Re: Support fused multiply-adds in fully-masked reductions


On Thu, May 24, 2018 at 2:17 PM Richard Sandiford <richard.sandiford@linaro.org> wrote:

> Richard Biener <richard.guenther@gmail.com> writes:
> > On Wed, May 16, 2018 at 11:26 AM Richard Sandiford <richard.sandiford@linaro.org> wrote:
> >
> >> This patch adds support for fusing a conditional add or subtract
> >> with a multiplication, so that we can use fused multiply-add and
> >> multiply-subtract operations for fully-masked reductions.  E.g.
> >> for SVE we vectorise:
> >
> >>    double res = 0.0;
> >>    for (int i = 0; i < n; ++i)
> >>      res += x[i] * y[i];
> >
> >> using a fully-masked loop in which the loop body has the form:
> >
> >>    res_1 = PHI<0(preheader), res_2(latch)>;
> >>    avec = IFN_MASK_LOAD (loop_mask, a)
> >>    bvec = IFN_MASK_LOAD (loop_mask, b)
> >>    prod = avec * bvec;
> >>    res_2 = IFN_COND_ADD (loop_mask, res_1, prod);
> >
> >> where the last statement does the equivalent of:
> >
> >>    res_2 = loop_mask ? res_1 + prod : res_1;
> >
> >> (operating elementwise).  The point of the patch is to convert the last
> >> two statements into a single internal function that is the equivalent of:
> >
> >>    res_2 = loop_mask ? fma (avec, bvec, res_1) : res_1;
> >
> >> (again operating elementwise).
> >
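As a concrete scalar model of those elementwise semantics, one lane of the
fused operation behaves as in the sketch below (illustrative C only;
cond_fma_lane and its parameter names are made up for this example, not
GCC internals):

    #include <math.h>

    /* One lane of the masked FMA: perform the fused multiply-add where
       the mask bit is set, otherwise keep the accumulator unchanged.  */
    static double
    cond_fma_lane (int mask_lane, double a, double b, double acc)
    {
      return mask_lane ? fma (a, b, acc) : acc;
    }
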
> >> All current conditional X operations have the form "do X or don't do X
> >> to the first operand" (add/don't add to first operand, etc.).  However,
> >> the FMA optabs and functions are ordered so that the accumulator comes
> >> last.  There were two obvious ways of resolving this: break the
> >> convention for conditional operators and have "add/don't add to the
> >> final operand" or break the convention for FMA and put the accumulator
> >> first.  The patch goes for the latter, but adds _REV to make it obvious
> >> that the operands are in a different order.
> >
> > Eh.  I guess you'll do the same to SAD/DOT_PROD/WIDEN_SUM?
> >
> > That said, I don't really see the "do or not do to the first operand", it's
> > "do or not do the operation on operands 1 to 2 (or 3)".  None of the
> > current ops modify operand 1, they all produce a new value, no?

> Yeah, neither the current functions nor these ones actually changed
> operand 1.  It was all about deciding what the "else" value should be.
> The _REV thing was a "fix" for the fact that we wanted the else value
> to be the final operand of fma.

> Of course, the real fix was to make all the IFN_COND_* functions take an
> explicit else value, as you suggested in the review of the other patch
> in the series.  So all this _REV stuff is redundant now.
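
With the explicit else value, the loop body from the example above would end
in a single call of roughly this shape (a sketch following the operand order
the patch uses below: condition, the two multiplied operands, the addend,
then the fallback value for inactive lanes):

    res_2 = IFN_COND_FMA (loop_mask, avec, bvec, res_1, res_1);

so each element becomes fma (avec, bvec, res_1) where loop_mask is true and
stays res_1 otherwise.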

> Here's an updated version based on top of the IFN_COND_FMA patch
> that I just posted.  Tested in the same way.

OK.

Thanks,
Richard.

> Thanks,
> Richard

> 2018-05-24  Richard Sandiford  <richard.sandiford@linaro.org>
>              Alan Hayward  <alan.hayward@arm.com>
>              David Sherwood  <david.sherwood@arm.com>

> gcc/
>          * internal-fn.h (can_interpret_as_conditional_op_p): Declare.
>          * internal-fn.c (can_interpret_as_conditional_op_p): New function.
>          * tree-ssa-math-opts.c (convert_mult_to_fma_1): Handle conditional
>          plus and minus and convert them into IFN_COND_FMA-based sequences.
>          (convert_mult_to_fma): Handle conditional plus and minus.

> gcc/testsuite/
>          * gcc.dg/vect/vect-fma-2.c: New test.
>          * gcc.target/aarch64/sve/reduc_4.c: Likewise.
>          * gcc.target/aarch64/sve/reduc_6.c: Likewise.
>          * gcc.target/aarch64/sve/reduc_7.c: Likewise.

> Index: gcc/internal-fn.h
> ===================================================================
> --- gcc/internal-fn.h   2018-05-24 13:05:46.049605128 +0100
> +++ gcc/internal-fn.h   2018-05-24 13:08:24.643987582 +0100
> @@ -196,6 +196,9 @@ extern internal_fn get_conditional_inter
>   extern internal_fn get_conditional_internal_fn (internal_fn);
>   extern tree_code conditional_internal_fn_code (internal_fn);
>   extern internal_fn get_unconditional_internal_fn (internal_fn);
> +extern bool can_interpret_as_conditional_op_p (gimple *, tree *,
> +                                              tree_code *, tree (&)[3],
> +                                              tree *);

>   extern bool internal_load_fn_p (internal_fn);
>   extern bool internal_store_fn_p (internal_fn);
> Index: gcc/internal-fn.c
> ===================================================================
> --- gcc/internal-fn.c   2018-05-24 13:05:46.048606357 +0100
> +++ gcc/internal-fn.c   2018-05-24 13:08:24.643987582 +0100
> @@ -3333,6 +3333,62 @@ #define CASE(NAME) case IFN_COND_##NAME:
>       }
>   }

> +/* Return true if STMT can be interpreted as a conditional tree code
> +   operation of the form:
> +
> +     LHS = COND ? OP (RHS1, ...) : ELSE;
> +
> +   operating elementwise if the operands are vectors.  This includes
> +   the case of an all-true COND, so that the operation always happens.
> +
> +   When returning true, set:
> +
> +   - *COND_OUT to the condition COND, or to NULL_TREE if the condition
> +     is known to be all-true
> +   - *CODE_OUT to the tree code
> +   - OPS[I] to operand I of *CODE_OUT
> +   - *ELSE_OUT to the fallback value ELSE, or to NULL_TREE if the
> +     condition is known to be all true.  */
> +
> +bool
> +can_interpret_as_conditional_op_p (gimple *stmt, tree *cond_out,
> +                                  tree_code *code_out,
> +                                  tree (&ops)[3], tree *else_out)
> +{
> +  if (gassign *assign = dyn_cast <gassign *> (stmt))
> +    {
> +      *cond_out = NULL_TREE;
> +      *code_out = gimple_assign_rhs_code (assign);
> +      ops[0] = gimple_assign_rhs1 (assign);
> +      ops[1] = gimple_assign_rhs2 (assign);
> +      ops[2] = gimple_assign_rhs3 (assign);
> +      *else_out = NULL_TREE;
> +      return true;
> +    }
> +  if (gcall *call = dyn_cast <gcall *> (stmt))
> +    if (gimple_call_internal_p (call))
> +      {
> +       internal_fn ifn = gimple_call_internal_fn (call);
> +       tree_code code = conditional_internal_fn_code (ifn);
> +       if (code != ERROR_MARK)
> +         {
> +           *cond_out = gimple_call_arg (call, 0);
> +           *code_out = code;
> +           unsigned int nops = gimple_call_num_args (call) - 2;
> +           for (unsigned int i = 0; i < 3; ++i)
> +             ops[i] = i < nops ? gimple_call_arg (call, i + 1) : NULL_TREE;
> +           *else_out = gimple_call_arg (call, nops + 1);
> +           if (integer_truep (*cond_out))
> +             {
> +               *cond_out = NULL_TREE;
> +               *else_out = NULL_TREE;
> +             }
> +           return true;
> +         }
> +      }
> +  return false;
> +}
> +
>   /* Return true if IFN is some form of load from memory.  */

>   bool
> Index: gcc/tree-ssa-math-opts.c
> ===================================================================
> --- gcc/tree-ssa-math-opts.c    2018-05-18 09:26:37.749713749 +0100
> +++ gcc/tree-ssa-math-opts.c    2018-05-24 13:08:24.644961583 +0100
> @@ -2655,7 +2655,6 @@ convert_mult_to_fma_1 (tree mul_result,
>     FOR_EACH_IMM_USE_STMT (use_stmt, imm_iter, mul_result)
>       {
>         gimple_stmt_iterator gsi = gsi_for_stmt (use_stmt);
> -      enum tree_code use_code;
>         tree addop, mulop1 = op1, result = mul_result;
>         bool negate_p = false;
>         gimple_seq seq = NULL;
> @@ -2663,8 +2662,8 @@ convert_mult_to_fma_1 (tree mul_result,
>         if (is_gimple_debug (use_stmt))
>          continue;

> -      use_code = gimple_assign_rhs_code (use_stmt);
> -      if (use_code == NEGATE_EXPR)
> +      if (is_gimple_assign (use_stmt)
> +         && gimple_assign_rhs_code (use_stmt) == NEGATE_EXPR)
>          {
>            result = gimple_assign_lhs (use_stmt);
>            use_operand_p use_p;
> @@ -2675,22 +2674,23 @@ convert_mult_to_fma_1 (tree mul_result,

>            use_stmt = neguse_stmt;
>            gsi = gsi_for_stmt (use_stmt);
> -         use_code = gimple_assign_rhs_code (use_stmt);
>            negate_p = true;
>          }

> -      if (gimple_assign_rhs1 (use_stmt) == result)
> +      tree cond, else_value, ops[3];
> +      tree_code code;
> +      if (!can_interpret_as_conditional_op_p (use_stmt, &cond, &code,
> +                                             ops, &else_value))
> +       gcc_unreachable ();
> +      addop = ops[0] == result ? ops[1] : ops[0];
> +
> +      if (code == MINUS_EXPR)
>          {
> -         addop = gimple_assign_rhs2 (use_stmt);
> -         /* a * b - c -> a * b + (-c)  */
> -         if (gimple_assign_rhs_code (use_stmt) == MINUS_EXPR)
> +         if (ops[0] == result)
> +           /* a * b - c -> a * b + (-c)  */
>              addop = gimple_build (&seq, NEGATE_EXPR, type, addop);
> -       }
> -      else
> -       {
> -         addop = gimple_assign_rhs1 (use_stmt);
> -         /* a - b * c -> (-b) * c + a */
> -         if (gimple_assign_rhs_code (use_stmt) == MINUS_EXPR)
> +         else
> +           /* a - b * c -> (-b) * c + a */
>              negate_p = !negate_p;
>          }

> @@ -2699,8 +2699,13 @@ convert_mult_to_fma_1 (tree mul_result,

>         if (seq)
>          gsi_insert_seq_before (&gsi, seq, GSI_SAME_STMT);
> -      fma_stmt = gimple_build_call_internal (IFN_FMA, 3, mulop1, op2, addop);
> -      gimple_call_set_lhs (fma_stmt, gimple_assign_lhs (use_stmt));
> +
> +      if (cond)
> +       fma_stmt = gimple_build_call_internal (IFN_COND_FMA, 5, cond, mulop1,
> +                                              op2, addop, else_value);
> +      else
> +       fma_stmt = gimple_build_call_internal (IFN_FMA, 3, mulop1, op2, addop);
> +      gimple_set_lhs (fma_stmt, gimple_get_lhs (use_stmt));
>         gimple_call_set_nothrow (fma_stmt, !stmt_can_throw_internal (use_stmt));
>         gsi_replace (&gsi, fma_stmt, true);
>         /* Follow all SSA edges so that we generate FMS, FNMA and FNMS
> @@ -2883,7 +2888,6 @@ convert_mult_to_fma (gimple *mul_stmt, t
>        as an addition.  */
>     FOR_EACH_IMM_USE_FAST (use_p, imm_iter, mul_result)
>       {
> -      enum tree_code use_code;
>         tree result = mul_result;
>         bool negate_p = false;

> @@ -2904,13 +2908,9 @@ convert_mult_to_fma (gimple *mul_stmt, t
>         if (gimple_bb (use_stmt) != gimple_bb (mul_stmt))
>          return false;

> -      if (!is_gimple_assign (use_stmt))
> -       return false;
> -
> -      use_code = gimple_assign_rhs_code (use_stmt);
> -
>         /* A negate on the multiplication leads to FNMA.  */
> -      if (use_code == NEGATE_EXPR)
> +      if (is_gimple_assign (use_stmt)
> +         && gimple_assign_rhs_code (use_stmt) == NEGATE_EXPR)
>          {
>            ssa_op_iter iter;
>            use_operand_p usep;
> @@ -2932,17 +2932,20 @@ convert_mult_to_fma (gimple *mul_stmt, t
>            use_stmt = neguse_stmt;
>            if (gimple_bb (use_stmt) != gimple_bb (mul_stmt))
>              return false;
> -         if (!is_gimple_assign (use_stmt))
> -           return false;

> -         use_code = gimple_assign_rhs_code (use_stmt);
>            negate_p = true;
>          }

> -      switch (use_code)
> +      tree cond, else_value, ops[3];
> +      tree_code code;
> +      if (!can_interpret_as_conditional_op_p (use_stmt, &cond, &code, ops,
> +                                             &else_value))
> +       return false;
> +
> +      switch (code)
>          {
>          case MINUS_EXPR:
> -         if (gimple_assign_rhs2 (use_stmt) == result)
> +         if (ops[1] == result)
>              negate_p = !negate_p;
>            break;
>          case PLUS_EXPR:
> @@ -2952,47 +2955,50 @@ convert_mult_to_fma (gimple *mul_stmt, t
>            return false;
>          }

> -      /* If the subtrahend (gimple_assign_rhs2 (use_stmt)) is computed
> -        by a MULT_EXPR that we'll visit later, we might be able to
> -        get a more profitable match with fnma.
> +      if (cond)
> +       {
> +         if (cond == result || else_value == result)
> +           return false;
> +         if (!direct_internal_fn_supported_p (IFN_COND_FMA, type, opt_type))
> +           return false;
> +       }
> +
> +      /* If the subtrahend (OPS[1]) is computed by a MULT_EXPR that
> +        we'll visit later, we might be able to get a more profitable
> +        match with fnma.
>           OTOH, if we don't, a negate / fma pair has likely lower latency
>           than a mult / subtract pair.  */
> -      if (use_code == MINUS_EXPR && !negate_p
> -         && gimple_assign_rhs1 (use_stmt) == result
> +      if (code == MINUS_EXPR
> +         && !negate_p
> +         && ops[0] == result
>            && !direct_internal_fn_supported_p (IFN_FMS, type, opt_type)
> -         && direct_internal_fn_supported_p (IFN_FNMA, type, opt_type))
> +         && direct_internal_fn_supported_p (IFN_FNMA, type, opt_type)
> +         && TREE_CODE (ops[1]) == SSA_NAME
> +         && has_single_use (ops[1]))
>          {
> -         tree rhs2 = gimple_assign_rhs2 (use_stmt);
> -
> -         if (TREE_CODE (rhs2) == SSA_NAME)
> -           {
> -             gimple *stmt2 = SSA_NAME_DEF_STMT (rhs2);
> -             if (has_single_use (rhs2)
> -                 && is_gimple_assign (stmt2)
> -                 && gimple_assign_rhs_code (stmt2) == MULT_EXPR)
> -             return false;
> -           }
> +         gimple *stmt2 = SSA_NAME_DEF_STMT (ops[1]);
> +         if (is_gimple_assign (stmt2)
> +             && gimple_assign_rhs_code (stmt2) == MULT_EXPR)
> +           return false;
>          }

> -      tree use_rhs1 = gimple_assign_rhs1 (use_stmt);
> -      tree use_rhs2 = gimple_assign_rhs2 (use_stmt);
>         /* We can't handle a * b + a * b.  */
> -      if (use_rhs1 == use_rhs2)
> +      if (ops[0] == ops[1])
>          return false;
>         /* If deferring, make sure we are not looking at an instruction that
>           wouldn't have existed if we were not.  */
>         if (state->m_deferring_p
> -         && (state->m_mul_result_set.contains (use_rhs1)
> -             || state->m_mul_result_set.contains (use_rhs2)))
> +         && (state->m_mul_result_set.contains (ops[0])
> +             || state->m_mul_result_set.contains (ops[1])))
>          return false;

>         if (check_defer)
>          {
> -         tree use_lhs = gimple_assign_lhs (use_stmt);
> +         tree use_lhs = gimple_get_lhs (use_stmt);
>            if (state->m_last_result)
>              {
> -             if (use_rhs2 == state->m_last_result
> -                 || use_rhs1 == state->m_last_result)
> +             if (ops[1] == state->m_last_result
> +                 || ops[0] == state->m_last_result)
>                  defer = true;
>                else
>                  defer = false;
> @@ -3001,12 +3007,12 @@ convert_mult_to_fma (gimple *mul_stmt, t
>              {
>                gcc_checking_assert (!state->m_initial_phi);
>                gphi *phi;
> -             if (use_rhs1 == result)
> -               phi = result_of_phi (use_rhs2);
> +             if (ops[0] == result)
> +               phi = result_of_phi (ops[1]);
>                else
>                  {
> -                 gcc_assert (use_rhs2 == result);
> -                 phi = result_of_phi (use_rhs1);
> +                 gcc_assert (ops[1] == result);
> +                 phi = result_of_phi (ops[0]);
>                  }

>                if (phi)
> Index: gcc/testsuite/gcc.dg/vect/vect-fma-2.c
> ===================================================================
> --- /dev/null   2018-04-20 16:19:46.369131350 +0100
> +++ gcc/testsuite/gcc.dg/vect/vect-fma-2.c      2018-05-24 13:08:24.643987582 +0100
> @@ -0,0 +1,17 @@
> +/* { dg-do compile } */
> +/* { dg-additional-options "-fdump-tree-optimized -fassociative-math -fno-trapping-math -fno-signed-zeros" } */
> +
> +#include "tree-vect.h"
> +
> +#define N (VECTOR_BITS * 11 / 64 + 3)
> +
> +double
> +dot_prod (double *x, double *y)
> +{
> +  double sum = 0;
> +  for (int i = 0; i < N; ++i)
> +    sum += x[i] * y[i];
> +  return sum;
> +}
> +
> +/* { dg-final { scan-tree-dump { = \.COND_FMA } "optimized" { target { vect_double && { vect_fully_masked && scalar_all_fma } } } } } */
> Index: gcc/testsuite/gcc.target/aarch64/sve/reduc_4.c
> ===================================================================
> --- /dev/null   2018-04-20 16:19:46.369131350 +0100
> +++ gcc/testsuite/gcc.target/aarch64/sve/reduc_4.c      2018-05-24 13:08:24.643987582 +0100
> @@ -0,0 +1,18 @@
> +/* { dg-do compile } */
> +/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */
> +
> +double
> +f (double *restrict a, double *restrict b, int *lookup)
> +{
> +  double res = 0.0;
> +  for (int i = 0; i < 512; ++i)
> +    res += a[lookup[i]] * b[i];
> +  return res;
> +}
> +
> +/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+.d, p[0-7]/m, } 2 } } */
> +/* Check that the vector instructions are the only instructions.  */
> +/* { dg-final { scan-assembler-times {\tfmla\t} 2 } } */
> +/* { dg-final { scan-assembler-not {\tfadd\t} } } */
> +/* { dg-final { scan-assembler-times {\tfaddv\td0,} 1 } } */
> +/* { dg-final { scan-assembler-not {\tsel\t} } } */
> Index: gcc/testsuite/gcc.target/aarch64/sve/reduc_6.c
> ===================================================================
> --- /dev/null   2018-04-20 16:19:46.369131350 +0100
> +++ gcc/testsuite/gcc.target/aarch64/sve/reduc_6.c      2018-05-24 13:08:24.643987582 +0100
> @@ -0,0 +1,17 @@
> +/* { dg-do compile } */
> +/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */
> +
> +#define REDUC(TYPE)                                            \
> +  TYPE reduc_##TYPE (TYPE *x, TYPE *y, int count)              \
> +  {                                                            \
> +    TYPE sum = 0;                                              \
> +    for (int i = 0; i < count; ++i)                            \
> +      sum += x[i] * y[i];                                      \
> +    return sum;                                                        \
> +  }
> +
> +REDUC (float)
> +REDUC (double)
> +
> +/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+\.s, p[0-7]/m} 1 } } */
> +/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+\.d, p[0-7]/m} 1 } } */
> Index: gcc/testsuite/gcc.target/aarch64/sve/reduc_7.c
> ===================================================================
> --- /dev/null   2018-04-20 16:19:46.369131350 +0100
> +++ gcc/testsuite/gcc.target/aarch64/sve/reduc_7.c      2018-05-24 13:08:24.643987582 +0100
> @@ -0,0 +1,17 @@
> +/* { dg-do compile } */
> +/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */
> +
> +#define REDUC(TYPE)                                            \
> +  TYPE reduc_##TYPE (TYPE *x, TYPE *y, int count)              \
> +  {                                                            \
> +    TYPE sum = 0;                                              \
> +    for (int i = 0; i < count; ++i)                            \
> +      sum -= x[i] * y[i];                                      \
> +    return sum;                                                        \
> +  }
> +
> +REDUC (float)
> +REDUC (double)
> +
> +/* { dg-final { scan-assembler-times {\tfmls\tz[0-9]+\.s, p[0-7]/m} 1 } } */
> +/* { dg-final { scan-assembler-times {\tfmls\tz[0-9]+\.d, p[0-7]/m} 1 } } */

