From: Ju-Zhe Zhong <juzhe.zh...@rivai.ai>

gcc/ChangeLog:

        * config/riscv/riscv-vsetvl.cc (can_backward_propagate_p): Fix for null
        iter_bb.
        (vector_insn_info::set_demand_info): New function.
        (pass_vsetvl::emit_local_forward_vsetvls): Adjust for refinement of
        Phase 3.
        (pass_vsetvl::merge_successors): Return whether any block info
        changed.
        (pass_vsetvl::compute_global_backward_infos): Rename to...
        (pass_vsetvl::backward_demand_fusion): ...this.  Return whether any
        block info changed.
        (pass_vsetvl::forward_demand_fusion): New function.
        (pass_vsetvl::demand_fusion): New function.
        (pass_vsetvl::lazy_vsetvl): Adjust for refinement of phase 3.
        * config/riscv/riscv-vsetvl.h (vector_insn_info::set_demand_info):
        New function declaration.

---
 gcc/config/riscv/riscv-vsetvl.cc | 138 ++++++++++++++++++++++++++++---
 gcc/config/riscv/riscv-vsetvl.h  |   1 +
 2 files changed, 128 insertions(+), 11 deletions(-)

diff --git a/gcc/config/riscv/riscv-vsetvl.cc b/gcc/config/riscv/riscv-vsetvl.cc
index 52f0195980a..d42cfa91d63 100644
--- a/gcc/config/riscv/riscv-vsetvl.cc
+++ b/gcc/config/riscv/riscv-vsetvl.cc
@@ -43,7 +43,8 @@ along with GCC; see the file COPYING3.  If not see
     -  Phase 2 - Emit vsetvl instructions within each basic block according to
        demand, compute and save ANTLOC && AVLOC of each block.
 
-    -  Phase 3 - Backward demanded info propagation and fusion across blocks.
+    -  Phase 3 - Backward && forward demanded info propagation and fusion
+       across blocks.
 
     -  Phase 4 - Lazy code motion including: compute local properties,
        pre_edge_lcm and vsetvl insertion && delete edges for LCM results.
@@ -434,8 +434,12 @@ can_backward_propagate_p (const function_info *ssa, const basic_block cfg_bb,
        set_info *ultimate_def = look_through_degenerate_phi (set);
        const basic_block ultimate_bb = ultimate_def->bb ()->cfg_bb ();
        FOR_BB_BETWEEN (iter_bb, ultimate_bb, def->bb ()->cfg_bb (), next_bb)
-         if (iter_bb->index == cfg_bb->index)
-           return true;
+         {
+           if (!iter_bb)
+             break;
+           if (iter_bb->index == cfg_bb->index)
+             return true;
+         }
 
        return false;
       };
@@ -1172,6 +1176,22 @@ vector_insn_info::parse_insn (insn_info *insn)
     m_demands[DEMAND_MASK_POLICY] = true;
 }
 
+/* Copy the SEW, VLMUL, ratio, TA, MA and AVL information together with all
+   demand flags from OTHER.  Unlike a plain assignment, OTHER's instruction
+   and state (e.g. dirty) are not copied.  */
+void
+vector_insn_info::set_demand_info (const vector_insn_info &other)
+{
+  set_sew (other.get_sew ());
+  set_vlmul (other.get_vlmul ());
+  set_ratio (other.get_ratio ());
+  set_ta (other.get_ta ());
+  set_ma (other.get_ma ());
+  set_avl_info (other.get_avl_info ());
+  for (size_t i = 0; i < NUM_DEMAND; i++)
+    m_demands[i] = other.demand_p ((enum demand_type) i);
+}
+
 void
 vector_insn_info::demand_vl_vtype ()
 {
@@ -1691,8 +1708,10 @@ private:
   void emit_local_forward_vsetvls (const bb_info *);
 
   /* Phase 3.  */
-  void merge_successors (const basic_block, const basic_block);
-  void compute_global_backward_infos (void);
+  bool merge_successors (const basic_block, const basic_block);
+  bool backward_demand_fusion (void);
+  bool forward_demand_fusion (void);
+  void demand_fusion (void);
 
   /* Phase 4.  */
   void prune_expressions (void);
@@ -1866,7 +1885,7 @@ pass_vsetvl::emit_local_forward_vsetvls (const bb_info *bb)
 }
 
 /* Merge all successors of Father except child node.  */
-void
+bool
 pass_vsetvl::merge_successors (const basic_block father,
                               const basic_block child)
 {
@@ -1877,7 +1896,8 @@ pass_vsetvl::merge_successors (const basic_block father,
              || father_info.local_dem.empty_p ());
   gcc_assert (father_info.reaching_out.dirty_p ()
              || father_info.reaching_out.empty_p ());
 
+  bool changed_p = false;
   FOR_EACH_EDGE (e, ei, father->succs)
     {
       const basic_block succ = e->dest;
@@ -1907,12 +1927,15 @@ pass_vsetvl::merge_successors (const basic_block father,
 
       father_info.local_dem = new_info;
       father_info.reaching_out = new_info;
+      changed_p = true;
     }
+
+  return changed_p;
 }
 
 /* Compute global backward demanded info.  */
-void
-pass_vsetvl::compute_global_backward_infos (void)
+bool
+pass_vsetvl::backward_demand_fusion (void)
 {
   /* We compute global infos by backward propagation.
      We want to have better performance in these following cases:
@@ -1939,6 +1962,7 @@ pass_vsetvl::compute_global_backward_infos (void)
           We backward propagate the first VSETVL into e32,mf2 so that we
           could be able to eliminate the second VSETVL in LCM.  */
 
+  bool changed_p = false;
   for (const bb_info *bb : crtl->ssa->reverse_bbs ())
     {
       basic_block cfg_bb = bb->cfg_bb ();
@@ -1982,9 +2006,10 @@ pass_vsetvl::compute_global_backward_infos (void)
                  block_info.reaching_out.set_dirty ();
                  block_info.reaching_out.set_dirty_pat (new_pat);
                  block_info.local_dem = block_info.reaching_out;
+                 changed_p = true;
                }
 
-             merge_successors (e->src, cfg_bb);
+             changed_p |= merge_successors (e->src, cfg_bb);
            }
          else if (block_info.reaching_out.dirty_p ())
            {
@@ -2011,6 +2036,7 @@ pass_vsetvl::compute_global_backward_infos (void)
              new_info.set_dirty_pat (new_pat);
              block_info.local_dem = new_info;
              block_info.reaching_out = new_info;
+             changed_p = true;
            }
          else
            {
@@ -2031,9 +2057,106 @@ pass_vsetvl::compute_global_backward_infos (void)
              if (block_info.local_dem == block_info.reaching_out)
                block_info.local_dem = new_info;
              block_info.reaching_out = new_info;
+             changed_p = true;
+           }
+       }
+    }
+  return changed_p;
+}
+
+/* Compute global forward demanded info by propagating the reaching-out info
+   of each block into the local demand of its successors whenever the former
+   compares greater (PROP > LOCAL_DEM).  Return true if anything changed.  */
+bool
+pass_vsetvl::forward_demand_fusion (void)
+{
+  /* Enhance the global information propagation, especially for cases
+     that backward propagation misses.
+     Consider the following case:
+
+                       bb0
+                       (TU)
+                      /   \
+                    bb1   bb2
+                    (TU)  (ANY)
+  existing edge -----> \    / (TU) <----- LCM creates this edge.
+                       bb3
+                       (TU)
+
+     In this situation, LCM fails to eliminate the VSETVL instruction and
+     inserts one on the edge from bb2 to bb3, since we can't backward
+     propagate bb3 into bb2.  To avoid this confusing LCM result and
+     non-optimal codegen, we forward propagate information from bb0 to bb2,
+     which is friendly to LCM.  */
+  bool changed_p = false;
+  for (const bb_info *bb : crtl->ssa->bbs ())
+    {
+      basic_block cfg_bb = bb->cfg_bb ();
+      const auto &prop
+       = m_vector_manager->vector_block_infos[cfg_bb->index].reaching_out;
+
+      /* If there is nothing to propagate, just skip it.  */
+      if (!prop.valid_or_dirty_p ())
+       continue;
+
+      edge e;
+      edge_iterator ei;
+      /* Forward propagate to each successor.  */
+      FOR_EACH_EDGE (e, ei, cfg_bb->succs)
+       {
+         auto &local_dem
+           = m_vector_manager->vector_block_infos[e->dest->index].local_dem;
+         auto &reaching_out
+           = m_vector_manager->vector_block_infos[e->dest->index].reaching_out;
+
+         /* Don't propagate a block's info into itself.  */
+         if (e->dest->index == cfg_bb->index)
+           continue;
+
+         /* Skip successors that have no demand to fuse with.  */
+         if (!local_dem.valid_or_dirty_p ())
+           continue;
+
+         if (prop > local_dem)
+           {
+             if (local_dem.dirty_p ())
+               {
+                 gcc_assert (local_dem == reaching_out);
+                 /* Replace it by PROP and regenerate the dirty pattern.  */
+                 rtx dirty_pat
+                   = gen_vsetvl_pat (prop.get_insn ()->rtl (), prop);
+                 local_dem = prop;
+                 local_dem.set_dirty ();
+                 local_dem.set_dirty_pat (dirty_pat);
+                 reaching_out = local_dem;
+               }
+             else
+               {
+                 /* Keep this block's own insn; only adopt PROP's demands.  */
+                 if (reaching_out == local_dem)
+                   reaching_out.set_demand_info (prop);
+                 local_dem.set_demand_info (prop);
+                 change_vsetvl_insn (local_dem.get_insn (), prop);
+               }
+             changed_p = true;
            }
        }
     }
+  return changed_p;
+}
+
+/* Phase 3: repeat backward and forward demand fusion until an iteration
+   in which neither of them changes any block info.  */
+void
+pass_vsetvl::demand_fusion (void)
+{
+  bool changed_p = true;
+  while (changed_p)
+    {
+      changed_p = false;
+      changed_p |= backward_demand_fusion ();
+      changed_p |= forward_demand_fusion ();
+    }
 
   if (dump_file)
     {
@@ -2519,7 +2635,7 @@ pass_vsetvl::lazy_vsetvl (void)
   /* Phase 3 - Propagate demanded info across blocks.  */
   if (dump_file)
     fprintf (dump_file, "\nPhase 3: Demands propagation across blocks\n");
-  compute_global_backward_infos ();
+  demand_fusion ();
   if (dump_file)
     m_vector_manager->dump (dump_file);
 
diff --git a/gcc/config/riscv/riscv-vsetvl.h b/gcc/config/riscv/riscv-vsetvl.h
index c8218a6ff00..33481a87163 100644
--- a/gcc/config/riscv/riscv-vsetvl.h
+++ b/gcc/config/riscv/riscv-vsetvl.h
@@ -273,6 +273,7 @@ public:
   void set_dirty () { m_state = DIRTY; }
   void set_dirty_pat (rtx pat) { m_dirty_pat = pat; }
   void set_insn (rtl_ssa::insn_info *insn) { m_insn = insn; }
+  void set_demand_info (const vector_insn_info &);
 
   bool demand_p (enum demand_type type) const { return m_demands[type]; }
   void demand (enum demand_type type) { m_demands[type] = true; }
-- 
2.36.3

Reply via email to