================ @@ -2810,9 +2819,117 @@ LogicalResult WinogradInputTransformOp::verify() { if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) { return emitOpError("the output shape is not expected"); } + return success(); } +SmallVector<Range> +WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) { + Location loc = getLoc(); + auto indexType = builder.getIndexType(); + auto zeroAttr = builder.getIntegerAttr(indexType, 0); + auto oneAttr = builder.getIntegerAttr(indexType, 1); + Value output = getOutput(); + SmallVector<Range> loopBounds(6); + for (unsigned dim = 0; dim < 6; ++dim) { + loopBounds[dim].offset = zeroAttr; + loopBounds[dim].size = getDimValue(builder, loc, output, dim); + loopBounds[dim].stride = oneAttr; + } + return loopBounds; +} + +SmallVector<utils::IteratorType> +WinogradInputTransformOp::getLoopIteratorTypes() { + SmallVector<utils::IteratorType> iteratorTypes(6, + utils::IteratorType::parallel); + return iteratorTypes; +} + +LogicalResult WinogradInputTransformOp::getResultTilePosition( + OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets, + ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets, + SmallVector<OpFoldResult> &resultSizes) { + auto zeroAttr = builder.getI64IntegerAttr(0); + auto oneAttr = builder.getI64IntegerAttr(1); + + resultOffsets.push_back(zeroAttr); + resultOffsets.push_back(zeroAttr); + resultOffsets.push_back(offsets[2]); + resultOffsets.push_back(offsets[3]); + resultOffsets.push_back(zeroAttr); + resultOffsets.push_back(zeroAttr); + resultSizes.push_back(sizes[0]); + resultSizes.push_back(sizes[1]); + resultSizes.push_back(oneAttr); + resultSizes.push_back(oneAttr); + resultSizes.push_back(sizes[4]); + resultSizes.push_back(sizes[5]); + + return success(); +} + +FailureOr<TilingResult> +WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder, + ArrayRef<OpFoldResult> offsets, + ArrayRef<OpFoldResult> sizes) { + auto oneAttr = 
builder.getI64IntegerAttr(1); + auto zeroAttr = builder.getI64IntegerAttr(0); + Value input = getInput(); + auto inputType = cast<ShapedType>(input.getType()); + auto inputShape = inputType.getShape(); ---------------- ftynse wrote:
Here and below.

https://github.com/llvm/llvm-project/pull/96184
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits