================
@@ -45,6 +46,93 @@
using namespace cir;
using namespace llvm;
+
+/// Render an MLIR type as text.
+///
+/// Returns exactly what the type's printer emits for `llvmTy`.
+/// NOTE(review): despite the function's name, this is the printed assembly
+/// form of the type. It is not necessarily the mangled suffix LLVM uses for
+/// overloaded intrinsic names (e.g. `v4f32`), so confirm that callers expect
+/// this form. No caller is visible in this hunk.
+static std::string getLLVMIntrinsicNameForType(mlir::Type llvmTy) {
+  std::string printed;
+  llvm::raw_string_ostream stream(printed);
+  stream << llvmTy;
+  // str() hands back `printed` with all pending output written into it.
+  return stream.str();
+}
+
+// Lower cir.sqrt to the LLVM dialect's square-root intrinsic op
+// (mlir::LLVM::SqrtOp, `llvm.intr.sqrt`).
+//
+// LLVM::SqrtOp accepts both scalar floats and vectors of floats. It is
+// translated to the matching overloaded `llvm.sqrt.*` intrinsic when the
+// module is exported to LLVM IR. This lowering therefore does not need to
+// mangle an intrinsic name by hand (which would cover only the float widths
+// spelled out explicitly), declare an `llvm.sqrt.*` function at module scope,
+// or emit an explicit llvm.call.
+//
+// Returns failure, with an op error, if the result type cannot be converted.
+mlir::LogicalResult CIRToLLVMSqrtOpLowering::matchAndRewrite(
+    cir::SqrtOp op, typename cir::SqrtOp::Adaptor adaptor,
+    mlir::ConversionPatternRewriter &rewriter) const {
+  mlir::Type resTy = getTypeConverter()->convertType(op.getType());
+  if (!resTy)
+    return op.emitOpError(
+        "expected LLVM dialect result type for cir.sqrt lowering");
+
+  // The adaptor's operand is already type-converted and is passed straight
+  // through. No bitcast is inserted: a bitcast reinterprets bits rather than
+  // converting values, so it could never reconcile a mismatched sqrt operand.
+  // NOTE(review): assumes cir.sqrt has identical operand and result types
+  // (e.g. SameOperandsAndResultType); confirm against the op definition.
+  rewriter.replaceOpWithNewOp<mlir::LLVM::SqrtOp>(op, resTy,
+                                                  adaptor.getOperands()[0]);
+  return mlir::success();
----------------
andykaylor wrote:
```suggestion
mlir::Type resTy = typeConverter->convertType(op.getType());
rewriter.replaceOpWithNewOp<mlir::LLVM::SqrtOp>(op, resTy,
adaptor.getSrc());
return mlir::success();
```
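The LLVM dialect's SqrtOp (`mlir::LLVM::SqrtOp`) is a direct replacement and will ultimately be
lowered to the `llvm.sqrt.*` intrinsic, so there is no need to build the intrinsic name or declare the function by hand.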
https://github.com/llvm/llvm-project/pull/169310
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits