diff --git a/gcc/doc/md.texi b/gcc/doc/md.texi index ddaf1abaccbd44dae11ea902ec38b474aacfb8e1..d8142f745050d963e8d15c7793fae06d9ad02020 100644 --- a/gcc/doc/md.texi +++ b/gcc/doc/md.texi @@ -6143,6 +6143,50 @@ rotations @var{m} of 90 or 270. This pattern is not allowed to @code{FAIL}. +@cindex @code{cmla@var{m}4} instruction pattern +@item @samp{cmla@var{m}4} +Perform a vector floating point multiply and accumulate of complex numbers +in operand 0, operand 1 and operand 2. + +The instruction must perform the operation on data loaded contiguously into the +vectors. +The operation is only supported for vector modes @var{m}. + +This pattern is not allowed to @code{FAIL}. + +@cindex @code{cmla_conj@var{m}4} instruction pattern +@item @samp{cmla_conj@var{m}4} +Perform a vector floating point multiply and accumulate of complex numbers +in operand 0, operand 1 and the conjugate of operand 2. + +The instruction must perform the operation on data loaded contiguously into the +vectors. +The operation is only supported for vector modes @var{m}. + +This pattern is not allowed to @code{FAIL}. + +@cindex @code{cmls@var{m}4} instruction pattern +@item @samp{cmls@var{m}4} +Perform a vector floating point multiply and subtract of complex numbers +in operand 0, operand 1 and operand 2. + +The instruction must perform the operation on data loaded contiguously into the +vectors. +The operation is only supported for vector modes @var{m}. + +This pattern is not allowed to @code{FAIL}. + +@cindex @code{cmls_conj@var{m}4} instruction pattern +@item @samp{cmls_conj@var{m}4} +Perform a vector floating point multiply and subtract of complex numbers +in operand 0, operand 1 and the conjugate of operand 2. + +The instruction must perform the operation on data loaded contiguously into the +vectors. +The operation is only supported for vector modes @var{m}. + +This pattern is not allowed to @code{FAIL}. 
+ @cindex @code{cmul@var{m}4} instruction pattern @item @samp{cmul@var{m}4} Perform a vector floating point multiplication of complex numbers in operand 0 diff --git a/gcc/internal-fn.def b/gcc/internal-fn.def index cb41643f5e332518a0271bb8e1af4883c8bd6880..acb7d9f3bdc757437d5492a652144ba31c2ef702 100644 --- a/gcc/internal-fn.def +++ b/gcc/internal-fn.def @@ -288,6 +288,10 @@ DEF_INTERNAL_FLT_FN (LDEXP, ECF_CONST, ldexp, binary) /* Ternary math functions. */ DEF_INTERNAL_FLT_FLOATN_FN (FMA, ECF_CONST, fma, ternary) +DEF_INTERNAL_OPTAB_FN (COMPLEX_FMA, ECF_CONST, cmla, ternary) +DEF_INTERNAL_OPTAB_FN (COMPLEX_FMA_CONJ, ECF_CONST, cmla_conj, ternary) +DEF_INTERNAL_OPTAB_FN (COMPLEX_FMS, ECF_CONST, cmls, ternary) +DEF_INTERNAL_OPTAB_FN (COMPLEX_FMS_CONJ, ECF_CONST, cmls_conj, ternary) /* Unary integer ops. */ DEF_INTERNAL_INT_FN (CLRSB, ECF_CONST | ECF_NOTHROW, clrsb, unary) diff --git a/gcc/optabs.def b/gcc/optabs.def index 9c267d422478d0011f288b1f5f62daabe3989ba7..19db9c00896cd08adfd20a01669990bbbebd79f1 100644 --- a/gcc/optabs.def +++ b/gcc/optabs.def @@ -294,6 +294,10 @@ OPTAB_D (cadd90_optab, "cadd90$a3") OPTAB_D (cadd270_optab, "cadd270$a3") OPTAB_D (cmul_optab, "cmul$a3") OPTAB_D (cmul_conj_optab, "cmul_conj$a3") +OPTAB_D (cmla_optab, "cmla$a4") +OPTAB_D (cmla_conj_optab, "cmla_conj$a4") +OPTAB_D (cmls_optab, "cmls$a4") +OPTAB_D (cmls_conj_optab, "cmls_conj$a4") OPTAB_D (cos_optab, "cos$a2") OPTAB_D (cosh_optab, "cosh$a2") OPTAB_D (exp10_optab, "exp10$a2") diff --git a/gcc/tree-vect-slp-patterns.c b/gcc/tree-vect-slp-patterns.c index 2edb0117f9cbbfc40e9ed3a96120a3c88f84a68e..c2987c2afac2fbd55e2acd6b56fc13c7d3ad13c1 100644 --- a/gcc/tree-vect-slp-patterns.c +++ b/gcc/tree-vect-slp-patterns.c @@ -1172,6 +1172,176 @@ complex_mul_pattern::validate_p () return true; } + +/******************************************************************************* + * complex_fma_pattern class + ******************************************************************************/ + 
+class complex_fma_pattern : public complex_mul_pattern +{ + protected: + complex_fma_pattern (slp_tree *node, vec_info *vinfo) + : complex_mul_pattern (node, vinfo) + { + this->m_arity = 2; + this->m_num_args = 3; + } + + public: + static vect_pattern* create (slp_tree *node, vec_info *vinfo) + { + return new complex_fma_pattern (node, vinfo); + } + + const char* get_name () + { + return "Complex FM(A|S)"; + } + + bool matches (); + bool matches (complex_operation_t op, vec ops); +}; + +/* Pattern matcher for trying to match complex multiply and accumulate + and multiply and subtract patterns in SLP tree. + If the operation matches then IFN is set to the operation it matched and + the arguments to the two replacement statements are put in M_OPS. + + If no match is found then IFN is set to IFN_LAST and M_OPS is unchanged. + + This function matches the patterns shaped as: + + double ax = (b[i+1] * a[i]) + (b[i] * a[i]); + double bx = (a[i+1] * b[i]) - (a[i+1] * b[i+1]); + + c[i] = c[i] - ax; + c[i+1] = c[i+1] + bx; + + If a match occurred then TRUE is returned, else FALSE. */ +bool +complex_fma_pattern::matches (complex_operation_t op1, vec args0) +{ + this->m_ifn = IFN_LAST; + + /* Find the two components. We match Complex MUL first which reduces the + amount of work this pattern has to do. After that we just match the + head node and we're done: + + * FMA: + + + * FMS: - +. */ + slp_tree child = NULL; + + /* We need to ignore the two_operands nodes that may also match, + for that we can check if they have any scalar statements and also + check that it's not a permute node as we're looking for a normal + PLUS_EXPR operation. 
*/ + if (op1 == PLUS_MINUS) + { + child = SLP_TREE_CHILDREN (args0[1])[1]; + } + else if (SLP_TREE_SCALAR_STMTS (*this->m_node).length () > 0 + && SLP_TREE_CODE (*this->m_node) != VEC_PERM_EXPR + && vect_match_expression_p (*this->m_node, PLUS_EXPR)) + { + if (SLP_TREE_CHILDREN (*this->m_node).length () != 2) + return false; + + op1 = PLUS_PLUS; + args0.safe_splice (SLP_TREE_CHILDREN (*this->m_node)); + child = args0[1]; + } + else + return false; + + auto_vec ops; + internal_fn mulfn = IFN_LAST; + /* The accumulation step produces an inverse tree from normal + multiply so match the nodes in reverse. */ + if (!vect_slp_matches_complex_mul (child, &mulfn, &ops, false, + op1 == PLUS_MINUS)) + return false; + + this->m_ops.create (6); + if (op1 == PLUS_MINUS) + { + if (mulfn == IFN_COMPLEX_MUL) + this->m_ifn = IFN_COMPLEX_FMS; + else if (mulfn == IFN_COMPLEX_MUL_CONJ) + this->m_ifn = IFN_COMPLEX_FMS_CONJ; + + child = SLP_TREE_CHILDREN (args0[0])[0]; + this->workset.safe_splice (SLP_TREE_CHILDREN (*this->m_node)); + save_match (); + } + else if (op1 == PLUS_PLUS) + { + if (mulfn == IFN_COMPLEX_MUL) + this->m_ifn = IFN_COMPLEX_FMA; + else if (mulfn == IFN_COMPLEX_MUL_CONJ) + this->m_ifn = IFN_COMPLEX_FMA_CONJ; + + /* Add doesn't generate a two_operators node, so for it we replace it + inline by turning the add node itself into a pattern. */ + this->m_inplace = true; + this->workset.safe_push (*this->m_node); + child = args0[0]; + this->m_match + = new vect_simple_pattern_match (this->m_arity, this->m_ifn, + this->m_vinfo, &this->workset, + this->m_num_args); + } + + if (this->m_ifn == IFN_LAST) + return false; + + /* The conjugate nodes have different orderings, oddly enough the SUB node + has the same order regardless of the conjugate. This needs to be made more + consistent in the mid-end. 
*/ + if (op1 == PLUS_MINUS || mulfn == IFN_COMPLEX_MUL) + { + this->m_ops.quick_push (child); + this->m_ops.quick_push (ops[1]); + this->m_ops.quick_push (ops[0]); + this->m_ops.quick_push (child); + this->m_ops.quick_push (ops[3]); + this->m_ops.quick_push (ops[2]); + } + else + { + this->m_ops.quick_push (child); + this->m_ops.quick_push (ops[0]); + this->m_ops.quick_push (ops[1]); + this->m_ops.quick_push (child); + this->m_ops.quick_push (ops[2]); + this->m_ops.quick_push (ops[3]); + } + + vect_build_perm_groups (&this->m_blocks[0], this->m_ops); + + /* Unfortunately the sequence for a conjugate and rotation by 180 and 270 are + remarkably similar. So we need to do some extra checks to make sure we + don't match those. */ + if (mulfn == IFN_COMPLEX_MUL_CONJ) + for (unsigned i = 0; i < this->m_ops.length (); i++) + { + map_t m = this->m_blocks[i]; + if (m.a > m.b) + return false; + } + + return true; +} + +bool +complex_fma_pattern::matches () +{ + auto_vec args0; + complex_operation_t op + = vect_detect_pair_op (*this->m_node, true, &args0); + return matches (op, args0); +} + + /******************************************************************************* * complex_operations_pattern class ******************************************************************************/ @@ -1303,6 +1473,10 @@ vect_pattern_decl_t slp_patterns[] order patterns from the largest to the smallest. Especially if they overlap in what they can detect. */ + /* FMA overlaps with MUL but is the longer sequence. Because we're in post + order traversal we can't match FMA if included in + complex_operations_pattern so must be checked on its own. */ + SLP_PATTERN (complex_fma_pattern), SLP_PATTERN (complex_operations_pattern), }; #undef SLP_PATTERN