LuoYuanke updated this revision to Diff 393012.
LuoYuanke added a comment.
Herald added a subscriber: martong.

Updating D115199 <https://reviews.llvm.org/D115199>: [WIP][X86][AMX] Support 
amxpreserve attribute in clang.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D115199/new/

https://reviews.llvm.org/D115199

Files:
  clang/include/clang/AST/Type.h
  clang/include/clang/AST/TypeProperties.td
  clang/include/clang/Basic/Attr.td
  clang/include/clang/Basic/AttrDocs.td
  clang/include/clang/CodeGen/CGFunctionInfo.h
  clang/lib/AST/ASTContext.cpp
  clang/lib/AST/ASTStructuralEquivalence.cpp
  clang/lib/AST/TypePrinter.cpp
  clang/lib/CodeGen/CGCall.cpp
  clang/lib/CodeGen/CodeGenModule.cpp
  clang/lib/Sema/SemaDecl.cpp
  clang/lib/Sema/SemaType.cpp
  clang/lib/Serialization/ASTWriter.cpp
  clang/test/Sema/attr-target-mv.c
  clang/test/SemaCXX/attr-non-x86-amx-preserve.cpp
  clang/test/SemaCXX/attr-x86-amx-preserve.cpp
  clang/unittests/AST/StructuralEquivalenceTest.cpp

Index: clang/unittests/AST/StructuralEquivalenceTest.cpp
===================================================================
--- clang/unittests/AST/StructuralEquivalenceTest.cpp
+++ clang/unittests/AST/StructuralEquivalenceTest.cpp
@@ -476,6 +476,16 @@
   EXPECT_FALSE(testStructuralMatch(t));
 }
 
+TEST_F(StructuralEquivalenceFunctionTest,
+       FunctionsWithDifferentAMXSavedRegsAttr) {
+  if (llvm::Triple(llvm::sys::getDefaultTargetTriple()).getArch() !=
+      llvm::Triple::x86_64)
+    return;
+  auto t = makeNamedDecls("__attribute__((amxpreserve)) void foo();",
+                          "                             void foo();", Lang_C99);
+  EXPECT_FALSE(testStructuralMatch(t));
+}
+
 struct StructuralEquivalenceCXXMethodTest : StructuralEquivalenceTest {
 };
 
Index: clang/test/SemaCXX/attr-x86-amx-preserve.cpp
===================================================================
--- /dev/null
+++ clang/test/SemaCXX/attr-x86-amx-preserve.cpp
@@ -0,0 +1,33 @@
+// RUN: %clang_cc1 -std=c++11 -triple x86_64-unknown-linux-gnu -fsyntax-only -verify %s
+
+struct a {
+  int b __attribute__((amxpreserve)); // expected-warning {{'amxpreserve' only applies to function types; type here is 'int'}}
+  static void foo(int *a) __attribute__((amxpreserve)) {}
+};
+
+struct a test __attribute__((amxpreserve)); // expected-warning {{'amxpreserve' only applies to function types; type here is 'struct a'}}
+
+__attribute__((amxpreserve(999))) void bar(int *) {} // expected-error {{'amxpreserve' attribute takes no arguments}}
+
+void __attribute__((amxpreserve)) foo(int *){}
+
+__attribute__((amxpreserve)) void foo2(int *) {}
+
+typedef __attribute__((amxpreserve)) void (*foo3)(int *);
+
+int (*foo4)(double a, __attribute__((amxpreserve)) float b); // expected-warning {{'amxpreserve' only applies to function types; type here is 'float'}}
+
+typedef void (*foo5)(int *);
+
+void foo6(){} // expected-note {{previous declaration is here}}
+
+void __attribute__((amxpreserve)) foo6(); // expected-error {{function declared with 'amxpreserve' attribute was previously declared without the 'amxpreserve' attribute}} 
+
+int main(int argc, char **argv) {
+  void (*fp)(int *) = foo; // expected-error {{cannot initialize a variable of type 'void (*)(int *)' with an lvalue of type 'void (int *) __attribute__((amxpreserve))'}} 
+  a::foo(&argc);
+  foo3 func = foo2;
+  func(&argc);
+  foo5 __attribute__((amxpreserve)) func2 = foo2;
+  return 0;
+}
Index: clang/test/SemaCXX/attr-non-x86-amx-preserve.cpp
===================================================================
--- /dev/null
+++ clang/test/SemaCXX/attr-non-x86-amx-preserve.cpp
@@ -0,0 +1,29 @@
+// RUN: %clang_cc1 -std=c++11 -triple armv7-unknown-linux-gnueabi -fsyntax-only -verify %s
+
+struct a {
+  int __attribute__((amxpreserve)) b;                     // expected-warning {{unknown attribute 'amxpreserve' ignored}}
+  static void foo(int *a) __attribute__((amxpreserve)) {} // expected-warning {{unknown attribute 'amxpreserve' ignored}}
+};
+
+struct a test __attribute__((amxpreserve)); // expected-warning {{unknown attribute 'amxpreserve' ignored}}
+
+__attribute__((amxpreserve(999))) void bar(int *) {} // expected-warning {{unknown attribute 'amxpreserve' ignored}}
+
+__attribute__((amxpreserve)) void foo(int *){} // expected-warning {{unknown attribute 'amxpreserve' ignored}}
+
+[[clang::amxpreserve]] void foo2(int *) {} // expected-warning {{unknown attribute 'amxpreserve' ignored}}
+
+typedef __attribute__((amxpreserve)) void (*foo3)(int *); // expected-warning {{unknown attribute 'amxpreserve' ignored}}
+
+typedef void (*foo5)(int *);
+
+int (*foo4)(double a, __attribute__((amxpreserve)) float b); // expected-warning {{unknown attribute 'amxpreserve' ignored}}
+
+int main(int argc, char **argv) {
+  void (*fp)(int *) = foo;
+  a::foo(&argc);
+  foo3 func = foo2;
+  func(&argc);
+  foo5 __attribute__((amxpreserve)) func2 = foo2; // expected-warning {{unknown attribute 'amxpreserve' ignored}}
+  return 0;
+}
Index: clang/test/Sema/attr-target-mv.c
===================================================================
--- clang/test/Sema/attr-target-mv.c
+++ clang/test/Sema/attr-target-mv.c
@@ -76,6 +76,11 @@
 // expected-note@+1 {{function multiversioning caused by this declaration}}
 int __attribute__((target("arch=ivybridge")))  prev_no_target2(void);
 
+void __attribute__((target("sse4.2"))) addtl_amx_attrs(void);
+//expected-error@+2 {{attribute 'target' multiversioning cannot be combined with attribute 'amxpreserve'}}
+void __attribute__((amxpreserve,target("arch=sandybridge")))
+addtl_amx_attrs(void);
+
 void __attribute__((target("sse4.2"))) addtl_attrs(void);
 //expected-error@+2 {{attribute 'target' multiversioning cannot be combined with attribute 'no_caller_saved_registers'}}
 void __attribute__((no_caller_saved_registers,target("arch=sandybridge")))
Index: clang/lib/Serialization/ASTWriter.cpp
===================================================================
--- clang/lib/Serialization/ASTWriter.cpp
+++ clang/lib/Serialization/ASTWriter.cpp
@@ -595,6 +595,7 @@
   Abv->Add(BitCodeAbbrevOp(0));                         // NoCallerSavedRegs
   Abv->Add(BitCodeAbbrevOp(0));                         // NoCfCheck
   Abv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, 1)); // CmseNSCall
+  Abv->Add(BitCodeAbbrevOp(0));                         // AMXPreserve
   // FunctionProtoType
   Abv->Add(BitCodeAbbrevOp(0));                         // IsVariadic
   Abv->Add(BitCodeAbbrevOp(0));                         // HasTrailingReturn
Index: clang/lib/Sema/SemaType.cpp
===================================================================
--- clang/lib/Sema/SemaType.cpp
+++ clang/lib/Sema/SemaType.cpp
@@ -135,6 +135,7 @@
   case ParsedAttr::AT_CmseNSCall:                                              \
   case ParsedAttr::AT_AnyX86NoCallerSavedRegisters:                            \
   case ParsedAttr::AT_AnyX86NoCfCheck:                                         \
+  case ParsedAttr::AT_AMXPreserve:                                             \
     CALLING_CONV_ATTRS_CASELIST
 
 // Microsoft-specific type qualifiers.
@@ -7514,6 +7515,20 @@
     return true;
   }
 
+  if (attr.getKind() == ParsedAttr::AT_AMXPreserve) {
+    if (S.CheckAttrTarget(attr) || S.CheckAttrNoArgs(attr))
+      return true;
+
+    // Delay if this is not a function type.
+    if (!unwrapped.isFunctionType())
+      return false;
+
+    FunctionType::ExtInfo EI =
+        unwrapped.get()->getExtInfo().withAMXPreserve(true);
+    type = unwrapped.wrap(S, S.Context.adjustFunctionType(unwrapped.get(), EI));
+    return true;
+  }
+
   if (attr.getKind() == ParsedAttr::AT_AnyX86NoCallerSavedRegisters) {
     if (S.CheckAttrTarget(attr) || S.CheckAttrNoArgs(attr))
       return true;
Index: clang/lib/Sema/SemaDecl.cpp
===================================================================
--- clang/lib/Sema/SemaDecl.cpp
+++ clang/lib/Sema/SemaDecl.cpp
@@ -3526,6 +3526,18 @@
     RequiresAdjustment = true;
   }
 
+  if (OldTypeInfo.getAMXPreserve() != NewTypeInfo.getAMXPreserve()) {
+    if (NewTypeInfo.getAMXPreserve()) {
+      AMXPreserveAttr *Attr = New->getAttr<AMXPreserveAttr>();
+      Diag(New->getLocation(), diag::err_function_attribute_mismatch) << Attr;
+      Diag(OldLocation, diag::note_previous_declaration);
+      return true;
+    }
+
+    NewTypeInfo = NewTypeInfo.withAMXPreserve(true);
+    RequiresAdjustment = true;
+  }
+
   if (RequiresAdjustment) {
     const FunctionType *AdjustedType = New->getType()->getAs<FunctionType>();
     AdjustedType = Context.adjustFunctionType(AdjustedType, NewTypeInfo);
Index: clang/lib/CodeGen/CodeGenModule.cpp
===================================================================
--- clang/lib/CodeGen/CodeGenModule.cpp
+++ clang/lib/CodeGen/CodeGenModule.cpp
@@ -1858,6 +1858,8 @@
     // carry an explicit noinline attribute.
     if (!F->hasFnAttribute(llvm::Attribute::AlwaysInline))
       B.addAttribute(llvm::Attribute::NoInline);
+  } else if (D->hasAttr<AMXPreserveAttr>()) {
+    B.addAttribute(llvm::Attribute::AMXPreserve);
   } else {
     // Otherwise, propagate the inline hint attribute and potentially use its
     // absence to mark things as noinline.
Index: clang/lib/CodeGen/CGCall.cpp
===================================================================
--- clang/lib/CodeGen/CGCall.cpp
+++ clang/lib/CodeGen/CGCall.cpp
@@ -823,6 +823,7 @@
   FI->NoReturn = info.getNoReturn();
   FI->ReturnsRetained = info.getProducesResult();
   FI->NoCallerSavedRegs = info.getNoCallerSavedRegs();
+  FI->AMXPreserve = info.getAMXPreserve();
   FI->NoCfCheck = info.getNoCfCheck();
   FI->Required = required;
   FI->HasRegParm = info.getHasRegParm();
@@ -2116,6 +2117,8 @@
       FuncAttrs.addAttribute(llvm::Attribute::NoCfCheck);
     if (TargetDecl->hasAttr<LeafAttr>())
       FuncAttrs.addAttribute(llvm::Attribute::NoCallback);
+    if (TargetDecl->hasAttr<AMXPreserveAttr>())
+      FuncAttrs.addAttribute(llvm::Attribute::AMXPreserve);
 
     HasOptnone = TargetDecl->hasAttr<OptimizeNoneAttr>();
     if (auto *AllocSize = TargetDecl->getAttr<AllocSizeAttr>()) {
Index: clang/lib/AST/TypePrinter.cpp
===================================================================
--- clang/lib/AST/TypePrinter.cpp
+++ clang/lib/AST/TypePrinter.cpp
@@ -1000,6 +1000,8 @@
        << Info.getRegParm() << ")))";
   if (Info.getNoCallerSavedRegs())
     OS << " __attribute__((no_caller_saved_registers))";
+  if (Info.getAMXPreserve())
+    OS << " __attribute__((amxpreserve))";
   if (Info.getNoCfCheck())
     OS << " __attribute__((nocf_check))";
 }
Index: clang/lib/AST/ASTStructuralEquivalence.cpp
===================================================================
--- clang/lib/AST/ASTStructuralEquivalence.cpp
+++ clang/lib/AST/ASTStructuralEquivalence.cpp
@@ -617,6 +617,8 @@
     return false;
   if (EI1.getNoCallerSavedRegs() != EI2.getNoCallerSavedRegs())
     return false;
+  if (EI1.getAMXPreserve() != EI2.getAMXPreserve())
+    return false;
   if (EI1.getNoCfCheck() != EI2.getNoCfCheck())
     return false;
 
Index: clang/lib/AST/ASTContext.cpp
===================================================================
--- clang/lib/AST/ASTContext.cpp
+++ clang/lib/AST/ASTContext.cpp
@@ -9632,6 +9632,8 @@
     return {};
   if (lbaseInfo.getNoCallerSavedRegs() != rbaseInfo.getNoCallerSavedRegs())
     return {};
+  if (lbaseInfo.getAMXPreserve() != rbaseInfo.getAMXPreserve())
+    return {};
   if (lbaseInfo.getNoCfCheck() != rbaseInfo.getNoCfCheck())
     return {};
 
Index: clang/include/clang/CodeGen/CGFunctionInfo.h
===================================================================
--- clang/include/clang/CodeGen/CGFunctionInfo.h
+++ clang/include/clang/CodeGen/CGFunctionInfo.h
@@ -579,6 +579,9 @@
   /// Whether this function saved caller registers.
   unsigned NoCallerSavedRegs : 1;
 
+  /// Whether this function preserves AMX state.
+  unsigned AMXPreserve : 1;
+
   /// How many arguments to pass inreg.
   unsigned HasRegParm : 1;
   unsigned RegParm : 3;
@@ -671,6 +674,9 @@
   /// Whether this function no longer saves caller registers.
   bool isNoCallerSavedRegs() const { return NoCallerSavedRegs; }
 
+  /// Whether this function preserves AMX state.
+  bool isAMXPreserve() const { return AMXPreserve; }
+
   /// Whether this function has nocf_check attribute.
   bool isNoCfCheck() const { return NoCfCheck; }
 
@@ -700,7 +706,7 @@
     return FunctionType::ExtInfo(isNoReturn(), getHasRegParm(), getRegParm(),
                                  getASTCallingConvention(), isReturnsRetained(),
                                  isNoCallerSavedRegs(), isNoCfCheck(),
-                                 isCmseNSCall());
+                                 isCmseNSCall(), isAMXPreserve());
   }
 
   CanQualType getReturnType() const { return getArgsBuffer()[0].type; }
@@ -742,6 +748,7 @@
     ID.AddInteger(RegParm);
     ID.AddBoolean(NoCfCheck);
     ID.AddBoolean(CmseNSCall);
+    ID.AddBoolean(AMXPreserve);
     ID.AddInteger(Required.getOpaqueData());
     ID.AddBoolean(HasExtParameterInfos);
     if (HasExtParameterInfos) {
@@ -770,6 +777,7 @@
     ID.AddInteger(info.getRegParm());
     ID.AddBoolean(info.getNoCfCheck());
     ID.AddBoolean(info.getCmseNSCall());
+    ID.AddBoolean(info.getAMXPreserve());
     ID.AddInteger(required.getOpaqueData());
     ID.AddBoolean(!paramInfos.empty());
     if (!paramInfos.empty()) {
Index: clang/include/clang/Basic/AttrDocs.td
===================================================================
--- clang/include/clang/Basic/AttrDocs.td
+++ clang/include/clang/Basic/AttrDocs.td
@@ -4457,6 +4457,39 @@
   }];
 }
 
+def AMXPreserveDocs : Documentation {
+  let Category = DocCatFunction;
+  let Content = [{
+Use this attribute to indicate that the specified function has no
+caller-saved AMX registers. The compiler doesn't save and restore any
+AMX register across a call to the function. It is the user's
+responsibility to ensure that a function with the "amxpreserve"
+attribute does not clobber any AMX register.
+
+Like 'no_caller_saved_registers', the 'amxpreserve' attribute is not a
+calling convention. In fact, it only overrides the decision of which
+AMX registers should be saved by the caller.
+
+For example:
+
+  .. code-block:: c
+
+    __attribute__ ((amxpreserve ))
+    void f () {
+      ...
+    }
+
+    void bar () {
+      ...
+      f();
+      ...
+    }
+
+  In this case the compiler doesn't save and restore AMX registers across
+  the call to f().
+  }];
+}
+
 def X86ForceAlignArgPointerDocs : Documentation {
   let Category = DocCatFunction;
   let Content = [{
Index: clang/include/clang/Basic/Attr.td
===================================================================
--- clang/include/clang/Basic/Attr.td
+++ clang/include/clang/Basic/Attr.td
@@ -2892,6 +2892,12 @@
   let SimpleHandler = 1;
 }
 
+def AMXPreserve : InheritableAttr, TargetSpecificAttr<TargetAnyX86> {
+  let Spellings = [Clang<"amxpreserve">];
+  let Documentation = [AMXPreserveDocs];
+  let SimpleHandler = 1;
+}
+
 def AnyX86Interrupt : InheritableAttr, TargetSpecificAttr<TargetAnyX86> {
   // NOTE: If you add any additional spellings, ARMInterrupt's,
   // M68kInterrupt's, MSP430Interrupt's and MipsInterrupt's spellings must match.
Index: clang/include/clang/AST/TypeProperties.td
===================================================================
--- clang/include/clang/AST/TypeProperties.td
+++ clang/include/clang/AST/TypeProperties.td
@@ -287,6 +287,9 @@
   def : Property<"cmseNSCall", Bool> {
     let Read = [{ node->getExtInfo().getCmseNSCall() }];
   }
+  def : Property<"amxPreserve", Bool> {
+    let Read = [{ node->getExtInfo().getAMXPreserve() }];
+  }
 }
 
 let Class = FunctionNoProtoType in {
@@ -294,7 +297,7 @@
     auto extInfo = FunctionType::ExtInfo(noReturn, hasRegParm, regParm,
                                          callingConvention, producesResult,
                                          noCallerSavedRegs, noCfCheck,
-                                         cmseNSCall);
+                                         cmseNSCall, amxPreserve);
     return ctx.getFunctionNoProtoType(returnType, extInfo);
   }]>;
 }
@@ -328,7 +331,7 @@
     auto extInfo = FunctionType::ExtInfo(noReturn, hasRegParm, regParm,
                                          callingConvention, producesResult,
                                          noCallerSavedRegs, noCfCheck,
-                                         cmseNSCall);
+                                         cmseNSCall, amxPreserve);
     FunctionProtoType::ExtProtoInfo epi;
     epi.ExtInfo = extInfo;
     epi.Variadic = variadic;
Index: clang/include/clang/AST/Type.h
===================================================================
--- clang/include/clang/AST/Type.h
+++ clang/include/clang/AST/Type.h
@@ -1585,7 +1585,7 @@
 
     /// Extra information which affects how the function is called, like
     /// regparm and the calling convention.
-    unsigned ExtInfo : 13;
+    unsigned ExtInfo : 14;
 
     /// The ref-qualifier associated with a \c FunctionProtoType.
     ///
@@ -1823,7 +1823,7 @@
   Type(TypeClass tc, QualType canon, TypeDependence Dependence)
       : ExtQualsTypeCommonBase(this,
                                canon.isNull() ? QualType(this_(), 0) : canon) {
-    static_assert(sizeof(*this) <= 8 + sizeof(ExtQualsTypeCommonBase),
+    static_assert(sizeof(*this) <= 16 + sizeof(ExtQualsTypeCommonBase),
                   "changing bitfields changed sizeof(Type)!");
     static_assert(alignof(decltype(*this)) % sizeof(void *) == 0,
                   "Insufficient alignment!");
@@ -3663,6 +3663,8 @@
 
     // |  CC  |noreturn|produces|nocallersavedregs|regparm|nocfcheck|cmsenscall|
     // |0 .. 4|   5    |    6   |       7         |8 .. 10|    11   |    12    |
+    // |amxpreserve|
+    // |   13      |
     //
     // regparm is either 0 (no regparm attribute) or the regparm value+1.
     enum { CallConvMask = 0x1F };
@@ -3675,6 +3677,7 @@
     };
     enum { NoCfCheckMask = 0x800 };
     enum { CmseNSCallMask = 0x1000 };
+    enum { AMXPreserveMask = 0x2000 };
     uint16_t Bits = CC_C;
 
     ExtInfo(unsigned Bits) : Bits(static_cast<uint16_t>(Bits)) {}
@@ -3684,14 +3687,15 @@
     // have all the elements (when reading an AST file for example).
     ExtInfo(bool noReturn, bool hasRegParm, unsigned regParm, CallingConv cc,
             bool producesResult, bool noCallerSavedRegs, bool NoCfCheck,
-            bool cmseNSCall) {
+            bool cmseNSCall, bool amxPreserve) {
       assert((!hasRegParm || regParm < 7) && "Invalid regparm value");
       Bits = ((unsigned)cc) | (noReturn ? NoReturnMask : 0) |
              (producesResult ? ProducesResultMask : 0) |
              (noCallerSavedRegs ? NoCallerSavedRegsMask : 0) |
              (hasRegParm ? ((regParm + 1) << RegParmOffset) : 0) |
              (NoCfCheck ? NoCfCheckMask : 0) |
-             (cmseNSCall ? CmseNSCallMask : 0);
+             (cmseNSCall ? CmseNSCallMask : 0) |
+             (amxPreserve ? AMXPreserveMask : 0);
     }
 
     // Constructor with all defaults. Use when for example creating a
@@ -3706,6 +3710,7 @@
     bool getProducesResult() const { return Bits & ProducesResultMask; }
     bool getCmseNSCall() const { return Bits & CmseNSCallMask; }
     bool getNoCallerSavedRegs() const { return Bits & NoCallerSavedRegsMask; }
+    bool getAMXPreserve() const { return Bits & AMXPreserveMask; }
     bool getNoCfCheck() const { return Bits & NoCfCheckMask; }
     bool getHasRegParm() const { return ((Bits & RegParmMask) >> RegParmOffset) != 0; }
 
@@ -3756,6 +3761,13 @@
         return ExtInfo(Bits & ~NoCallerSavedRegsMask);
     }
 
+    ExtInfo withAMXPreserve(bool amxPreserve) const {
+      if (amxPreserve)
+        return ExtInfo(Bits | AMXPreserveMask);
+      else
+        return ExtInfo(Bits & ~AMXPreserveMask);
+    }
+
     ExtInfo withNoCfCheck(bool noCfCheck) const {
       if (noCfCheck)
         return ExtInfo(Bits | NoCfCheckMask);
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to