tbaeder updated this revision to Diff 516701.

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D149133/new/

https://reviews.llvm.org/D149133

Files:
  clang/lib/AST/Interp/ByteCodeExprGen.cpp
  clang/lib/AST/Interp/ByteCodeExprGen.h
  clang/lib/AST/Interp/Interp.cpp
  clang/lib/AST/Interp/Interp.h
  clang/lib/AST/Interp/Opcodes.td
  clang/lib/AST/Interp/Pointer.h
  clang/test/AST/Interp/records.cpp

Index: clang/test/AST/Interp/records.cpp
===================================================================
--- clang/test/AST/Interp/records.cpp
+++ clang/test/AST/Interp/records.cpp
@@ -586,6 +586,58 @@
   static_assert(test() == 1);
 }
 
+namespace BaseToDerived {
+namespace A {
+  struct A {};
+  struct B : A { int n; };
+  struct C : B {};
+  C c = {};
+  constexpr C *pb = (C*)((A*)&c + 1); // expected-error {{must be initialized by a constant expression}} \
+                                      // expected-note {{cannot access derived class of pointer past the end of object}} \
+                                      // ref-error {{must be initialized by a constant expression}} \
+                                      // ref-note {{cannot access derived class of pointer past the end of object}}
+}
+namespace B {
+  struct A {};
+  struct Z {};
+  struct B : Z, A {
+    int n;
+   constexpr B() : n(10) {}
+  };
+  struct C : B {
+   constexpr C() : B() {}
+  };
+
+  constexpr C c = {};
+  constexpr const A *pa = &c;
+  constexpr const C *cp = (C*)pa;
+  constexpr const B *cb = (B*)cp;
+
+  static_assert(cb->n == 10);
+  static_assert(cp->n == 10);
+}
+
+namespace C {
+  struct Base { int *a; };
+  struct Base2 : Base { int f[12]; };
+
+  struct Middle1 { int b[3]; };
+  struct Middle2 : Base2 { char c; };
+  struct Middle3 : Middle2 { char g[3]; };
+  struct Middle4 { int f[3]; };
+  struct Middle5 : Middle4, Middle3 { char g2[3]; };
+
+  struct NotQuiteDerived : Middle1, Middle5 { bool d; };
+  struct Derived : NotQuiteDerived { int e; };
+
+  constexpr NotQuiteDerived NQD1 = {};
+
+  constexpr Middle5 *M4 = (Middle5*)((Base2*)&NQD1);
+  static_assert(M4->a == nullptr);
+  static_assert(M4->g2[0] == 0);
+}
+}
+
 
 namespace VirtualDtors {
   class A {
Index: clang/lib/AST/Interp/Pointer.h
===================================================================
--- clang/lib/AST/Interp/Pointer.h
+++ clang/lib/AST/Interp/Pointer.h
@@ -100,6 +100,14 @@
     return Pointer(Pointee, Field, Field);
   }
 
+  /// Creates a pointer whose Base and Offset are both the current
+  /// Offset minus the given offset, which must not exceed it.
+  Pointer atFieldSub(unsigned Off) const {
+    assert(Offset >= Off);
+    unsigned O = Offset - Off;
+    return Pointer(Pointee, O, O);
+  }
+
   /// Restricts the scope of an array element pointer.
   Pointer narrow() const {
     // Null pointers cannot be narrowed.
Index: clang/lib/AST/Interp/Opcodes.td
===================================================================
--- clang/lib/AST/Interp/Opcodes.td
+++ clang/lib/AST/Interp/Opcodes.td
@@ -291,6 +291,10 @@
   let Args = [ArgUint32];
 }
 
+def GetPtrDerivedPop : Opcode {
+  let Args = [ArgUint32];
+}
+
 // [Pointer] -> [Pointer]
 def GetPtrVirtBase : Opcode {
   // RecordDecl of base class.
Index: clang/lib/AST/Interp/Interp.h
===================================================================
--- clang/lib/AST/Interp/Interp.h
+++ clang/lib/AST/Interp/Interp.h
@@ -67,8 +67,9 @@
 bool CheckRange(InterpState &S, CodePtr OpPC, const Pointer &Ptr,
                 CheckSubobjectKind CSK);
 
-/// Checks if accessing a base of the given pointer is valid.
-bool CheckBase(InterpState &S, CodePtr OpPC, const Pointer &Ptr);
+/// Checks if accessing a base or derived record of the given pointer is valid.
+bool CheckBaseDerived(InterpState &S, CodePtr OpPC, const Pointer &Ptr,
+                      CheckSubobjectKind CSK);
 
 /// Checks if a pointer points to const storage.
 bool CheckConst(InterpState &S, CodePtr OpPC, const Pointer &Ptr);
@@ -1078,11 +1079,21 @@
   return true;
 }
 
+inline bool GetPtrDerivedPop(InterpState &S, CodePtr OpPC, uint32_t Off) {
+  const Pointer &Ptr = S.Stk.pop<Pointer>();
+  if (!CheckNull(S, OpPC, Ptr, CSK_Derived))
+    return false;
+  if (!CheckBaseDerived(S, OpPC, Ptr, CSK_Derived))
+    return false;
+  S.Stk.push<Pointer>(Ptr.atFieldSub(Off));
+  return true;
+}
+
 inline bool GetPtrBase(InterpState &S, CodePtr OpPC, uint32_t Off) {
   const Pointer &Ptr = S.Stk.peek<Pointer>();
   if (!CheckNull(S, OpPC, Ptr, CSK_Base))
     return false;
-  if (!CheckBase(S, OpPC, Ptr))
+  if (!CheckBaseDerived(S, OpPC, Ptr, CSK_Base))
     return false;
   S.Stk.push<Pointer>(Ptr.atField(Off));
   return true;
@@ -1092,7 +1103,7 @@
   const Pointer &Ptr = S.Stk.pop<Pointer>();
   if (!CheckNull(S, OpPC, Ptr, CSK_Base))
     return false;
-  if (!CheckBase(S, OpPC, Ptr))
+  if (!CheckBaseDerived(S, OpPC, Ptr, CSK_Base))
     return false;
   S.Stk.push<Pointer>(Ptr.atField(Off));
   return true;
Index: clang/lib/AST/Interp/Interp.cpp
===================================================================
--- clang/lib/AST/Interp/Interp.cpp
+++ clang/lib/AST/Interp/Interp.cpp
@@ -211,12 +211,13 @@
   return false;
 }
 
-bool CheckBase(InterpState &S, CodePtr OpPC, const Pointer &Ptr) {
+bool CheckBaseDerived(InterpState &S, CodePtr OpPC, const Pointer &Ptr,
+                      CheckSubobjectKind CSK) {
   if (!Ptr.isOnePastEnd())
     return true;
 
   const SourceInfo &Loc = S.Current->getSource(OpPC);
-  S.FFDiag(Loc, diag::note_constexpr_past_end_subobject) << CSK_Base;
+  S.FFDiag(Loc, diag::note_constexpr_past_end_subobject) << CSK;
   return false;
 }
 
Index: clang/lib/AST/Interp/ByteCodeExprGen.h
===================================================================
--- clang/lib/AST/Interp/ByteCodeExprGen.h
+++ clang/lib/AST/Interp/ByteCodeExprGen.h
@@ -269,8 +269,8 @@
   }
 
   bool emitRecordDestruction(const Descriptor *Desc);
-  bool emitDerivedToBaseCasts(const RecordType *DerivedType,
-                              const RecordType *BaseType, const Expr *E);
+  unsigned collectBaseOffset(const RecordType *BaseType,
+                             const RecordType *DerivedType);
   bool emitBuiltinBitCast(const CastExpr *E);
 
 protected:
Index: clang/lib/AST/Interp/ByteCodeExprGen.cpp
===================================================================
--- clang/lib/AST/Interp/ByteCodeExprGen.cpp
+++ clang/lib/AST/Interp/ByteCodeExprGen.cpp
@@ -138,8 +138,20 @@
     if (!this->visit(SubExpr))
       return false;
 
-    return this->emitDerivedToBaseCasts(getRecordTy(SubExpr->getType()),
-                                        getRecordTy(CE->getType()), CE);
+    unsigned DerivedOffset = collectBaseOffset(getRecordTy(CE->getType()),
+                                               getRecordTy(SubExpr->getType()));
+
+    return this->emitGetPtrBasePop(DerivedOffset, CE);
+  }
+
+  case CK_BaseToDerived: {
+    if (!this->visit(SubExpr))
+      return false;
+
+    unsigned DerivedOffset = collectBaseOffset(getRecordTy(SubExpr->getType()),
+                                               getRecordTy(CE->getType()));
+
+    return this->emitGetPtrDerivedPop(DerivedOffset, CE);
   }
 
   case CK_FloatingCast: {
@@ -2124,13 +2136,15 @@
 }
 
 template <class Emitter>
-bool ByteCodeExprGen<Emitter>::emitDerivedToBaseCasts(
-    const RecordType *DerivedType, const RecordType *BaseType, const Expr *E) {
-  // Pointer of derived type is already on the stack.
+unsigned
+ByteCodeExprGen<Emitter>::collectBaseOffset(const RecordType *BaseType,
+                                            const RecordType *DerivedType) {
   const auto *FinalDecl = cast<CXXRecordDecl>(BaseType->getDecl());
   const RecordDecl *CurDecl = DerivedType->getDecl();
   const Record *CurRecord = getRecord(CurDecl);
   assert(CurDecl && FinalDecl);
+
+  unsigned OffsetSum = 0;
   for (;;) {
     assert(CurRecord->getNumBases() > 0);
     // One level up
@@ -2138,21 +2152,18 @@
       const auto *BaseDecl = cast<CXXRecordDecl>(B.Decl);
 
       if (BaseDecl == FinalDecl || BaseDecl->isDerivedFrom(FinalDecl)) {
-        // This decl will lead us to the final decl, so emit a base cast.
-        if (!this->emitGetPtrBasePop(B.Offset, E))
-          return false;
-
+        OffsetSum += B.Offset;
         CurRecord = B.R;
         CurDecl = BaseDecl;
         break;
       }
     }
     if (CurDecl == FinalDecl)
-      return true;
+      break;
   }
 
-  llvm_unreachable("Couldn't find the base class?");
-  return false;
+  assert(OffsetSum > 0);
+  return OffsetSum;
 }
 
 /// When calling this, we have a pointer of the local-to-destroy
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to