https://gcc.gnu.org/g:2b473162d33e0f3fe31a7f098c745cb388a01aa4

commit 2b473162d33e0f3fe31a7f098c745cb388a01aa4
Author: Pierre-Emmanuel Patry <pierre-emmanuel.pa...@embecosm.com>
Date:   Mon Jan 20 13:49:25 2025 +0100

    Add environment capture to NR2
    
    The compiler was still relying on NR1 for closure captures when using NR2,
    even though the NR1 resolver was not run and its state was therefore empty.
    
    gcc/rust/ChangeLog:
    
            * resolve/rust-late-name-resolver-2.0.cc (Late::visit): Add environment
            collection.
            * resolve/rust-late-name-resolver-2.0.h: Add function prototype.
            * resolve/rust-name-resolver.cc (Resolver::get_captures): Add assertion
            to prevent NR2 usage with NR1 capture functions.
            * typecheck/rust-hir-type-check-expr.cc (TypeCheckExpr::visit): Use
            nr2 captures.
            * util/rust-hir-map.cc (Mappings::add_capture): Add function to
            register capture for a given closure.
            (Mappings::lookup_captures): Add a function to look up all captures
            available for a given closure.
            * util/rust-hir-map.h: Add function prototypes.
    
    Signed-off-by: Pierre-Emmanuel Patry <pierre-emmanuel.pa...@embecosm.com>

Diff:
---
 gcc/rust/resolve/rust-late-name-resolver-2.0.cc | 13 +++++++++++++
 gcc/rust/resolve/rust-late-name-resolver-2.0.h  |  2 ++
 gcc/rust/resolve/rust-name-resolver.cc          |  2 ++
 gcc/rust/typecheck/rust-hir-type-check-expr.cc  | 20 +++++++++++++++++++-
 gcc/rust/util/rust-hir-map.cc                   | 20 ++++++++++++++++++++
 gcc/rust/util/rust-hir-map.h                    |  5 +++++
 6 files changed, 61 insertions(+), 1 deletion(-)

diff --git a/gcc/rust/resolve/rust-late-name-resolver-2.0.cc b/gcc/rust/resolve/rust-late-name-resolver-2.0.cc
index 7c6948565202..1e7f9f1546cf 100644
--- a/gcc/rust/resolve/rust-late-name-resolver-2.0.cc
+++ b/gcc/rust/resolve/rust-late-name-resolver-2.0.cc
@@ -390,5 +390,18 @@ Late::visit (AST::GenericArg &arg)
   DefaultResolver::visit (arg);
 }
 
+void
+Late::visit (AST::ClosureExprInner &closure)
+{
+  // Register every value binding of the scope on top of the value stack
+  // (presumably the innermost rib) as a capture of this closure.
+  // NOTE(review): ClosureExprInnerTyped is not handled here; confirm.
+  auto closure_id = closure.get_node_id ();
+  auto bindings = ctx.values.peek ().get_values ();
+  for (auto &binding : bindings)
+    ctx.mappings.add_capture (closure_id, binding.second.get_node_id ());
+  DefaultResolver::visit (closure);
+}
+
 } // namespace Resolver2_0
 } // namespace Rust
diff --git a/gcc/rust/resolve/rust-late-name-resolver-2.0.h b/gcc/rust/resolve/rust-late-name-resolver-2.0.h
index bdaae143e73d..3030261f10bf 100644
--- a/gcc/rust/resolve/rust-late-name-resolver-2.0.h
+++ b/gcc/rust/resolve/rust-late-name-resolver-2.0.h
@@ -21,6 +21,7 @@
 
 #include "rust-ast-full.h"
 #include "rust-default-resolver.h"
+#include "rust-expr.h"
 
 namespace Rust {
 namespace Resolver2_0 {
@@ -55,6 +56,7 @@ public:
   void visit (AST::StructStruct &) override;
   void visit (AST::GenericArgs &) override;
   void visit (AST::GenericArg &);
+  void visit (AST::ClosureExprInner &) override;
 
 private:
   /* Setup Rust's builtin types (u8, i32, !...) in the resolver */
diff --git a/gcc/rust/resolve/rust-name-resolver.cc b/gcc/rust/resolve/rust-name-resolver.cc
index 6b131ad374d5..31da593b86c8 100644
--- a/gcc/rust/resolve/rust-name-resolver.cc
+++ b/gcc/rust/resolve/rust-name-resolver.cc
@@ -674,6 +674,8 @@ Resolver::decl_needs_capture (NodeId decl_rib_node_id,
 const std::set<NodeId> &
 Resolver::get_captures (NodeId id) const
 {
+  rust_assert (!flag_name_resolution_2_0);
+
   auto it = closures_capture_mappings.find (id);
   rust_assert (it != closures_capture_mappings.end ());
   return it->second;
diff --git a/gcc/rust/typecheck/rust-hir-type-check-expr.cc b/gcc/rust/typecheck/rust-hir-type-check-expr.cc
index 2554a72dc2a7..356a960f3174 100644
--- a/gcc/rust/typecheck/rust-hir-type-check-expr.cc
+++ b/gcc/rust/typecheck/rust-hir-type-check-expr.cc
@@ -16,6 +16,7 @@
 // along with GCC; see the file COPYING3.  If not see
 // <http://www.gnu.org/licenses/>.
 
+#include "rust-system.h"
 #include "rust-tyty-call.h"
 #include "rust-hir-type-check-struct-field.h"
 #include "rust-hir-path-probe.h"
@@ -1599,7 +1600,24 @@ TypeCheckExpr::visit (HIR::ClosureExpr &expr)
 
   // generate the closure type
   NodeId closure_node_id = expr.get_mappings ().get_nodeid ();
-  const std::set<NodeId> &captures = resolver->get_captures (closure_node_id);
+
+  // Resolve closure captures
+
+  std::set<NodeId> captures;
+  if (flag_name_resolution_2_0)
+    {
+      auto &nr_ctx = const_cast<Resolver2_0::NameResolutionContext &> (
+       Resolver2_0::ImmutableNameResolutionContext::get ().resolver ());
+
+      if (auto opt_cap = nr_ctx.mappings.lookup_captures (closure_node_id))
+       for (auto cap : opt_cap.value ())
+         captures.insert (cap);
+    }
+  else
+    {
+      captures = resolver->get_captures (closure_node_id);
+    }
+
   infered = new TyTy::ClosureType (ref, id, ident, closure_args, result_type,
                                   subst_refs, captures);
 
diff --git a/gcc/rust/util/rust-hir-map.cc b/gcc/rust/util/rust-hir-map.cc
index 26e3ee134c14..4d2927281a21 100644
--- a/gcc/rust/util/rust-hir-map.cc
+++ b/gcc/rust/util/rust-hir-map.cc
@@ -1321,5 +1321,25 @@ Mappings::get_auto_traits ()
   return auto_traits;
 }
 
+// Append DEFINITION to the captures recorded for CLOSURE.  Duplicate
+// definitions are not filtered out.  operator[] value-initializes an empty
+// vector the first time CLOSURE is seen, so a single hash lookup both
+// creates the entry and appends to it.
+void
+Mappings::add_capture (NodeId closure, NodeId definition)
+{
+  captures[closure].push_back (definition);
+}
+
+tl::optional<std::vector<NodeId>>
+Mappings::lookup_captures (NodeId closure)
+{
+  // Copy of CLOSURE's captures, or nullopt if none were ever recorded.
+  auto it = captures.find (closure);
+  if (it != captures.end ())
+    return it->second;
+  return tl::nullopt;
+}
+
 } // namespace Analysis
 } // namespace Rust
diff --git a/gcc/rust/util/rust-hir-map.h b/gcc/rust/util/rust-hir-map.h
index 177894de9f85..6f21f38b4491 100644
--- a/gcc/rust/util/rust-hir-map.h
+++ b/gcc/rust/util/rust-hir-map.h
@@ -344,6 +344,8 @@ public:
 
   void insert_auto_trait (HIR::Trait *trait);
   std::vector<HIR::Trait *> &get_auto_traits ();
+  void add_capture (NodeId closure, NodeId definition);
+  tl::optional<std::vector<NodeId>> lookup_captures (NodeId closure);
 
 private:
   Mappings ();
@@ -434,6 +436,9 @@ private:
 
   // AST mappings
   std::map<NodeId, AST::Item *> ast_item_mappings;
+
+  // Closure AST NodeId -> vector of Definition node ids
+  std::unordered_map<NodeId, std::vector<NodeId>> captures;
 };
 
 } // namespace Analysis

Reply via email to