Extend fold_vec_perm to fold VEC_PERM_EXPR in VLA manner

Richard Sandiford richard.sandiford@arm.com
Mon Sep 5 10:21:16 GMT 2022


Sorry for the slow reply.  I wrote a response a couple of weeks ago
but I think it got lost in a machine outage.

Prathamesh Kulkarni <prathamesh.kulkarni@linaro.org> writes:
> Hi,
> The attached prototype patch extends fold_vec_perm to fold VEC_PERM_EXPR
> in VLA manner, and currently handles the following cases:
> (a) fixed len arg0, arg1 and fixed len sel.
> (b) fixed len arg0, arg1 and vla sel
> (c) vla arg0, arg1 and vla sel with arg0, arg1 being VECTOR_CST.
>
> It seems to work for the VLA tests written in
> test_vec_perm_vla_folding (), and am working thru the fallout observed in
> regression testing.
>
> Does the approach taken in the patch look in the right direction ?
> I am not sure if I have got the conversion from "sel_index"
> to index of either arg0, or arg1 entirely correct.
> I would be grateful for suggestions on the patch.
>
> Thanks,
> Prathamesh
>
> diff --git a/gcc/fold-const.cc b/gcc/fold-const.cc
> index 4f4ec81c8d4..5e12260211e 100644
> --- a/gcc/fold-const.cc
> +++ b/gcc/fold-const.cc
> @@ -85,6 +85,9 @@ along with GCC; see the file COPYING3.  If not see
>  #include "vec-perm-indices.h"
>  #include "asan.h"
>  #include "gimple-range.h"
> +#include "tree-pretty-print.h"
> +#include "gimple-pretty-print.h"
> +#include "print-tree.h"
>  
>  /* Nonzero if we are folding constants inside an initializer or a C++
>     manifestly-constant-evaluated context; zero otherwise.
> @@ -10496,40 +10499,6 @@ fold_mult_zconjz (location_t loc, tree type, tree expr)
>  			  build_zero_cst (itype));
>  }
>  
> -
> -/* Helper function for fold_vec_perm.  Store elements of VECTOR_CST or
> -   CONSTRUCTOR ARG into array ELTS, which has NELTS elements, and return
> -   true if successful.  */
> -
> -static bool
> -vec_cst_ctor_to_array (tree arg, unsigned int nelts, tree *elts)
> -{
> -  unsigned HOST_WIDE_INT i, nunits;
> -
> -  if (TREE_CODE (arg) == VECTOR_CST
> -      && VECTOR_CST_NELTS (arg).is_constant (&nunits))
> -    {
> -      for (i = 0; i < nunits; ++i)
> -	elts[i] = VECTOR_CST_ELT (arg, i);
> -    }
> -  else if (TREE_CODE (arg) == CONSTRUCTOR)
> -    {
> -      constructor_elt *elt;
> -
> -      FOR_EACH_VEC_SAFE_ELT (CONSTRUCTOR_ELTS (arg), i, elt)
> -	if (i >= nelts || TREE_CODE (TREE_TYPE (elt->value)) == VECTOR_TYPE)
> -	  return false;
> -	else
> -	  elts[i] = elt->value;
> -    }
> -  else
> -    return false;
> -  for (; i < nelts; i++)
> -    elts[i]
> -      = fold_convert (TREE_TYPE (TREE_TYPE (arg)), integer_zero_node);
> -  return true;
> -}
> -
>  /* Attempt to fold vector permutation of ARG0 and ARG1 vectors using SEL
>     selector.  Return the folded VECTOR_CST or CONSTRUCTOR if successful,
>     NULL_TREE otherwise.  */
> @@ -10537,45 +10506,149 @@ vec_cst_ctor_to_array (tree arg, unsigned int nelts, tree *elts)
>  tree
>  fold_vec_perm (tree type, tree arg0, tree arg1, const vec_perm_indices &sel)
>  {
> -  unsigned int i;
> -  unsigned HOST_WIDE_INT nelts;
> -  bool need_ctor = false;
> +  poly_uint64 arg0_len = TYPE_VECTOR_SUBPARTS (TREE_TYPE (arg0));
> +  poly_uint64 arg1_len = TYPE_VECTOR_SUBPARTS (TREE_TYPE (arg1));
> +
> +  gcc_assert (known_eq (TYPE_VECTOR_SUBPARTS (type),
> +			sel.length ()));
> +  gcc_assert (known_eq (arg0_len, arg1_len));
>  
> -  if (!sel.length ().is_constant (&nelts))
> -    return NULL_TREE;
> -  gcc_assert (known_eq (TYPE_VECTOR_SUBPARTS (type), nelts)
> -	      && known_eq (TYPE_VECTOR_SUBPARTS (TREE_TYPE (arg0)), nelts)
> -	      && known_eq (TYPE_VECTOR_SUBPARTS (TREE_TYPE (arg1)), nelts));
>    if (TREE_TYPE (TREE_TYPE (arg0)) != TREE_TYPE (type)
>        || TREE_TYPE (TREE_TYPE (arg1)) != TREE_TYPE (type))
>      return NULL_TREE;
>  
> -  tree *in_elts = XALLOCAVEC (tree, nelts * 2);
> -  if (!vec_cst_ctor_to_array (arg0, nelts, in_elts)
> -      || !vec_cst_ctor_to_array (arg1, nelts, in_elts + nelts))
> +  unsigned input_npatterns = 0;
> +  unsigned out_npatterns = sel.encoding ().npatterns ();
> +  unsigned out_nelts_per_pattern = sel.encoding ().nelts_per_pattern ();
> +
> +  /* FIXME: How to reshape fixed length vector_cst, so that
> +     npatterns == vector.length () and nelts_per_pattern == 1 ?
> +     It seems the vector is canonicalized to minimize npatterns.  */
> +
> +  if (arg0_len.is_constant ())
> +    {
> +      /* If arg0, arg1 are fixed width vectors, and sel is VLA,
> +         ensure that it is a dup sequence and has same period
> +	 as input vector.  */
> +
> +      if (!sel.length ().is_constant ()
> +	  && (sel.encoding ().nelts_per_pattern () > 2
> +	      || !known_eq (arg0_len, sel.encoding ().npatterns ())))
> +	return NULL_TREE;
> +
> +      input_npatterns = arg0_len.to_constant ();
> +
> +      if (sel.length ().is_constant ())
> +	{
> +	  out_npatterns = sel.length ().to_constant ();
> +	  out_nelts_per_pattern = 1;
> +	}
> +    }
> +  else if (TREE_CODE (arg0) == VECTOR_CST
> +	   && TREE_CODE (arg1) == VECTOR_CST)
> +    {
> +      unsigned npatterns = VECTOR_CST_NPATTERNS (arg0);
> +      unsigned input_nelts_per_pattern = VECTOR_CST_NELTS_PER_PATTERN (arg0);
> +
> +      /* If arg0, arg1 are VLA, then ensure that,
> +	 (a) sel also has same length as input vectors.
> +	 (b) arg0 and arg1 have same encoding.
> +	 (c) sel has same number of patterns as input vectors.
> +	 (d) if sel is a stepped sequence, then it has same
> +	     encoding as input vectors.  */
> +
> +      if (!known_eq (arg0_len, sel.length ())
> +	  || npatterns != VECTOR_CST_NPATTERNS (arg1)
> +	  || input_nelts_per_pattern != VECTOR_CST_NELTS_PER_PATTERN (arg1)
> +	  || npatterns != sel.encoding ().npatterns ()
> +	  || (sel.encoding ().nelts_per_pattern () > 2
> +	      && sel.encoding ().nelts_per_pattern () != input_nelts_per_pattern))
> +	return NULL_TREE;

This seems too restrictive.  More below.

> +
> +      input_npatterns = npatterns;
> +    }
> +  else
>      return NULL_TREE;
>  
> -  tree_vector_builder out_elts (type, nelts, 1);
> -  for (i = 0; i < nelts; i++)
> +  tree_vector_builder out_elts_builder (type, out_npatterns,
> +					out_nelts_per_pattern);
> +  bool need_ctor = false;
> +  unsigned out_encoded_nelts = out_npatterns * out_nelts_per_pattern;
> +
> +  for (unsigned i = 0; i < out_encoded_nelts; i++)
>      {
> -      HOST_WIDE_INT index;
> -      if (!sel[i].is_constant (&index))
> +      HOST_WIDE_INT sel_index;
> +      if (!sel[i].is_constant (&sel_index))
>  	return NULL_TREE;
> -      if (!CONSTANT_CLASS_P (in_elts[index]))
> -	need_ctor = true;
> -      out_elts.quick_push (unshare_expr (in_elts[index]));
> +
> +      /* Convert sel_index to index of either arg0 or arg1.
> +	 For eg:
> +	 arg0: {a0, b0, a1, b1, a1 + S, b1 + S, ...}
> +	 arg1: {c0, d0, c1, d1, c1 + S, d1 + S, ...}
> +	 Both have npatterns == 2, nelts_per_pattern == 3.
> +	 Then the combined vector would be:
> +	 {a0, b0, c0, d0, a1, b1, c1, d1, a1 + S, b1 + S, c1 + S, d1 + S, ... }
> +	 This combined vector will have,
> +	 npatterns = 2 * input_npatterns == 4.
> +	 sel_index is used to index this above combined vector.

There's no interleaving of the arguments though.  The selector selects from:

{a0, b0, a1, b1, a1 + S, b1 + S, ..., c0, d0, c1, d1, c1 + S, d1 + S, ...}

The VLA encoding encodes the first N patterns explicitly.  The
npatterns/nelts_per_pattern values then describe how to extend that
initial sequence to an arbitrary number of elements.  So when performing
an operation on (potentially) variable-length vectors, the question is:

* Can we work out an initial sequence and npatterns/nelts_per_pattern
  pair that will be correct for all elements of the result?

This depends on the operation that we're performing.  E.g. it's
different for unary operations (vector_builder::new_unary_operation)
and binary operations (vector_builder::new_binary_operation).  It also
varies between unary operations and between binary operations, hence
the allow_stepped_p parameters.

For VEC_PERM_EXPR, I think the key requirement is that:

(R) Each individual selector pattern must always select from the same vector.

Whether this condition is met depends both on the pattern itself and on
the number of patterns that it's combined with.

E.g. suppose we had the selector pattern:

  { 0, 1, 4, ... }   i.e. 3x - 2 for x > 0

If the arguments and selector are n elements then this pattern on its
own would select from more than one argument if 3(n-1) - 2 >= n.
This is clearly true for large enough n.  So if n is variable then
we cannot represent this.

If the pattern above is one of two patterns, so interleaved as:

     { 0, _, 1, _, 4, _, ... }  o=0
  or { _, 0, _, 1, _, 4, ... }  o=1

then the pattern would select from more than one argument if
3(n/2-1) - 2 + o >= n.  This too would be a problem for variable n.

But if the pattern above is one of four patterns then it selects
from more than one argument if 3(n/4-1) - 2 + o >= n.  This is not
true for any valid n or o, so the pattern is OK.

So let's define some ad hoc terminology:

* Px is the number of patterns in x
* Ex is the number of elements per pattern in x

where x can be:

* 1: first argument
* 2: second argument
* s: selector
* r: result

Then:

(1) The number of elements encoded explicitly for x is Ex*Px

(2) The explicit encoding can be used to produce a sequence of N*Ex*Px
    elements for any integer N.  This extended sequence can be reencoded
    as having N*Px patterns, with Ex staying the same.

(3) If Ex < 3, Ex can be increased by 1 by repeating the final Px elements
    of the explicit encoding.

So let's assume (optimistically) that we can produce the result
by calculating the first Pr*Er elements and using the Pr,Er encoding
to imply the rest.  Then:

* (2) means that, when combining multiple input operands with potentially
  different encodings, we can set the number of patterns in the result
  to the least common multiple of the number of patterns in the inputs.
  In this case:

  Pr = least_common_multiple(P1, P2, Ps)

  is a valid number of patterns.

* (3) means that the number of elements per pattern of the result can
  be the maximum of the number of elements per pattern in the inputs.
  (Alternatively, we could always use 3.)  In this case:

  Er = max(E1, E2, Es)

  is a valid number of elements per pattern.

So if (R) holds we can compute the result -- for both VLA and VLS -- by
calculating the first Pr*Er elements of the result and using the
encoding to derive the rest.  If (R) doesn't hold then we need the
selector to be constant-length.  We should then fill in the result
based on:

- Pr == number of elements in the result
- Er == 1

But this should be the fallback option, even for VLS.

As far as the arguments go: we should reject CONSTRUCTORs for
variable-length types.  After doing that, we can treat a CONSTRUCTOR
for an N-element vector type by setting the number of patterns to N
and the number of elements per pattern to 1.

Thanks,
Richard

> +	 Since we don't explicitly build the combined vector, we convert
> +	 sel_index to corresponding index for either arg0 or arg1.
> +	 For eg, if sel_index == 7,
> +	 pattern = 7 % 4 == 3.
> +	 Since pattern > input_npatterns, the elem will come from:
> +	 pattern = 3 - input_npatterns ie, pattern 1 from arg1.
> +	 elem_index_in_pattern = 7 / 4 == 1.
> +	 So the actual index of the element in arg1 would be: 1 + (1 * 2) == 3.
> +	 So, sel_index == 7 corresponds to arg1[3], ie, d1.  */
> +
> +      unsigned pattern = sel_index % (2 * input_npatterns);
> +      unsigned elem_index_in_pattern = sel_index / (2 * input_npatterns);
> +      tree arg;
> +      if (pattern < input_npatterns)
> +	arg = arg0;
> +      else
> +	{
> +	  arg = arg1;
> +	  pattern -= input_npatterns;
> +	}
> +
> +      unsigned elem_index = (elem_index_in_pattern * input_npatterns) + pattern;
> +      tree elem;
> +      if (TREE_CODE (arg) == VECTOR_CST)
> +	{
> +	  /* If arg is fixed width vector, and elem_index goes out of range,
> +	     then return NULL_TREE.  */
> +	  if (TYPE_VECTOR_SUBPARTS (TREE_TYPE (arg)).is_constant ()
> +	      && elem_index > vector_cst_encoded_nelts (arg))
> +	    return NULL_TREE;
> +	  elem = vector_cst_elt (arg, elem_index);
> +	}
> +      else
> +	{
> +	  gcc_assert (TREE_CODE (arg) == CONSTRUCTOR);
> +	  if (elem_index >= CONSTRUCTOR_NELTS (arg))
> +	    return NULL_TREE;
> +	  elem = CONSTRUCTOR_ELT (arg, elem_index)->value;
> +	  if (VECTOR_TYPE_P (TREE_TYPE (elem)))
> +	    return NULL_TREE;
> +	  need_ctor = true;
> +	}
> +
> +      out_elts_builder.quick_push (unshare_expr (elem));
>      }
>  
>    if (need_ctor)
>      {
>        vec<constructor_elt, va_gc> *v;
> -      vec_alloc (v, nelts);
> -      for (i = 0; i < nelts; i++)
> -	CONSTRUCTOR_APPEND_ELT (v, NULL_TREE, out_elts[i]);
> +      vec_alloc (v, out_encoded_nelts);
> +
> +      for (unsigned i = 0; i < out_encoded_nelts; i++)
> +	CONSTRUCTOR_APPEND_ELT (v, NULL_TREE, out_elts_builder[i]);
>        return build_constructor (type, v);
>      }
> -  else
> -    return out_elts.build ();
> +
> +  return out_elts_builder.build ();
>  }
>  
>  /* Try to fold a pointer difference of type TYPE two address expressions of
> @@ -16912,6 +16985,91 @@ test_vec_duplicate_folding ()
>    ASSERT_TRUE (operand_equal_p (dup5_expr, dup5_cst, 0));
>  }
>  
> +static tree
> +build_vec_int_cst (unsigned npatterns, unsigned nelts_per_pattern,
> +		   int *encoded_elems)
> +{
> +  scalar_int_mode int_mode = SCALAR_INT_TYPE_MODE (integer_type_node);
> +  machine_mode vmode = targetm.vectorize.preferred_simd_mode (int_mode);
> +  poly_uint64 nunits = GET_MODE_NUNITS (vmode);
> +  tree vectype = build_vector_type (integer_type_node, nunits);
> +
> +  tree_vector_builder builder (vectype, npatterns, nelts_per_pattern);
> +  for (unsigned i = 0; i < npatterns * nelts_per_pattern; i++)
> +    builder.quick_push (build_int_cst (integer_type_node, encoded_elems[i]));
> +  return builder.build ();
> +}
> +
> +static void
> +vpe_verify_res (tree res, unsigned npatterns, unsigned nelts_per_pattern,
> +		int *encoded_elems)
> +{
> +  ASSERT_TRUE (res != NULL_TREE);
> +  ASSERT_TRUE (VECTOR_CST_NPATTERNS (res) == npatterns);
> +  ASSERT_TRUE (VECTOR_CST_NELTS_PER_PATTERN (res) == nelts_per_pattern);
> +
> +  for (unsigned i = 0; i < npatterns * nelts_per_pattern; i++)
> +    ASSERT_TRUE (wi::to_wide (VECTOR_CST_ELT (res, i))
> +			      == encoded_elems[i]);
> +}
> +
> +static void
> +test_vec_perm_vla_folding ()
> +{
> +  /* For all cases
> +     arg0: {1, 11, 21, 31, 2, 12, 22, 32, 3, 13, 23, 33, ...}, npatterns == 4, nelts_per_pattern == 3.
> +     arg1: {41, 51, 61, 71, 42, 52, 62, 72, 43, 53, 63, 73 ...}, npatterns == 4, nelts_per_pattern == 3.  */
> +
> +  int arg0_elems[] = { 1, 11, 21, 31, 2, 12, 22, 32, 3, 13, 23, 33 };
> +  tree arg0 = build_vec_int_cst (4, 3, arg0_elems);
> +
> +  int arg1_elems[] = { 41, 51, 61, 71, 42, 52, 62, 72, 43, 53, 63, 73 };
> +  tree arg1 = build_vec_int_cst (4, 3, arg1_elems);
> +
> +  if (TYPE_VECTOR_SUBPARTS (TREE_TYPE (arg0)).is_constant ()
> +      || TYPE_VECTOR_SUBPARTS (TREE_TYPE (arg1)).is_constant ())
> +    return;
> +
> +  /* Case 1: Dup mask sequence.
> +     mask = {0, 9, 3, 11, ...}, npatterns == 4, nelts_per_pattern == 1.
> +     expected result: {1, 21, 31, 32, ...}, npatterns == 4, nelts_per_pattern == 1.  */
> +  {
> +    int mask_elems[] = {0, 9, 3, 12};
> +    tree mask = build_vec_int_cst (4, 1, mask_elems);
> +    if (TYPE_VECTOR_SUBPARTS (TREE_TYPE (mask)).is_constant ())
> +      return;
> +    tree res = fold_ternary (VEC_PERM_EXPR, TREE_TYPE (arg0), arg0, arg1, mask);
> +    int res_encoded_elems[] = {1, 12, 31, 42};
> +    vpe_verify_res (res, 4, 1, res_encoded_elems);
> +  }
> +
> +  /* Case 2:
> +     mask = {0, 4, 1, 5, 8, 12, 9, 13 ...}, npatterns == 4, nelts_per_pattern == 2.
> +     expected result: {1, 41, 11, 51, 2, 12, 42, 52, ...}, npatterns == 4, nelts_per_pattern == 2.  */
> +  {
> +    int mask_elems[] = {0, 4, 1, 5, 8, 12, 9, 13};
> +    tree mask = build_vec_int_cst (4, 2, mask_elems);
> +    if (TYPE_VECTOR_SUBPARTS (TREE_TYPE (mask)).is_constant ())
> +      return;
> +    tree res = fold_ternary (VEC_PERM_EXPR, TREE_TYPE (arg0), arg0, arg1, mask);
> +    int res_encoded_elems[] = {1, 41, 11, 51, 2, 42, 12, 52};
> +    vpe_verify_res (res, 4, 2, res_encoded_elems);
> +  }
> +
> +  /* Case 3: Stepped mask sequence.
> +     mask = {0, 4, 1, 5, 8, 12, 9, 13, 16, 20, 17, 21}, npatterns == 4, nelts_per_pattern == 3.
> +     expected result = {1, 41, 11, 51, 2, 42, 12, 52, 3, 43, 13, 53 ...}, npatterns == 4, nelts_per_pattern == 3.  */
> +  {
> +    int mask_elems[] = {0, 4, 1, 5, 8, 12, 9, 13, 16, 20, 17, 21};
> +    tree mask = build_vec_int_cst (4, 3, mask_elems);
> +    if (TYPE_VECTOR_SUBPARTS (TREE_TYPE (mask)).is_constant ())
> +      return;
> +    tree res = fold_ternary (VEC_PERM_EXPR, TREE_TYPE (arg0), arg0, arg1, mask);
> +    int res_encoded_elems[] = {1, 41, 11, 51, 2, 42, 12, 52, 3, 43, 13, 53};
> +    vpe_verify_res (res, 4, 3, res_encoded_elems);
> +  }
> +}
> +
>  /* Run all of the selftests within this file.  */
>  
>  void
> @@ -16920,6 +17078,7 @@ fold_const_cc_tests ()
>    test_arithmetic_folding ();
>    test_vector_folding ();
>    test_vec_duplicate_folding ();
> +  test_vec_perm_vla_folding ();
>  }
>  
>  } // namespace selftest


More information about the Gcc-patches mailing list