nand created this revision. nand added reviewers: Bigcheese, jfb, rsmith, dexonsmith. Herald added a project: clang. Herald added a subscriber: cfe-commits.
Added support for the following statements:
- for
- switch
- do-while
- while
Break and continue statements inside these constructs are supported as well.

Assignment is also implemented, so that meaningful tests can be written. The patch also includes range-based for loops, which are not yet tested; having them in ByteCodeStmtGen should simplify slicing off further patches.

Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D70086

Files:
  clang/lib/AST/Interp/ByteCodeEmitter.cpp
  clang/lib/AST/Interp/ByteCodeExprGen.cpp
  clang/lib/AST/Interp/ByteCodeExprGen.h
  clang/lib/AST/Interp/ByteCodeStmtGen.cpp
  clang/lib/AST/Interp/ByteCodeStmtGen.h
  clang/lib/AST/Interp/Context.cpp
  clang/lib/AST/Interp/Context.h
  clang/lib/AST/Interp/EvalEmitter.cpp
  clang/lib/AST/Interp/InterpLoop.cpp
  clang/lib/AST/Interp/Opcodes.td
  clang/lib/AST/Interp/Opcodes/Comparison.h
  clang/test/AST/Interp/flow.cpp
Index: clang/test/AST/Interp/flow.cpp =================================================================== --- /dev/null +++ clang/test/AST/Interp/flow.cpp @@ -0,0 +1,116 @@ +// RUN: %clang_cc1 -std=c++17 -fsyntax-only -fexperimental-new-constant-interpreter %s -verify +// RUN: %clang_cc1 -std=c++17 -fsyntax-only %s -verify +// expected-no-diagnostics + +constexpr int fn_for_break(int n) { + int x = 0; + for (int i = 0; i < 20; i = i + 1) { + if (i == n) + break; + x = x + 1; + } + return x; +} +using A = int[5]; +using A = int[fn_for_break(5)]; + +constexpr int fn_for_cont(int n) { + int x = 0; + for (int i = 0; i < 20; i = i + 1) { + if (i < n) + continue; + x = x + 1; + } + return x; +} +using B = int[15]; +using B = int[fn_for_cont(5)]; + +constexpr int fn_while_break(unsigned n) { + int x = 0; + int i = 0; + while (int next = i = i + 1) { + if (next == n) { + break; + } + x = x + 1; + } + return x; +} +using C = int[4]; +using C = int[fn_while_break(5)]; + +constexpr int fn_while_cont(unsigned n) { + int x = 0; + int i = 0; + while (i < 10) { + i = i + 1; + if (i < n) { + continue; + } + x = x + 1; + } + return x; +} +using D = int[6]; +using D = int[fn_while_cont(5)]; + +constexpr int fn_do_break(unsigned n) { + int x = 0; + int i = 0; + do { + if (i == n) { + break; + } + x = x + 1; + i = i + 1; + } while (i < 20); + return x; +} +using E = int[5]; +using E = int[fn_do_break(5)]; + +constexpr int fn_do_continue(unsigned n) { + int x = 0; + int i = 0; + do { + i = i + 1; + if (i < n) { + continue; + } + x = x + 1; + } while (i < 20); + return x; +} +using F = int[16]; +using F = int[fn_do_continue(5)]; + +constexpr int fn_for_cond_var(int n) { + int a = 0; + for (int i = n; int b = a + 1; i = i + 1) { + if (i == 50) { + a = 0 - b; + } + } + return a - 1; +} + +using G = int[-fn_for_cond_var(10)]; +using G = int[2]; + +constexpr int fn_switch(int a) { + switch (a) { + case 0: return 2; + case 1: return 5; + default: return 6; + } +} + +using H0 = int[2]; +using
H0 = int[fn_switch(0)]; + +using H1 = int[5]; +using H1 = int[fn_switch(1)]; + +using H2 = int[6]; +using H2 = int[fn_switch(2)]; Index: clang/lib/AST/Interp/Opcodes/Comparison.h =================================================================== --- clang/lib/AST/Interp/Opcodes/Comparison.h +++ clang/lib/AST/Interp/Opcodes/Comparison.h @@ -118,4 +118,15 @@ return false; } +template <PrimType Name> +bool InRange(InterpState &S, CodePtr OpPC) { + using T = typename PrimConv<Name>::T; + const T &RHS = S.Stk.pop<T>(); + const T &LHS = S.Stk.pop<T>(); + const T &Value = S.Stk.pop<T>(); + + S.Stk.push<bool>(LHS <= Value && Value <= RHS); + return true; +} + #endif Index: clang/lib/AST/Interp/Opcodes.td =================================================================== --- clang/lib/AST/Interp/Opcodes.td +++ clang/lib/AST/Interp/Opcodes.td @@ -339,6 +339,16 @@ def GT : ComparisonOpcode; def GE : ComparisonOpcode; +//===----------------------------------------------------------------------===// +// Range test. +//===----------------------------------------------------------------------===// + +// [Real, Real, Real] -> [Bool]: pops High, Low and Value; pushes Low <= Value && Value <= High. +def InRange : Opcode { + let Types = [AluFPTypeClass]; + let HasGroup = 1; +} + //===----------------------------------------------------------------------===// // Stack management. //===----------------------------------------------------------------------===// Index: clang/lib/AST/Interp/InterpLoop.cpp =================================================================== --- clang/lib/AST/Interp/InterpLoop.cpp +++ clang/lib/AST/Interp/InterpLoop.cpp @@ -108,6 +108,8 @@ return false; if (S.checkingPotentialConstantExpression()) return false; + if (!F->isConstexpr()) + return false; // Adjust the state.
S.CallStackDepth++; Index: clang/lib/AST/Interp/EvalEmitter.cpp =================================================================== --- clang/lib/AST/Interp/EvalEmitter.cpp +++ clang/lib/AST/Interp/EvalEmitter.cpp @@ -177,6 +177,8 @@ return false; if (S.checkingPotentialConstantExpression()) return false; + if (!F->isConstexpr()) + return false; S.Current = new InterpFrame(S, F, S.Current, OpPC, std::move(This)); return Interpret(S, Result); } Index: clang/lib/AST/Interp/Context.h =================================================================== --- clang/lib/AST/Interp/Context.h +++ clang/lib/AST/Interp/Context.h @@ -70,13 +70,6 @@ /// Classifies an expression. llvm::Optional<PrimType> classify(QualType T); -private: - /// Runs a function. - bool Run(State &Parent, Function *Func, APValue &Result); - - /// Checks a result fromt the interpreter. - bool Check(State &Parent, llvm::Expected<bool> &&R); - private: /// Current compilation context. ASTContext &Ctx; Index: clang/lib/AST/Interp/Context.cpp =================================================================== --- clang/lib/AST/Interp/Context.cpp +++ clang/lib/AST/Interp/Context.cpp @@ -26,6 +26,8 @@ Context::~Context() {} bool Context::isPotentialConstantExpr(State &Parent, const FunctionDecl *FD) { + // Try to compile the function. This either produces an error message (if this + // is the first attempt to compile) or returns a dummy function with no body. Function *Func = P->getFunction(FD); if (!Func) { ByteCodeStmtGen<ByteCodeEmitter> C(*this, *P, Parent); @@ -39,22 +41,43 @@ } } + // If function has no body, it is definitely not constexpr. if (!Func->isConstexpr()) return false; - APValue Dummy; - return Run(Parent, Func, Dummy); + // Run the function in a dummy context.
+ APValue DummyResult; + InterpState State(Parent, *P, Stk, *this); + State.Current = new InterpFrame(State, Func, nullptr, {}, {}); + if (Interpret(State, DummyResult)) + return true; + Stk.clear(); + return false; } bool Context::evaluateAsRValue(State &Parent, const Expr *E, APValue &Result) { ByteCodeExprGen<EvalEmitter> C(*this, *P, Parent, Stk, Result); - return Check(Parent, C.interpretExpr(E)); + if (auto Flag = C.interpretExpr(E)) { + return *Flag; + } else { + handleAllErrors(Flag.takeError(), [&Parent](ByteCodeGenError &Err) { + Parent.FFDiag(Err.getLoc(), diag::err_experimental_clang_interp_failed); + }); + return false; + } } bool Context::evaluateAsInitializer(State &Parent, const VarDecl *VD, APValue &Result) { ByteCodeExprGen<EvalEmitter> C(*this, *P, Parent, Stk, Result); - return Check(Parent, C.interpretDecl(VD)); + if (auto Flag = C.interpretDecl(VD)) { + return *Flag; + } else { + handleAllErrors(Flag.takeError(), [&Parent](ByteCodeGenError &Err) { + Parent.FFDiag(Err.getLoc(), diag::err_experimental_clang_interp_failed); + }); + return false; + } } const LangOptions &Context::getLangOpts() const { return Ctx.getLangOpts(); } @@ -114,21 +137,3 @@ unsigned Context::getCharBit() const { return Ctx.getTargetInfo().getCharWidth(); } - -bool Context::Run(State &Parent, Function *Func, APValue &Result) { - InterpState State(Parent, *P, Stk, *this); - State.Current = new InterpFrame(State, Func, nullptr, {}, {}); - if (Interpret(State, Result)) - return true; - Stk.clear(); - return false; -} - -bool Context::Check(State &Parent, llvm::Expected<bool> &&Flag) { - if (Flag) - return *Flag; - handleAllErrors(Flag.takeError(), [&Parent](ByteCodeGenError &Err) { - Parent.FFDiag(Err.getLoc(), diag::err_experimental_clang_interp_failed); - }); - return false; -} Index: clang/lib/AST/Interp/ByteCodeStmtGen.h =================================================================== --- clang/lib/AST/Interp/ByteCodeStmtGen.h +++ clang/lib/AST/Interp/ByteCodeStmtGen.h
@@ -60,8 +60,16 @@ bool visitStmt(const Stmt *S); bool visitCompoundStmt(const CompoundStmt *S); bool visitDeclStmt(const DeclStmt *DS); + bool visitForStmt(const ForStmt *FS); + bool visitWhileStmt(const WhileStmt *DS); + bool visitDoStmt(const DoStmt *DS); bool visitReturnStmt(const ReturnStmt *RS); bool visitIfStmt(const IfStmt *IS); + bool visitBreakStmt(const BreakStmt *BS); + bool visitContinueStmt(const ContinueStmt *CS); + bool visitSwitchStmt(const SwitchStmt *SS); + bool visitCaseStmt(const SwitchCase *CS); + bool visitCXXForRangeStmt(const CXXForRangeStmt *FS); /// Compiles a variable declaration. bool visitVarDecl(const VarDecl *VD); Index: clang/lib/AST/Interp/ByteCodeStmtGen.cpp =================================================================== --- clang/lib/AST/Interp/ByteCodeStmtGen.cpp +++ clang/lib/AST/Interp/ByteCodeStmtGen.cpp @@ -118,10 +118,27 @@ return visitCompoundStmt(cast<CompoundStmt>(S)); case Stmt::DeclStmtClass: return visitDeclStmt(cast<DeclStmt>(S)); + case Stmt::ForStmtClass: + return visitForStmt(cast<ForStmt>(S)); + case Stmt::WhileStmtClass: + return visitWhileStmt(cast<WhileStmt>(S)); + case Stmt::DoStmtClass: + return visitDoStmt(cast<DoStmt>(S)); case Stmt::ReturnStmtClass: return visitReturnStmt(cast<ReturnStmt>(S)); case Stmt::IfStmtClass: return visitIfStmt(cast<IfStmt>(S)); + case Stmt::BreakStmtClass: + return visitBreakStmt(cast<BreakStmt>(S)); + case Stmt::ContinueStmtClass: + return visitContinueStmt(cast<ContinueStmt>(S)); + case Stmt::SwitchStmtClass: + return visitSwitchStmt(cast<SwitchStmt>(S)); + case Stmt::CaseStmtClass: + case Stmt::DefaultStmtClass: + return visitCaseStmt(cast<SwitchCase>(S)); + case Stmt::CXXForRangeStmtClass: + return visitCXXForRangeStmt(cast<CXXForRangeStmt>(S)); case Stmt::NullStmtClass: return true; default: { @@ -136,9 +153,10 @@ bool ByteCodeStmtGen<Emitter>::visitCompoundStmt( const CompoundStmt *CompoundStmt) { BlockScope<Emitter> Scope(this); - for (auto *InnerStmt :
CompoundStmt->body()) + for (auto *InnerStmt : CompoundStmt->body()) { if (!visitStmt(InnerStmt)) return false; + } return true; } @@ -161,6 +179,114 @@ return true; } +template <class Emitter> +bool ByteCodeStmtGen<Emitter>::visitForStmt(const ForStmt *FS) { + // Compile the initialisation statement in an outer scope. + BlockScope<Emitter> OuterScope(this); + if (auto *Init = FS->getInit()) + if (!visitStmt(Init)) + return false; + + LabelTy LabelStart = this->getLabel(); + LabelTy LabelEnd = this->getLabel(); + + // Compile the condition, body and increment in the loop scope. NOTE(review): InnerScope closes only after the unconditional jump to LabelStart, and LabelEnd is emitted after it, so any cleanup the scope emits on destruction appears unreachable - confirm loop-scope locals are released. + this->emitLabel(LabelStart); + { + BlockScope<Emitter> InnerScope(this); + + if (auto *Cond = FS->getCond()) { + if (auto *CondDecl = FS->getConditionVariableDeclStmt()) + if (!visitDeclStmt(CondDecl)) + return false; + + if (!this->visit(Cond)) + return false; + + if (!this->jumpFalse(LabelEnd)) + return false; + } + + if (auto *Body = FS->getBody()) { + LabelTy LabelSkip = this->getLabel(); + LoopScope<Emitter> FlowScope(this, LabelEnd, LabelSkip); + if (!visitStmt(Body)) + return false; + this->emitLabel(LabelSkip); + } + + if (auto *Inc = FS->getInc()) { + ExprScope<Emitter> IncScope(this); + if (!this->discard(Inc)) + return false; + } + if (!this->jump(LabelStart)) + return false; + } + this->emitLabel(LabelEnd); + return true; +} + +template <class Emitter> +bool ByteCodeStmtGen<Emitter>::visitWhileStmt(const WhileStmt *WS) { + LabelTy LabelStart = this->getLabel(); + LabelTy LabelEnd = this->getLabel(); + + this->emitLabel(LabelStart); + { + BlockScope<Emitter> InnerScope(this); + if (auto *CondDecl = WS->getConditionVariableDeclStmt()) + if (!visitDeclStmt(CondDecl)) + return false; + + if (!this->visit(WS->getCond())) + return false; + + if (!this->jumpFalse(LabelEnd)) + return false; + + { + LoopScope<Emitter> FlowScope(this, LabelEnd, LabelStart); + if (!visitStmt(WS->getBody())) + return false; + } + if (!this->jump(LabelStart)) + return false; + } + this->emitLabel(LabelEnd);
+ + return true; +} + +template <class Emitter> +bool ByteCodeStmtGen<Emitter>::visitDoStmt(const DoStmt *DS) { + LabelTy LabelStart = this->getLabel(); + LabelTy LabelEnd = this->getLabel(); + LabelTy LabelSkip = this->getLabel(); + + this->emitLabel(LabelStart); + { + { + LoopScope<Emitter> FlowScope(this, LabelEnd, LabelSkip); + if (!visitStmt(DS->getBody())) + return false; + this->emitLabel(LabelSkip); + } + + { + ExprScope<Emitter> CondScope(this); + if (!this->visitBool(DS->getCond())) + return false; + } + + if (!this->jumpTrue(LabelStart)) + return false; + } + this->emitLabel(LabelEnd); + + return true; +} + template <class Emitter> bool ByteCodeStmtGen<Emitter>::visitReturnStmt(const ReturnStmt *RS) { if (const Expr *RE = RS->getRetValue()) { @@ -222,6 +348,167 @@ return true; } +template <class Emitter> +bool ByteCodeStmtGen<Emitter>::visitBreakStmt(const BreakStmt *BS) { + if (!BreakLabel) + return this->bail(BS); + return this->jump(*BreakLabel); +} + +template <class Emitter> +bool ByteCodeStmtGen<Emitter>::visitContinueStmt(const ContinueStmt *CS) { + if (!ContinueLabel) + return this->bail(CS); + return this->jump(*ContinueLabel); +} + +template <class Emitter> +bool ByteCodeStmtGen<Emitter>::visitSwitchStmt(const SwitchStmt *SS) { + BlockScope<Emitter> InnerScope(this); + + if (Optional<PrimType> T = this->classify(SS->getCond()->getType())) { + // The condition is stored in a local and fetched for every test. + unsigned Off = this->allocateLocalPrimitive(SS->getCond(), *T, + /*isConst=*/true); + + // Compile the condition in its own scope. NOTE(review): the init-statement and condition variable are declared in this scope too, yet the switch body may refer to them; confirm they outlive it.
+ { + ExprScope<Emitter> CondScope(this); + if (const Stmt *CondInit = SS->getInit()) + if (!visitStmt(SS->getInit())) + return false; + + if (const DeclStmt *CondDecl = SS->getConditionVariableDeclStmt()) + if (!visitDeclStmt(CondDecl)) + return false; + + if (!this->visit(SS->getCond())) + return false; + + if (!this->emitSetLocal(*T, Off, SS->getCond())) + return false; + } + + LabelTy LabelEnd = this->getLabel(); + + // Generate code to inspect all case labels, jumping to the matched one. + const DefaultStmt *Default = nullptr; + CaseMap Labels; + for (auto *SC = SS->getSwitchCaseList(); SC; SC = SC->getNextSwitchCase()) { + LabelTy Label = this->getLabel(); + Labels.insert({SC, Label}); + + if (auto *DS = dyn_cast<DefaultStmt>(SC)) { + Default = DS; + continue; + } + + if (auto *CS = dyn_cast<CaseStmt>(SC)) { + if (!this->emitGetLocal(*T, Off, CS)) + return false; + if (!this->visit(CS->getLHS())) + return false; + + if (auto *RHS = CS->getRHS()) { + if (!this->visit(CS->getRHS())) + return false; + if (!this->emitInRange(*T, CS)) + return false; + } else { + if (!this->emitEQ(*T, CS)) + return false; + } + + if (!this->jumpTrue(Label)) + return false; + continue; + } + + return this->bail(SS); + } + + // If a case wasn't matched, jump to default or skip the body. + if (!this->jump(Default ? Labels[Default] : LabelEnd)) + return false; + OptLabelTy DefaultLabel = Default ? Labels[Default] : OptLabelTy{}; + + // Compile the body, using labels defined previously.
+ SwitchScope<Emitter> LabelScope(this, std::move(Labels), LabelEnd, + DefaultLabel); + if (!visitStmt(SS->getBody())) + return false; + this->emitLabel(LabelEnd); + return true; + } else { + return this->bail(SS); + } +} + +template <class Emitter> +bool ByteCodeStmtGen<Emitter>::visitCaseStmt(const SwitchCase *CS) { + auto It = CaseLabels.find(CS); + if (It == CaseLabels.end()) + return this->bail(CS); + + this->emitLabel(It->second); + return visitStmt(CS->getSubStmt()); +} + +template <class Emitter> +bool ByteCodeStmtGen<Emitter>::visitCXXForRangeStmt(const CXXForRangeStmt *FS) { + BlockScope<Emitter> Scope(this); + + // Emit the optional init-statement. + if (auto *Init = FS->getInit()) { + if (!visitStmt(Init)) + return false; + } + + // Initialise the __range variable. + if (!visitStmt(FS->getRangeStmt())) + return false; + + // Create the __begin and __end iterators. + if (!visitStmt(FS->getBeginStmt()) || !visitStmt(FS->getEndStmt())) + return false; + + LabelTy LabelStart = this->getLabel(); + LabelTy LabelEnd = this->getLabel(); + + this->emitLabel(LabelStart); + { + // Lower the condition. + if (!this->visitBool(FS->getCond())) + return false; + if (!this->jumpFalse(LabelEnd)) + return false; + + // Lower the loop var and body, marking labels for continue/break.
+ { + BlockScope<Emitter> InnerScope(this); + if (!visitStmt(FS->getLoopVarStmt())) + return false; + + LabelTy LabelSkip = this->getLabel(); + { + LoopScope<Emitter> FlowScope(this, LabelEnd, LabelSkip); + + if (!visitStmt(FS->getBody())) + return false; + } + this->emitLabel(LabelSkip); + } + + // Increment: ++__begin + if (!visitStmt(FS->getInc())) + return false; + if (!this->jump(LabelStart)) + return false; + } + this->emitLabel(LabelEnd); + return true; +} + template <class Emitter> bool ByteCodeStmtGen<Emitter>::visitVarDecl(const VarDecl *VD) { auto DT = VD->getType(); Index: clang/lib/AST/Interp/ByteCodeExprGen.h =================================================================== --- clang/lib/AST/Interp/ByteCodeExprGen.h +++ clang/lib/AST/Interp/ByteCodeExprGen.h @@ -76,6 +76,10 @@ bool VisitUnaryMinus(const UnaryOperator *E); bool VisitCallExpr(const CallExpr *E); + // Fallback methods for nodes which are not yet implemented. + bool VisitStmt(const Stmt *E) { llvm_unreachable("not an expression"); } + bool VisitExpr(const Expr *E) { return this->bail(E); } + protected: bool visitExpr(const Expr *E) override; bool visitDecl(const VarDecl *VD) override; @@ -146,6 +150,8 @@ bool emitFunctionCall(const FunctionDecl *Callee, llvm::Optional<PrimType> T, const Expr *Call); + bool visitAssign(PrimType T, const BinaryOperator *BO); + enum class DerefKind { /// Value is read and pushed to stack.
Read, Index: clang/lib/AST/Interp/ByteCodeExprGen.cpp =================================================================== --- clang/lib/AST/Interp/ByteCodeExprGen.cpp +++ clang/lib/AST/Interp/ByteCodeExprGen.cpp @@ -300,6 +300,9 @@ }; switch (BO->getOpcode()) { + case BO_Assign: + return visitAssign(*T, BO); + case BO_EQ: return Discard(this->emitEQ(*LT, BO)); case BO_NE: return Discard(this->emitNE(*LT, BO)); case BO_LT: return Discard(this->emitLT(*LT, BO)); @@ -372,8 +375,10 @@ return false; } } else { - consumeError(Func.takeError()); - return this->bail(E); + handleAllErrors(Func.takeError(), [this](ByteCodeGenError &Err) { + S.FFDiag(Err.getLoc(), diag::err_experimental_clang_interp_failed); + }); + return false; } return DiscardResult && T ? this->emitPop(*T, E) : true; } @@ -410,6 +415,31 @@ } } +template <class Emitter> +bool ByteCodeExprGen<Emitter>::visitAssign(PrimType T, + const BinaryOperator *BO) { + return dereference( + BO->getLHS(), DerefKind::Write, + [this, BO](PrimType) { + // Direct write: only the RHS value is generated here; presumably dereference() emits the store itself - confirm. + return visit(BO->getRHS()); + }, + [this, BO](PrimType T) { + // Pointer on stack - compile RHS and store it through the pointer. NOTE(review): this lambda's T shadows visitAssign's own T parameter, which is otherwise unused. + if (!visit(BO->getRHS())) + return false; + + if (BO->getLHS()->refersToBitField()) { + return this->bail(BO); + } else { + if (DiscardResult) + return this->emitStorePop(T, BO); + else + return this->emitStore(T, BO); + } + }); +} + template <class Emitter> bool ByteCodeExprGen<Emitter>::dereference( const Expr *LV, DerefKind AK, llvm::function_ref<bool(PrimType)> Direct, Index: clang/lib/AST/Interp/ByteCodeEmitter.cpp =================================================================== --- clang/lib/AST/Interp/ByteCodeEmitter.cpp +++ clang/lib/AST/Interp/ByteCodeEmitter.cpp @@ -57,12 +57,17 @@ Function *Func = P.createFunction(F, ParamOffset, std::move(ParamTypes), std::move(ParamDescriptors)); // Compile the function body.
- if (!F->isConstexpr() || !visitFunc(F)) { - // Return a dummy function if compilation failed. - if (BailLocation) + if (!F->isConstexpr()) { + // Return a dummy function for non-constexpr. + return Func; + } else if (!visitFunc(F)) { + if (BailLocation) { + // If compiler bailed, return an error. return llvm::make_error<ByteCodeGenError>(*BailLocation); - else + } else { + // Otherwise, return a dummy function which is not constexpr. return Func; + } } else { // Create scopes from descriptors. llvm::SmallVector<Scope, 2> Scopes;
_______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits