================ @@ -335,49 +336,129 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion, for (auto [i, opOrSingle] : llvm::enumerate(regions)) { bool isLast = i + 1 == regions.size(); if (std::holds_alternative<SingleRegion>(opOrSingle)) { - OpBuilder singleBuilder(sourceRegion.getContext()); - Block *singleBlock = new Block(); - singleBuilder.setInsertionPointToStart(singleBlock); - OpBuilder allocaBuilder(sourceRegion.getContext()); Block *allocaBlock = new Block(); allocaBuilder.setInsertionPointToStart(allocaBlock); - OpBuilder parallelBuilder(sourceRegion.getContext()); - Block *parallelBlock = new Block(); - parallelBuilder.setInsertionPointToStart(parallelBlock); - - auto [allParallelized, copyprivateVars] = - moveToSingle(std::get<SingleRegion>(opOrSingle), allocaBuilder, - singleBuilder, parallelBuilder); - if (allParallelized) { - // The single region was not required as all operations were safe to - // parallelize - assert(copyprivateVars.empty()); - assert(allocaBlock->empty()); - delete singleBlock; + it = block.begin(); + while (&*it != terminator) + if (isa<hlfir::SumOp>(it)) + break; + else + it++; + + if (auto sumOp = dyn_cast<hlfir::SumOp>(it)) { + /// Implementation: + /// Intrinsic function `SUM` operations + /// -- + /// x = sum(array) + /// + /// is converted to + /// + /// !$omp parallel do + /// do i = 1, size(array) + /// x = x + array(i) + /// end do + /// !$omp end parallel do + + OpBuilder wslBuilder(sourceRegion.getContext()); + Block *wslBlock = new Block(); + wslBuilder.setInsertionPointToStart(wslBlock); + + Value target = dyn_cast<hlfir::AssignOp>(++it).getLhs(); + Value array = sumOp.getArray(); + Value dim = sumOp.getDim(); + fir::SequenceType arrayTy = dyn_cast<fir::SequenceType>( + hlfir::getFortranElementOrSequenceType(array.getType())); + llvm::ArrayRef<int64_t> arrayShape = arrayTy.getShape(); + if (arrayShape.size() == 1 && !dim) { + Value itr = allocaBuilder.create<fir::AllocaOp>( + loc, 
allocaBuilder.getI64Type()); + Value c_one = allocaBuilder.create<arith::ConstantOp>( + loc, allocaBuilder.getI64IntegerAttr(1)); + Value c_arr_size = allocaBuilder.create<arith::ConstantOp>( + loc, allocaBuilder.getI64IntegerAttr(arrayShape[0])); + // Value c_zero = allocaBuilder.create<arith::ConstantOp>(loc, + // allocaBuilder.getZeroAttr(arrayTy.getEleTy())); + // allocaBuilder.create<fir::StoreOp>(loc, c_zero, target); + + omp::WsloopOperands wslOps; + omp::WsloopOp wslOp = + rootBuilder.create<omp::WsloopOp>(loc, wslOps); + + hlfir::LoopNest ln; + ln.outerOp = wslOp; + omp::LoopNestOperands lnOps; + lnOps.loopLowerBounds.push_back(c_one); + lnOps.loopUpperBounds.push_back(c_arr_size); + lnOps.loopSteps.push_back(c_one); + lnOps.loopInclusive = wslBuilder.getUnitAttr(); + omp::LoopNestOp lnOp = + wslBuilder.create<omp::LoopNestOp>(loc, lnOps); + Block *lnBlock = wslBuilder.createBlock(&lnOp.getRegion()); + lnBlock->addArgument(c_one.getType(), loc); + wslBuilder.create<fir::StoreOp>( + loc, lnOp.getRegion().getArgument(0), itr); + Value tarLoad = wslBuilder.create<fir::LoadOp>(loc, target); + Value itrLoad = wslBuilder.create<fir::LoadOp>(loc, itr); + hlfir::DesignateOp arrDesOp = wslBuilder.create<hlfir::DesignateOp>( + loc, fir::ReferenceType::get(arrayTy.getEleTy()), array, + itrLoad); + Value desLoad = wslBuilder.create<fir::LoadOp>(loc, arrDesOp); + Value addf = + wslBuilder.create<arith::AddFOp>(loc, tarLoad, desLoad); + wslBuilder.create<fir::StoreOp>(loc, addf, target); + wslBuilder.create<omp::YieldOp>(loc); + ln.body = lnBlock; + wslOp.getRegion().push_back(wslBlock); + targetRegion.front().getOperations().splice( + wslOp->getIterator(), allocaBlock->getOperations()); + } else { + emitError(loc, "Only 1D array scalar assignment for sum " ---------------- tblah wrote:
Instead of emitting an error here, it would be better to fall back to the outer else branch and use the runtime library version of SUM. https://github.com/llvm/llvm-project/pull/113082 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits