This is the mail archive of the gcc@gcc.gnu.org mailing list for the GCC project.



Re: [GSoC] commutative patterns


On Sun, Jun 22, 2014 at 3:09 AM, Prathamesh Kulkarni
<bilbotheelffriend@gmail.com> wrote:
> On Fri, Jun 20, 2014 at 3:02 AM, Prathamesh Kulkarni
> <bilbotheelffriend@gmail.com> wrote:
>>
>> On Fri, Jun 20, 2014 at 2:53 AM, Prathamesh Kulkarni
>> <bilbotheelffriend@gmail.com> wrote:
>> > Hi,
>> >     The attached patch attempts to generate commutative variants for
>> > a given expression.
>> >
>> > Example:
>> > For the AST: (PLUS_EXPR (PLUS_EXPR @0 @1) @2),
>> >
>> > the commutative variants are:
>> > (PLUS_EXPR (PLUS_EXPR @0 @1 ) @2 )
>> > (PLUS_EXPR (PLUS_EXPR @1 @0 ) @2 )
>> > (PLUS_EXPR @2 (PLUS_EXPR @0 @1 ) )
>> > (PLUS_EXPR @2 (PLUS_EXPR @1 @0 ) )
>> >
>> >
>> > * Basic Idea:
>> > Consider an expression e with two operands, o0 and o1,
>> > and let expr-code denote the expression's code (plus, mult, etc.).
>> >
>> > Commutative variants are stored in a vector (vec<operand *>).
>> >
>> > vec<operand *>
>> > commutative (e)
>> > {
>> >   if (e is not commutative)
>> >     return [e];  // vector with only one expression
>> >
>> >   v1 = commutative (o0);
>> >   v2 = commutative (o1);
>> >   ret = []
>> >
>> >   for i = 0 ... v1.length ()
>> >     for j = 0 ... v2.length ()
>> >       {
>> >         ne = new expr with <expr-code> and operands: v1[i], v2[j];
>> >         append ne to ret;
>> >       }
>> >
>> >   for i = 0 ... v2.length ()
>> >     for j = 0 ... v1.length ()
>> >       {
>> >         ne = new expr with <expr-code> and operands: v2[i], v1[j];
>> >         append ne to ret;
>> >       }
>> >
>> >   return ret;
>> > }
>> >
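A minimal C++ transcription of the pseudocode above, for illustration only. The Node type, the node() and print() helpers, and treating exactly plus and mult as commutative are assumptions of this sketch; they are not genmatch's real classes (which use operand, expr and vec<>).

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for genmatch's AST (not its real classes):
// a leaf such as "@0" has no operands; an expression has a code and operands.
struct Node {
  std::string name;                        // "plus", "mult", or a capture "@N"
  std::vector<std::shared_ptr<Node>> ops;  // empty for a leaf
};
using NodeP = std::shared_ptr<Node>;

NodeP node(std::string name, std::vector<NodeP> ops = {}) {
  return std::make_shared<Node>(Node{std::move(name), std::move(ops)});
}

// Assumption carried over from the first patch: every plus and mult is commutative.
bool is_commutative(const Node &n) { return n.name == "plus" || n.name == "mult"; }

// Direct transcription of the pseudocode.  A non-commutative node is
// returned as-is, without looking at its operands.
std::vector<NodeP> commutative(const NodeP &e) {
  if (!is_commutative(*e))
    return {e};
  assert(e->ops.size() == 2 && "commutative operators are binary here");
  std::vector<NodeP> v1 = commutative(e->ops[0]);
  std::vector<NodeP> v2 = commutative(e->ops[1]);
  std::vector<NodeP> ret;
  for (const NodeP &a : v1)      // operands in the original order
    for (const NodeP &b : v2)
      ret.push_back(node(e->name, {a, b}));
  for (const NodeP &b : v2)      // operands swapped
    for (const NodeP &a : v1)
      ret.push_back(node(e->name, {b, a}));
  return ret;
}

// Print in match.pd-like prefix form, e.g. "(plus (plus @0 @1) @2)".
std::string print(const NodeP &n) {
  if (n->ops.empty())
    return n->name;
  std::string s = "(" + n->name;
  for (const NodeP &op : n->ops)
    s += " " + print(op);
  return s + ")";
}
```

Applied to (plus (plus @0 @1) @2), this yields the four variants listed at the top of the message, in the same order. Applied to (negate (plus @0 @1)), it returns only the original expression, which is the limitation reported later in the thread.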
>> > Example:
>> > (plus (plus @0 @1) (plus @2 @3))
>> > generates the following commutative variants:
>> Oops.
>> The pattern actually given to genmatch was (bogus):
>> (plus (plus @0 @1) (plus @0 @3))
>> >
>> > (PLUS_EXPR (PLUS_EXPR @0 @1 ) (PLUS_EXPR @0 @3 ) )
>> > (PLUS_EXPR (PLUS_EXPR @0 @1 ) (PLUS_EXPR @3 @0 ) )
>> > (PLUS_EXPR (PLUS_EXPR @1 @0 ) (PLUS_EXPR @0 @3 ) )
>> > (PLUS_EXPR (PLUS_EXPR @1 @0 ) (PLUS_EXPR @3 @0 ) )
>> > (PLUS_EXPR (PLUS_EXPR @0 @3 ) (PLUS_EXPR @0 @1 ) )
>> > (PLUS_EXPR (PLUS_EXPR @0 @3 ) (PLUS_EXPR @1 @0 ) )
>> > (PLUS_EXPR (PLUS_EXPR @3 @0 ) (PLUS_EXPR @0 @1 ) )
>> > (PLUS_EXPR (PLUS_EXPR @3 @0 ) (PLUS_EXPR @1 @0 ) )
>> >
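The size of the list follows from the two nested loops of the pseudocode: each commutative binary node pairs every variant of o0 with every variant of o1, once in each order. Assuming every commutative operator is binary (and counting identical variants separately), the number of variants N(e) is:

$$
N(e) =
\begin{cases}
1, & e \text{ is a leaf or a non-commutative expression},\\
2\,N(o_0)\,N(o_1), & e \text{ is a commutative binary expression.}
\end{cases}
$$

For (plus (plus @0 @1) (plus @0 @3)), each inner plus gives $2 \cdot 1 \cdot 1 = 2$, so the whole expression gives $2 \cdot 2 \cdot 2 = 8$, matching the eight variants listed.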
>> >
>> > * Decide which operators are commutative.
>> > Currently I assume all PLUS_EXPR and MULT_EXPR are true.
>> s/true/commutative
> There's a bug in the previous patch: if the operator is not
> commutative, it does not try
> to generate commutative variants of its operands, and does not
> commutate the captured
> expression (.what).
> Example:
> (negate (plus @0 @1)) has two commutative variants (including the
> original pattern),
> but the patch does not generate them, since negate is not commutative.
>
> The attached patch fixes that. As a quick hack, I handled each operator
> class (unary, binary, ternary)
> separately (commutate_unary, commutate_binary, commutate_ternary).
> Ideally these should be unified
> (I tried that, but it segfaulted). I will try to come up
> with a better way.
> Also, the current patch won't work for built-in functions/operators
> with more than 3 operands
> (the maximum so far in match.pd is 3, for cond; I hope this doesn't get
> in the way).
>
> With the current patch,
> for the expression (negate (plus @0 @1))
> it generates the following commutative variants:
> (negate (plus @0 @1))
> (negate (plus @1 @0))
>
> and for the following pattern (involving captured expression):
> (negate (plus@0 @1 @2))
> it generates the following variants:
> (negate (plus@0 @1 @2))
> (negate (plus@0 @2 @1))
>
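A sketch of the corrected recursion described above: recurse into the operands of every expression, whatever its arity, by taking the cartesian product of the per-operand variant lists; keep each expression's capture; and add the swapped operand order only for operators flagged commutative. Node, node(), product() and print() are hypothetical stand-ins, and folding the capture into the expression node simplifies genmatch's separate capture/.what structure.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical AST (not genmatch's classes).  `capture` is an optional "@N"
// attached to an expression, as in (plus@0 @1 @2).
struct Node {
  std::string name;
  std::vector<std::shared_ptr<Node>> ops;
  bool commutative = false;  // set by an op:c-style marker in the pattern
  std::string capture;
};
using NodeP = std::shared_ptr<Node>;

NodeP node(std::string name, std::vector<NodeP> ops = {},
           bool commutative = false, std::string capture = "") {
  auto n = std::make_shared<Node>();
  n->name = std::move(name);
  n->ops = std::move(ops);
  n->commutative = commutative;
  n->capture = std::move(capture);
  return n;
}

// Cartesian product of the per-operand variant lists.
std::vector<std::vector<NodeP>>
product(const std::vector<std::vector<NodeP>> &choices) {
  std::vector<std::vector<NodeP>> result(1);  // start from one empty tuple
  for (const std::vector<NodeP> &options : choices) {
    std::vector<std::vector<NodeP>> next;
    for (const std::vector<NodeP> &prefix : result)
      for (const NodeP &opt : options) {
        next.push_back(prefix);
        next.back().push_back(opt);
      }
    result = std::move(next);
  }
  return result;
}

// Recurse into every expression's operands, then add the swapped order
// only when the operator itself is marked commutative.
std::vector<NodeP> commutate(const NodeP &e) {
  if (e->ops.empty())
    return {e};
  std::vector<std::vector<NodeP>> choices;
  for (const NodeP &op : e->ops)
    choices.push_back(commutate(op));
  std::vector<std::vector<NodeP>> combos = product(choices);
  std::vector<NodeP> ret;
  for (const std::vector<NodeP> &ops : combos)
    ret.push_back(node(e->name, ops, e->commutative, e->capture));
  if (e->commutative) {
    assert(e->ops.size() == 2 && "only binary operators can be commutative");
    for (const std::vector<NodeP> &ops : combos)
      ret.push_back(node(e->name, {ops[1], ops[0]}, true, e->capture));
  }
  return ret;
}

std::string print(const NodeP &n) {
  if (n->ops.empty())
    return n->name;
  std::string s = "(" + n->name + n->capture;
  for (const NodeP &op : n->ops)
    s += " " + print(op);
  return s + ")";
}
```

With plus flagged commutative, commutate on (negate (plus @0 @1)) gives (negate (plus @0 @1)) and (negate (plus @1 @0)), and on (negate (plus@0 @1 @2)) gives (negate (plus@0 @1 @2)) and (negate (plus@0 @2 @1)), matching the lists above.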
> * generates multiple (identical) matching patterns
> Since we do not test captures for equality (true/match) at the AST level,
> both captures are treated
> as different, even though they are the same.
> Example: 2 variants are also generated for the following expression:
> (BUILT_IN_SQRT (mult @0 @0))
> commutative variants:
> (BUILT_IN_SQRT (mult @0 @0))
> (BUILT_IN_SQRT (mult @0 @0))
> I guess this won't really be a problem with the decision tree. If we decide to emit a
> warning, we should warn only for user-defined patterns, not for generated ones.
>
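If identical variants such as the two copies of (BUILT_IN_SQRT (mult @0 @0)) ever need to be dropped before code generation, one simple option is to compare variants by a canonical printed form. The helper below is hypothetical (drop_duplicates, print, node and Node are all assumptions), not something the patch does; comparing printed forms only catches textually identical patterns, which is exactly the case described here.

```cpp
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical AST node, not genmatch's classes.
struct Node {
  std::string name;                        // operator or capture such as "@0"
  std::vector<std::shared_ptr<Node>> ops;  // empty for a leaf
};
using NodeP = std::shared_ptr<Node>;

NodeP node(std::string name, std::vector<NodeP> ops = {}) {
  return std::make_shared<Node>(Node{std::move(name), std::move(ops)});
}

std::string print(const NodeP &n) {
  if (n->ops.empty())
    return n->name;
  std::string s = "(" + n->name;
  for (const NodeP &op : n->ops)
    s += " " + print(op);
  return s + ")";
}

// Keep the first variant with each printed form, preserving their order.
std::vector<NodeP> drop_duplicates(const std::vector<NodeP> &variants) {
  std::set<std::string> seen;
  std::vector<NodeP> out;
  for (const NodeP &v : variants)
    if (seen.insert(print(v)).second)  // .second is false if already seen
      out.push_back(v);
  return out;
}
```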
> * syntax for commutative operators
> Currently, I assume every PLUS_EXPR / MULT_EXPR to be commutative.
> I guess we should have a syntax that lets users mark an operator as commutative.
>
> Something like:
> a) op:c
> b) op "c"
> c) op!
> d) op "commutative"
>
> Or any other syntax you would prefer :-)
>
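A sketch of how option a) op:c could be split into an operator name and a commutative flag, assuming the operator token has already been separated from any trailing capture (so plus:c rather than plus:c@0). parse_operator, ParsedOp and check_commutative_arity are hypothetical names; the arity check mirrors the rule in the patch that only two-operand operators can be marked commutative.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical parse result; the patch instead stores the flag in e_operation.
struct ParsedOp {
  std::string name;
  bool commutative;
};

// Split "op" or "op:c" into a name and a commutative flag.  Any other
// ":suffix" is rejected, much like the patch's
// "not implemented: predicates on expressions" error.
ParsedOp parse_operator(const std::string &tok) {
  std::string::size_type colon = tok.find(':');
  if (colon == std::string::npos)
    return {tok, false};
  if (tok.substr(colon + 1) != "c")
    throw std::invalid_argument("unsupported operator flag in " + tok);
  return {tok.substr(0, colon), true};
}

// Swapping operands only makes sense when there are exactly two of them.
void check_commutative_arity(const ParsedOp &op, unsigned n_operands) {
  if (op.commutative && n_operands != 2)
    throw std::invalid_argument(op.name + ": only binary operators can be "
                                "marked commutative");
}
```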
> * cloning AST nodes
> Currently I do not make a deep copy of the AST for each distinct
> commutative variant, so nodes
> are shared among the different expressions that are commutative variants
> of the original expression.
> Is this OK, or should we clone each AST node, so that each expression
> is represented by a distinct AST?
> Cloning will eat up space, while sharing will require more careful
> memory management (freeing one AST may also
> free nodes of another expression).
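One way to keep node sharing without the risk of freeing nodes that another expression still uses is reference counting. The sketch below (hypothetical Node, node() and swapped_variants(); using std::shared_ptr is an assumption, not what genmatch does) builds a swapped variant whose root is new but whose operands are the original nodes, with no deep copy.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical AST node; children are held by std::shared_ptr so several
// variants can point at the same subtree.
struct Node {
  std::string name;
  std::vector<std::shared_ptr<Node>> ops;
};
using NodeP = std::shared_ptr<Node>;

NodeP node(std::string name, std::vector<NodeP> ops = {}) {
  return std::make_shared<Node>(Node{std::move(name), std::move(ops)});
}

// Return e together with its swapped variant.  The new root is fresh, but
// its operands are the *same* node objects as e's operands.
std::vector<NodeP> swapped_variants(const NodeP &e) {
  assert(e->ops.size() == 2);
  return {e, node(e->name, {e->ops[1], e->ops[0]})};
}
```

A shared node is destroyed only when the last variant referring to it goes away, so variants can be released independently; the cost is a reference count per node instead of a full copy per variant.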
This patch removes the hack of handling each operator class (unary, binary, ternary) separately.
For now, I added the op:c syntax to mark operator op as commutative.

Example (the outer plus is not commutated, since it is not marked commutative):
(plus (plus:c@0 @1 @2) @3)
generates the following variants:
(PLUS_EXPR (PLUS_EXPR@0 @1 @2 )  @3 )
(PLUS_EXPR (PLUS_EXPR@0 @2 @1 )  @3 )

How do we resize a vector to hold n elements up front?
I tried:
vec<operand *> v = vNULL;
v.resize (n);
v.resize_exact (n);
However, accessing v[i] led to an internal abort in operator[] (vec.h line 735).
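The abort is consistent with the difference between reserving capacity and growing the length: reserving space leaves the vector's length at 0, so any index is out of range, presumably caught by a checking assert in operator[]. The std::vector sketch below shows the distinction (int * stands in for operand *, and make_slots is a hypothetical helper). In GCC's own vec.h the length-growing counterpart appears to be safe_grow_cleared (n), but verify that against the vec.h in your tree.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: produce a vector of n null slots that can be
// assigned through v[i].
std::vector<int *> make_slots(std::size_t n) {
  std::vector<int *> v;
  v.reserve(n);          // capacity >= n, but the length is still 0,
  assert(v.empty());     // so v[0] at this point would be out of range
  v.resize(n, nullptr);  // the length is now n; new slots are nullptr
  return v;
}
```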

As a workaround, I did this (in cartesian_product):
for (unsigned i = 0; i < n_ops; ++i)
  v.safe_push (0);
This makes the vector "big enough" to hold n_ops elements, but is
rather ugly.
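An alternative that avoids pre-sizing altogether: grow the scratch tuple as the recursion descends and shrink it on the way back, so its length always equals the current depth n. This sketch uses std::vector and strings in place of vec<operand *> (an assumption for self-containment); the overloads reuse the patch's name cartesian_product but are not genmatch code. With GCC's vec, the same shape would presumably use safe_push and pop.

```cpp
#include <string>
#include <vector>

// Recursive step: tuple.size() plays the role of the patch's index n.
void cartesian_product(const std::vector<std::vector<std::string>> &choices,
                       std::vector<std::vector<std::string>> &result,
                       std::vector<std::string> &tuple) {
  if (tuple.size() == choices.size()) {
    result.push_back(tuple);  // store a copy of the completed tuple
    return;
  }
  for (const std::string &c : choices[tuple.size()]) {
    tuple.push_back(c);       // extend by one slot for this operand
    cartesian_product(choices, result, tuple);
    tuple.pop_back();         // undo before trying the next choice
  }
}

// Entry point: no dummy elements need to be pushed up front.
std::vector<std::vector<std::string>>
cartesian_product(const std::vector<std::vector<std::string>> &choices) {
  std::vector<std::vector<std::string>> result;
  std::vector<std::string> tuple;
  tuple.reserve(choices.size());  // capacity only; the length grows as we recurse
  cartesian_product(choices, result, tuple);
  return result;
}
```

For the choices {a, b}, {c}, {d, e}, this produces the tuples a c d, a c e, b c d, b c e, in that order.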

Thanks and Regards,
Prathamesh
>
> Thanks and Regards,
> Prathamesh
>
>> > Maybe we should add syntax to mark a particular operator as commutative ?
>> >
>> > * Cloning AST nodes
>> > While creating another AST that represents one of
>> > the commutative variants, should we clone the AST nodes,
>> > so that all commutative variants have distinct AST nodes?
>> > That's not done currently: AST nodes are shared among
>> > different commutative expressions, so we end up with a DAG
>> > for a set of commutative expressions.
>> >
>> > Thanks and Regards,
>> > Prathamesh
Index: genmatch.c
===================================================================
--- genmatch.c	(revision 211732)
+++ genmatch.c	(working copy)
@@ -119,7 +119,7 @@ struct id_base : typed_free_remove<id_ba
 {
   enum id_kind { CODE, FN } kind;
 
-  id_base (id_kind, const char *);
+  id_base (id_kind, const char *); 
 
   hashval_t hashval;
   const char *id;
@@ -146,7 +146,7 @@ id_base::equal (const value_type *op1,
 
 static hash_table<id_base> operators;
 
-id_base::id_base (id_kind kind_, const char *id_)
+id_base::id_base (id_kind kind_, const char *id_) 
 {
   kind = kind_;
   id = id_;
@@ -218,8 +218,9 @@ struct predicate : public operand
 };
 
 struct e_operation {
-  e_operation (const char *id);
+  e_operation (const char *id, bool is_commutative_ = false);
   id_base *op;
+  bool is_commutative;
 };
 
 
@@ -258,9 +259,11 @@ struct capture : public operand
 };
 
 
-e_operation::e_operation (const char *id)
+e_operation::e_operation (const char *id, bool is_commutative_)
 {
   id_base tem (id_base::CODE, id);
+  is_commutative = is_commutative_;
+
   op = operators.find_with_hash (&tem, tem.hashval);
   if (op)
     return;
@@ -293,14 +296,14 @@ e_operation::e_operation (const char *id
 
 struct simplify {
   simplify (const char *name_,
-	    struct operand *match_, source_location match_location_,
+	    vec<operand *> matchers_, source_location match_location_,
 	    struct operand *ifexpr_, source_location ifexpr_location_,
 	    struct operand *result_, source_location result_location_)
-      : name (name_), match (match_), match_location (match_location_),
+      : name (name_), matchers (matchers_), match_location (match_location_),
       ifexpr (ifexpr_), ifexpr_location (ifexpr_location_),
       result (result_), result_location (result_location_) {}
   const char *name;
-  struct operand *match;
+  vec<operand *> matchers;  // vector to hold commutative expressions
   source_location match_location;
   struct operand *ifexpr;
   source_location ifexpr_location;
@@ -308,7 +311,148 @@ struct simplify {
   source_location result_location;
 };
 
+void
+print_operand (operand *o, FILE *f = stderr)
+{
+  if (o->type == operand::OP_CAPTURE)
+    {
+      capture *c = static_cast<capture *> (o);
+      fprintf (f, "@%s", c->where);
+      if (c->what)
+	{
+	  putc (':', f);
+	  print_operand (c->what, f);
+	  putc (' ', f);
+	}
+    }
+
+  else if (o->type == operand::OP_PREDICATE)
+    fprintf (f, "%s", (static_cast<predicate *> (o))->ident);
+
+  else if (o->type == operand::OP_C_EXPR)
+    fprintf (f, "c_expr");
+
+  else if (o->type == operand::OP_EXPR)
+    {
+      expr *e = static_cast<expr *> (o);
+      fprintf (f, "(%s ", e->operation->op->id);
+
+      for (unsigned i = 0; i < e->ops.length (); ++i)
+	{
+	  print_operand (e->ops[i], f);
+	  putc (' ', f);
+	}
+
+      putc (')', f);
+    }
+
+  else
+    gcc_unreachable ();
+}
+
+void
+print_matches (struct simplify *s, FILE *f = stderr)
+{
+  if (s->matchers.length () == 1)
+    return;
+
+  fprintf (f, "for expression: ");
+  print_operand (s->matchers[0], f);  // s->matchers[0] is equivalent to original expression
+  putc ('\n', f);
+
+  fprintf (f, "commutative expressions:\n");
+  for (unsigned i = 0; i < s->matchers.length (); ++i)
+    {
+      print_operand (s->matchers[i], f);
+      putc ('\n', f);
+    }
+}
+
+void
+cartesian_product (const vec< vec<operand *> >& ops_vector, vec< vec<operand *> >& result, vec<operand *>& v, unsigned n)
+{
+  if (n == ops_vector.length ())
+    {
+      vec<operand *> xv = v.copy (); 
+      result.safe_push (xv);
+      return;
+    }
+
+  for (unsigned i = 0; i < ops_vector[n].length (); ++i)
+    {
+      v[n] = ops_vector[n][i];
+      cartesian_product (ops_vector, result, v, n + 1);
+    }
+}
+ 
+void
+cartesian_product (const vec< vec<operand *> >& ops_vector, vec< vec<operand *> >& result, unsigned n_ops)
+{
+  vec<operand *> v = vNULL;
+// FIXME: this is done to resize v to length n_ops.
+  for (unsigned i = 0; i < n_ops; ++i)
+    v.safe_push (0);
+  cartesian_product (ops_vector, result, v, 0);
+}
+
+vec<operand *>
+commutate (operand *op)
+{
+  vec<operand *> ret = vNULL;
+
+  if (op->type == operand::OP_CAPTURE)
+    {
+      capture *c = static_cast<capture *> (op);
+      if (!c->what)
+	{
+	  ret.safe_push (op);
+	  return ret;
+	}
+      vec<operand *> v = commutate (c->what);
+      for (unsigned i = 0; i < v.length (); ++i)
+	{
+	  capture *nc = new capture (c->where, v[i]);
+	  ret.safe_push (nc);
+	}
+      return ret;	
+    }
+
+  if (op->type != operand::OP_EXPR)
+    {
+      ret.safe_push (op);
+      return ret;
+    }
+
+  expr *e = static_cast<expr *> (op);
+
+  vec< vec<operand *> > ops_vector = vNULL;
+  for (unsigned i = 0; i < e->ops.length (); ++i)
+    ops_vector.safe_push (commutate (e->ops[i]));
+
+  vec< vec<operand *> > result = vNULL;
+  cartesian_product (ops_vector, result, e->ops.length ());
+
+  for (unsigned i = 0; i < result.length (); ++i)
+    {
+      expr *ne = new expr (e->operation);
+      for (unsigned j = 0; j < result[i].length (); ++j)
+	ne->append_op (result[i][j]);
+      ret.safe_push (ne);
+    }
 
+  if (!e->operation->is_commutative)
+    return ret;
+
+  for (unsigned i = 0; i < result.length (); ++i)
+    {
+      expr *ne = new expr (e->operation);
+      for (unsigned j = result[i].length (); j; --j)  // result[i].length () is 2 since e->operation is binary
+	ne->append_op (result[i][j-1]);
+      ret.safe_push (ne);
+    }
+
+  return ret;
+}
 
 /* Code gen off the AST.  */
 
@@ -574,11 +718,15 @@ write_nary_simplifiers (FILE *f, vec<sim
     {
       simplify *s = simplifiers[i];
       /* ???  This means we can't capture the outermost expression.  */
-      if (s->match->type != operand::OP_EXPR)
+      for (unsigned j = 0; j < s->matchers.length (); ++j)
+	{
+	  operand *match = s->matchers[j];
+	  if (match->type != operand::OP_EXPR)
 	continue;
-      expr *e = static_cast <expr *> (s->match);
+	  expr *e = static_cast <expr *> (match);
       if (e->ops.length () != n)
 	continue;
+
       char fail_label[16];
       snprintf (fail_label, 16, "fail%d", label_cnt++);
       output_line_directive (f, s->match_location);
@@ -627,6 +775,7 @@ write_nary_simplifiers (FILE *f, vec<sim
       fprintf (f, "    }\n");
       fprintf (f, "%s:\n", fail_label);
     }
+    }
   fprintf (f, "  return false;\n");
   fprintf (f, "}\n");
 }
@@ -827,6 +976,25 @@ parse_expr (cpp_reader *r)
   expr *e = new expr (parse_operation (r));
   const cpp_token *token = peek (r);
   operand *op;
+  bool is_commutative = false;
+
+  if (token->type == CPP_COLON)
+    {
+      eat_token (r, CPP_COLON);
+      token = peek (r);
+      if (token->type == CPP_NAME
+	  && !(token->flags & PREV_WHITE))
+	{
+	  const char *s = (const char *)CPP_HASHNODE (token->val.node.node)->ident.str;
+	  eat_token (r, CPP_NAME);
+	  token = peek (r);
+	  if (s[0] == 'c' && !s[1])
+	    is_commutative = true;
+	  else
+	    fatal_at (token, "not implemented: predicates on expressions");
+	}
+    }
+
   if (token->type == CPP_ATSIGN
       && !(token->flags & PREV_WHITE))
     op = parse_capture (r, e);
@@ -847,6 +1015,13 @@ parse_expr (cpp_reader *r)
 		fatal_at (token, "got %d operands instead of the required %d",
 			  e->ops.length (), opr->get_required_nargs ());
 	    }
+	  if (is_commutative)
+	    {
+	      if (e->ops.length () == 2)
+		e->operation->is_commutative = true;
+	      else
+		fatal_at (token, "only binary operators or functions with two arguments can be marked commutative");
+	    }
 	  return op;
 	}
       e->append_op (parse_op (r));
@@ -971,7 +1146,7 @@ parse_match_and_simplify (cpp_reader *r,
       ifexpr = parse_c_expr (r, CPP_OPEN_PAREN);
     }
   token = peek (r);
-  return new simplify (id, match, match_location,
+  return new simplify (id, commutate (match), match_location,
 		       ifexpr, ifexpr_location, parse_op (r), token->src_loc);
 }
 
@@ -1043,6 +1218,9 @@ main(int argc, char **argv)
     }
   while (1);
 
+  for (unsigned i = 0; i < simplifiers.length (); ++i)
+    print_matches (simplifiers[i]);
+
   write_gimple (stdout, simplifiers);
 
   cpp_finish (r, NULL);
