[PING][PATCH 2/2] arm: Add support for MVE Tail-Predicated Low Overhead Loops
Richard Sandiford
richard.sandiford@arm.com
Tue Oct 24 15:11:26 GMT 2023
Sorry for the slow review. I had a look at the arm bits too, to get
some context for the target-independent bits.
Stamatis Markianos-Wright via Gcc-patches <gcc-patches@gcc.gnu.org> writes:
> [...]
> diff --git a/gcc/config/arm/arm-protos.h b/gcc/config/arm/arm-protos.h
> index 77e76336e94..74186930f0b 100644
> --- a/gcc/config/arm/arm-protos.h
> +++ b/gcc/config/arm/arm-protos.h
> @@ -65,8 +65,8 @@ extern void arm_emit_speculation_barrier_function (void);
> extern void arm_decompose_di_binop (rtx, rtx, rtx *, rtx *, rtx *, rtx *);
> extern bool arm_q_bit_access (void);
> extern bool arm_ge_bits_access (void);
> -extern bool arm_target_insn_ok_for_lob (rtx);
> -
> +extern bool arm_target_bb_ok_for_lob (basic_block);
> +extern rtx arm_attempt_dlstp_transform (rtx);
> #ifdef RTX_CODE
> enum reg_class
> arm_mode_base_reg_class (machine_mode);
> diff --git a/gcc/config/arm/arm.cc b/gcc/config/arm/arm.cc
> index 6e933c80183..39d97ba5e4d 100644
> --- a/gcc/config/arm/arm.cc
> +++ b/gcc/config/arm/arm.cc
> @@ -659,6 +659,12 @@ static const struct attribute_spec arm_attribute_table[]
> [...]
> +/* Wrapper around arm_get_required_vpr_reg with TYPE == 1: it returns
> + something only if the VPR reg is an input operand to the insn. */
> +
> +static rtx
> +ALWAYS_INLINE
Probably best to leave out the ALWAYS_INLINE. That's generally only
appropriate for things that need to be inlined for correctness.
> +arm_get_required_vpr_reg_param (rtx_insn *insn)
> +{
> + return arm_get_required_vpr_reg (insn, 1);
> +}
> [...]
> +/* Recursively scan through the DF chain backwards within the basic block and
> + determine if any of the USEs of the original insn (or the USEs of the insns
> + where they were DEF-ed, etc., recursively) were affected by implicit VPT
> + predication of an MVE_VPT_UNPREDICATED_INSN_P in a dlstp/letp loop.
> + This function returns true if the insn is affected by implicit predication
> + and false otherwise.
> + Having such implicit predication on an unpredicated insn wouldn't in itself
> + block tail predication, because the output of that insn might then be used
> + in a correctly predicated store insn, where the disabled lanes will be
> + ignored. To verify this we later call:
> + `arm_mve_check_df_chain_fwd_for_implic_predic_impact`, which will check the
> + DF chains forward to see if any implicitly-predicated operand gets used in
> + an improper way. */
> +
> +static bool
> +arm_mve_check_df_chain_back_for_implic_predic
> + (hash_map<int_hash<int, -1, -2>, bool>* safe_insn_map, rtx_insn *insn,
> + rtx vctp_vpr_generated)
> +{
> + bool* temp = NULL;
> + if ((temp = safe_insn_map->get (INSN_UID (insn))))
> + return *temp;
> +
> + basic_block body = BLOCK_FOR_INSN (insn);
> + /* The circumstances under which an instruction is affected by "implicit
> + predication" are as follows:
> + * It is an UNPREDICATED_INSN_P:
> + * That loads/stores from/to memory.
> + * Where any one of its operands is an MVE vector from outside the
> + loop body bb.
> + Or:
> + * Any of its operands, recursively backwards, are affected. */
> + if (MVE_VPT_UNPREDICATED_INSN_P (insn)
> + && (arm_is_mve_load_store_insn (insn)
> + || (arm_is_mve_across_vector_insn (insn)
> + && !arm_mve_is_allowed_unpredic_across_vector_insn (insn))))
> + {
> + safe_insn_map->put (INSN_UID (insn), true);
> + return true;
> + }
> +
> + df_ref insn_uses = NULL;
> + FOR_EACH_INSN_USE (insn_uses, insn)
> + {
> + /* If the operand is in the input reg set of the basic block
> + (i.e. it has come from outside the loop), consider it unsafe if:
> + * It's being used in an unpredicated insn.
> + * It is a predicable MVE vector. */
> + if (MVE_VPT_UNPREDICATED_INSN_P (insn)
> + && VALID_MVE_MODE (GET_MODE (DF_REF_REG (insn_uses)))
> + && REGNO_REG_SET_P (DF_LR_IN (body), DF_REF_REGNO (insn_uses)))
> + {
> + safe_insn_map->put (INSN_UID (insn), true);
> + return true;
> + }
> + /* Scan backwards from the current INSN through the instruction chain
> + until the start of the basic block. */
> + for (rtx_insn *prev_insn = PREV_INSN (insn);
> + prev_insn && prev_insn != PREV_INSN (BB_HEAD (body));
> + prev_insn = PREV_INSN (prev_insn))
> + {
> + /* If a previous insn defines a register that INSN uses, then recurse
> + in order to check that insn's USEs.
> + If any of those insns are MVE_VPT_UNPREDICATED_INSN_Ps that this check
> + flags, then the whole chain is affected by the change in behaviour from
> + being placed in a dlstp/letp loop. */
> + df_ref prev_insn_defs = NULL;
> + FOR_EACH_INSN_DEF (prev_insn_defs, prev_insn)
> + {
> + if (DF_REF_REGNO (insn_uses) == DF_REF_REGNO (prev_insn_defs)
> + && !arm_mve_vec_insn_is_predicated_with_this_predicate
> + (insn, vctp_vpr_generated)
> + && arm_mve_check_df_chain_back_for_implic_predic
> + (safe_insn_map, prev_insn, vctp_vpr_generated))
> + {
> + safe_insn_map->put (INSN_UID (insn), true);
> + return true;
> + }
> + }
> + }
> + }
> + safe_insn_map->put (INSN_UID (insn), false);
> + return false;
> +}
It looks like the recursion depth here is proportional to the length
of the longest insn-to-insn DU chain. That could be a problem for
pathologically large loops. Would it be possible to restructure
this to use a worklist instead?
> [...]
> +/* If we have identified the loop to have an incrementing counter, we need to
> + make sure that it increments by 1 and that the loop is structured correctly:
> + * The counter starts from 0
> + * The counter terminates at (num_of_elem + num_of_lanes - 1) / num_of_lanes
> + * The vctp insn uses a reg that decrements appropriately in each iteration.
> +*/
> +
> +static rtx_insn*
> +arm_mve_dlstp_check_inc_counter (basic_block body, rtx_insn* vctp_insn,
> + rtx condconst, rtx condcount)
> +{
> + rtx vctp_reg = XVECEXP (XEXP (PATTERN (vctp_insn), 1), 0, 0);
> + /* The loop latch has to be empty. When compiling all the known MVE LoLs in
> + user applications, none of those with incrementing counters had any real
> + insns in the loop latch. As such, this function has only been tested with
> + an empty latch and may misbehave or ICE if we somehow get here with an
> + increment in the latch, so, for correctness, error out early. */
> + rtx_insn *dec_insn = BB_END (body->loop_father->latch);
> + if (NONDEBUG_INSN_P (dec_insn))
> + return NULL;
Could this use empty_block_p instead? It would avoid hard-coding the
assumption that BB_END is not a debug instruction.
> +
> + class rtx_iv vctp_reg_iv;
> + /* For loops of type B) the loop counter is independent of the decrement
> + of the reg used in the vctp_insn. So run iv analysis on that reg. This
> + has to succeed for such loops to be supported. */
> + if (!iv_analyze (vctp_insn, as_a<scalar_int_mode> (GET_MODE (vctp_reg)),
> + vctp_reg, &vctp_reg_iv))
> + return NULL;
> +
> + /* Find where both of those are modified in the loop body bb. */
> + rtx condcount_reg_set
> + = PATTERN (DF_REF_INSN (df_bb_regno_only_def_find
> + (body, REGNO (condcount))));
> + rtx vctp_reg_set = PATTERN (DF_REF_INSN (df_bb_regno_only_def_find
> + (body, REGNO (vctp_reg))));
> + if (!vctp_reg_set || !condcount_reg_set)
> + return NULL;
It looks like these should be testing whether the df_bb_regno_only_def_find
calls return null instead.  If they do, the DF_REF_INSN will segfault.
If they don't, the rest will succeed.
> +
> + if (REG_P (condcount) && REG_P (condconst))
> + {
> + /* First we need to prove that the loop is going 0..condconst with an
> + inc of 1 in each iteration. */
> + if (GET_CODE (XEXP (condcount_reg_set, 1)) == PLUS
> + && CONST_INT_P (XEXP (XEXP (condcount_reg_set, 1), 1))
> + && INTVAL (XEXP (XEXP (condcount_reg_set, 1), 1)) == 1)
> + {
> + rtx counter_reg = XEXP (condcount_reg_set, 0);
> + /* Check that the counter did indeed start from zero. */
> + df_ref this_set = DF_REG_DEF_CHAIN (REGNO (counter_reg));
> + if (!this_set)
> + return NULL;
> + df_ref last_set = DF_REF_NEXT_REG (this_set);
> + if (!last_set)
> + return NULL;
> + rtx_insn* last_set_insn = DF_REF_INSN (last_set);
> + if (!single_set (last_set_insn))
> + return NULL;
> + rtx counter_orig_set;
> + counter_orig_set = XEXP (PATTERN (last_set_insn), 1);
> + if (!CONST_INT_P (counter_orig_set)
> + || (INTVAL (counter_orig_set) != 0))
> + return NULL;
> + /* And finally check that the target value of the counter,
> + condconst, is of the correct shape. */
> + if (!arm_mve_check_reg_origin_is_num_elems (body, condconst,
> + vctp_reg_iv.step))
> + return NULL;
> + }
> + else
> + return NULL;
> + }
> + else
> + return NULL;
> +
> + /* Extract the decrementnum of the vctp reg. */
> + int decrementnum = abs (INTVAL (vctp_reg_iv.step));
> + /* Ensure it matches the number of lanes of the vctp instruction. */
> + if (decrementnum != arm_mve_get_vctp_lanes (PATTERN (vctp_insn)))
> + return NULL;
> +
> + /* Everything looks valid. */
> + return vctp_insn;
> +}
One of the main reasons for reading the arm bits was to try to answer
the question: if we switch to a downcounting loop with a GE condition,
how do we make sure that the start value is not a large unsigned
number that is interpreted as negative by GE? E.g. if the loop
originally counted up in steps of N and used an LTU condition,
it could stop at a value in the range [INT_MAX + 1, UINT_MAX].
But the loop might never iterate if we start counting down from
most values in that range.
Does the patch handle that?
[I didn't look at the Arm parts much beyond this point]
> [...]
> diff --git a/gcc/df-core.cc b/gcc/df-core.cc
> index d4812b04a7c..4fcc14bf790 100644
> --- a/gcc/df-core.cc
> +++ b/gcc/df-core.cc
> @@ -1964,6 +1964,21 @@ df_bb_regno_last_def_find (basic_block bb, unsigned int regno)
> return NULL;
> }
>
> +/* Return the one and only def of REGNO within BB. If there is no def or
> + there are multiple defs, return NULL. */
> +
> +df_ref
> +df_bb_regno_only_def_find (basic_block bb, unsigned int regno)
> +{
> + df_ref temp = df_bb_regno_first_def_find (bb, regno);
> + if (!temp)
> + return NULL;
> + else if (temp == df_bb_regno_last_def_find (bb, regno))
> + return temp;
> + else
> + return NULL;
> +}
> +
> /* Finds the reference corresponding to the definition of REG in INSN.
> DF is the dataflow object. */
>
> diff --git a/gcc/df.h b/gcc/df.h
> index 402657a7076..98623637f9c 100644
> --- a/gcc/df.h
> +++ b/gcc/df.h
> @@ -987,6 +987,7 @@ extern void df_check_cfg_clean (void);
> #endif
> extern df_ref df_bb_regno_first_def_find (basic_block, unsigned int);
> extern df_ref df_bb_regno_last_def_find (basic_block, unsigned int);
> +extern df_ref df_bb_regno_only_def_find (basic_block, unsigned int);
> extern df_ref df_find_def (rtx_insn *, rtx);
> extern bool df_reg_defined (rtx_insn *, rtx);
> extern df_ref df_find_use (rtx_insn *, rtx);
> diff --git a/gcc/loop-doloop.cc b/gcc/loop-doloop.cc
> index 4feb0a25ab9..f6dbd0515de 100644
> --- a/gcc/loop-doloop.cc
> +++ b/gcc/loop-doloop.cc
> @@ -85,29 +85,29 @@ doloop_condition_get (rtx_insn *doloop_pat)
> forms:
>
> 1) (parallel [(set (pc) (if_then_else (condition)
> - (label_ref (label))
> - (pc)))
> - (set (reg) (plus (reg) (const_int -1)))
> - (additional clobbers and uses)])
> + (label_ref (label))
> + (pc)))
> + (set (reg) (plus (reg) (const_int -n)))
> + (additional clobbers and uses)])
>
> The branch must be the first entry of the parallel (also required
> by jump.cc), and the second entry of the parallel must be a set of
> the loop counter register. Some targets (IA-64) wrap the set of
> the loop counter in an if_then_else too.
>
> - 2) (set (reg) (plus (reg) (const_int -1))
> - (set (pc) (if_then_else (reg != 0)
> - (label_ref (label))
> - (pc))).
> + 2) (set (reg) (plus (reg) (const_int -n))
> + (set (pc) (if_then_else (reg != 0)
> + (label_ref (label))
> + (pc))).
>
> Some targets (ARM) do the comparison before the branch, as in the
> following form:
>
> - 3) (parallel [(set (cc) (compare ((plus (reg) (const_int -1), 0)))
> - (set (reg) (plus (reg) (const_int -1)))])
> - (set (pc) (if_then_else (cc == NE)
> - (label_ref (label))
> - (pc))) */
> + 3) (parallel [(set (cc) (compare ((plus (reg) (const_int -n), 0)))
Pre-existing, but I think this should be:
(set (cc) (compare (plus (reg) (const_int -n)) 0))
Same for the copy further down.
> + (set (reg) (plus (reg) (const_int -n)))])
> + (set (pc) (if_then_else (cc == NE)
> + (label_ref (label))
> + (pc))) */
>
> pattern = PATTERN (doloop_pat);
>
I agree with Andre that it would be good to include the GE possibility
in the comments, e.g. ==/>=.
> @@ -143,7 +143,7 @@ doloop_condition_get (rtx_insn *doloop_pat)
> || GET_CODE (cmp_arg1) != PLUS)
> return 0;
> reg_orig = XEXP (cmp_arg1, 0);
> - if (XEXP (cmp_arg1, 1) != GEN_INT (-1)
> + if (!CONST_INT_P (XEXP (cmp_arg1, 1))
> || !REG_P (reg_orig))
> return 0;
> cc_reg = SET_DEST (cmp_orig);
> @@ -156,7 +156,8 @@ doloop_condition_get (rtx_insn *doloop_pat)
> {
> /* We expect the condition to be of the form (reg != 0) */
> cond = XEXP (SET_SRC (cmp), 0);
> - if (GET_CODE (cond) != NE || XEXP (cond, 1) != const0_rtx)
> + if ((GET_CODE (cond) != NE && GET_CODE (cond) != GE)
> + || XEXP (cond, 1) != const0_rtx)
> return 0;
> }
> }
> @@ -173,14 +174,14 @@ doloop_condition_get (rtx_insn *doloop_pat)
> if (! REG_P (reg))
> return 0;
>
> - /* Check if something = (plus (reg) (const_int -1)).
> + /* Check if something = (plus (reg) (const_int -n)).
> On IA-64, this decrement is wrapped in an if_then_else. */
> inc_src = SET_SRC (inc);
> if (GET_CODE (inc_src) == IF_THEN_ELSE)
> inc_src = XEXP (inc_src, 1);
> if (GET_CODE (inc_src) != PLUS
> || XEXP (inc_src, 0) != reg
> - || XEXP (inc_src, 1) != constm1_rtx)
> + || !CONST_INT_P (XEXP (inc_src, 1)))
> return 0;
>
> /* Check for (set (pc) (if_then_else (condition)
> @@ -211,42 +212,49 @@ doloop_condition_get (rtx_insn *doloop_pat)
> || (GET_CODE (XEXP (condition, 0)) == PLUS
> && XEXP (XEXP (condition, 0), 0) == reg))
> {
> - if (GET_CODE (pattern) != PARALLEL)
> /* For the second form we expect:
>
> - (set (reg) (plus (reg) (const_int -1))
> - (set (pc) (if_then_else (reg != 0)
> - (label_ref (label))
> - (pc))).
> + (set (reg) (plus (reg) (const_int -n))
> + (set (pc) (if_then_else (reg != 0)
> + (label_ref (label))
> + (pc))).
>
> - is equivalent to the following:
> + If n == 1, that is equivalent to the following:
>
> - (parallel [(set (pc) (if_then_else (reg != 1)
> - (label_ref (label))
> - (pc)))
> - (set (reg) (plus (reg) (const_int -1)))
> - (additional clobbers and uses)])
> + (parallel [(set (pc) (if_then_else (reg != 1)
> + (label_ref (label))
> + (pc)))
> + (set (reg) (plus (reg) (const_int -1)))
> + (additional clobbers and uses)])
>
> - For the third form we expect:
> + For the third form we expect:
>
> - (parallel [(set (cc) (compare ((plus (reg) (const_int -1)), 0))
> - (set (reg) (plus (reg) (const_int -1)))])
> - (set (pc) (if_then_else (cc == NE)
> - (label_ref (label))
> - (pc)))
> + (parallel [(set (cc) (compare ((plus (reg) (const_int -n)), 0))
> + (set (reg) (plus (reg) (const_int -n)))])
> + (set (pc) (if_then_else (cc == NE)
> + (label_ref (label))
> + (pc)))
>
> - which is equivalent to the following:
> + which for n == 1 is likewise equivalent to the following:
>
> - (parallel [(set (cc) (compare (reg, 1))
> - (set (reg) (plus (reg) (const_int -1)))
> - (set (pc) (if_then_else (NE == cc)
> - (label_ref (label))
> - (pc))))])
> + (parallel [(set (cc) (compare (reg, 1))
> + (set (reg) (plus (reg) (const_int -1)))
> + (set (pc) (if_then_else (NE == cc)
> + (label_ref (label))
> + (pc))))])
>
> - So we return the second form instead for the two cases.
> + So we return the second form instead for the two cases.
>
> + For the "elementwise" form where the decrement number isn't -1,
> + the final value may be exceeded, so use GE instead of NE.
> */
> - condition = gen_rtx_fmt_ee (NE, VOIDmode, inc_src, const1_rtx);
> + if (GET_CODE (pattern) != PARALLEL)
> + {
> + if (INTVAL (XEXP (inc_src, 1)) != -1)
> + condition = gen_rtx_fmt_ee (GE, VOIDmode, inc_src, const0_rtx);
> + else
> + condition = gen_rtx_fmt_ee (NE, VOIDmode, inc_src, const1_rtx);
> + }
>
> return condition;
> }
> @@ -685,17 +693,6 @@ doloop_optimize (class loop *loop)
> return false;
> }
>
> - max_cost
> - = COSTS_N_INSNS (param_max_iterations_computation_cost);
> - if (set_src_cost (desc->niter_expr, mode, optimize_loop_for_speed_p (loop))
> - > max_cost)
> - {
> - if (dump_file)
> - fprintf (dump_file,
> - "Doloop: number of iterations too costly to compute.\n");
> - return false;
> - }
> -
> if (desc->const_iter)
> iterations = widest_int::from (rtx_mode_t (desc->niter_expr, mode),
> UNSIGNED);
> @@ -716,11 +713,24 @@ doloop_optimize (class loop *loop)
>
> /* Generate looping insn. If the pattern FAILs then give up trying
> to modify the loop since there is some aspect the back-end does
> - not like. */
> - count = copy_rtx (desc->niter_expr);
> + not like. If this succeeds, the loop's desc->niter_expr may have
> + been altered by the backend, so only extract that data after the
> + call to gen_doloop_end. */
> start_label = block_label (desc->in_edge->dest);
> doloop_reg = gen_reg_rtx (mode);
> rtx_insn *doloop_seq = targetm.gen_doloop_end (doloop_reg, start_label);
> + count = copy_rtx (desc->niter_expr);
Very minor, but I think the copy should still happen after the cost check.
OK for the df and doloop parts with those changes.
Thanks,
Richard
> +
> + max_cost
> + = COSTS_N_INSNS (param_max_iterations_computation_cost);
> + if (set_src_cost (count, mode, optimize_loop_for_speed_p (loop))
> + > max_cost)
> + {
> + if (dump_file)
> + fprintf (dump_file,
> + "Doloop: number of iterations too costly to compute.\n");
> + return false;
> + }
>
> word_mode_size = GET_MODE_PRECISION (word_mode);
> word_mode_max = (HOST_WIDE_INT_1U << (word_mode_size - 1) << 1) - 1;