[PING][PATCH 2/2] arm: Add support for MVE Tail-Predicated Low Overhead Loops

Richard Sandiford richard.sandiford@arm.com
Tue Oct 24 15:11:26 GMT 2023


Sorry for the slow review.  I had a look at the arm bits too, to get
some context for the target-independent bits.

Stamatis Markianos-Wright via Gcc-patches <gcc-patches@gcc.gnu.org> writes:
> [...]
> diff --git a/gcc/config/arm/arm-protos.h b/gcc/config/arm/arm-protos.h
> index 77e76336e94..74186930f0b 100644
> --- a/gcc/config/arm/arm-protos.h
> +++ b/gcc/config/arm/arm-protos.h
> @@ -65,8 +65,8 @@ extern void arm_emit_speculation_barrier_function (void);
>  extern void arm_decompose_di_binop (rtx, rtx, rtx *, rtx *, rtx *, rtx *);
>  extern bool arm_q_bit_access (void);
>  extern bool arm_ge_bits_access (void);
> -extern bool arm_target_insn_ok_for_lob (rtx);
> -
> +extern bool arm_target_bb_ok_for_lob (basic_block);
> +extern rtx arm_attempt_dlstp_transform (rtx);
>  #ifdef RTX_CODE
>  enum reg_class
>  arm_mode_base_reg_class (machine_mode);
> diff --git a/gcc/config/arm/arm.cc b/gcc/config/arm/arm.cc
> index 6e933c80183..39d97ba5e4d 100644
> --- a/gcc/config/arm/arm.cc
> +++ b/gcc/config/arm/arm.cc
> @@ -659,6 +659,12 @@ static const struct attribute_spec arm_attribute_table[]
> [...]
> +/* Wrapper function of arm_get_required_vpr_reg with TYPE == 1, so return
> +   something only if the VPR reg is an input operand to the insn.  */
> +
> +static rtx
> +ALWAYS_INLINE

Probably best to leave out the ALWAYS_INLINE.  That's generally only
appropriate for things that need to be inlined for correctness.

> +arm_get_required_vpr_reg_param (rtx_insn *insn)
> +{
> +  return arm_get_required_vpr_reg (insn, 1);
> +}
> [...]
> +/* Recursively scan through the DF chain backwards within the basic block and
> +   determine if any of the USEs of the original insn (or the USEs of the insns
> +   where they were DEF-ed, etc., recursively) were affected by implicit VPT
> +   predication of an MVE_VPT_UNPREDICATED_INSN_P in a dlstp/letp loop.
> +   This function returns true if the insn is affected by implicit predication
> +   and false otherwise.
> +   Having such implicit predication on an unpredicated insn wouldn't in itself
> +   block tail predication, because the output of that insn might then be used
> +   in a correctly predicated store insn, where the disabled lanes will be
> +   ignored.  To verify this we later call:
> +   `arm_mve_check_df_chain_fwd_for_implic_predic_impact`, which will check the
> +   DF chains forward to see if any implicitly-predicated operand gets used in
> +   an improper way.  */
> +
> +static bool
> +arm_mve_check_df_chain_back_for_implic_predic
> +  (hash_map<int_hash<int, -1, -2>, bool>* safe_insn_map, rtx_insn *insn,
> +   rtx vctp_vpr_generated)
> +{
> +  bool* temp = NULL;
> +  if ((temp = safe_insn_map->get (INSN_UID (insn))))
> +    return *temp;
> +
> +  basic_block body = BLOCK_FOR_INSN (insn);
> +  /* The circumstances under which an instruction is affected by "implicit
> +     predication" are as follows:
> +      * It is an UNPREDICATED_INSN_P:
> +	* That loads/stores from/to memory.
> +	* Where any one of its operands is an MVE vector from outside the
> +	  loop body bb.
> +     Or:
> +      * Any of its operands, recursively backwards, are affected.  */
> +  if (MVE_VPT_UNPREDICATED_INSN_P (insn)
> +      && (arm_is_mve_load_store_insn (insn)
> +	  || (arm_is_mve_across_vector_insn (insn)
> +	      && !arm_mve_is_allowed_unpredic_across_vector_insn (insn))))
> +    {
> +      safe_insn_map->put (INSN_UID (insn), true);
> +      return true;
> +    }
> +
> +  df_ref insn_uses = NULL;
> +  FOR_EACH_INSN_USE (insn_uses, insn)
> +  {
> +    /* If the operand is in the input reg set to the basic block,
> +       (i.e. it has come from outside the loop!), consider it unsafe if:
> +	 * It's being used in an unpredicated insn.
> +	 * It is a predicable MVE vector.  */
> +    if (MVE_VPT_UNPREDICATED_INSN_P (insn)
> +	&& VALID_MVE_MODE (GET_MODE (DF_REF_REG (insn_uses)))
> +	&& REGNO_REG_SET_P (DF_LR_IN (body), DF_REF_REGNO (insn_uses)))
> +      {
> +	safe_insn_map->put (INSN_UID (insn), true);
> +	return true;
> +      }
> +    /* Scan backwards from the current INSN through the instruction chain
> +       until the start of the basic block.  */
> +    for (rtx_insn *prev_insn = PREV_INSN (insn);
> +	 prev_insn && prev_insn != PREV_INSN (BB_HEAD (body));
> +	 prev_insn = PREV_INSN (prev_insn))
> +      {
> +	/* If a previous insn defines a register that INSN uses, then recurse
> +	   in order to check that insn's USEs.
> +	   If any of these insns return true as MVE_VPT_UNPREDICATED_INSN_Ps,
> +	   then the whole chain is affected by the change in behaviour from
> +	   being placed in dlstp/letp loop.  */
> +	df_ref prev_insn_defs = NULL;
> +	FOR_EACH_INSN_DEF (prev_insn_defs, prev_insn)
> +	{
> +	  if (DF_REF_REGNO (insn_uses) == DF_REF_REGNO (prev_insn_defs)
> +	      && !arm_mve_vec_insn_is_predicated_with_this_predicate
> +		   (insn, vctp_vpr_generated)
> +	      && arm_mve_check_df_chain_back_for_implic_predic
> +		  (safe_insn_map, prev_insn, vctp_vpr_generated))
> +	    {
> +	      safe_insn_map->put (INSN_UID (insn), true);
> +	      return true;
> +	    }
> +	}
> +      }
> +  }
> +  safe_insn_map->put (INSN_UID (insn), false);
> +  return false;
> +}

It looks like the recursion depth here is proportional to the length
of the longest insn-to-insn DU chain.  That could be a problem for
pathologically large loops.  Would it be possible to restructure
this to use a worklist instead?
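
Roughly something like the following, as an untested sketch (here
insn_unsafe_p is a hypothetical stand-in for the per-insn checks in the
recursive version, and the safe_insn_map memoisation is replaced by a
plain visited bitmap; whether caching across top-level calls is still
worth having is a separate question):

  static bool
  arm_mve_check_df_chain_back_for_implic_predic (rtx_insn *start_insn,
                                                 rtx vctp_vpr_generated)
  {
    auto_bitmap visited;
    auto_vec<rtx_insn *> worklist;
    worklist.safe_push (start_insn);
    while (!worklist.is_empty ())
      {
        rtx_insn *insn = worklist.pop ();
        /* bitmap_set_bit returns false if the bit was already set.  */
        if (!bitmap_set_bit (visited, INSN_UID (insn)))
          continue;
        basic_block body = BLOCK_FOR_INSN (insn);
        /* The same per-insn tests as in the recursive version would go
           here, behind insn_unsafe_p.  */
        if (insn_unsafe_p (insn, vctp_vpr_generated))
          return true;
        /* Instead of recursing into each defining insn, queue it.  */
        df_ref use;
        FOR_EACH_INSN_USE (use, insn)
          for (rtx_insn *prev = PREV_INSN (insn);
               prev && prev != PREV_INSN (BB_HEAD (body));
               prev = PREV_INSN (prev))
            {
              df_ref def;
              FOR_EACH_INSN_DEF (def, prev)
                if (DF_REF_REGNO (use) == DF_REF_REGNO (def))
                  worklist.safe_push (prev);
            }
      }
    return false;
  }

The explicit worklist keeps the stack usage bounded however long the
DU chains get.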

> [...]
> +/* If we have identified the loop to have an incrementing counter, we need to
> +   make sure that it increments by 1 and that the loop is structured correctly:
> +    * The counter starts from 0
> +    * The counter terminates at (num_of_elem + num_of_lanes - 1) / num_of_lanes
> +    * The vctp insn uses a reg that decrements appropriately in each iteration.
> +*/
> +
> +static rtx_insn*
> +arm_mve_dlstp_check_inc_counter (basic_block body, rtx_insn* vctp_insn,
> +				 rtx condconst, rtx condcount)
> +{
> +  rtx vctp_reg = XVECEXP (XEXP (PATTERN (vctp_insn), 1), 0, 0);
> +  /* The loop latch has to be empty.  When compiling all the known MVE LoLs in
> +     user applications, none of those with incrementing counters had any real
> +     insns in the loop latch.  As such, this function has only been tested with
> +     an empty latch and may misbehave or ICE if we somehow get here with an
> +     increment in the latch, so, for correctness, error out early.  */
> +  rtx_insn *dec_insn = BB_END (body->loop_father->latch);
> +  if (NONDEBUG_INSN_P (dec_insn))
> +    return NULL;

Could this use empty_block_p instead?  It would avoid hard-coding the
assumption that BB_END is not a debug instruction.
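
I.e. something like (untested, and assuming dec_insn isn't otherwise
needed):

  if (!empty_block_p (body->loop_father->latch))
    return NULL;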

> +
> +  class rtx_iv vctp_reg_iv;
> +  /* For loops of type B) the loop counter is independent of the decrement
> +     of the reg used in the vctp_insn. So run iv analysis on that reg.  This
> +     has to succeed for such loops to be supported.  */
> +  if (!iv_analyze (vctp_insn, as_a<scalar_int_mode> (GET_MODE (vctp_reg)),
> +      vctp_reg, &vctp_reg_iv))
> +    return NULL;
> +
> +  /* Find where both of those are modified in the loop body bb.  */
> +  rtx condcount_reg_set
> +	= PATTERN (DF_REF_INSN (df_bb_regno_only_def_find
> +				 (body, REGNO (condcount))));
> +  rtx vctp_reg_set = PATTERN (DF_REF_INSN (df_bb_regno_only_def_find
> +					    (body, REGNO (vctp_reg))));
> +  if (!vctp_reg_set || !condcount_reg_set)
> +    return NULL;

It looks like these should be testing whether the df_bb_regno_only_def_find
calls return null instead.  If they do, the DF_REF_INSN will segfault.
If they don't, the rest will succeed.
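
Something along these lines (untested sketch):

  df_ref condcount_def
    = df_bb_regno_only_def_find (body, REGNO (condcount));
  df_ref vctp_reg_def
    = df_bb_regno_only_def_find (body, REGNO (vctp_reg));
  if (!condcount_def || !vctp_reg_def)
    return NULL;
  rtx condcount_reg_set = PATTERN (DF_REF_INSN (condcount_def));
  rtx vctp_reg_set = PATTERN (DF_REF_INSN (vctp_reg_def));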

> +
> +  if (REG_P (condcount) && REG_P (condconst))
> +    {
> +      /* First we need to prove that the loop is going 0..condconst with an
> +	 inc of 1 in each iteration.  */
> +      if (GET_CODE (XEXP (condcount_reg_set, 1)) == PLUS
> +	  && CONST_INT_P (XEXP (XEXP (condcount_reg_set, 1), 1))
> +	  && INTVAL (XEXP (XEXP (condcount_reg_set, 1), 1)) == 1)
> +	{
> +	    rtx counter_reg = XEXP (condcount_reg_set, 0);
> +	    /* Check that the counter did indeed start from zero.  */
> +	    df_ref this_set = DF_REG_DEF_CHAIN (REGNO (counter_reg));
> +	    if (!this_set)
> +	      return NULL;
> +	    df_ref last_set = DF_REF_NEXT_REG (this_set);
> +	    if (!last_set)
> +	      return NULL;
> +	    rtx_insn* last_set_insn = DF_REF_INSN (last_set);
> +	    if (!single_set (last_set_insn))
> +	      return NULL;
> +	    rtx counter_orig_set;
> +	    counter_orig_set = XEXP (PATTERN (last_set_insn), 1);
> +	    if (!CONST_INT_P (counter_orig_set)
> +		|| (INTVAL (counter_orig_set) != 0))
> +	      return NULL;
> +	    /* And finally check that the target value of the counter,
> +	       condconst is of the correct shape.  */
> +	    if (!arm_mve_check_reg_origin_is_num_elems (body, condconst,
> +							vctp_reg_iv.step))
> +	      return NULL;
> +	}
> +      else
> +	return NULL;
> +    }
> +  else
> +    return NULL;
> +
> +  /* Extract the decrementnum of the vctp reg.  */
> +  int decrementnum = abs (INTVAL (vctp_reg_iv.step));
> +  /* Ensure it matches the number of lanes of the vctp instruction.  */
> +  if (decrementnum != arm_mve_get_vctp_lanes (PATTERN (vctp_insn)))
> +    return NULL;
> +
> +  /* Everything looks valid.  */
> +  return vctp_insn;
> +}

One of the main reasons for reading the arm bits was to try to answer
the question: if we switch to a downcounting loop with a GE condition,
how do we make sure that the start value is not a large unsigned
number that is interpreted as negative by GE?  E.g. if the loop
originally counted up in steps of N and used an LTU condition,
it could stop at a value in the range [INT_MAX + 1, UINT_MAX].
But the loop might never iterate if we start counting down from
most values in that range.

Does the patch handle that?
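
In plain C, the kind of thing I mean (not from the patch, just to
illustrate the concern):

  /* Counts up in steps of 4 with an unsigned (LTU) exit condition,
     so it iterates, with a bound in [INT_MAX + 1, UINT_MAX]...  */
  for (unsigned int i = 0; i < 0x80000005u; i += 4)
    f (i);

  /* ...but a naive conversion to counting down with a signed GE exit
     condition never iterates, because the start value is negative
     when reinterpreted as signed.  */
  for (int n = (int) 0x80000005u; n >= 0; n -= 4)
    f (n);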

[I didn't look at the Arm parts much beyond this point]

> [...]
> diff --git a/gcc/df-core.cc b/gcc/df-core.cc
> index d4812b04a7c..4fcc14bf790 100644
> --- a/gcc/df-core.cc
> +++ b/gcc/df-core.cc
> @@ -1964,6 +1964,21 @@ df_bb_regno_last_def_find (basic_block bb, unsigned int regno)
>    return NULL;
>  }
>  
> +/* Return the one and only def of REGNO within BB.  If there is no def or
> +   there are multiple defs, return NULL.  */
> +
> +df_ref
> +df_bb_regno_only_def_find (basic_block bb, unsigned int regno)
> +{
> +  df_ref temp = df_bb_regno_first_def_find (bb, regno);
> +  if (!temp)
> +    return NULL;
> +  else if (temp == df_bb_regno_last_def_find (bb, regno))
> +    return temp;
> +  else
> +    return NULL;
> +}
> +
>  /* Finds the reference corresponding to the definition of REG in INSN.
>     DF is the dataflow object.  */
>  
> diff --git a/gcc/df.h b/gcc/df.h
> index 402657a7076..98623637f9c 100644
> --- a/gcc/df.h
> +++ b/gcc/df.h
> @@ -987,6 +987,7 @@ extern void df_check_cfg_clean (void);
>  #endif
>  extern df_ref df_bb_regno_first_def_find (basic_block, unsigned int);
>  extern df_ref df_bb_regno_last_def_find (basic_block, unsigned int);
> +extern df_ref df_bb_regno_only_def_find (basic_block, unsigned int);
>  extern df_ref df_find_def (rtx_insn *, rtx);
>  extern bool df_reg_defined (rtx_insn *, rtx);
>  extern df_ref df_find_use (rtx_insn *, rtx);
> diff --git a/gcc/loop-doloop.cc b/gcc/loop-doloop.cc
> index 4feb0a25ab9..f6dbd0515de 100644
> --- a/gcc/loop-doloop.cc
> +++ b/gcc/loop-doloop.cc
> @@ -85,29 +85,29 @@ doloop_condition_get (rtx_insn *doloop_pat)
>       forms:
>  
>       1)  (parallel [(set (pc) (if_then_else (condition)
> -	  			            (label_ref (label))
> -				            (pc)))
> -	             (set (reg) (plus (reg) (const_int -1)))
> -	             (additional clobbers and uses)])
> +					    (label_ref (label))
> +					    (pc)))
> +		     (set (reg) (plus (reg) (const_int -n)))
> +		     (additional clobbers and uses)])
>  
>       The branch must be the first entry of the parallel (also required
>       by jump.cc), and the second entry of the parallel must be a set of
>       the loop counter register.  Some targets (IA-64) wrap the set of
>       the loop counter in an if_then_else too.
>  
> -     2)  (set (reg) (plus (reg) (const_int -1))
> -         (set (pc) (if_then_else (reg != 0)
> -	                         (label_ref (label))
> -			         (pc))).  
> +     2)  (set (reg) (plus (reg) (const_int -n))
> +	 (set (pc) (if_then_else (reg != 0)
> +				 (label_ref (label))
> +				 (pc))).
>  
>       Some targets (ARM) do the comparison before the branch, as in the
>       following form:
>  
> -     3) (parallel [(set (cc) (compare ((plus (reg) (const_int -1), 0)))
> -                   (set (reg) (plus (reg) (const_int -1)))])
> -        (set (pc) (if_then_else (cc == NE)
> -                                (label_ref (label))
> -                                (pc))) */
> +     3) (parallel [(set (cc) (compare ((plus (reg) (const_int -n), 0)))

Pre-existing, but I think this should be:

  (set (cc) (compare (plus (reg) (const_int -n)) 0))

Same for the copy further down.

> +		   (set (reg) (plus (reg) (const_int -n)))])
> +	(set (pc) (if_then_else (cc == NE)
> +				(label_ref (label))
> +				(pc))) */
>  
>    pattern = PATTERN (doloop_pat);
>  

I agree with Andre that it would be good to include the GE possibility
in the comments, e.g. ==/>=.
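
E.g. for form 2, something like:

  (set (reg) (plus (reg) (const_int -n)))
  (set (pc) (if_then_else (reg {!=,>=} 0)
                          (label_ref (label))
                          (pc)))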

> @@ -143,7 +143,7 @@ doloop_condition_get (rtx_insn *doloop_pat)
>  	      || GET_CODE (cmp_arg1) != PLUS)
>  	    return 0;
>  	  reg_orig = XEXP (cmp_arg1, 0);
> -	  if (XEXP (cmp_arg1, 1) != GEN_INT (-1) 
> +	  if (!CONST_INT_P (XEXP (cmp_arg1, 1))
>  	      || !REG_P (reg_orig))
>  	    return 0;
>  	  cc_reg = SET_DEST (cmp_orig);
> @@ -156,7 +156,8 @@ doloop_condition_get (rtx_insn *doloop_pat)
>  	{
>  	  /* We expect the condition to be of the form (reg != 0)  */
>  	  cond = XEXP (SET_SRC (cmp), 0);
> -	  if (GET_CODE (cond) != NE || XEXP (cond, 1) != const0_rtx)
> +	  if ((GET_CODE (cond) != NE && GET_CODE (cond) != GE)
> +	      || XEXP (cond, 1) != const0_rtx)
>  	    return 0;
>  	}
>      }
> @@ -173,14 +174,14 @@ doloop_condition_get (rtx_insn *doloop_pat)
>    if (! REG_P (reg))
>      return 0;
>  
> -  /* Check if something = (plus (reg) (const_int -1)).
> +  /* Check if something = (plus (reg) (const_int -n)).
>       On IA-64, this decrement is wrapped in an if_then_else.  */
>    inc_src = SET_SRC (inc);
>    if (GET_CODE (inc_src) == IF_THEN_ELSE)
>      inc_src = XEXP (inc_src, 1);
>    if (GET_CODE (inc_src) != PLUS
>        || XEXP (inc_src, 0) != reg
> -      || XEXP (inc_src, 1) != constm1_rtx)
> +      || !CONST_INT_P (XEXP (inc_src, 1)))
>      return 0;
>  
>    /* Check for (set (pc) (if_then_else (condition)
> @@ -211,42 +212,49 @@ doloop_condition_get (rtx_insn *doloop_pat)
>        || (GET_CODE (XEXP (condition, 0)) == PLUS
>  	  && XEXP (XEXP (condition, 0), 0) == reg))
>     {
> -     if (GET_CODE (pattern) != PARALLEL)
>       /*  For the second form we expect:
>  
> -         (set (reg) (plus (reg) (const_int -1))
> -         (set (pc) (if_then_else (reg != 0)
> -                                 (label_ref (label))
> -                                 (pc))).
> +	 (set (reg) (plus (reg) (const_int -n))
> +	 (set (pc) (if_then_else (reg != 0)
> +				 (label_ref (label))
> +				 (pc))).
>  
> -         is equivalent to the following:
> +	 If n == 1, that is equivalent to the following:
>  
> -         (parallel [(set (pc) (if_then_else (reg != 1)
> -                                            (label_ref (label))
> -                                            (pc)))
> -                     (set (reg) (plus (reg) (const_int -1)))
> -                     (additional clobbers and uses)])
> +	 (parallel [(set (pc) (if_then_else (reg != 1)
> +					    (label_ref (label))
> +					    (pc)))
> +		     (set (reg) (plus (reg) (const_int -1)))
> +		     (additional clobbers and uses)])
>  
> -        For the third form we expect:
> +	For the third form we expect:
>  
> -        (parallel [(set (cc) (compare ((plus (reg) (const_int -1)), 0))
> -                   (set (reg) (plus (reg) (const_int -1)))])
> -        (set (pc) (if_then_else (cc == NE)
> -                                (label_ref (label))
> -                                (pc))) 
> +	(parallel [(set (cc) (compare ((plus (reg) (const_int -n)), 0))
> +		   (set (reg) (plus (reg) (const_int -n)))])
> +	(set (pc) (if_then_else (cc == NE)
> +				(label_ref (label))
> +				(pc)))
>  
> -        which is equivalent to the following:
> +	Which, for n == 1, is also equivalent to the following:
>  
> -        (parallel [(set (cc) (compare (reg,  1))
> -                   (set (reg) (plus (reg) (const_int -1)))
> -                   (set (pc) (if_then_else (NE == cc)
> -                                           (label_ref (label))
> -                                           (pc))))])
> +	(parallel [(set (cc) (compare (reg,  1))
> +		   (set (reg) (plus (reg) (const_int -1)))
> +		   (set (pc) (if_then_else (NE == cc)
> +					   (label_ref (label))
> +					   (pc))))])
>  
> -        So we return the second form instead for the two cases.
> +	So we return the second form instead for the two cases.
>  
> +	For the "elementwise" form where the decrement number isn't -1,
> +	the final value may be exceeded, so use GE instead of NE.
>       */
> -        condition = gen_rtx_fmt_ee (NE, VOIDmode, inc_src, const1_rtx);
> +     if (GET_CODE (pattern) != PARALLEL)
> +       {
> +	if (INTVAL (XEXP (inc_src, 1)) != -1)
> +	  condition = gen_rtx_fmt_ee (GE, VOIDmode, inc_src, const0_rtx);
> +	else
> +	  condition = gen_rtx_fmt_ee (NE, VOIDmode, inc_src, const1_rtx);
> +       }
>  
>      return condition;
>     }
> @@ -685,17 +693,6 @@ doloop_optimize (class loop *loop)
>        return false;
>      }
>  
> -  max_cost
> -    = COSTS_N_INSNS (param_max_iterations_computation_cost);
> -  if (set_src_cost (desc->niter_expr, mode, optimize_loop_for_speed_p (loop))
> -      > max_cost)
> -    {
> -      if (dump_file)
> -	fprintf (dump_file,
> -		 "Doloop: number of iterations too costly to compute.\n");
> -      return false;
> -    }
> -
>    if (desc->const_iter)
>      iterations = widest_int::from (rtx_mode_t (desc->niter_expr, mode),
>  				   UNSIGNED);
> @@ -716,11 +713,24 @@ doloop_optimize (class loop *loop)
>  
>    /* Generate looping insn.  If the pattern FAILs then give up trying
>       to modify the loop since there is some aspect the back-end does
> -     not like.  */
> -  count = copy_rtx (desc->niter_expr);
> +     not like.  If this succeeds, there is a chance that the loop
> +     desc->niter_expr has been altered by the backend, so only extract
> +     that data after the gen_doloop_end.  */
>    start_label = block_label (desc->in_edge->dest);
>    doloop_reg = gen_reg_rtx (mode);
>    rtx_insn *doloop_seq = targetm.gen_doloop_end (doloop_reg, start_label);
> +  count = copy_rtx (desc->niter_expr);

Very minor, but I think the copy should still happen after the cost check.
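
I.e. something like (sketch):

  rtx_insn *doloop_seq = targetm.gen_doloop_end (doloop_reg, start_label);

  max_cost
    = COSTS_N_INSNS (param_max_iterations_computation_cost);
  if (set_src_cost (desc->niter_expr, mode, optimize_loop_for_speed_p (loop))
      > max_cost)
    {
      if (dump_file)
        fprintf (dump_file,
                 "Doloop: number of iterations too costly to compute.\n");
      return false;
    }
  count = copy_rtx (desc->niter_expr);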

OK for the df and doloop parts with those changes.

Thanks,
Richard

> +
> +  max_cost
> +    = COSTS_N_INSNS (param_max_iterations_computation_cost);
> +  if (set_src_cost (count, mode, optimize_loop_for_speed_p (loop))
> +      > max_cost)
> +    {
> +      if (dump_file)
> +	fprintf (dump_file,
> +		 "Doloop: number of iterations too costly to compute.\n");
> +      return false;
> +    }
>  
>    word_mode_size = GET_MODE_PRECISION (word_mode);
>    word_mode_max = (HOST_WIDE_INT_1U << (word_mode_size - 1) << 1) - 1;

