On Fri, Jun 20, 2014 at 3:02 AM, Prathamesh Kulkarni
<bilbotheelffri...@gmail.com> wrote:
>
> On Fri, Jun 20, 2014 at 2:53 AM, Prathamesh Kulkarni
> <bilbotheelffri...@gmail.com> wrote:
> > Hi,
> >     The attached patch attempts to generate commutative variants for
> > a given expression.
> >
> > Example:
> > For the AST: (PLUS_EXPR (PLUS_EXPR @0 @1) @2),
> >
> > the commutative variants are:
> > (PLUS_EXPR (PLUS_EXPR @0 @1 ) @2 )
> > (PLUS_EXPR (PLUS_EXPR @1 @0 ) @2 )
> > (PLUS_EXPR @2 (PLUS_EXPR @0 @1 ) )
> > (PLUS_EXPR @2 (PLUS_EXPR @1 @0 ) )
> >
> >
> > * Basic Idea:
> > Consider expression e with two operands o0, and o1,
> > and expr-code denoting expression's code (plus/mult, etc.)
> >
> > Commutative variants are stored in vector (vec<operand *>).
> >
> > vec<operand *>
> > commutative (e)
> > {
> >   if (e is not commutative)
> >     return [e];  // vector with only one expression
> >
> >   v1 = commutative (o0);
> >   v2 = commutative (o1);
> >   ret = []
> >
> >   for i = 0 ... v1.length ()
> >     for j = 0 ... v2.length ()
> >       {
> >         ne = new expr with <expr-code> and operands: v1[i], v2[j];
> >         append ne to ret;
> >       }
> >
> >   for i = 0 ... v2.length ()
> >     for j = 0 ... v1.length ()
> >       {
> >         ne = new expr with <expr-code> and operands: v2[i], v1[j];
> >         append ne to ret;
> >       }
> >
> >   return ret;
> > }
> >
> > Example:
> > (plus (plus @0 @1) (plus @2 @3))
> > generates following commutative variants:
> oops.
> the pattern given to genmatch was (bogus):
> (plus (plus @0 @1) (plus @0 @3))
> >
> > (PLUS_EXPR (PLUS_EXPR @0 @1 ) (PLUS_EXPR @0 @3 ) )
> > (PLUS_EXPR (PLUS_EXPR @0 @1 ) (PLUS_EXPR @3 @0 ) )
> > (PLUS_EXPR (PLUS_EXPR @1 @0 ) (PLUS_EXPR @0 @3 ) )
> > (PLUS_EXPR (PLUS_EXPR @1 @0 ) (PLUS_EXPR @3 @0 ) )
> > (PLUS_EXPR (PLUS_EXPR @0 @3 ) (PLUS_EXPR @0 @1 ) )
> > (PLUS_EXPR (PLUS_EXPR @0 @3 ) (PLUS_EXPR @1 @0 ) )
> > (PLUS_EXPR (PLUS_EXPR @3 @0 ) (PLUS_EXPR @0 @1 ) )
> > (PLUS_EXPR (PLUS_EXPR @3 @0 ) (PLUS_EXPR @1 @0 ) )
> >
> >
> > * Decide which operators are commutative.
> > Currently I assume all PLUS_EXPR and MULT_EXPR are true.
> s/true/commutative
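For reference, the basic idea from the quoted mail can be written as a standalone C++ sketch. Node, is_commutative, commutative and to_sexp are hypothetical stand-ins rather than genmatch's operand/expr classes, and only binary plus and mult are assumed to commute. Like the pseudocode, it returns a non-commutative node unchanged without visiting its operands.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for genmatch's operands: a leaf such as "@0"
// (no operands) or an operator applied to child nodes.
struct Node {
  std::string op;         // operator name, or capture name for a leaf
  std::vector<Node> ops;  // operands; empty for a leaf
};

// Assumption mirroring the patch: only binary plus and mult commute.
static bool is_commutative (const Node &e)
{
  return e.ops.size () == 2 && (e.op == "plus" || e.op == "mult");
}

// Direct transcription of the pseudocode: a non-commutative node is
// returned unchanged, without visiting its operands.
static std::vector<Node> commutative (const Node &e)
{
  if (!is_commutative (e))
    return {e};  // vector with only one expression

  std::vector<Node> v1 = commutative (e.ops[0]);
  std::vector<Node> v2 = commutative (e.ops[1]);
  std::vector<Node> ret;

  for (const Node &a : v1)  // original operand order
    for (const Node &b : v2)
      ret.push_back (Node{e.op, {a, b}});

  for (const Node &b : v2)  // swapped operand order
    for (const Node &a : v1)
      ret.push_back (Node{e.op, {b, a}});

  // A commutative node yields 2 * |v1| * |v2| variants.
  assert (ret.size () == 2 * v1.size () * v2.size ());
  return ret;
}

// Prints a node in the (op operand ...) form used in the mail.
static std::string to_sexp (const Node &n)
{
  if (n.ops.empty ())
    return n.op;
  std::string s = "(" + n.op;
  for (const Node &c : n.ops)
    s += " " + to_sexp (c);
  return s + ")";
}
```

For (plus (plus @0 @1) @2) this yields the four variants listed at the top of the quoted mail, in the same order. Each commutative node whose operands have a and b variants produces 2ab variants, which gives the 8 variants of (plus (plus @0 @1) (plus @0 @3)); for (negate (plus @0 @1)) it returns only the input expression.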
There's a bug in the previous patch: if the operator is not
commutative, it does not try to generate commutative variants of its
operands, and it does not commutate the captured expression (.what).
Example:
(negate (plus @0 @1)) has two commutative variants (including the
original pattern), but the patch does not generate them, since negate
is not commutative.

The attached patch fixes that. As a quick hack I handled each operator
class (unary, binary, ternary) separately (commutate_unary,
commutate_binary, commutate_ternary). Ideally these should be unified
(I tried that, but it segfaulted); I will try to come up with a better
way. Also, the current patch won't work for built-in functions or
operators with more than 3 operands (the maximum in match.pd so far is
3, for cond, so I hope this doesn't get in the way).

With the current patch, the expression (negate (plus @0 @1))
generates the following commutative variants:
(negate (plus @0 @1))
(negate (plus @1 @0))

and for the following pattern, which involves a captured expression:
(negate (plus@0 @1 @2))
it generates the following variants:
(negate (plus@0 @1 @2))
(negate (plus@0 @2 @1))
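One possible way to unify commutate_unary, commutate_binary and commutate_ternary (and to drop the 3-operand limit) is to build the cartesian product of the operands' variant lists one operand at a time, then add the swapped forms only when the node itself is commutative. The standalone C++ sketch below assumes a hypothetical Node type instead of genmatch's classes, models a capture as a node named @N whose single child is its .what (so captures need no special case), and treats only binary plus and mult as commutative.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for genmatch's operands. A capture is a node whose
// op starts with '@'; it has either no child or one child (its .what).
struct Node {
  std::string op;
  std::vector<Node> ops;
};

// Assumption mirroring the patch: only binary plus and mult commute.
static bool is_commutative (const Node &e)
{
  return e.ops.size () == 2 && (e.op == "plus" || e.op == "mult");
}

// One function for every arity (including captures and leaves).
static std::vector<Node> commutate (const Node &e)
{
  // Start from the node with no operands and extend every partial node
  // by each variant of the next operand (cartesian product).
  std::vector<Node> partial{Node{e.op, {}}};
  for (const Node &child : e.ops)
    {
      std::vector<Node> child_variants = commutate (child);
      std::vector<Node> next;
      for (const Node &p : partial)
        for (const Node &cv : child_variants)
          {
            Node n = p;
            n.ops.push_back (cv);
            next.push_back (n);
          }
      partial = next;
    }

  if (!is_commutative (e))
    return partial;

  // Commutative binary node: also emit each variant with operands swapped.
  std::vector<Node> ret = partial;
  for (const Node &p : partial)
    {
      assert (p.ops.size () == 2);
      ret.push_back (Node{p.op, {p.ops[1], p.ops[0]}});
    }
  return ret;
}

// Prints (op operand ...), and a capture with an expression as @N:expr.
static std::string to_sexp (const Node &n)
{
  if (n.ops.empty ())
    return n.op;
  if (n.op[0] == '@')
    return n.op + ":" + to_sexp (n.ops[0]);
  std::string s = "(" + n.op;
  for (const Node &c : n.ops)
    s += " " + to_sexp (c);
  return s + ")";
}
```

It produces the two variants listed above for (negate (plus @0 @1)) and for (negate (plus@0 @1 @2)); the latter are printed as (negate @0:(plus @1 @2)) and (negate @0:(plus @2 @1)), mirroring print_operand's @N:expr form. A node with four or more operands goes through the same loop. Swapped forms come after all unswapped ones, so for nested commutative operators the order can differ from the quoted pseudocode, although the set of variants is the same.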

* Generates multiple matching patterns
Since we do not test captures for equality at the AST level (true/match),
the two captures are treated as different, even though they are the same.
For example, the following expression also gets 2 variants generated:
(BUILT_IN_SQRT (mult @0 @0))
commutative variants:
(BUILT_IN_SQRT (mult @0 @0))
(BUILT_IN_SQRT (mult @0 @0))
I guess this won't really be a problem with the decision tree. If we decide
to emit a warning, we should warn only for user-defined patterns, not for
generated ones.
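If the duplicates did turn out to matter (for example, before emitting such a warning), one simple option, which is an assumption and not part of the patch, is to compare the printed form of each variant and keep only the first occurrence. The sketch assumes the variants have already been printed to strings (e.g. by something like print_operand); drop_duplicate_variants is a hypothetical helper.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Keeps the first occurrence of each printed variant, preserving order.
// Two variants that print identically test the same captures in the same
// positions, so keeping both would only emit redundant matching code.
static std::vector<std::string>
drop_duplicate_variants (const std::vector<std::string> &variants)
{
  std::set<std::string> seen;
  std::vector<std::string> kept;
  for (const std::string &v : variants)
    if (seen.insert (v).second)  // true only on the first occurrence
      kept.push_back (v);
  assert (kept.size () <= variants.size ());
  return kept;
}
```

The two (BUILT_IN_SQRT (mult @0 @0)) variants collapse to one, while (plus @0 @1) and (plus @1 @0) are both kept. This only catches variants that are textually identical; it does not reason about which captures must match.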

* Syntax for commutative operators
Currently, I assume any PLUS_EXPR / MULT_EXPR to be commutative.
I guess we should have a syntax that lets users mark an operator as
commutative.

Something like:
a) op:c
b) op "c"
c) op!
d) op "commutative"

Or any other syntax you would prefer :-)
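As an illustration of option a), a trailing :c could be stripped from the operator name and recorded as a per-occurrence flag, which is_commutative would then consult instead of hard-coding PLUS_EXPR / MULT_EXPR. The sketch below works on a plain string rather than on libcpp tokens (where the colon would arrive as a separate token), and parsed_operator / parse_operator_token are hypothetical names.

```cpp
#include <cassert>
#include <string>

// Hypothetical result of parsing an operator written as "op" or "op:c".
struct parsed_operator {
  std::string name;   // operator name with any ":c" suffix removed
  bool commutative;   // true if this occurrence was marked with ":c"
};

// Option a): a trailing ":c" marks this use of the operator as commutative.
static parsed_operator parse_operator_token (const std::string &token)
{
  assert (!token.empty ());
  const std::string suffix = ":c";
  if (token.size () > suffix.size ()
      && token.compare (token.size () - suffix.size (), suffix.size (),
                        suffix) == 0)
    return {token.substr (0, token.size () - suffix.size ()), true};
  return {token, false};
}
```

With this, (plus:c @0 @1) would request both operand orders, while a plain (plus @0 @1) would match only as written.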

* Cloning AST nodes
Currently I do not deep-copy the AST for each distinct commutative
variant, so nodes are shared among the expressions that are
commutative variants of the original expression.
Is this OK, or should we clone each AST node, so that each expression
is represented by a distinct AST?
Cloning will use more memory, while sharing requires more careful
memory management (freeing one AST may also free nodes of another
expression).
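To illustrate the sharing option: if operands were reference counted, freeing one variant could not free nodes that another variant still uses. The sketch below uses std::shared_ptr with a hypothetical Node / make / swapped API; genmatch does not currently manage AST nodes this way, so this is only one possible approach, trading a reference count per node against the cost of deep copies.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Operands are held through shared_ptr, so several commutative variants
// can point at the same subtree (a DAG rather than a tree).
struct Node;
using NodePtr = std::shared_ptr<const Node>;

struct Node {
  std::string op;
  std::vector<NodePtr> ops;
};

// Allocates a node; the shared_ptr<Node> converts to shared_ptr<const Node>.
static NodePtr make (std::string op, std::vector<NodePtr> ops = {})
{
  return std::make_shared<Node> (Node{std::move (op), std::move (ops)});
}

// Builds the swapped variant of a binary node without copying its
// operands: both nodes reference the same children.
static NodePtr swapped (const NodePtr &e)
{
  assert (e->ops.size () == 2);
  return make (e->op, {e->ops[1], e->ops[0]});
}
```

The swapped variant points at the same (plus @0 @1) node as the original, and releasing the original leaves that shared subtree alive.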

Thanks and Regards,
Prathamesh

> > Maybe we should add syntax to mark a particular operator as commutative ?
> >
> > * Cloning AST nodes
> > While creating another AST that represents one of
> > the commutative variants, should we clone the AST nodes,
> > so that all commutative variants have distinct AST nodes ?
> > That's not done currently, and AST nodes are shared amongst
> > different commutative expressions, and we end up with a DAG,
> > for a set of commutative expressions.
> >
> > Thanks and Regards,
> > Prathamesh
Index: gcc/genmatch.c
===================================================================
--- gcc/genmatch.c	(revision 211732)
+++ gcc/genmatch.c	(working copy)
@@ -293,14 +293,14 @@ e_operation::e_operation (const char *id
 
 struct simplify {
   simplify (const char *name_,
-	    struct operand *match_, source_location match_location_,
+	    vec<operand *> matchers_, source_location match_location_,
 	    struct operand *ifexpr_, source_location ifexpr_location_,
 	    struct operand *result_, source_location result_location_)
-      : name (name_), match (match_), match_location (match_location_),
+      : name (name_), matchers (matchers_), match_location (match_location_),
       ifexpr (ifexpr_), ifexpr_location (ifexpr_location_),
       result (result_), result_location (result_location_) {}
   const char *name;
-  struct operand *match;
+  vec<operand *> matchers;  // vector to hold commutative expressions
   source_location match_location;
   struct operand *ifexpr;
   source_location ifexpr_location;
@@ -308,7 +308,193 @@ struct simplify {
   source_location result_location;
 };
 
+void
+print_operand (operand *o, FILE *f = stderr)
+{
+  if (o->type == operand::OP_CAPTURE)
+    {
+      capture *c = static_cast<capture *> (o);
+      fprintf (f, "@%s", c->where);
+      if (c->what)
+	{
+	  putc (':', f);
+	  print_operand (c->what, f);
+	  putc (' ', f);
+	}
+    }
+
+  else if (o->type == operand::OP_PREDICATE)
+    fprintf (f, "%s", (static_cast<predicate *> (o))->ident);
+
+  else if (o->type == operand::OP_C_EXPR)
+    fprintf (f, "c_expr");
+
+  else if (o->type == operand::OP_EXPR)
+    {
+      expr *e = static_cast<expr *> (o);
+      fprintf (f, "(%s ", e->operation->op->id);
+
+      for (unsigned i = 0; i < e->ops.length (); ++i)
+	{
+	  print_operand (e->ops[i], f);
+	  putc (' ', f);
+	}
+
+      putc (')', f);
+    }
+
+  else
+    gcc_unreachable ();
+}
+
+void
+print_matches (struct simplify *s, FILE *f = stderr)
+{
+  if (s->matchers.length () == 1)
+    return;
+
+  fprintf (f, "for expression: ");
+  print_operand (s->matchers[0], f);  // s->matchers[0] is equivalent to original expression
+  putc ('\n', f);
+
+  fprintf (f, "commutative expressions:\n");
+  for (unsigned i = 0; i < s->matchers.length (); ++i)
+    {
+      print_operand (s->matchers[i], f);
+      putc ('\n', f);
+    }
+}
+
+bool
+is_commutative (expr *e)
+{
+  if (e->operation->op->kind != id_base::CODE)
+    return false;
+
+  operator_id *op_id = static_cast <operator_id *> (e->operation->op);
+  enum tree_code code = op_id->code;
+
+  if (code == PLUS_EXPR || code == MULT_EXPR)
+    return true;
+
+  return false;
+}
+
+vec<operand *> commutate (operand *);
+
+vec<operand *>
+commutate_unary (expr *e)
+{
+  vec<operand *> ret = vNULL;
+
+  vec<operand *> v1 = commutate (e->ops[0]);
+  for (unsigned i = 0; i < v1.length (); ++i)
+    {
+      expr *ne = new expr (e->operation);
+      ne->append_op (v1[i]);
+      ret.safe_push (ne);
+    }
+
+  return ret;
+}
+
+vec<operand *>
+commutate_binary (expr *e)
+{
+  vec<operand *> ret = vNULL;
+  vec<operand *> v1 = commutate (e->ops[0]);
+  vec<operand *> v2 = commutate (e->ops[1]);
+
+  unsigned i, j;
+
+  for (i = 0; i < v1.length (); ++i)
+    for (j = 0; j < v2.length (); ++j)
+      {
+	expr *ne = new expr (e->operation);
+	ne->append_op (v1[i]);
+	ne->append_op (v2[j]);
+	ret.safe_push (ne);
+      }
+
+  if (is_commutative (e))
+    for (i = 0; i < v2.length (); ++i)
+      for (j = 0; j < v1.length (); ++j)
+	{
+	  expr *ne = new expr (e->operation);
+	  ne->append_op (v2[i]);
+	  ne->append_op (v1[j]);
+	  ret.safe_push (ne);
+	}
+
+  return ret;
+}
+
+vec<operand *>
+commutate_ternary (expr *e)
+{
+  vec<operand *> ret = vNULL;
+  vec<operand *> v1 = commutate (e->ops[0]);
+  vec<operand *> v2 = commutate (e->ops[1]);
+  vec<operand *> v3 = commutate (e->ops[2]);
+
+  for (unsigned i = 0; i < v1.length (); ++i)
+    for (unsigned j = 0; j < v2.length (); ++j)
+      for (unsigned k = 0; k < v3.length (); ++k)
+	{
+	  expr *ne = new expr (e->operation);
+	  ne->append_op (v1[i]);
+	  ne->append_op (v2[j]);
+	  ne->append_op (v3[k]);
+	  ret.safe_push (ne);
+	}
+
+  return ret;
+}
+
+
+vec<operand *>
+commutate (operand *op)
+{
+  vec<operand *> ret = vNULL;
+
+  if (op->type == operand::OP_CAPTURE)
+    {
+      capture *c = static_cast<capture *> (op);
+      if (!c->what)
+	{
+	  ret.safe_push (op);
+	  return ret;
+	}
+      vec<operand *> v = commutate (c->what);
+      for (unsigned i = 0; i < v.length (); ++i)
+	{
+	  capture *nc = new capture (c->where, v[i]);
+	  ret.safe_push (nc);
+	}
+      return ret;
+    }
+
+  if (op->type != operand::OP_EXPR)
+    {
+      ret.safe_push (op);
+      return ret;
+    }
+
+  expr *e = static_cast<expr *> (op);
+  unsigned n = e->ops.length ();
 
+  if (n == 1)
+    return commutate_unary (e);
+  else if (n == 2)
+    return commutate_binary (e);
+  else if (n == 3)
+    return commutate_ternary (e);
+  else  // FIXME: does not commutate if node has >3 children
+    {
+      ret.safe_push (op);
+      return ret;
+    }
+}
 
 /* Code gen off the AST.  */
 
@@ -574,11 +760,15 @@ write_nary_simplifiers (FILE *f, vec<sim
     {
       simplify *s = simplifiers[i];
       /* ???  This means we can't capture the outermost expression.  */
-      if (s->match->type != operand::OP_EXPR)
+      for (unsigned j = 0; j < s->matchers.length (); ++j)
+	{
+	  operand *match = s->matchers[j];
+	  if (match->type != operand::OP_EXPR)
 	continue;
-      expr *e = static_cast <expr *> (s->match);
+	  expr *e = static_cast <expr *> (match);
       if (e->ops.length () != n)
 	continue;
+
       char fail_label[16];
       snprintf (fail_label, 16, "fail%d", label_cnt++);
       output_line_directive (f, s->match_location);
@@ -627,6 +817,7 @@ write_nary_simplifiers (FILE *f, vec<sim
       fprintf (f, "    }\n");
       fprintf (f, "%s:\n", fail_label);
     }
+    }
   fprintf (f, "  return false;\n");
   fprintf (f, "}\n");
 }
@@ -971,7 +1162,7 @@ parse_match_and_simplify (cpp_reader *r,
       ifexpr = parse_c_expr (r, CPP_OPEN_PAREN);
     }
   token = peek (r);
-  return new simplify (id, match, match_location,
+  return new simplify (id, commutate (match), match_location,
 		       ifexpr, ifexpr_location, parse_op (r), token->src_loc);
 }
 
@@ -1043,6 +1234,9 @@ main(int argc, char **argv)
     }
   while (1);
 
+  for (unsigned i = 0; i < simplifiers.length (); ++i)
+    print_matches (simplifiers[i]);
+
   write_gimple (stdout, simplifiers);
 
   cpp_finish (r, NULL);
