ergawy updated this revision to Diff 315658.
ergawy marked 5 inline comments as done.
ergawy added a comment.

Handle review comments.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D93591/new/

https://reviews.llvm.org/D93591

Files:
  mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
  mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
  mlir/lib/Target/SPIRV/Deserialization.cpp
  mlir/lib/Target/SPIRV/Serialization.cpp
  mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
  mlir/test/Target/SPIRV/spec-constant.mlir

Index: mlir/test/Target/SPIRV/spec-constant.mlir
===================================================================
--- mlir/test/Target/SPIRV/spec-constant.mlir
+++ mlir/test/Target/SPIRV/spec-constant.mlir
@@ -85,3 +85,34 @@
   // CHECK: spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3xf32>
   spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3 x f32>
 }
+
+// -----
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+
+  spv.specConstant @sc_i32_1 = 1 : i32
+  // Exercises spv.SpecConstantOperation ops nested inside a function body.
+  spv.func @use_composite() -> (i32) "None" {
+    // CHECK: [[USE1:%.*]] = spv.mlir.referenceof @sc_i32_1 : i32
+    // CHECK: [[USE2:%.*]] = spv.constant 0 : i32
+
+    // CHECK: [[RES1:%.*]] = spv.SpecConstantOperation wraps "spv.ISub"([[USE1]], [[USE2]]) : (i32, i32) -> i32
+
+    // CHECK: [[USE3:%.*]] = spv.mlir.referenceof @sc_i32_1 : i32
+    // CHECK: [[USE4:%.*]] = spv.constant 0 : i32
+
+    // CHECK: [[RES2:%.*]] = spv.SpecConstantOperation wraps "spv.ISub"([[USE3]], [[USE4]]) : (i32, i32) -> i32
+    // %2 is used twice below; the checks above expect one op per use.
+    %0 = spv.mlir.referenceof @sc_i32_1 : i32
+    %1 = spv.constant 0 : i32
+    %2 = spv.SpecConstantOperation wraps "spv.ISub"(%0, %1) : (i32, i32) -> i32
+
+    // CHECK: [[RES3:%.*]] = spv.SpecConstantOperation wraps "spv.IMul"([[RES1]], [[RES2]]) : (i32, i32) -> i32
+    %3 = spv.SpecConstantOperation wraps "spv.IMul"(%2, %2) : (i32, i32) -> i32
+
+    // Make sure deserialization continues from the right place after creating
+    // the previous op.
+    // CHECK: spv.ReturnValue [[RES3]]
+    spv.ReturnValue %3 : i32
+  }
+}
Index: mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
===================================================================
--- mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -780,6 +780,20 @@
 
 // -----
 
+spv.module Logical GLSL450 {
+  spv.specConstant @sc = 42 : i32
+  // Both operands of the wrapped op are the spv.mlir.referenceof result %0.
+  spv.func @foo() -> i32 "None" {
+    // CHECK: [[SC:%.*]] = spv.mlir.referenceof @sc
+    %0 = spv.mlir.referenceof @sc : i32
+    // CHECK: spv.SpecConstantOperation wraps "spv.ISub"([[SC]], [[SC]]) : (i32, i32) -> i32
+    %1 = spv.SpecConstantOperation wraps "spv.ISub"(%0, %0) : (i32, i32) -> i32
+    spv.ReturnValue %1 : i32
+  }
+}
+
+// -----
+
 spv.module Logical GLSL450 {
   spv.func @foo() -> i32 "None" {
     %0 = spv.constant 1: i32
Index: mlir/lib/Target/SPIRV/Serialization.cpp
===================================================================
--- mlir/lib/Target/SPIRV/Serialization.cpp
+++ mlir/lib/Target/SPIRV/Serialization.cpp
@@ -204,6 +204,9 @@
   LogicalResult
   processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op);
 
+  LogicalResult
+  processSpecConstantOperationOp(spirv::SpecConstantOperationOp op);
+
   /// SPIR-V dialect supports OpUndef using spv.UndefOp that produces a SSA
   /// value to use with other operations. The SPIR-V spec recommends that
   /// OpUndef be generated at module level. The serialization generates an
@@ -711,6 +714,49 @@
   return processName(resultID, op.sym_name());
 }
 
+LogicalResult
+Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
+  uint32_t typeID = 0;
+  if (failed(processType(op.getLoc(), op.getType(), typeID))) {
+    return failure();
+  }
+  // Emits OpSpecConstantOperation <type> <result> <opcode> <operand ids...>.
+  auto resultID = getNextID();
+
+  SmallVector<uint32_t, 8> operands;
+  operands.push_back(typeID);
+  operands.push_back(resultID);
+  // The enclosed op is the first operation of the region's first block.
+  Block &block = op.getRegion().getBlocks().front();
+  Operation &enclosedOp = block.getOperations().front();
+  // Build the opcode name as "Op" + the op name with its dialect stripped.
+  std::string enclosedOpName;
+  llvm::raw_string_ostream rss(enclosedOpName);
+  rss << "Op" << enclosedOp.getName().stripDialect();
+  auto enclosedOpcode = spirv::symbolizeOpcode(rss.str());
+
+  if (!enclosedOpcode) {
+    op.emitError("Couldn't find op code for op ")
+        << enclosedOp.getName().getStringRef();
+    return failure();
+  }
+
+  operands.push_back(static_cast<uint32_t>(enclosedOpcode.getValue()));
+
+  // Append the <id>s of the enclosed op's operands to the list of operands.
+  for (Value operand : enclosedOp.getOperands()) {
+    uint32_t id = getValueID(operand);
+    assert(id && "use before def!");
+    operands.push_back(id);
+  }
+  // Encode into `typesGlobalValues` and map the op's result to its new <id>.
+  encodeInstructionInto(typesGlobalValues,
+                        spirv::Opcode::OpSpecConstantOperation, operands);
+  valueIDMap[op.getResult()] = resultID;
+
+  return success();
+}
+
 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
   auto undefType = op.getType();
   auto &id = undefValIDMap[undefType];
@@ -1929,6 +1975,9 @@
       .Case([&](spirv::SpecConstantCompositeOp op) {
         return processSpecConstantCompositeOp(op);
       })
+      .Case([&](spirv::SpecConstantOperationOp op) {
+        return processSpecConstantOperationOp(op);
+      })
       .Case([&](spirv::UndefOp op) { return processUndefOp(op); })
       .Case([&](spirv::VariableOp op) { return processVariableOp(op); })
 
Index: mlir/lib/Target/SPIRV/Deserialization.cpp
===================================================================
--- mlir/lib/Target/SPIRV/Deserialization.cpp
+++ mlir/lib/Target/SPIRV/Deserialization.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Target/SPIRV/Deserialization.h"
 
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVModule.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
@@ -28,6 +29,7 @@
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/bit.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/SaveAndRestore.h"
 #include "llvm/Support/raw_ostream.h"
 
 using namespace mlir;
@@ -132,6 +134,14 @@
   SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
 };
 
+/// Info recorded for an OpSpecConstantOperation instruction so that the
+/// corresponding SpecConstantOperation op can be materialized/emitted later.
+struct SpecConstOperationMaterializationInfo {
+  spirv::Opcode enclodesOpcode; // Opcode of the enclosed op (name is sic).
+  uint32_t resultTypeID; // <id> of the op's result type.
+  SmallVector<uint32_t> enclosedOpOperands; // Operand words of the enclosed op.
+};
+
 //===----------------------------------------------------------------------===//
 // Deserializer Declaration
 //===----------------------------------------------------------------------===//
@@ -216,9 +226,14 @@
   /// Gets the constant's attribute and type associated with the given <id>.
   Optional<std::pair<Attribute, Type>> getConstant(uint32_t id);
 
-  /// Gets the constant's integer attribute with the given <id>. Returns a null
-  /// IntegerAttr if the given is not registered or does not correspond to an
-  /// integer constant.
+  /// Gets the info needed to materialize the spec constant operation op
+  /// associated with the given <id>.
+  Optional<SpecConstOperationMaterializationInfo>
+  getSpecConstantOperation(uint32_t id);
+
+  /// Gets the constant's integer attribute with the given <id>. Returns a
+  /// null IntegerAttr if the given <id> is not registered or does not
+  /// correspond to an integer constant.
   IntegerAttr getConstantInt(uint32_t id);
 
   /// Returns a symbol to be used for the function name with the given
@@ -305,8 +320,20 @@
   /// `operands`.
   LogicalResult processConstantComposite(ArrayRef<uint32_t> operands);
 
+  /// Processes a SPIR-V OpSpecConstantComposite instruction with the given
+  /// `operands`.
   LogicalResult processSpecConstantComposite(ArrayRef<uint32_t> operands);
 
+  /// Processes a SPIR-V OpSpecConstantOperation instruction with the given
+  /// `operands`.
+  LogicalResult processSpecConstantOperation(ArrayRef<uint32_t> operands);
+
+  /// Materializes/emits an OpSpecConstantOperation instruction.
+  Value materializeSpecConstantOperation(uint32_t resultID,
+                                         spirv::Opcode enclosedOpcode,
+                                         uint32_t resultTypeID,
+                                         ArrayRef<uint32_t> enclosedOpOperands);
+
   /// Processes a SPIR-V OpConstantNull instruction with the given `operands`.
   LogicalResult processConstantNull(ArrayRef<uint32_t> operands);
 
@@ -534,6 +561,11 @@
   // Result <id> to composite spec constant mapping.
   DenseMap<uint32_t, spirv::SpecConstantCompositeOp> specConstCompositeMap;
 
+  /// Result <id> to info needed to materialize an OpSpecConstantOperation
+  /// mapping.
+  DenseMap<uint32_t, SpecConstOperationMaterializationInfo>
+      specConstOperationMap;
+
   // Result <id> to variable mapping.
   DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
 
@@ -1036,6 +1068,14 @@
   return constIt->getSecond();
 }
 
+Optional<SpecConstOperationMaterializationInfo>
+Deserializer::getSpecConstantOperation(uint32_t id) {
+  auto entry = specConstOperationMap.find(id);
+  if (entry != specConstOperationMap.end())
+    return entry->getSecond(); // Info recorded for this <id>.
+  return llvm::None; // No OpSpecConstantOperation was recorded for <id>.
+}
+
 std::string Deserializer::getFunctionSymbol(uint32_t id) {
   auto funcName = nameMap.lookup(id).str();
   if (funcName.empty()) {
@@ -1745,6 +1785,91 @@
   return success();
 }
 
+LogicalResult
+Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
+  if (operands.size() < 3)
+    return emitError(unknownLoc, "OpSpecConstantOperation must have type <id>, "
+                                 "result <id>, and operand opcode");
+  // Layout: result type <id>, result <id>, enclosed opcode, enclosed operands.
+  uint32_t resultTypeID = operands[0];
+
+  if (!getType(resultTypeID))
+    return emitError(unknownLoc, "undefined result type from <id> ")
+           << resultTypeID;
+  // Only record the info, keyed by result <id>; no op is created here.
+  uint32_t resultID = operands[1];
+  spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
+  auto emplaceResult = specConstOperationMap.try_emplace(
+      resultID,
+      SpecConstOperationMaterializationInfo{
+          enclosedOpcode, resultTypeID,
+          SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
+  // try_emplace reports failure when resultID already has an entry.
+  if (!emplaceResult.second)
+    return emitError(unknownLoc, "value with <id>: ")
+           << resultID << " is probably defined before.";
+
+  return success();
+}
+
+Value Deserializer::materializeSpecConstantOperation(
+    uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
+    ArrayRef<uint32_t> enclosedOpOperands) {
+  // Returns the new op's result, or a null Value if the enclosed op fails.
+  Type resultType = getType(resultTypeID);
+  // NOTE(review): `resultID` is not referenced anywhere in this body.
+  // Instructions wrapped by OpSpecConstantOp need an ID for their
+  // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
+  // dialect wrapped op. For that purpose, a new value map is created and "fake"
+  // ID in that map is assigned to the result of the enclosed instruction. Note
+  // that there is no need to update this fake ID since we only need to
+  // reference the created Value for the enclosed op from the spirv::YieldOp
+  // created later in this method (both of which are the only values in their
+  // region: the SpecConstantOperation's region). If we encounter another
+  // SpecConstantOperation in the module, we simply re-use the fake ID since the
+  // previous Value assigned to it isn't visible in the current scope anyway.
+  DenseMap<uint32_t, Value> newValueMap;
+  llvm::SaveAndRestore<DenseMap<uint32_t, Value>> valueMapGuard(valueMap,
+                                                                newValueMap);
+  constexpr uint32_t fakeID = static_cast<uint32_t>(-3);
+
+  SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
+  enclosedOpResultTypeAndOperands.push_back(resultTypeID);
+  enclosedOpResultTypeAndOperands.push_back(fakeID);
+  enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
+                                         enclosedOpOperands.end());
+
+  // Process enclosed instruction before creating the enclosing
+  // specConstantOperation (and its region). This way, references to constants,
+  // global variables, and spec constants will be materialized outside the new
+  // op's region. For more info, see Deserializer::getValue's implementation.
+  if (failed(
+          processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
+    return Value();
+
+  // Since the enclosed op is emitted in the current block, split it in a
+  // separate new block.
+  Block *enclosedBlock = curBlock->splitBlock(&curBlock->back());
+
+  auto loc = createFileLineColLoc(opBuilder);
+  auto specConstOperationOp =
+      opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);
+
+  Region &body = specConstOperationOp.body();
+  // Move the new block into SpecConstantOperation's body.
+  body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
+                          Region::iterator(enclosedBlock));
+  Block &block = body.back();
+
+  // RAII guard that restores the builder's current insertion point on return,
+  // after the yield op has been appended to the specConstantOperation's body.
+  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
+  opBuilder.setInsertionPointToEnd(&block);
+  // Yield the first result of the block's first op (the enclosed op).
+  opBuilder.create<spirv::YieldOp>(loc, block.front().getResult(0));
+  return specConstOperationOp.getResult();
+}
+
 LogicalResult Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
   if (operands.size() != 2) {
     return emitError(unknownLoc,
@@ -2378,6 +2503,12 @@
         opBuilder.getSymbolRefAttr(constCompositeOp.getOperation()));
     return referenceOfOp.reference();
   }
+  if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
+    return materializeSpecConstantOperation(
+        id, specConstOperationInfo->enclodesOpcode,
+        specConstOperationInfo->resultTypeID,
+        specConstOperationInfo->enclosedOpOperands);
+  }
   if (auto undef = getUndefType(id)) {
     return opBuilder.create<spirv::UndefOp>(unknownLoc, undef);
   }
@@ -2483,6 +2614,8 @@
     return processConstantComposite(operands);
   case spirv::Opcode::OpSpecConstantComposite:
     return processSpecConstantComposite(operands);
+  case spirv::Opcode::OpSpecConstantOperation:
+    return processSpecConstantOperation(operands);
   case spirv::Opcode::OpConstantTrue:
     return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
   case spirv::Opcode::OpSpecConstantTrue:
Index: mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
===================================================================
--- mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -3445,9 +3445,8 @@
     return constOp.emitOpError("invalid enclosed op");
 
   for (auto operand : enclosedOp.getOperands())
-    if (!isa<spirv::ConstantOp, spirv::SpecConstantOp,
-             spirv::SpecConstantCompositeOp, spirv::SpecConstantOperationOp>(
-            operand.getDefiningOp()))
+    if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
+             spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
       return constOp.emitOpError(
           "invalid operand, must be defined by a constant operation");
 
Index: mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
===================================================================
--- mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3170,6 +3170,7 @@
 def SPV_OC_OpSpecConstantFalse         : I32EnumAttrCase<"OpSpecConstantFalse", 49>;
 def SPV_OC_OpSpecConstant              : I32EnumAttrCase<"OpSpecConstant", 50>;
 def SPV_OC_OpSpecConstantComposite     : I32EnumAttrCase<"OpSpecConstantComposite", 51>;
+def SPV_OC_OpSpecConstantOperation     : I32EnumAttrCase<"OpSpecConstantOperation", 52>;
 def SPV_OC_OpFunction                  : I32EnumAttrCase<"OpFunction", 54>;
 def SPV_OC_OpFunctionParameter         : I32EnumAttrCase<"OpFunctionParameter", 55>;
 def SPV_OC_OpFunctionEnd               : I32EnumAttrCase<"OpFunctionEnd", 56>;
@@ -3314,7 +3315,8 @@
       SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse, SPV_OC_OpConstant,
       SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue,
       SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant,
-      SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
+      SPV_OC_OpSpecConstantComposite, SPV_OC_OpSpecConstantOperation,
+      SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
       SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad,
       SPV_OC_OpStore, SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
       SPV_OC_OpMemberDecorate, SPV_OC_OpVectorExtractDynamic,
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to