[PATCH 5/8 v9] middle-end slp: support complex multiply and complex multiply conjugate
Richard Biener
rguenther@suse.de
Fri Jan 8 09:37:17 GMT 2021
On Mon, 28 Dec 2020, Tamar Christina wrote:
> Hi All,
>
> This adds support for complex multiply and complex multiply and accumulate to
> the vect pattern detector.
>
> Bootstrapped Regtested on aarch64-none-linux-gnu, x86_64-pc-linux-gnu
> and no issues.
>
> Ok for master?
>
> Thanks,
> Tamar
>
> gcc/ChangeLog:
>
> * internal-fn.def (COMPLEX_MUL, COMPLEX_MUL_CONJ): New.
> * optabs.def (cmul_optab, cmul_conj_optab): New.
> * doc/md.texi: Document them.
> * tree-vect-slp-patterns.c (vect_match_call_complex_mla,
> vect_normalize_conj_loc, is_eq_or_top, vect_validate_multiplication,
> vect_build_combine_node, class complex_mul_pattern,
> complex_mul_pattern::matches, complex_mul_pattern::recognize,
> complex_mul_pattern::build): New.
>
> --- inline copy of patch --
> diff --git a/gcc/doc/md.texi b/gcc/doc/md.texi
> index ec6ec180b91fcf9f481b6754c044483787fd923c..b8cc90e1a75e402abbf8a8cf2efefc1a333f8b3a 100644
> --- a/gcc/doc/md.texi
> +++ b/gcc/doc/md.texi
> @@ -6202,6 +6202,50 @@ The operation is only supported for vector modes @var{m}.
>
> This pattern is not allowed to @code{FAIL}.
>
> +@cindex @code{cmul@var{m}4} instruction pattern
> +@item @samp{cmul@var{m}4}
> +Perform a vector multiply that is semantically the same as multiply of
> +complex numbers.
> +
> +@smallexample
> + complex TYPE c[N];
> + complex TYPE a[N];
> + complex TYPE b[N];
> + for (int i = 0; i < N; i += 1)
> + @{
> + c[i] = a[i] * b[i];
> + @}
> +@end smallexample
> +
> +In GCC lane ordering the real part of the number must be in the even lanes with
> +the imaginary part in the odd lanes.
> +
> +The operation is only supported for vector modes @var{m}.
> +
> +This pattern is not allowed to @code{FAIL}.
> +
> +@cindex @code{cmul_conj@var{m}4} instruction pattern
> +@item @samp{cmul_conj@var{m}4}
> +Perform a vector multiply by conjugate that is semantically the same as a
> +multiply of complex numbers where the second multiply argument is conjugated.
> +
> +@smallexample
> + complex TYPE c[N];
> + complex TYPE a[N];
> + complex TYPE b[N];
> + for (int i = 0; i < N; i += 1)
> + @{
> + c[i] = a[i] * conj (b[i]);
> + @}
> +@end smallexample
> +
> +In GCC lane ordering the real part of the number must be in the even lanes with
> +the imaginary part in the odd lanes.
> +
> +The operation is only supported for vector modes @var{m}.
> +
> +This pattern is not allowed to @code{FAIL}.
> +
> @cindex @code{ffs@var{m}2} instruction pattern
> @item @samp{ffs@var{m}2}
> Store into operand 0 one plus the index of the least significant 1-bit
> diff --git a/gcc/internal-fn.def b/gcc/internal-fn.def
> index 511fe70162b5d9db3a61a5285d31c008f6835487..5a0bbe3fe5dee591d54130e60f6996b28164ae38 100644
> --- a/gcc/internal-fn.def
> +++ b/gcc/internal-fn.def
> @@ -279,6 +279,8 @@ DEF_INTERNAL_FLT_FLOATN_FN (FMAX, ECF_CONST, fmax, binary)
> DEF_INTERNAL_OPTAB_FN (XORSIGN, ECF_CONST, xorsign, binary)
> DEF_INTERNAL_OPTAB_FN (COMPLEX_ADD_ROT90, ECF_CONST, cadd90, binary)
> DEF_INTERNAL_OPTAB_FN (COMPLEX_ADD_ROT270, ECF_CONST, cadd270, binary)
> +DEF_INTERNAL_OPTAB_FN (COMPLEX_MUL, ECF_CONST, cmul, binary)
> +DEF_INTERNAL_OPTAB_FN (COMPLEX_MUL_CONJ, ECF_CONST, cmul_conj, binary)
>
>
> /* FP scales. */
> diff --git a/gcc/optabs.def b/gcc/optabs.def
> index e9727def4dbf941bb9ac8b56f83f8ea0f52b262c..e82396bae1117c6de91304761a560b7fbcb69ce1 100644
> --- a/gcc/optabs.def
> +++ b/gcc/optabs.def
> @@ -292,6 +292,8 @@ OPTAB_D (copysign_optab, "copysign$F$a3")
> OPTAB_D (xorsign_optab, "xorsign$F$a3")
> OPTAB_D (cadd90_optab, "cadd90$a3")
> OPTAB_D (cadd270_optab, "cadd270$a3")
> +OPTAB_D (cmul_optab, "cmul$a3")
> +OPTAB_D (cmul_conj_optab, "cmul_conj$a3")
> OPTAB_D (cos_optab, "cos$a2")
> OPTAB_D (cosh_optab, "cosh$a2")
> OPTAB_D (exp10_optab, "exp10$a2")
> diff --git a/gcc/tree-vect-slp-patterns.c b/gcc/tree-vect-slp-patterns.c
> index dbc58f7c53868ed431fc67de1f0162eb0d3b2c24..82721acbab8cf81c4d6f9954c98fb913a7bb6282 100644
> --- a/gcc/tree-vect-slp-patterns.c
> +++ b/gcc/tree-vect-slp-patterns.c
> @@ -719,6 +719,368 @@ complex_add_pattern::recognize (slp_tree_to_load_perm_map_t *perm_cache,
> return new complex_add_pattern (node, &ops, ifn);
> }
>
> +/*******************************************************************************
> + * complex_mul_pattern
> + ******************************************************************************/
> +
> +/* Helper function that looks for a match in the CHILDth child of NODE. The
> + child used is stored in RES.
> +
> + If the match is successful then ARGS will contain the operands matched
> + and the complex_operation_t type is returned. If match is not successful
> + then CMPLX_NONE is returned and ARGS is left unmodified. */
> +
> +static inline complex_operation_t
> +vect_match_call_complex_mla (slp_tree node, unsigned child,
> + vec<slp_tree> *args = NULL, slp_tree *res = NULL)
> +{
> + gcc_assert (child < SLP_TREE_CHILDREN (node).length ());
> +
> + slp_tree data = SLP_TREE_CHILDREN (node)[child];
> +
> + if (res)
> + *res = data;
> +
> + return vect_detect_pair_op (data, false, args);
> +}
> +
> +/* Check to see if either of the trees in ARGS are a NEGATE_EXPR. If the first
> + child (args[0]) is a NEGATE_EXPR then NEG_FIRST_P is set to TRUE.
> +
> + If a negate is found then the values in ARGS are reordered such that the
> + negate node is always the second one and the entry is replaced by the child
> + of the negate node. */
> +
> +static inline bool
> +vect_normalize_conj_loc (vec<slp_tree> args, bool *neg_first_p = NULL)
> +{
> + gcc_assert (args.length () == 2);
> + bool neg_found = false;
> +
> + if (vect_match_expression_p (args[0], NEGATE_EXPR))
> + {
> + std::swap (args[0], args[1]);
> + neg_found = true;
> + if (neg_first_p)
> + *neg_first_p = true;
> + }
> + else if (vect_match_expression_p (args[1], NEGATE_EXPR))
> + {
> + neg_found = true;
> + if (neg_first_p)
> + *neg_first_p = false;
> + }
> +
> + if (neg_found)
> + args[1] = SLP_TREE_CHILDREN (args[1])[0];
> +
> + return neg_found;
> +}
> +
> +/* Helper function to check if PERM is KIND or PERM_TOP. */
> +
> +static inline bool
> +is_eq_or_top (complex_load_perm_t perm, complex_perm_kinds_t kind)
> +{
> + return perm.first == kind || perm.first == PERM_TOP;
> +}
> +
> +/* Helper function that checks to see if LEFT_OP and RIGHT_OP are both MULT_EXPR
> + nodes but also that they represent an operation that is either a complex
> + multiplication or a complex multiplication by conjugated value.
> +
> + If the negation is expected to be in the first half of the tree (As required
> + by an FMS pattern) then NEG_FIRST is true. If the operation is a conjugate
> + operation then CONJ_FIRST_OPERAND is set to indicate whether the first or
> + second operand contains the conjugate operation. */
> +
> +static inline bool
> +vect_validate_multiplication (slp_tree_to_load_perm_map_t *perm_cache,
> + vec<slp_tree> left_op, vec<slp_tree> right_op,
> + bool neg_first, bool *conj_first_operand,
> + bool fms)
> +{
> + /* The presence of a negation indicates that we have either a conjugate or a
> + rotation. We need to distinguish which one. */
> + *conj_first_operand = false;
> + complex_perm_kinds_t kind;
> +
> + /* Complex conjugates have the negation on the imaginary part of the
> + number where rotations affect the real component. So check if the
> + negation is on a dup of lane 1. */
> + if (fms)
> + {
> + /* Canonicalization for fms is not consistent. So have to test both
> + variants to be sure. This needs to be fixed in the mid-end so
> + this part can be simpler. */
> + kind = linear_loads_p (perm_cache, right_op[0]).first;
> + if (!((kind == PERM_ODDODD
> + && is_eq_or_top (linear_loads_p (perm_cache, right_op[1]),
> + PERM_ODDEVEN))
> + || (kind == PERM_ODDEVEN
> + && is_eq_or_top (linear_loads_p (perm_cache, right_op[1]),
> + PERM_ODDODD))))
> + return false;
> + }
> + else
> + {
> + if (linear_loads_p (perm_cache, right_op[1]).first != PERM_ODDODD
> + && !is_eq_or_top (linear_loads_p (perm_cache, right_op[0]),
> + PERM_ODDEVEN))
> + return false;
> + }
> +
> + /* Deal with differences in indexes. */
> + int index1 = fms ? 1 : 0;
> + int index2 = fms ? 0 : 1;
> +
> + /* Check if the conjugate is on the first or second operand. The
> + order of the node with the conjugate value determines this, and the dup
> + node must be one of lane 0 of the same DR as the neg node. */
> + kind = linear_loads_p (perm_cache, left_op[index1]).first;
> + if (kind == PERM_TOP)
> + {
> + if (linear_loads_p (perm_cache, left_op[index2]).first == PERM_EVENODD)
> + return true;
> + }
> + else if (kind == PERM_EVENODD)
> + {
> + if ((kind = linear_loads_p (perm_cache, left_op[index2]).first) == PERM_EVENODD)
> + return false;
> + }
> + else if (!neg_first)
> + *conj_first_operand = true;
> + else
> + return false;
> +
> + if (kind != PERM_EVENEVEN)
> + return false;
> +
> + return true;
> +}
> +
> +/* Helper function to help distinguish between a conjugate and a rotation in a
> + complex multiplication. The operations have similar shapes but the order of
> + the load permutes are different. This function returns TRUE when the order
> + is consistent with a multiplication or multiplication by conjugated
> + operand but returns FALSE if it's a multiplication by rotated operand. */
> +
> +static inline bool
> +vect_validate_multiplication (slp_tree_to_load_perm_map_t *perm_cache,
> + vec<slp_tree> op, complex_perm_kinds_t permKind)
> +{
> + /* The left node is the more common case, test it first. */
> + if (!is_eq_or_top (linear_loads_p (perm_cache, op[0]), permKind))
> + {
> + if (!is_eq_or_top (linear_loads_p (perm_cache, op[1]), permKind))
> + return false;
> + }
> + return true;
> +}
> +
> +/* This function combines two nodes containing only even and only odd lanes
> + together into a single node which contains the nodes in even/odd order
> + by using a lane permute. */
> +
> +static slp_tree
> +vect_build_combine_node (slp_tree even, slp_tree odd, slp_tree rep)
> +{
> + auto_vec<slp_tree> nodes;
> + nodes.create (2);
> + vec<std::pair<unsigned, unsigned> > perm;
> + perm.create (SLP_TREE_LANES (rep));
> +
> + for (unsigned x = 0; x < SLP_TREE_LANES (rep); x+=2)
> + {
> + perm.quick_push (std::make_pair (0, x));
> + perm.quick_push (std::make_pair (1, x));
> + }
That looks wrong: it creates {0, 0}, {1, 0}, {0, 2}, {1, 2}
but you want {0, 0}, {1, 0}, {0, 1}, {1, 1} AFAICS. At least
I assume SLP_TREE_LANES (odd/even) == SLP_TREE_LANES (rep) / 2?
'rep' isn't documented; I assume it's supposed to be a "representative"
for the result?
> +
> + nodes.quick_push (even);
> + nodes.quick_push (odd);
No need for this intermediate nodes array, just push to ...
> + SLP_TREE_REF_COUNT (even)++;
> + SLP_TREE_REF_COUNT (odd)++;
> +
> + slp_tree vnode = vect_create_new_slp_node (2, SLP_TREE_CODE (even));
> + SLP_TREE_CODE (vnode) = VEC_PERM_EXPR;
> + SLP_TREE_LANE_PERMUTATION (vnode) = perm;
> + SLP_TREE_CHILDREN (vnode).safe_splice (nodes);
... the children array directly (even with quick_push, we've
already allocated 2 elements for the children).
> + SLP_TREE_REF_COUNT (vnode) = 1;
> + SLP_TREE_LANES (vnode) = SLP_TREE_LANES (rep);
> + gcc_assert (perm.length () == SLP_TREE_LANES (vnode));
> + /* Representation is set to that of the current node as the vectorizer
> + can't deal with VEC_PERMs with no representation, as would be the
> + case with invariants. */
Yeah, I need to fix this ...
> + SLP_TREE_REPRESENTATIVE (vnode) = SLP_TREE_REPRESENTATIVE (rep);
> + SLP_TREE_VECTYPE (vnode) = SLP_TREE_VECTYPE (rep);
> + return vnode;
> +}
> +
> +class complex_mul_pattern : public complex_pattern
> +{
> + protected:
> + complex_mul_pattern (slp_tree *node, vec<slp_tree> *m_ops, internal_fn ifn)
> + : complex_pattern (node, m_ops, ifn)
> + {
> + this->m_num_args = 2;
> + }
> +
> + public:
> + void build (vec_info *);
> + static internal_fn
> + matches (complex_operation_t op, slp_tree_to_load_perm_map_t *, slp_tree *,
> + vec<slp_tree> *);
> +
> + static vect_pattern*
> + recognize (slp_tree_to_load_perm_map_t *, slp_tree *);
> +
> + static vect_pattern*
> + mkInstance (slp_tree *node, vec<slp_tree> *m_ops, internal_fn ifn)
> + {
> + return new complex_mul_pattern (node, m_ops, ifn);
> + }
> +
> +};
> +
> +/* Pattern matcher for trying to match complex multiply pattern in SLP tree
> + If the operation matches then IFN is set to the operation it matched
> + and the arguments to the two replacement statements are put in m_ops.
> +
> + If no match is found then IFN is set to IFN_LAST and m_ops is unchanged.
> +
> + This function matches the patterns shaped as:
> +
> + double ax = (b[i+1] * a[i]);
> + double bx = (a[i+1] * b[i]);
> +
> + c[i] = c[i] - ax;
> + c[i+1] = c[i+1] + bx;
> +
> + If a match occurred then TRUE is returned, else FALSE. The initial match is
> + expected to be in OP1 and the initial match operands in args0. */
> +
> +internal_fn
> +complex_mul_pattern::matches (complex_operation_t op,
> + slp_tree_to_load_perm_map_t *perm_cache,
> + slp_tree *node, vec<slp_tree> *ops)
> +{
> + internal_fn ifn = IFN_LAST;
> +
> + if (op != MINUS_PLUS)
> + return IFN_LAST;
> +
> + slp_tree root = *node;
> + /* First two nodes must be a multiply. */
> + auto_vec<slp_tree> muls;
> + if (vect_match_call_complex_mla (root, 0) != MULT_MULT
> + || vect_match_call_complex_mla (root, 1, &muls) != MULT_MULT)
> + return IFN_LAST;
> +
> + /* Now operand2+4 may lead to another expression. */
> + auto_vec<slp_tree> left_op, right_op;
> + left_op.safe_splice (SLP_TREE_CHILDREN (muls[0]));
> + right_op.safe_splice (SLP_TREE_CHILDREN (muls[1]));
> +
> + if (linear_loads_p (perm_cache, left_op[1]).first == PERM_ODDEVEN)
> + return IFN_LAST;
> +
> + bool neg_first;
> + bool is_neg = vect_normalize_conj_loc (right_op, &neg_first);
> +
> + if (!is_neg)
> + {
> + /* A multiplication needs to multiply against the real pair, otherwise
> + the pattern matches that of FMS. */
> + if (!vect_validate_multiplication (perm_cache, left_op, PERM_EVENEVEN)
> + || vect_normalize_conj_loc (left_op))
> + return IFN_LAST;
> + ifn = IFN_COMPLEX_MUL;
> + }
> + else if (is_neg)
> + {
> + bool conj_first_operand;
> + if (!vect_validate_multiplication (perm_cache, left_op, right_op,
> + neg_first, &conj_first_operand,
> + false))
> + return IFN_LAST;
> +
> + ifn = IFN_COMPLEX_MUL_CONJ;
> + }
> +
> + if (!vect_pattern_validate_optab (ifn, *node))
> + return IFN_LAST;
> +
> + ops->truncate (0);
> + ops->create (3);
> +
> + complex_perm_kinds_t kind = linear_loads_p (perm_cache, left_op[0]).first;
> + if (kind == PERM_EVENODD)
> + {
> + ops->quick_push (left_op[1]);
> + ops->quick_push (right_op[1]);
> + ops->quick_push (left_op[0]);
> + }
> + else if (kind == PERM_TOP)
> + {
> + ops->quick_push (left_op[1]);
> + ops->quick_push (right_op[1]);
> + ops->quick_push (left_op[0]);
> + }
> + else
> + {
> + ops->quick_push (left_op[0]);
> + ops->quick_push (right_op[0]);
> + ops->quick_push (left_op[1]);
> + }
> +
> + return ifn;
> +}
> +
> +/* Attempt to recognize a complex mul pattern. */
> +
> +vect_pattern*
> +complex_mul_pattern::recognize (slp_tree_to_load_perm_map_t *perm_cache,
> + slp_tree *node)
> +{
> + auto_vec<slp_tree> ops;
> + complex_operation_t op
> + = vect_detect_pair_op (*node, true, &ops);
> + internal_fn ifn
> + = complex_mul_pattern::matches (op, perm_cache, node, &ops);
> + if (ifn == IFN_LAST)
> + return NULL;
> +
> + return new complex_mul_pattern (node, &ops, ifn);
> +}
> +
> +/* Perform a replacement of the detected complex mul pattern with the new
> + instruction sequences. */
> +
> +void
> +complex_mul_pattern::build (vec_info *vinfo)
> +{
> + auto_vec<slp_tree> nodes;
> +
> + /* First re-arrange the children. */
> + nodes.create (2);
> +
> + nodes.quick_push (this->m_ops[2]);
> + nodes.quick_push (
> + vect_build_combine_node (this->m_ops[0], this->m_ops[1], *this->m_node));
> + SLP_TREE_REF_COUNT (this->m_ops[2])++;
> +
> + slp_tree node;
> + unsigned i;
> + FOR_EACH_VEC_ELT (SLP_TREE_CHILDREN (*this->m_node), i, node)
> + vect_free_slp_tree (node);
> +
> + SLP_TREE_CHILDREN (*this->m_node).truncate (0);
> + SLP_TREE_CHILDREN (*this->m_node).safe_splice (nodes);
Please elide the nodes array. *this->m_node now has a "wrong"
representative but I guess
> + complex_pattern::build (vinfo);
will fix that up? I still find the structure of the pattern matching
& transform hard to follow. But well - I've settled on the idea
of refactoring it for next stage1 after the fact ;)
Thanks,
Richard.
More information about the Gcc-patches
mailing list