================
@@ -45,6 +46,93 @@
 using namespace cir;
 using namespace llvm;
 
+
+/// Returns the printed (textual) form of \p llvmTy, exactly as produced by
+/// streaming the type into an llvm::raw_ostream.
+///
+/// NOTE(review): despite its name, this does not compute an LLVM intrinsic
+/// mangling suffix (such as "v4f32"); the sqrt lowering in this patch builds
+/// that suffix by hand instead. The helper is not referenced anywhere in the
+/// visible part of this patch -- if it is unused, drop it to avoid a
+/// -Wunused-function warning.
+static std::string getLLVMIntrinsicNameForType(mlir::Type llvmTy) {
+  std::string printed;
+  llvm::raw_string_ostream stream(printed);
+  stream << llvmTy;
+  // str() returns the backing string with all streamed output written into it.
+  return stream.str();
+}
+
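+// Lower cir.sqrt to an llvm.call of the matching llvm.sqrt.* intrinsic,
+// declaring the intrinsic in the module first if it is not already present.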
+mlir::LogicalResult CIRToLLVMSqrtOpLowering::matchAndRewrite(
+    cir::SqrtOp op, typename cir::SqrtOp::Adaptor adaptor,
+    mlir::ConversionPatternRewriter &rewriter) const {
+
+  mlir::Location loc = op.getLoc();
+  mlir::MLIRContext *ctx = rewriter.getContext();
+
+  mlir::Type cirResTy = op.getResult().getType();
+  mlir::Type llvmResTy = getTypeConverter()->convertType(cirResTy);
+  if (!llvmResTy)
+    return op.emitOpError(
+        "expected LLVM dialect result type for cir.sqrt lowering");
+
+  Value operand = adaptor.getInput();
+  Value llvmOperand = operand;
+  if (operand.getType() != llvmResTy) {
+    llvmOperand = rewriter.create<LLVM::BitcastOp>(loc, llvmResTy, operand);
+  }
+
+  // Build the llvm.sqrt.* intrinsic name depending on scalar vs vector result
+  std::string intrinsicName = "llvm.sqrt.";
+  std::string suffix;
+
+  // If the CIR result type is a vector, include the 'vN' part in the suffix.
+  if (auto vec = cirResTy.dyn_cast<cir::VectorType>()) {
+    Type elt = vec.getElementType();
+    if (auto f = elt.dyn_cast<cir::FloatType>()) {
+      unsigned width = f.getWidth();
+      unsigned n = vec.getNumElements();
+      if (width == 32)
+        suffix = "v" + std::to_string(n) + "f32";
+      else if (width == 64)
+        suffix = "v" + std::to_string(n) + "f64";
+      else if (width == 16)
+        suffix = "v" + std::to_string(n) + "f16";
+      else
+        return op.emitOpError("unsupported float width for sqrt");
+    } else {
+      return op.emitOpError("vector element must be floating point for sqrt");
+    }
+  } else if (auto f = cirResTy.dyn_cast<cir::FloatType>()) {
+    // Scalar float
+    unsigned width = f.getWidth();
+    if (width == 32)
+      suffix = "f32";
+    else if (width == 64)
+      suffix = "f64";
+    else if (width == 16)
+      suffix = "f16";
+    else
+      return op.emitOpError("unsupported float width for sqrt");
+  } else {
+    return op.emitOpError("unsupported type for cir.sqrt lowering");
+  }
+
+  intrinsicName += suffix;
+
+  // Ensure the llvm intrinsic function exists at module scope. Insert it at
+  // the start of the module body using an insertion guard.
+  ModuleOp module = op->getParentOfType<ModuleOp>();
+  if (!module.lookupSymbol<LLVM::LLVMFuncOp>(intrinsicName)) {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(module.getBody());
+    auto llvmFnType = LLVM::LLVMFunctionType::get(ctx, llvmResTy, {llvmResTy},
+                                                  /*isVarArg=*/false);
+    rewriter.create<LLVM::LLVMFuncOp>(loc, intrinsicName, llvmFnType);
+  }
+
+  // Create the call and replace cir.sqrt
+  auto callee = SymbolRefAttr::get(ctx, intrinsicName);
+  rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, llvmResTy, callee,
+                                            ArrayRef<Value>{llvmOperand});
+
+  return mlir::success();
----------------
andykaylor wrote:

```suggestion
  mlir::Type resTy = typeConverter->convertType(op.getType());
  rewriter.replaceOpWithNewOp<mlir::LLVM::SqrtOp>(op, resTy,
                                                    adaptor.getSrc());
  return mlir::success();

```
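The LLVM dialect's SqrtOp (`llvm.intr.sqrt`) is a direct replacement and will
ultimately be lowered to the `llvm.sqrt.*` intrinsic, so there is no need to
declare the intrinsic function or build its mangled name by hand.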

https://github.com/llvm/llvm-project/pull/169310
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to