================
@@ -184,6 +185,53 @@ static mlir::Value emitX86Select(CIRGenBuilderTy &builder,
mlir::Location loc,
return cir::VecTernaryOp::create(builder, loc, mask, op0, op1);
}
+static mlir::Value emitX86ScalarSelect(CIRGenBuilderTy &builder,
+ mlir::Location loc, mlir::Value mask,
+ mlir::Value op0, mlir::Value op1) {
+
+ // If the mask is all ones just return first argument.
+ if (auto c = mlir::dyn_cast_or_null<cir::ConstantOp>(mask.getDefiningOp()))
+ if (c.isAllOnesValue())
+ return op0;
+
+ // Extract the scalar values from the vector operands
+ auto vecTy0 = mlir::dyn_cast<cir::VectorType>(op0.getType());
+ auto vecTy1 = mlir::dyn_cast<cir::VectorType>(op1.getType());
+
+ mlir::Value scalar0 = op0;
+ mlir::Value scalar1 = op1;
+
+ if (vecTy0)
+ scalar0 = builder.createExtractElement(loc, op0, uint64_t(0));
+
+ if (vecTy1)
+ scalar1 = builder.createExtractElement(loc, op1, uint64_t(0));
+
+ // Get the mask as a vector of i1 and extract bit 0
+ auto intTy = mlir::dyn_cast<cir::IntType>(mask.getType());
+ assert(intTy && "mask must be an integer type");
+ unsigned width = intTy.getWidth();
+
+ auto i1Ty = builder.getUIntNTy(1);
+ auto maskVecTy = cir::VectorType::get(i1Ty, width);
+ mlir::Value maskVec = builder.createBitcast(mask, maskVecTy);
+
+ // Extract bit 0 from the mask vector
+ mlir::Value bit0 = builder.createExtractElement(loc, maskVec, uint64_t(0));
+
+ // Convert i1 to bool for select
+ auto boolTy = cir::BoolType::get(builder.getContext());
+ mlir::Value cond = cir::CastOp::create(builder, loc, boolTy,
+ cir::CastKind::int_to_bool, bit0);
+
+ mlir::Value result = builder.createSelect(loc, cond, scalar0, scalar1);
+
+ if (vecTy0)
+ result = builder.createInsertElement(loc, op0, result, uint64_t(0));
----------------
Priyanshu3820 wrote:
Moved to https://github.com/llvm/llvm-project/pull/174003
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits