This is the mail archive of the
gcc@gcc.gnu.org
mailing list for the GCC project.
Re: [GSoC] commutative patterns
- From: Prathamesh Kulkarni <bilbotheelffriend at gmail dot com>
- To: Richard Biener <richard dot guenther at gmail dot com>
- Cc: Diego Novillo <dnovillo at google dot com>, gcc <gcc at gcc dot gnu dot org>, Maxim Kuvyrkov <maxim dot kuvyrkov at linaro dot org>
- Date: Mon, 23 Jun 2014 01:22:35 +0530
- Subject: Re: [GSoC] commutative patterns
- References: <CAJXstsCmdC9FQPqynTVmV2kBco7vTk4ifUp5TvauZ4FNfWpCUQ at mail dot gmail dot com> <CAJXstsBeaZOx64=CqQQtJvYcBw0LdVsdbZCR3kPp_JV6-V-sUw at mail dot gmail dot com> <CAJXstsBMu88CEP7Ny_8icV32=oB8_7WNohj2es0c1+T1LXFy0w at mail dot gmail dot com>
On Sun, Jun 22, 2014 at 3:09 AM, Prathamesh Kulkarni
<bilbotheelffriend@gmail.com> wrote:
> On Fri, Jun 20, 2014 at 3:02 AM, Prathamesh Kulkarni
> <bilbotheelffriend@gmail.com> wrote:
>>
>> On Fri, Jun 20, 2014 at 2:53 AM, Prathamesh Kulkarni
>> <bilbotheelffriend@gmail.com> wrote:
>> > Hi,
>> > The attached patch attempts to generate commutative variants for
>> > a given expression.
>> >
>> > Example:
>> > For the AST: (PLUS_EXPR (PLUS_EXPR @0 @1) @2),
>> >
>> > the commutative variants are:
>> > (PLUS_EXPR (PLUS_EXPR @0 @1 ) @2 )
>> > (PLUS_EXPR (PLUS_EXPR @1 @0 ) @2 )
>> > (PLUS_EXPR @2 (PLUS_EXPR @0 @1 ) )
>> > (PLUS_EXPR @2 (PLUS_EXPR @1 @0 ) )
>> >
>> >
>> > * Basic Idea:
>> > Consider an expression e with two operands, o0 and o1,
>> > and expr-code denoting the expression's code (plus/mult, etc.).
>> >
>> > Commutative variants are stored in a vector (vec<operand *>).
>> >
>> > vec<operand *>
>> > commutative (e)
>> > {
>> > if (e is not commutative)
>> > return [e]; // vector with only one expression
>> >
>> > v1 = commutative (o0);
>> > v2 = commutative (o1);
>> > ret = []
>> >
>> > for i = 0 ... v1.length ()
>> > for j = 0 ... v2.length ()
>> > {
>> > ne = new expr with <expr-code> and operands: v1[i], v2[j];
>> > append ne to ret;
>> > }
>> >
>> > for i = 0 ... v2.length ()
>> > for j = 0 ... v1.length ()
>> > {
>> > ne = new expr with <expr-code> and operands: v2[i], v1[j];
>> > append ne to ret;
>> > }
>> >
>> > return ret;
>> > }
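A minimal, self-contained C++ sketch of the idea, under some stated assumptions: the toy Node type and the helpers make, commutative_variants and node_to_string are hypothetical (they are not genmatch's operand/expr/capture classes), a capture is modelled as a label on the expression, and commutativity is a per-node flag rather than being implied by PLUS_EXPR/MULT_EXPR. Unlike the pseudocode, it also recurses into the operands of non-commutative operators (the fix described later in this thread), and it builds the swapped variants by reversing each operand combination, as the later patch does.

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy AST node (hypothetical; not genmatch's operand/expr/capture classes).
// A leaf such as "@0" has no operands; 'capture' optionally labels an
// expression (printed as op@N); 'commutative' marks a binary operator
// whose two operands may be swapped.
struct Node {
  std::string op;
  std::vector<std::shared_ptr<Node>> ops;
  bool commutative = false;
  std::string capture;
};
using NodePtr = std::shared_ptr<Node>;

NodePtr make(const std::string &op, std::vector<NodePtr> ops,
             bool commutative = false, const std::string &capture = "") {
  auto n = std::make_shared<Node>();
  n->op = op;
  n->ops = std::move(ops);
  n->commutative = commutative;
  n->capture = capture;
  return n;
}

// Return every commutative variant of E, with the original order first.
std::vector<NodePtr> commutative_variants(const NodePtr &e) {
  if (e->ops.empty())
    return {e};  // a leaf is its own only variant

  // Variants of each operand, computed recursively (also beneath
  // non-commutative operators).
  std::vector<std::vector<NodePtr>> per_op;
  for (const NodePtr &o : e->ops)
    per_op.push_back(commutative_variants(o));

  // Cartesian product: pick one operand variant per position.
  std::vector<std::vector<NodePtr>> combos(1);  // one empty combination
  for (const auto &choices : per_op) {
    std::vector<std::vector<NodePtr>> next;
    for (const auto &prefix : combos)
      for (const NodePtr &c : choices) {
        std::vector<NodePtr> v = prefix;
        v.push_back(c);
        next.push_back(std::move(v));
      }
    combos = std::move(next);
  }

  std::vector<NodePtr> ret;
  auto emit = [&](std::vector<NodePtr> operands) {
    NodePtr n = std::make_shared<Node>(*e);  // same op, flag and capture
    n->ops = std::move(operands);
    ret.push_back(n);
  };
  for (const auto &c : combos)
    emit(c);
  if (e->commutative && e->ops.size() == 2)
    for (const auto &c : combos)
      emit({c[1], c[0]});  // the same combinations with operands swapped
  return ret;
}

std::string node_to_string(const NodePtr &e) {
  if (e->ops.empty())
    return e->op;
  std::string s = "(" + e->op;
  if (!e->capture.empty())
    s += "@" + e->capture;
  for (const NodePtr &o : e->ops)
    s += " " + node_to_string(o);
  return s + ")";
}
```

For (plus (plus @0 @1) @2) with both operators flagged, this yields the four variants listed earlier, in the same order. Unchanged subtrees (here the leaves) are shared between variants, much like the DAG the patch builds.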
>> >
>> > Example:
>> > (plus (plus @0 @1) (plus @2 @3))
>> > generates the following commutative variants:
>> Oops.
>> The pattern actually given to genmatch was (bogus):
>> (plus (plus @0 @1) (plus @0 @3))
>> >
>> > (PLUS_EXPR (PLUS_EXPR @0 @1 ) (PLUS_EXPR @0 @3 ) )
>> > (PLUS_EXPR (PLUS_EXPR @0 @1 ) (PLUS_EXPR @3 @0 ) )
>> > (PLUS_EXPR (PLUS_EXPR @1 @0 ) (PLUS_EXPR @0 @3 ) )
>> > (PLUS_EXPR (PLUS_EXPR @1 @0 ) (PLUS_EXPR @3 @0 ) )
>> > (PLUS_EXPR (PLUS_EXPR @0 @3 ) (PLUS_EXPR @0 @1 ) )
>> > (PLUS_EXPR (PLUS_EXPR @0 @3 ) (PLUS_EXPR @1 @0 ) )
>> > (PLUS_EXPR (PLUS_EXPR @3 @0 ) (PLUS_EXPR @0 @1 ) )
>> > (PLUS_EXPR (PLUS_EXPR @3 @0 ) (PLUS_EXPR @1 @0 ) )
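The variant counts follow a simple recurrence. Writing N(e) for the number of variants generated for e, and taking non-commutative operators to recurse into their operands (the behavior of the later revision in this thread):

```latex
N(e) =
\begin{cases}
1 & \text{if } e \text{ is a bare capture or a predicate,}\\
N(w) & \text{if } e \text{ is a capture } @n \text{ wrapping the expression } w,\\
\prod_{i=0}^{k-1} N(o_i) & \text{if } e = (\mathit{op}\; o_0 \ldots o_{k-1}) \text{ and } \mathit{op} \text{ is not commutative,}\\
2\, N(o_0)\, N(o_1) & \text{if } e = (\mathit{op}\; o_0\; o_1) \text{ and } \mathit{op} \text{ is commutative.}
\end{cases}
```

For (plus (plus @0 @1) (plus @0 @3)) this gives 2 * (2 * 1 * 1) * (2 * 1 * 1) = 8, and for (PLUS_EXPR (PLUS_EXPR @0 @1) @2) it gives 2 * (2 * 1 * 1) * 1 = 4, matching the listings above. Identical operands are not merged, so (mult @0 @0) still counts as 2.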
>> >
>> >
>> > * Decide which operators are commutative.
>> > Currently I assume all PLUS_EXPR and MULT_EXPR are true.
>> s/true/commutative
> There's a bug in the previous patch: if the operator is not
> commutative, it does not try to generate commutative variants
> of its operands, and it does not commutate the captured
> expression (.what).
> Example:
> (negate (plus @0 @1)) has two commutative variants (including the
> original pattern),
> but the patch generates only the original, since negate is not commutative.
>
> The attached patch fixes that. As a quick hack I handled each operator
> class (unary, binary, ternary)
> separately (commutate_unary, commutate_binary, commutate_ternary).
> Ideally this should be unified
> (I tried that, but it was segfaulting). I will try to come up
> with a better way.
> Also, the current patch won't work for built-in functions/operators
> with more than 3 operands
> (the maximum so far in match.pd is 3, for cond; I hope this doesn't get
> in the way).
>
> With the current patch,
> for the expression (negate (plus @0 @1))
> it generates the following commutative variants:
> (negate (plus @0 @1))
> (negate (plus @1 @0))
>
> and for the following pattern (involving a captured expression):
> (negate (plus@0 @1 @2))
> it generates the following variants:
> (negate (plus@0 @1 @2))
> (negate (plus@0 @2 @1))
>
> * Generates multiple identical matching patterns
> Since at the AST level we do not test captures for equality (true/match),
> both captures are treated
> as different, even though they are the same.
> Example: the following expression also has 2 variants generated:
> (BUILT_IN_SQRT (mult @0 @0))
> commutative variants:
> (BUILT_IN_SQRT (mult @0 @0))
> (BUILT_IN_SQRT (mult @0 @0))
> I guess this won't really be a problem with the decision tree. If we decide to emit a
> warning, we should warn only for user-defined patterns, and not for generated ones.
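If such a warning is added, generated duplicates would have to be filtered first so they are not reported as user errors. One possible approach, sketched under the assumption that each variant can be printed to a string (the patch's print_operand prints to a FILE rather than returning a string), is to keep only the first occurrence of each printed form. drop_duplicate_variants is hypothetical and not part of the patch.

```cpp
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical helper (not part of the patch): drop variants whose printed
// form has already been seen, keeping the first occurrence and the order.
// Comparing printed forms only catches syntactically identical variants,
// such as the two copies of (BUILT_IN_SQRT (mult @0 @0)).
std::vector<std::string>
drop_duplicate_variants(const std::vector<std::string> &variants) {
  std::unordered_set<std::string> seen;
  std::vector<std::string> kept;
  for (const std::string &v : variants)
    if (seen.insert(v).second)  // true only the first time v is inserted
      kept.push_back(v);
  return kept;
}
```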
>
> * Syntax for commutative operators
> Currently, I assume any PLUS_EXPR / MULT_EXPR to be commutative.
> I guess we should have a syntax that lets users mark an operator as commutative.
>
> Something like:
> a) op:c
> b) op "c"
> c) op!
> d) op "commutative"
>
> Or any other syntax you would prefer :-)
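Option (a) needs only a small check in the operator parser. A string-based sketch of that check follows; parse_operator is hypothetical, and genmatch itself works on cpp tokens (a CPP_COLON followed by a CPP_NAME) rather than strings, so this is only an illustration. Names other than c after the colon are rejected here, since that position would otherwise be taken by predicates on expressions.

```cpp
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical helper: split an operator spelling such as "plus:c" into the
// operator name and a commutative flag. A bare name is not commutative; any
// flag other than "c" after the colon is rejected.
std::pair<std::string, bool> parse_operator(const std::string &tok) {
  std::string::size_type colon = tok.find(':');
  if (colon == std::string::npos)
    return {tok, false};
  std::string flag = tok.substr(colon + 1);
  if (flag != "c")
    throw std::invalid_argument("unsupported operator flag: " + flag);
  return {tok.substr(0, colon), true};
}
```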
>
> * Cloning AST nodes
> Currently I do not deep-copy the AST for each distinct
> commutative variant, so nodes
> are shared between the different expressions that are commutative variants
> of the original expression.
> Is this OK, or should we clone each AST node, so that each expression
> is represented by a distinct AST?
> Cloning will eat up space, while sharing will require more careful
> memory management (freeing one AST may also
> free nodes of another expression).
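To make the trade-off concrete, here is a sketch of the cloning option on a hypothetical toy Node type (raw pointers allocated with new, loosely mirroring how the patch allocates nodes with new; these are not genmatch's classes). clone produces a tree that shares nothing with the original, so destroy can free one variant without touching any other; running destroy over a DAG with shared nodes would instead free those nodes twice.

```cpp
#include <string>
#include <vector>

// Toy AST node (hypothetical; genmatch's operand hierarchy is richer).
struct Node {
  std::string op;
  std::vector<Node *> ops;
};

// Deep copy: every node of the result is freshly allocated, so each
// variant owns a distinct tree.
Node *clone(const Node *n) {
  Node *c = new Node{n->op, {}};
  for (const Node *o : n->ops)
    c->ops.push_back(clone(o));
  return c;
}

// Free a tree produced by clone(). Only safe when no node is shared.
void destroy(Node *n) {
  for (Node *o : n->ops)
    destroy(o);
  delete n;
}
```

The cost is one allocation per node per variant, which is the extra space mentioned above.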
This patch removes the hack of special handling by operator class.
For now, I added the op:c syntax to mark operator op as commutative.
Example (the outer plus is not commutated, since it is not marked commutative):
(plus (plus:c@0 @1 @2) @3)
generates the following variants:
(PLUS_EXPR (PLUS_EXPR@0 @1 @2 ) @3 )
(PLUS_EXPR (PLUS_EXPR@0 @2 @1 ) @3 )
How do we resize a vector to hold n elements from the start?
I tried:
vec<operand *> v = vNULL;
v.resize (n);
v.resize_exact (n);
However, accessing v[i] led to an internal abort in operator[] (vec.h line 735).
As a workaround I did (in cartesian_product):
for (unsigned i = 0; i < n_ops; ++i)
v.safe_push (0);
This works to make the vector "big enough" to hold n_ops elements, but it is
rather ugly.
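A hedged note on the vector question: as far as I know, operator[] in vec.h checks the index against the vector's length, not its allocated capacity, so growing only the allocation leaves v[i] out of range. If I remember vec.h correctly, safe_grow (n) and safe_grow_cleared (n) are the members that set the length (the latter zero-filling the new slots) and could replace the safe_push loop; this should be checked against the vec.h at this revision. The same capacity/length distinction, shown with std::vector as an analogy (make_slots is a hypothetical helper):

```cpp
#include <cstddef>
#include <vector>

// std::vector analogy (not GCC's vec): reserve() only grows capacity,
// while resize() changes the length, which is what valid indexing needs.
std::vector<int *> make_slots(std::size_t n) {
  std::vector<int *> v;
  v.reserve(n);          // capacity >= n, but v.size() is still 0
  v.resize(n, nullptr);  // now v.size() == n and v[0] .. v[n-1] are valid
  return v;
}
```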
Thanks and Regards,
Prathamesh
>
> Thanks and Regards,
> Prathamesh
>
>> > Maybe we should add syntax to mark a particular operator as commutative?
>> >
>> > * Cloning AST nodes
>> > While creating another AST that represents one of
>> > the commutative variants, should we clone the AST nodes,
>> > so that all commutative variants have distinct AST nodes?
>> > That's not done currently: AST nodes are shared among
>> > the different commutative expressions, so we end up with a DAG
>> > for a set of commutative expressions.
>> >
>> > Thanks and Regards,
>> > Prathamesh
Index: genmatch.c
===================================================================
--- genmatch.c (revision 211732)
+++ genmatch.c (working copy)
@@ -119,7 +119,7 @@ struct id_base : typed_free_remove<id_ba
{
enum id_kind { CODE, FN } kind;
- id_base (id_kind, const char *);
+ id_base (id_kind, const char *);
hashval_t hashval;
const char *id;
@@ -146,7 +146,7 @@ id_base::equal (const value_type *op1,
static hash_table<id_base> operators;
-id_base::id_base (id_kind kind_, const char *id_)
+id_base::id_base (id_kind kind_, const char *id_)
{
kind = kind_;
id = id_;
@@ -218,8 +218,9 @@ struct predicate : public operand
};
struct e_operation {
- e_operation (const char *id);
+ e_operation (const char *id, bool is_commutative_ = false);
id_base *op;
+ bool is_commutative;
};
@@ -258,9 +259,11 @@ struct capture : public operand
};
-e_operation::e_operation (const char *id)
+e_operation::e_operation (const char *id, bool is_commutative_)
{
id_base tem (id_base::CODE, id);
+ is_commutative = is_commutative_;
+
op = operators.find_with_hash (&tem, tem.hashval);
if (op)
return;
@@ -293,14 +296,14 @@ e_operation::e_operation (const char *id
struct simplify {
simplify (const char *name_,
- struct operand *match_, source_location match_location_,
+ vec<operand *> matchers_, source_location match_location_,
struct operand *ifexpr_, source_location ifexpr_location_,
struct operand *result_, source_location result_location_)
- : name (name_), match (match_), match_location (match_location_),
+ : name (name_), matchers (matchers_), match_location (match_location_),
ifexpr (ifexpr_), ifexpr_location (ifexpr_location_),
result (result_), result_location (result_location_) {}
const char *name;
- struct operand *match;
+ vec<operand *> matchers; // vector to hold commutative expressions
source_location match_location;
struct operand *ifexpr;
source_location ifexpr_location;
@@ -308,7 +311,148 @@ struct simplify {
source_location result_location;
};
+void
+print_operand (operand *o, FILE *f = stderr)
+{
+ if (o->type == operand::OP_CAPTURE)
+ {
+ capture *c = static_cast<capture *> (o);
+ fprintf (f, "@%s", (static_cast<capture *> (o))->where);
+ if (c->what)
+ {
+ putc (':', f);
+ print_operand (c->what, f);
+ putc (' ', f);
+ }
+ }
+
+ else if (o->type == operand::OP_PREDICATE)
+ fprintf (f, "%s", (static_cast<predicate *> (o))->ident);
+
+ else if (o->type == operand::OP_C_EXPR)
+ fprintf (f, "c_expr");
+
+ else if (o->type == operand::OP_EXPR)
+ {
+ expr *e = static_cast<expr *> (o);
+ fprintf (f, "(%s ", e->operation->op->id);
+
+ for (unsigned i = 0; i < e->ops.length (); ++i)
+ {
+ print_operand (e->ops[i], f);
+ putc (' ', f);
+ }
+
+ putc (')', f);
+ }
+
+ else
+ gcc_unreachable ();
+}
+
+void
+print_matches (struct simplify *s, FILE *f = stderr)
+{
+ if (s->matchers.length () == 1)
+ return;
+
+ fprintf (f, "for expression: ");
+ print_operand (s->matchers[0], f); // s->matchers[0] is equivalent to original expression
+ putc ('\n', f);
+
+ fprintf (f, "commutative expressions:\n");
+ for (unsigned i = 0; i < s->matchers.length (); ++i)
+ {
+ print_operand (s->matchers[i], f);
+ putc ('\n', f);
+ }
+}
+
+void
+cartesian_product (const vec< vec<operand *> >& ops_vector, vec< vec<operand *> >& result, vec<operand *>& v, unsigned n)
+{
+ if (n == ops_vector.length ())
+ {
+ vec<operand *> xv = v.copy ();
+ result.safe_push (xv);
+ return;
+ }
+
+ for (unsigned i = 0; i < ops_vector[n].length (); ++i)
+ {
+ v[n] = ops_vector[n][i];
+ cartesian_product (ops_vector, result, v, n + 1);
+ }
+}
+
+void
+cartesian_product (const vec< vec<operand *> >& ops_vector, vec< vec<operand *> >& result, unsigned n_ops)
+{
+ vec<operand *> v = vNULL;
+// FIXME: this is done to resize v to length n_ops.
+ for (unsigned i = 0; i < n_ops; ++i)
+ v.safe_push (0);
+ cartesian_product (ops_vector, result, v, 0);
+}
+
+vec<operand *>
+commutate (operand *op)
+{
+ vec<operand *> ret = vNULL;
+
+ if (op->type == operand::OP_CAPTURE)
+ {
+ capture *c = static_cast<capture *> (op);
+ if (!c->what)
+ {
+ ret.safe_push (op);
+ return ret;
+ }
+ vec<operand *> v = commutate (c->what);
+ for (unsigned i = 0; i < v.length (); ++i)
+ {
+ capture *nc = new capture (c->where, v[i]);
+ ret.safe_push (nc);
+ }
+ return ret;
+ }
+
+ if (op->type != operand::OP_EXPR)
+ {
+ ret.safe_push (op);
+ return ret;
+ }
+
+ expr *e = static_cast<expr *> (op);
+
+ vec< vec<operand *> > ops_vector = vNULL;
+ for (unsigned i = 0; i < e->ops.length (); ++i)
+ ops_vector.safe_push (commutate (e->ops[i]));
+
+ vec< vec<operand *> > result = vNULL;
+ cartesian_product (ops_vector, result, e->ops.length ());
+
+ for (unsigned i = 0; i < result.length (); ++i)
+ {
+ expr *ne = new expr (e->operation);
+ for (unsigned j = 0; j < result[i].length (); ++j)
+ ne->append_op (result[i][j]);
+ ret.safe_push (ne);
+ }
+ if (!e->operation->is_commutative)
+ return ret;
+
+ for (unsigned i = 0; i < result.length (); ++i)
+ {
+ expr *ne = new expr (e->operation);
+ for (unsigned j = result[i].length (); j; --j) // result[i].length () is 2 since e->operation is binary
+ ne->append_op (result[i][j-1]);
+ ret.safe_push (ne);
+ }
+
+ return ret;
+}
/* Code gen off the AST. */
@@ -574,11 +718,15 @@ write_nary_simplifiers (FILE *f, vec<sim
{
simplify *s = simplifiers[i];
/* ??? This means we can't capture the outermost expression. */
- if (s->match->type != operand::OP_EXPR)
+ for (unsigned i = 0; i < s->matchers.length (); ++i)
+ {
+ operand *match = s->matchers[i];
+ if (match->type != operand::OP_EXPR)
continue;
- expr *e = static_cast <expr *> (s->match);
+ expr *e = static_cast <expr *> (match);
if (e->ops.length () != n)
continue;
+
char fail_label[16];
snprintf (fail_label, 16, "fail%d", label_cnt++);
output_line_directive (f, s->match_location);
@@ -627,6 +775,7 @@ write_nary_simplifiers (FILE *f, vec<sim
fprintf (f, " }\n");
fprintf (f, "%s:\n", fail_label);
}
+ }
fprintf (f, " return false;\n");
fprintf (f, "}\n");
}
@@ -827,6 +976,25 @@ parse_expr (cpp_reader *r)
expr *e = new expr (parse_operation (r));
const cpp_token *token = peek (r);
operand *op;
+ bool is_commutative = false;
+
+ if (token->type == CPP_COLON)
+ {
+ eat_token (r, CPP_COLON);
+ token = peek (r);
+ if (token->type == CPP_NAME
+ && !(token->flags & PREV_WHITE))
+ {
+ const char *s = (const char *)CPP_HASHNODE (token->val.node.node)->ident.str;
+ eat_token (r, CPP_NAME);
+ token = peek (r);
+ if (s[0] == 'c' && !s[1])
+ is_commutative = true;
+ else
+ fatal_at (token, "not implemented: predicates on expressions");
+ }
+ }
+
if (token->type == CPP_ATSIGN
&& !(token->flags & PREV_WHITE))
op = parse_capture (r, e);
@@ -847,6 +1015,13 @@ parse_expr (cpp_reader *r)
fatal_at (token, "got %d operands instead of the required %d",
e->ops.length (), opr->get_required_nargs ());
}
+ if (is_commutative)
+ {
+ if (e->ops.length () == 2)
+ e->operation->is_commutative = true;
+ else
+ fatal_at (token, "only binary operators or function with two arguments can be marked commutative");
+ }
return op;
}
e->append_op (parse_op (r));
@@ -971,7 +1146,7 @@ parse_match_and_simplify (cpp_reader *r,
ifexpr = parse_c_expr (r, CPP_OPEN_PAREN);
}
token = peek (r);
- return new simplify (id, match, match_location,
+ return new simplify (id, commutate (match), match_location,
ifexpr, ifexpr_location, parse_op (r), token->src_loc);
}
@@ -1043,6 +1218,9 @@ main(int argc, char **argv)
}
while (1);
+ for (unsigned i = 0; i < simplifiers.length (); ++i)
+ print_matches (simplifiers[i]);
+
write_gimple (stdout, simplifiers);
cpp_finish (r, NULL);
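For comparison, the recursion in cartesian_product can be written so that the scratch vector is created with its final length, which removes the need for the FIXME loop. A minimal sketch using std::vector and templates (an assumption for illustration; genmatch uses GCC's vec and operand pointers):

```cpp
#include <cstddef>
#include <vector>

// Recursive cartesian product in the shape of the patch's
// cartesian_product(). choices[k] holds the alternatives for position k;
// every combination picking one alternative per position is appended to
// result, in the order the recursion visits them.
template <typename T>
void cartesian_product(const std::vector<std::vector<T>> &choices,
                       std::vector<std::vector<T>> &result,
                       std::vector<T> &current, std::size_t n) {
  if (n == choices.size()) {
    result.push_back(current);  // copy the finished combination
    return;
  }
  for (const T &c : choices[n]) {
    current[n] = c;
    cartesian_product(choices, result, current, n + 1);
  }
}

template <typename T>
std::vector<std::vector<T>>
cartesian_product(const std::vector<std::vector<T>> &choices) {
  std::vector<std::vector<T>> result;
  std::vector<T> current(choices.size());  // sized up front, no push loop
  cartesian_product(choices, result, current, 0);
  return result;
}
```

Given choices {{1, 2}, {3}}, it produces {1, 3} and then {2, 3}, the same order the patch's version visits. T must be default-constructible for the sized constructor; an empty alternative list at any position yields no combinations.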