Treat a union as a struct-like ADT: its fields are handled like the
named fields of a struct, and the resulting ADTType is marked as a
union.  Add iterate and get_identifier methods to the AST Union class.
Do the same for the HIR Union class, plus a get_generic_params method.
Add a get_is_union method to the ADTType.
---
 gcc/rust/ast/rust-item.h                      | 11 ++++
 gcc/rust/hir/rust-ast-lower-item.h            | 51 +++++++++++++++++
 gcc/rust/hir/rust-ast-lower-stmt.h            | 53 ++++++++++++++++++
 gcc/rust/hir/tree/rust-hir-item.h             | 16 ++++++
 gcc/rust/resolve/rust-ast-resolve-item.h      | 22 ++++++++
 gcc/rust/resolve/rust-ast-resolve-stmt.h      | 32 +++++++++++
 gcc/rust/resolve/rust-ast-resolve-toplevel.h  | 14 +++++
 gcc/rust/typecheck/rust-hir-type-check-stmt.h | 55 ++++++++++++++++++-
 .../typecheck/rust-hir-type-check-toplevel.h  | 54 +++++++++++++++++-
 gcc/rust/typecheck/rust-hir-type-check.cc     | 12 +++-
 gcc/rust/typecheck/rust-tycheck-dump.h        |  6 ++
 gcc/rust/typecheck/rust-tyty.cc               |  4 +-
 gcc/rust/typecheck/rust-tyty.h                | 12 ++--
 13 files changed, 331 insertions(+), 11 deletions(-)

diff --git a/gcc/rust/ast/rust-item.h b/gcc/rust/ast/rust-item.h
index 30cab0ed900..1e928e8111a 100644
--- a/gcc/rust/ast/rust-item.h
+++ b/gcc/rust/ast/rust-item.h
@@ -2489,6 +2489,17 @@ public:
   std::vector<StructField> &get_variants () { return variants; }
   const std::vector<StructField> &get_variants () const { return variants; }
 
+  // Call CB on each field of the union in declaration order, stopping as
+  // soon as CB returns false.
+  void iterate (std::function<bool (StructField &)> cb)
+  {
+    for (auto &variant : variants)
+      {
+       if (!cb (variant))
+         return;
+      }
+  }
+
   std::vector<std::unique_ptr<GenericParam> > &get_generic_params ()
   {
     return generic_params;
@@ -2505,6 +2516,9 @@ public:
     return where_clause;
   }
 
+  // The name the union was declared with.
+  Identifier get_identifier () const { return union_name; }
+
 protected:
   /* Use covariance to implement clone function as returning this object
    * rather than base */
diff --git a/gcc/rust/hir/rust-ast-lower-item.h b/gcc/rust/hir/rust-ast-lower-item.h
index 5ba59183179..b6af00f6b54 100644
--- a/gcc/rust/hir/rust-ast-lower-item.h
+++ b/gcc/rust/hir/rust-ast-lower-item.h
@@ -192,6 +192,59 @@ public:
                               struct_decl.get_locus ());
   }
 
+  // Lower a union like a struct with named fields: every AST field becomes
+  // an HIR::StructField with fresh HIR and local-def ids.
+  void visit (AST::Union &union_decl) override
+  {
+    std::vector<std::unique_ptr<HIR::GenericParam> > generic_params;
+    if (union_decl.has_generics ())
+      {
+       generic_params
+         = lower_generic_params (union_decl.get_generic_params ());
+      }
+
+    std::vector<std::unique_ptr<HIR::WhereClauseItem> > where_clause_items;
+    HIR::WhereClause where_clause (std::move (where_clause_items));
+    HIR::Visibility vis = HIR::Visibility::create_public ();
+
+    std::vector<HIR::StructField> variants;
+    union_decl.iterate ([&] (AST::StructField &variant) mutable -> bool {
+      HIR::Visibility field_vis = HIR::Visibility::create_public ();
+      HIR::Type *type
+       = ASTLoweringType::translate (variant.get_field_type ().get ());
+
+      auto crate_num = mappings->get_current_crate ();
+      Analysis::NodeMapping mapping (crate_num, variant.get_node_id (),
+                                    mappings->get_next_hir_id (crate_num),
+                                    mappings->get_next_localdef_id (
+                                      crate_num));
+
+      HIR::StructField translated_variant (mapping, variant.get_field_name (),
+                                          std::unique_ptr<HIR::Type> (type),
+                                          field_vis, variant.get_locus (),
+                                          variant.get_outer_attrs ());
+      variants.push_back (std::move (translated_variant));
+      return true;
+    });
+
+    auto crate_num = mappings->get_current_crate ();
+    Analysis::NodeMapping mapping (crate_num, union_decl.get_node_id (),
+                                  mappings->get_next_hir_id (crate_num),
+                                  mappings->get_next_localdef_id (crate_num));
+
+    translated
+      = new HIR::Union (mapping, union_decl.get_identifier (), vis,
+                       std::move (generic_params), std::move (where_clause),
+                       std::move (variants), union_decl.get_outer_attrs (),
+                       union_decl.get_locus ());
+
+    mappings->insert_defid_mapping (mapping.get_defid (), translated);
+    mappings->insert_hir_item (mapping.get_crate_num (), mapping.get_hirid (),
+                              translated);
+    mappings->insert_location (crate_num, mapping.get_hirid (),
+                              union_decl.get_locus ());
+  }
+
   void visit (AST::StaticItem &var) override
   {
     HIR::Visibility vis = HIR::Visibility::create_public ();
diff --git a/gcc/rust/hir/rust-ast-lower-stmt.h b/gcc/rust/hir/rust-ast-lower-stmt.h
index 9df6b746bb7..2e97ca63a13 100644
--- a/gcc/rust/hir/rust-ast-lower-stmt.h
+++ b/gcc/rust/hir/rust-ast-lower-stmt.h
@@ -215,6 +215,58 @@ public:
                               struct_decl.get_locus ());
   }
 
+  // Lower a block-local union like a struct with named fields: every AST
+  // field becomes an HIR::StructField with fresh HIR and local-def ids.
+  void visit (AST::Union &union_decl) override
+  {
+    std::vector<std::unique_ptr<HIR::GenericParam> > generic_params;
+    if (union_decl.has_generics ())
+      {
+       generic_params
+         = lower_generic_params (union_decl.get_generic_params ());
+      }
+
+    std::vector<std::unique_ptr<HIR::WhereClauseItem> > where_clause_items;
+    HIR::WhereClause where_clause (std::move (where_clause_items));
+    HIR::Visibility vis = HIR::Visibility::create_public ();
+
+    std::vector<HIR::StructField> variants;
+    union_decl.iterate ([&] (AST::StructField &variant) mutable -> bool {
+      HIR::Visibility field_vis = HIR::Visibility::create_public ();
+      HIR::Type *type
+       = ASTLoweringType::translate (variant.get_field_type ().get ());
+
+      auto crate_num = mappings->get_current_crate ();
+      Analysis::NodeMapping mapping (crate_num, variant.get_node_id (),
+                                    mappings->get_next_hir_id (crate_num),
+                                    mappings->get_next_localdef_id (
+                                      crate_num));
+
+      HIR::StructField translated_variant (mapping, variant.get_field_name (),
+                                          std::unique_ptr<HIR::Type> (type),
+                                          field_vis, variant.get_locus (),
+                                          variant.get_outer_attrs ());
+      variants.push_back (std::move (translated_variant));
+      return true;
+    });
+
+    auto crate_num = mappings->get_current_crate ();
+    Analysis::NodeMapping mapping (crate_num, union_decl.get_node_id (),
+                                  mappings->get_next_hir_id (crate_num),
+                                  mappings->get_next_localdef_id (crate_num));
+
+    translated
+      = new HIR::Union (mapping, union_decl.get_identifier (), vis,
+                       std::move (generic_params), std::move (where_clause),
+                       std::move (variants), union_decl.get_outer_attrs (),
+                       union_decl.get_locus ());
+
+    mappings->insert_hir_stmt (mapping.get_crate_num (), mapping.get_hirid (),
+                              translated);
+    mappings->insert_location (crate_num, mapping.get_hirid (),
+                              union_decl.get_locus ());
+  }
+
   void visit (AST::EmptyStmt &empty) override
   {
     auto crate_num = mappings->get_current_crate ();
diff --git a/gcc/rust/hir/tree/rust-hir-item.h b/gcc/rust/hir/tree/rust-hir-item.h
index e7e110fda92..cfe45d73d85 100644
--- a/gcc/rust/hir/tree/rust-hir-item.h
+++ b/gcc/rust/hir/tree/rust-hir-item.h
@@ -1989,10 +1989,28 @@ public:
   Union (Union &&other) = default;
   Union &operator= (Union &&other) = default;
 
+  std::vector<std::unique_ptr<GenericParam> > &get_generic_params ()
+  {
+    return generic_params;
+  }
+
+  Identifier get_identifier () const { return union_name; }
+
   Location get_locus () const { return locus; }
 
   void accept_vis (HIRVisitor &vis) override;
 
+  // Call CB on each field of the union in declaration order, stopping as
+  // soon as CB returns false.
+  void iterate (std::function<bool (StructField &)> cb)
+  {
+    for (auto &variant : variants)
+      {
+       if (!cb (variant))
+         return;
+      }
+  }
+
 protected:
   /* Use covariance to implement clone function as returning this object
    * rather than base */
diff --git a/gcc/rust/resolve/rust-ast-resolve-item.h b/gcc/rust/resolve/rust-ast-resolve-item.h
index 0714f5d5706..54f1fe15533 100644
--- a/gcc/rust/resolve/rust-ast-resolve-item.h
+++ b/gcc/rust/resolve/rust-ast-resolve-item.h
@@ -260,6 +260,30 @@ public:
     resolver->get_type_scope ().pop ();
   }
 
+  // Resolve the union's generic parameters and field types inside a type
+  // scope pushed for the union's node id.
+  void visit (AST::Union &union_decl) override
+  {
+    NodeId scope_node_id = union_decl.get_node_id ();
+    resolver->get_type_scope ().push (scope_node_id);
+
+    if (union_decl.has_generics ())
+      {
+       for (auto &generic : union_decl.get_generic_params ())
+         {
+           ResolveGenericParam::go (generic.get (), union_decl.get_node_id ());
+         }
+      }
+
+    union_decl.iterate ([&] (AST::StructField &field) mutable -> bool {
+      ResolveType::go (field.get_field_type ().get (),
+                      union_decl.get_node_id ());
+      return true;
+    });
+
+    resolver->get_type_scope ().pop ();
+  }
+
   void visit (AST::StaticItem &var) override
   {
     ResolveType::go (var.get_type ().get (), var.get_node_id ());
diff --git a/gcc/rust/resolve/rust-ast-resolve-stmt.h b/gcc/rust/resolve/rust-ast-resolve-stmt.h
index 210a9fc047d..b6044327b27 100644
--- a/gcc/rust/resolve/rust-ast-resolve-stmt.h
+++ b/gcc/rust/resolve/rust-ast-resolve-stmt.h
@@ -131,6 +131,40 @@ public:
     resolver->get_type_scope ().pop ();
   }
 
+  // Declare a block-local union in the type scope, then resolve its
+  // generic parameters and field types inside a scope of its own.
+  void visit (AST::Union &union_decl) override
+  {
+    auto path = CanonicalPath::new_seg (union_decl.get_node_id (),
+                                       union_decl.get_identifier ());
+    resolver->get_type_scope ().insert (
+      path, union_decl.get_node_id (), union_decl.get_locus (), false,
+      [&] (const CanonicalPath &, NodeId, Location locus) -> void {
+       RichLocation r (union_decl.get_locus ());
+       r.add_range (locus);
+       rust_error_at (r, "redefined multiple times");
+      });
+
+    NodeId scope_node_id = union_decl.get_node_id ();
+    resolver->get_type_scope ().push (scope_node_id);
+
+    if (union_decl.has_generics ())
+      {
+       for (auto &generic : union_decl.get_generic_params ())
+         {
+           ResolveGenericParam::go (generic.get (), union_decl.get_node_id ());
+         }
+      }
+
+    union_decl.iterate ([&] (AST::StructField &field) mutable -> bool {
+      ResolveType::go (field.get_field_type ().get (),
+                      union_decl.get_node_id ());
+      return true;
+    });
+
+    resolver->get_type_scope ().pop ();
+  }
+
   void visit (AST::Function &function) override
   {
     auto path = ResolveFunctionItemToCanonicalPath::resolve (function);
diff --git a/gcc/rust/resolve/rust-ast-resolve-toplevel.h b/gcc/rust/resolve/rust-ast-resolve-toplevel.h
index 9abbb18e080..4df0467b994 100644
--- a/gcc/rust/resolve/rust-ast-resolve-toplevel.h
+++ b/gcc/rust/resolve/rust-ast-resolve-toplevel.h
@@ -81,6 +81,22 @@ public:
       });
   }
 
+  // Declare the union under its canonical path in the type scope,
+  // reporting "redefined multiple times" if the name is already taken.
+  void visit (AST::Union &union_decl) override
+  {
+    auto path
+      = prefix.append (CanonicalPath::new_seg (union_decl.get_node_id (),
+                                              union_decl.get_identifier ()));
+    resolver->get_type_scope ().insert (
+      path, union_decl.get_node_id (), union_decl.get_locus (), false,
+      [&] (const CanonicalPath &, NodeId, Location locus) -> void {
+       RichLocation r (union_decl.get_locus ());
+       r.add_range (locus);
+       rust_error_at (r, "redefined multiple times");
+      });
+  }
+
   void visit (AST::StaticItem &var) override
   {
     auto path = prefix.append (
diff --git a/gcc/rust/typecheck/rust-hir-type-check-stmt.h b/gcc/rust/typecheck/rust-hir-type-check-stmt.h
index 1b6f47c1595..fad2b7183df 100644
--- a/gcc/rust/typecheck/rust-hir-type-check-stmt.h
+++ b/gcc/rust/typecheck/rust-hir-type-check-stmt.h
@@ -159,7 +159,7 @@ public:
     TyTy::BaseType *type
       = new TyTy::ADTType (struct_decl.get_mappings ().get_hirid (),
                           mappings->get_next_hir_id (),
-                          struct_decl.get_identifier (), true,
+                          struct_decl.get_identifier (), true, false,
                           std::move (fields), std::move (substitutions));
 
     context->insert_type (struct_decl.get_mappings (), type);
@@ -209,13 +209,66 @@ public:
     TyTy::BaseType *type
       = new TyTy::ADTType (struct_decl.get_mappings ().get_hirid (),
                           mappings->get_next_hir_id (),
-                          struct_decl.get_identifier (), false,
+                          struct_decl.get_identifier (), false, false,
                           std::move (fields), std::move (substitutions));
 
     context->insert_type (struct_decl.get_mappings (), type);
     infered = type;
   }
 
+  // Type a union as a non-tuple ADTType with is_union set and one
+  // StructFieldType per field.  Lifetime parameters are skipped for now.
+  void visit (HIR::Union &union_decl) override
+  {
+    std::vector<TyTy::SubstitutionParamMapping> substitutions;
+    if (union_decl.has_generics ())
+      {
+       for (auto &generic_param : union_decl.get_generic_params ())
+         {
+           switch (generic_param.get ()->get_kind ())
+             {
+             case HIR::GenericParam::GenericKind::LIFETIME:
+               // Skipping Lifetime completely until better handling.
+               break;
+
+               case HIR::GenericParam::GenericKind::TYPE: {
+                 auto param_type
+                   = TypeResolveGenericParam::Resolve (generic_param.get ());
+                 context->insert_type (generic_param->get_mappings (),
+                                       param_type);
+
+                 substitutions.push_back (TyTy::SubstitutionParamMapping (
+                   static_cast<HIR::TypeParam &> (*generic_param),
+                   param_type));
+               }
+               break;
+             }
+         }
+      }
+
+    std::vector<TyTy::StructFieldType *> variants;
+    union_decl.iterate ([&] (HIR::StructField &variant) mutable -> bool {
+      TyTy::BaseType *variant_type
+       = TypeCheckType::Resolve (variant.get_field_type ().get ());
+      TyTy::StructFieldType *ty_variant
+       = new TyTy::StructFieldType (variant.get_mappings ().get_hirid (),
+                                    variant.get_field_name (), variant_type);
+      variants.push_back (ty_variant);
+      context->insert_type (variant.get_mappings (),
+                           ty_variant->get_field_type ());
+      return true;
+    });
+
+    TyTy::BaseType *type
+      = new TyTy::ADTType (union_decl.get_mappings ().get_hirid (),
+                          mappings->get_next_hir_id (),
+                          union_decl.get_identifier (), false, true,
+                          std::move (variants), std::move (substitutions));
+
+    context->insert_type (union_decl.get_mappings (), type);
+    infered = type;
+  }
+
   void visit (HIR::Function &function) override
   {
     std::vector<TyTy::SubstitutionParamMapping> substitutions;
diff --git a/gcc/rust/typecheck/rust-hir-type-check-toplevel.h b/gcc/rust/typecheck/rust-hir-type-check-toplevel.h
index dd3dd751ad6..a723e7e679f 100644
--- a/gcc/rust/typecheck/rust-hir-type-check-toplevel.h
+++ b/gcc/rust/typecheck/rust-hir-type-check-toplevel.h
@@ -94,7 +94,7 @@ public:
     TyTy::BaseType *type
       = new TyTy::ADTType (struct_decl.get_mappings ().get_hirid (),
                           mappings->get_next_hir_id (),
-                          struct_decl.get_identifier (), true,
+                          struct_decl.get_identifier (), true, false,
                           std::move (fields), std::move (substitutions));
 
     context->insert_type (struct_decl.get_mappings (), type);
@@ -143,12 +143,64 @@ public:
     TyTy::BaseType *type
       = new TyTy::ADTType (struct_decl.get_mappings ().get_hirid (),
                           mappings->get_next_hir_id (),
-                          struct_decl.get_identifier (), false,
+                          struct_decl.get_identifier (), false, false,
                           std::move (fields), std::move (substitutions));
 
     context->insert_type (struct_decl.get_mappings (), type);
   }
 
+  // Type a union as a non-tuple ADTType with is_union set and one
+  // StructFieldType per field.  Lifetime parameters are skipped for now.
+  void visit (HIR::Union &union_decl) override
+  {
+    std::vector<TyTy::SubstitutionParamMapping> substitutions;
+    if (union_decl.has_generics ())
+      {
+       for (auto &generic_param : union_decl.get_generic_params ())
+         {
+           switch (generic_param.get ()->get_kind ())
+             {
+             case HIR::GenericParam::GenericKind::LIFETIME:
+               // Skipping Lifetime completely until better handling.
+               break;
+
+               case HIR::GenericParam::GenericKind::TYPE: {
+                 auto param_type
+                   = TypeResolveGenericParam::Resolve (generic_param.get ());
+                 context->insert_type (generic_param->get_mappings (),
+                                       param_type);
+
+                 substitutions.push_back (TyTy::SubstitutionParamMapping (
+                   static_cast<HIR::TypeParam &> (*generic_param),
+                   param_type));
+               }
+               break;
+             }
+         }
+      }
+
+    std::vector<TyTy::StructFieldType *> variants;
+    union_decl.iterate ([&] (HIR::StructField &variant) mutable -> bool {
+      TyTy::BaseType *variant_type
+       = TypeCheckType::Resolve (variant.get_field_type ().get ());
+      TyTy::StructFieldType *ty_variant
+       = new TyTy::StructFieldType (variant.get_mappings ().get_hirid (),
+                                    variant.get_field_name (), variant_type);
+      variants.push_back (ty_variant);
+      context->insert_type (variant.get_mappings (),
+                           ty_variant->get_field_type ());
+      return true;
+    });
+
+    TyTy::BaseType *type
+      = new TyTy::ADTType (union_decl.get_mappings ().get_hirid (),
+                          mappings->get_next_hir_id (),
+                          union_decl.get_identifier (), false, true,
+                          std::move (variants), std::move (substitutions));
+
+    context->insert_type (union_decl.get_mappings (), type);
+  }
+
   void visit (HIR::StaticItem &var) override
   {
     TyTy::BaseType *type = TypeCheckType::Resolve (var.get_type ());
diff --git a/gcc/rust/typecheck/rust-hir-type-check.cc b/gcc/rust/typecheck/rust-hir-type-check.cc
index cb2896c0bb4..da528d7878a 100644
--- a/gcc/rust/typecheck/rust-hir-type-check.cc
+++ b/gcc/rust/typecheck/rust-hir-type-check.cc
@@ -180,7 +180,19 @@ TypeCheckStructExpr::visit (HIR::StructExprStructFields &struct_expr)
   // check the arguments are all assigned and fix up the ordering
-  if (fields_assigned.size () != struct_path_resolved->num_fields ())
+  if (struct_def->get_is_union ())
+    {
+      // A union expression must initialize exactly one field, however many
+      // fields the union declares, and may not use functional record update
+      // syntax (..base).
+      if (fields_assigned.size () != 1 || struct_expr.has_struct_base ())
+       {
+         rust_error_at (struct_expr.get_locus (),
+                        "union must have exactly one field variant assigned");
+         return;
+       }
+    }
+  else if (fields_assigned.size () != struct_path_resolved->num_fields ())
     {
       if (!struct_expr.has_struct_base ())
         {
           rust_error_at (struct_expr.get_locus (),
                          "constructor is missing fields");
diff --git a/gcc/rust/typecheck/rust-tycheck-dump.h b/gcc/rust/typecheck/rust-tycheck-dump.h
index b80372b2a9c..cc2e3c01110 100644
--- a/gcc/rust/typecheck/rust-tycheck-dump.h
+++ b/gcc/rust/typecheck/rust-tycheck-dump.h
@@ -48,6 +48,13 @@ public:
            + "\n";
   }
 
+  // Dump "union " followed by type_string () of this item's mappings.
+  void visit (HIR::Union &union_decl) override
+  {
+    dump += indent () + "union " + type_string (union_decl.get_mappings ())
+           + "\n";
+  }
+
   void visit (HIR::ImplBlock &impl_block) override
   {
     dump += indent () + "impl "
diff --git a/gcc/rust/typecheck/rust-tyty.cc b/gcc/rust/typecheck/rust-tyty.cc
index f043c7eabda..d059134f8a0 100644
--- a/gcc/rust/typecheck/rust-tyty.cc
+++ b/gcc/rust/typecheck/rust-tyty.cc
@@ -517,8 +517,8 @@ ADTType::clone ()
     cloned_fields.push_back ((StructFieldType *) f->clone ());
 
   return new ADTType (get_ref (), get_ty_ref (), identifier, get_is_tuple (),
-                     cloned_fields, clone_substs (), used_arguments,
-                     get_combined_refs ());
+                     get_is_union (), cloned_fields, clone_substs (),
+                     used_arguments, get_combined_refs ());
 }
 
 ADTType *
diff --git a/gcc/rust/typecheck/rust-tyty.h b/gcc/rust/typecheck/rust-tyty.h
index 2152c1b6d76..b7cf46bb783 100644
--- a/gcc/rust/typecheck/rust-tyty.h
+++ b/gcc/rust/typecheck/rust-tyty.h
@@ -848,7 +848,7 @@ protected:
 class ADTType : public BaseType, public SubstitutionRef
 {
 public:
-  ADTType (HirId ref, std::string identifier, bool is_tuple,
+  ADTType (HirId ref, std::string identifier, bool is_tuple, bool is_union,
           std::vector<StructFieldType *> fields,
           std::vector<SubstitutionParamMapping> subst_refs,
           SubstitutionArgumentMappings generic_arguments
@@ -856,21 +856,25 @@ public:
           std::set<HirId> refs = std::set<HirId> ())
     : BaseType (ref, ref, TypeKind::ADT, refs),
       SubstitutionRef (std::move (subst_refs), std::move (generic_arguments)),
-      identifier (identifier), fields (fields), is_tuple (is_tuple)
+      identifier (identifier), fields (fields), is_tuple (is_tuple),
+      is_union (is_union)
   {}
 
   ADTType (HirId ref, HirId ty_ref, std::string identifier, bool is_tuple,
-          std::vector<StructFieldType *> fields,
+          bool is_union, std::vector<StructFieldType *> fields,
           std::vector<SubstitutionParamMapping> subst_refs,
           SubstitutionArgumentMappings generic_arguments
           = SubstitutionArgumentMappings::error (),
           std::set<HirId> refs = std::set<HirId> ())
     : BaseType (ref, ty_ref, TypeKind::ADT, refs),
       SubstitutionRef (std::move (subst_refs), std::move (generic_arguments)),
-      identifier (identifier), fields (fields), is_tuple (is_tuple)
+      identifier (identifier), fields (fields), is_tuple (is_tuple),
+      is_union (is_union)
   {}
 
   bool get_is_tuple () { return is_tuple; }
+  // True when this ADT was declared as a union rather than a struct.
+  bool get_is_union () const { return is_union; }
 
   bool is_unit () const override { return this->fields.empty (); }
 
@@ -957,6 +961,7 @@ private:
   std::string identifier;
   std::vector<StructFieldType *> fields;
   bool is_tuple;
+  bool is_union;
 };
 
 class FnType : public BaseType, public SubstitutionRef
-- 
2.32.0

-- 
Gcc-rust mailing list
Gcc-rust@gcc.gnu.org
https://gcc.gnu.org/mailman/listinfo/gcc-rust

Reply via email to