https://github.com/erichkeane created https://github.com/llvm/llvm-project/pull/146414
This implements the async, wait, if, and if_present (as well as device_type, but that is a detail of async/wait) lowering. All of these are implemented the same way they are for the compute constructs, so this is a pretty mild amount of changes. From 74735a498e303bffc14175b06169105a3989c951 Mon Sep 17 00:00:00 2001 From: erichkeane <eke...@nvidia.com> Date: Mon, 30 Jun 2025 10:24:33 -0700 Subject: [PATCH] [OpenACC][CIR] Implement 'rest' of update clause lowering This implements the async, wait, if, and if_present (as well as device_type, but that is a detail of async/wait) lowering. All of these are implemented the same way they are for the compute constructs, so this is a pretty mild amount of changes. --- clang/lib/CIR/CodeGen/CIRGenOpenACCClause.cpp | 35 +++--- clang/test/CIR/CodeGenOpenACC/update.c | 111 ++++++++++++++++++ .../mlir/Dialect/OpenACC/OpenACCOps.td | 15 +++ mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 43 +++++++ 4 files changed, 186 insertions(+), 18 deletions(-) diff --git a/clang/lib/CIR/CodeGen/CIRGenOpenACCClause.cpp b/clang/lib/CIR/CodeGen/CIRGenOpenACCClause.cpp index b7a73e2f62945..2623b9bffe6ae 100644 --- a/clang/lib/CIR/CodeGen/CIRGenOpenACCClause.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenOpenACCClause.cpp @@ -376,7 +376,8 @@ class OpenACCClauseCIREmitter final // on all operation types. mlir::ArrayAttr getAsyncOnlyAttr() { if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp, - mlir::acc::KernelsOp, mlir::acc::DataOp>) { + mlir::acc::KernelsOp, mlir::acc::DataOp, + mlir::acc::UpdateOp>) { return operation.getAsyncOnlyAttr(); } else if constexpr (isOneOfTypes<OpTy, mlir::acc::EnterDataOp, mlir::acc::ExitDataOp>) { @@ -401,7 +402,8 @@ class OpenACCClauseCIREmitter final // on all operation types. 
mlir::ArrayAttr getAsyncOperandsDeviceTypeAttr() { if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp, - mlir::acc::KernelsOp, mlir::acc::DataOp>) { + mlir::acc::KernelsOp, mlir::acc::DataOp, + mlir::acc::UpdateOp>) { return operation.getAsyncOperandsDeviceTypeAttr(); } else if constexpr (isOneOfTypes<OpTy, mlir::acc::EnterDataOp, mlir::acc::ExitDataOp>) { @@ -427,7 +429,8 @@ class OpenACCClauseCIREmitter final // on all operation types. mlir::OperandRange getAsyncOperands() { if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp, - mlir::acc::KernelsOp, mlir::acc::DataOp>) + mlir::acc::KernelsOp, mlir::acc::DataOp, + mlir::acc::UpdateOp>) return operation.getAsyncOperands(); else if constexpr (isOneOfTypes<OpTy, mlir::acc::EnterDataOp, mlir::acc::ExitDataOp>) @@ -522,7 +525,8 @@ class OpenACCClauseCIREmitter final decodeDeviceType(clause.getArchitectures()[0].getIdentifierInfo())); } else if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp, mlir::acc::KernelsOp, - mlir::acc::DataOp, mlir::acc::LoopOp>) { + mlir::acc::DataOp, mlir::acc::LoopOp, + mlir::acc::UpdateOp>) { // Nothing to do here, these constructs don't have any IR for these, as // they just modify the other clauses IR. So setting of // `lastDeviceTypeValues` (done above) is all we need. @@ -531,7 +535,7 @@ class OpenACCClauseCIREmitter final // 'lastDeviceTypeValues' to set the value for the child visitor. } else { // TODO: When we've implemented this for everything, switch this to an - // unreachable. update, data, routine constructs remain. + // unreachable. routine construct remains. 
return clauseNotImplemented(clause); } } @@ -566,7 +570,8 @@ class OpenACCClauseCIREmitter final hasAsyncClause = true; if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp, mlir::acc::KernelsOp, mlir::acc::DataOp, - mlir::acc::EnterDataOp, mlir::acc::ExitDataOp>) { + mlir::acc::EnterDataOp, mlir::acc::ExitDataOp, + mlir::acc::UpdateOp>) { if (!clause.hasIntExpr()) { operation.addAsyncOnly(builder.getContext(), lastDeviceTypeValues); } else { @@ -655,27 +660,20 @@ class OpenACCClauseCIREmitter final mlir::acc::ShutdownOp, mlir::acc::SetOp, mlir::acc::DataOp, mlir::acc::WaitOp, mlir::acc::HostDataOp, mlir::acc::EnterDataOp, - mlir::acc::ExitDataOp>) { + mlir::acc::ExitDataOp, mlir::acc::UpdateOp>) { operation.getIfCondMutable().append( createCondition(clause.getConditionExpr())); } else if constexpr (isCombinedType<OpTy>) { applyToComputeOp(clause); } else { - // 'if' applies to most of the constructs, but hold off on lowering them - // until we can write tests/know what we're doing with codegen to make - // sure we get it right. - // TODO: When we've implemented this for everything, switch this to an - // unreachable. update construct remains. - return clauseNotImplemented(clause); + llvm_unreachable("Unknown construct kind in VisitIfClause"); } } void VisitIfPresentClause(const OpenACCIfPresentClause &clause) { - if constexpr (isOneOfTypes<OpTy, mlir::acc::HostDataOp>) { + if constexpr (isOneOfTypes<OpTy, mlir::acc::HostDataOp, + mlir::acc::UpdateOp>) { operation.setIfPresent(true); - } else if constexpr (isOneOfTypes<OpTy, mlir::acc::UpdateOp>) { - // Last unimplemented one here, so just put it in this way instead. 
- return clauseNotImplemented(clause); } else { llvm_unreachable("unknown construct kind in VisitIfPresentClause"); } @@ -710,7 +708,8 @@ class OpenACCClauseCIREmitter final void VisitWaitClause(const OpenACCWaitClause &clause) { if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp, mlir::acc::KernelsOp, mlir::acc::DataOp, - mlir::acc::EnterDataOp, mlir::acc::ExitDataOp>) { + mlir::acc::EnterDataOp, mlir::acc::ExitDataOp, + mlir::acc::UpdateOp>) { if (!clause.hasExprs()) { operation.addWaitOnly(builder.getContext(), lastDeviceTypeValues); } else { diff --git a/clang/test/CIR/CodeGenOpenACC/update.c b/clang/test/CIR/CodeGenOpenACC/update.c index 4e25a1df2a42b..2b29504e6ca20 100644 --- a/clang/test/CIR/CodeGenOpenACC/update.c +++ b/clang/test/CIR/CodeGenOpenACC/update.c @@ -64,4 +64,115 @@ void acc_update(int parmVar, int *ptrParmVar) { // CHECK-NEXT: %[[UPD_DEV2:.*]] = acc.update_device varPtr(%[[PTRPARM]] : !cir.ptr<!cir.ptr<!s32i>>) -> !cir.ptr<!cir.ptr<!s32i>> {name = "ptrParmVar", structured = false} // CHECK-NEXT: acc.update dataOperands(%[[GDP1]], %[[UPD_DEV2]] : !cir.ptr<!s32i>, !cir.ptr<!cir.ptr<!s32i>>) // CHECK-NEXT: acc.update_host accPtr(%[[GDP1]] : !cir.ptr<!s32i>) to varPtr(%[[PARM]] : !cir.ptr<!s32i>) {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} + +#pragma acc update self(parmVar) if (parmVar == 1) + // CHECK-NEXT: %[[GDP1:.*]] = acc.getdeviceptr varPtr(%[[PARM]] : !cir.ptr<!s32i>) -> !cir.ptr<!s32i> {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} + // CHECK-NEXT: %[[PARM_LOAD:.*]] = cir.load{{.*}} %[[PARM]] + // CHECK-NEXT: %[[ONE_CONST:.*]] = cir.const #cir.int<1> + // CHECK-NEXT: %[[CMP:.*]] = cir.cmp(eq, %[[PARM_LOAD]], %[[ONE_CONST]]) + // CHECK-NEXT: %[[CMP_CAST:.*]] = builtin.unrealized_conversion_cast %[[CMP]] + // CHECK-NEXT: acc.update if(%[[CMP_CAST]]) dataOperands(%[[GDP1]] : !cir.ptr<!s32i>) + // CHECK-NEXT: acc.update_host 
accPtr(%[[GDP1]] : !cir.ptr<!s32i>) to varPtr(%[[PARM]] : !cir.ptr<!s32i>) {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} +#pragma acc update self(parmVar) if (parmVar == 1) if_present + // CHECK-NEXT: %[[GDP1:.*]] = acc.getdeviceptr varPtr(%[[PARM]] : !cir.ptr<!s32i>) -> !cir.ptr<!s32i> {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} + // CHECK-NEXT: %[[PARM_LOAD:.*]] = cir.load{{.*}} %[[PARM]] + // CHECK-NEXT: %[[ONE_CONST:.*]] = cir.const #cir.int<1> + // CHECK-NEXT: %[[CMP:.*]] = cir.cmp(eq, %[[PARM_LOAD]], %[[ONE_CONST]]) + // CHECK-NEXT: %[[CMP_CAST:.*]] = builtin.unrealized_conversion_cast %[[CMP]] + // CHECK-NEXT: acc.update if(%[[CMP_CAST]]) dataOperands(%[[GDP1]] : !cir.ptr<!s32i>) attributes {ifPresent} + // CHECK-NEXT: acc.update_host accPtr(%[[GDP1]] : !cir.ptr<!s32i>) to varPtr(%[[PARM]] : !cir.ptr<!s32i>) {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} + +#pragma acc update self(parmVar) wait + // CHECK-NEXT: %[[GDP1:.*]] = acc.getdeviceptr varPtr(%[[PARM]] : !cir.ptr<!s32i>) -> !cir.ptr<!s32i> {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} + // CHECK-NEXT: acc.update wait dataOperands(%[[GDP1]] : !cir.ptr<!s32i>) + // CHECK-NEXT: acc.update_host accPtr(%[[GDP1]] : !cir.ptr<!s32i>) to varPtr(%[[PARM]] : !cir.ptr<!s32i>) {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} + +#pragma acc update self(parmVar) wait device_type(nvidia) + // CHECK-NEXT: %[[GDP1:.*]] = acc.getdeviceptr varPtr(%[[PARM]] : !cir.ptr<!s32i>) -> !cir.ptr<!s32i> {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} + // CHECK-NEXT: acc.update wait dataOperands(%[[GDP1]] : !cir.ptr<!s32i>) + // CHECK-NEXT: acc.update_host accPtr(%[[GDP1]] : !cir.ptr<!s32i>) to varPtr(%[[PARM]] : !cir.ptr<!s32i>) {dataClause = #acc<data_clause acc_update_self>, name = 
"parmVar", structured = false} + +#pragma acc update self(parmVar) device_type(radeon) wait + // CHECK-NEXT: %[[GDP1:.*]] = acc.getdeviceptr varPtr(%[[PARM]] : !cir.ptr<!s32i>) -> !cir.ptr<!s32i> {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} + // CHECK-NEXT: acc.update wait([#acc.device_type<radeon>]) dataOperands(%[[GDP1]] : !cir.ptr<!s32i>) + // CHECK-NEXT: acc.update_host accPtr(%[[GDP1]] : !cir.ptr<!s32i>) to varPtr(%[[PARM]] : !cir.ptr<!s32i>) {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} + +#pragma acc update self(parmVar) wait(parmVar) + // CHECK-NEXT: %[[GDP1:.*]] = acc.getdeviceptr varPtr(%[[PARM]] : !cir.ptr<!s32i>) -> !cir.ptr<!s32i> {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} + // CHECK-NEXT: %[[PARM_LOAD:.*]] = cir.load{{.*}} %[[PARM]] + // CHECK-NEXT: %[[PARM_CAST:.*]] = builtin.unrealized_conversion_cast %[[PARM_LOAD]] + // CHECK-NEXT: acc.update wait({%[[PARM_CAST]] : si32}) dataOperands(%[[GDP1]] : !cir.ptr<!s32i>) + // CHECK-NEXT: acc.update_host accPtr(%[[GDP1]] : !cir.ptr<!s32i>) to varPtr(%[[PARM]] : !cir.ptr<!s32i>) {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} + +#pragma acc update self(parmVar) wait(parmVar) device_type(nvidia) + // CHECK-NEXT: %[[GDP1:.*]] = acc.getdeviceptr varPtr(%[[PARM]] : !cir.ptr<!s32i>) -> !cir.ptr<!s32i> {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} + // CHECK-NEXT: %[[PARM_LOAD:.*]] = cir.load{{.*}} %[[PARM]] + // CHECK-NEXT: %[[PARM_CAST:.*]] = builtin.unrealized_conversion_cast %[[PARM_LOAD]] + // CHECK-NEXT: acc.update wait({%[[PARM_CAST]] : si32}) dataOperands(%[[GDP1]] : !cir.ptr<!s32i>) + // CHECK-NEXT: acc.update_host accPtr(%[[GDP1]] : !cir.ptr<!s32i>) to varPtr(%[[PARM]] : !cir.ptr<!s32i>) {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} + +#pragma acc update 
self(parmVar) device_type(radeon) wait(parmVar) + // CHECK-NEXT: %[[GDP1:.*]] = acc.getdeviceptr varPtr(%[[PARM]] : !cir.ptr<!s32i>) -> !cir.ptr<!s32i> {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} + // CHECK-NEXT: %[[PARM_LOAD:.*]] = cir.load{{.*}} %[[PARM]] + // CHECK-NEXT: %[[PARM_CAST:.*]] = builtin.unrealized_conversion_cast %[[PARM_LOAD]] + // CHECK-NEXT: acc.update wait({%[[PARM_CAST]] : si32} [#acc.device_type<radeon>]) dataOperands(%[[GDP1]] : !cir.ptr<!s32i>) + // CHECK-NEXT: acc.update_host accPtr(%[[GDP1]] : !cir.ptr<!s32i>) to varPtr(%[[PARM]] : !cir.ptr<!s32i>) {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} + +#pragma acc update self(parmVar) device_type(radeon) wait(parmVar, 1, 2) + // CHECK-NEXT: %[[GDP1:.*]] = acc.getdeviceptr varPtr(%[[PARM]] : !cir.ptr<!s32i>) -> !cir.ptr<!s32i> {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} + // CHECK-NEXT: %[[PARM_LOAD:.*]] = cir.load{{.*}} %[[PARM]] + // CHECK-NEXT: %[[PARM_CAST:.*]] = builtin.unrealized_conversion_cast %[[PARM_LOAD]] + // CHECK-NEXT: %[[ONE_CONST:.*]] = cir.const #cir.int<1> + // CHECK-NEXT: %[[ONE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ONE_CONST]] + // CHECK-NEXT: %[[TWO_CONST:.*]] = cir.const #cir.int<2> + // CHECK-NEXT: %[[TWO_CAST:.*]] = builtin.unrealized_conversion_cast %[[TWO_CONST]] + // CHECK-NEXT: acc.update wait({%[[PARM_CAST]] : si32, %[[ONE_CAST]] : si32, %[[TWO_CAST]] : si32} [#acc.device_type<radeon>]) dataOperands(%[[GDP1]] : !cir.ptr<!s32i>) + // CHECK-NEXT: acc.update_host accPtr(%[[GDP1]] : !cir.ptr<!s32i>) to varPtr(%[[PARM]] : !cir.ptr<!s32i>) {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} + +#pragma acc update self(parmVar) device_type(radeon) wait(devnum:parmVar: 1, 2) + // CHECK-NEXT: %[[GDP1:.*]] = acc.getdeviceptr varPtr(%[[PARM]] : !cir.ptr<!s32i>) -> !cir.ptr<!s32i> {dataClause = #acc<data_clause 
acc_update_self>, name = "parmVar", structured = false} + // CHECK-NEXT: %[[PARM_LOAD:.*]] = cir.load{{.*}} %[[PARM]] + // CHECK-NEXT: %[[PARM_CAST:.*]] = builtin.unrealized_conversion_cast %[[PARM_LOAD]] + // CHECK-NEXT: %[[ONE_CONST:.*]] = cir.const #cir.int<1> + // CHECK-NEXT: %[[ONE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ONE_CONST]] + // CHECK-NEXT: %[[TWO_CONST:.*]] = cir.const #cir.int<2> + // CHECK-NEXT: %[[TWO_CAST:.*]] = builtin.unrealized_conversion_cast %[[TWO_CONST]] + // CHECK-NEXT: acc.update wait({devnum: %[[PARM_CAST]] : si32, %[[ONE_CAST]] : si32, %[[TWO_CAST]] : si32} [#acc.device_type<radeon>]) dataOperands(%[[GDP1]] : !cir.ptr<!s32i>) + // CHECK-NEXT: acc.update_host accPtr(%[[GDP1]] : !cir.ptr<!s32i>) to varPtr(%[[PARM]] : !cir.ptr<!s32i>) {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} + +#pragma acc update self(parmVar) async + // CHECK-NEXT: %[[GDP1:.*]] = acc.getdeviceptr varPtr(%[[PARM]] : !cir.ptr<!s32i>) async -> !cir.ptr<!s32i> {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} + // CHECK-NEXT: acc.update async dataOperands(%[[GDP1]] : !cir.ptr<!s32i>) + // CHECK-NEXT: acc.update_host accPtr(%[[GDP1]] : !cir.ptr<!s32i>) async to varPtr(%[[PARM]] : !cir.ptr<!s32i>) {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} + +#pragma acc update self(parmVar) async device_type(nvidia) + // CHECK-NEXT: %[[GDP1:.*]] = acc.getdeviceptr varPtr(%[[PARM]] : !cir.ptr<!s32i>) async -> !cir.ptr<!s32i> {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} + // CHECK-NEXT: acc.update async dataOperands(%[[GDP1]] : !cir.ptr<!s32i>) + // CHECK-NEXT: acc.update_host accPtr(%[[GDP1]] : !cir.ptr<!s32i>) async to varPtr(%[[PARM]] : !cir.ptr<!s32i>) {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} + +#pragma acc update self(parmVar) device_type(radeon) async + // CHECK-NEXT: 
%[[GDP1:.*]] = acc.getdeviceptr varPtr(%[[PARM]] : !cir.ptr<!s32i>) async([#acc.device_type<radeon>]) -> !cir.ptr<!s32i> {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} + // CHECK-NEXT: acc.update async([#acc.device_type<radeon>]) dataOperands(%[[GDP1]] : !cir.ptr<!s32i>) + // CHECK-NEXT: acc.update_host accPtr(%[[GDP1]] : !cir.ptr<!s32i>) async([#acc.device_type<radeon>]) to varPtr(%[[PARM]] : !cir.ptr<!s32i>) {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} + +#pragma acc update self(parmVar) async(parmVar) + // CHECK-NEXT: %[[PARM_LOAD:.*]] = cir.load{{.*}} %[[PARM]] + // CHECK-NEXT: %[[PARM_CAST:.*]] = builtin.unrealized_conversion_cast %[[PARM_LOAD]] + // CHECK-NEXT: %[[GDP1:.*]] = acc.getdeviceptr varPtr(%[[PARM]] : !cir.ptr<!s32i>) async(%[[PARM_CAST]] : si32) -> !cir.ptr<!s32i> {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} + // CHECK-NEXT: acc.update async(%[[PARM_CAST]] : si32) dataOperands(%[[GDP1]] : !cir.ptr<!s32i>) + // CHECK-NEXT: acc.update_host accPtr(%[[GDP1]] : !cir.ptr<!s32i>) async(%[[PARM_CAST]] : si32) to varPtr(%[[PARM]] : !cir.ptr<!s32i>) {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} + +#pragma acc update self(parmVar) async(parmVar) device_type(nvidia) + // CHECK-NEXT: %[[PARM_LOAD:.*]] = cir.load{{.*}} %[[PARM]] + // CHECK-NEXT: %[[PARM_CAST:.*]] = builtin.unrealized_conversion_cast %[[PARM_LOAD]] + // CHECK-NEXT: %[[GDP1:.*]] = acc.getdeviceptr varPtr(%[[PARM]] : !cir.ptr<!s32i>) async(%[[PARM_CAST]] : si32) -> !cir.ptr<!s32i> {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} + // CHECK-NEXT: acc.update async(%[[PARM_CAST]] : si32) dataOperands(%[[GDP1]] : !cir.ptr<!s32i>) + // CHECK-NEXT: acc.update_host accPtr(%[[GDP1]] : !cir.ptr<!s32i>) async(%[[PARM_CAST]] : si32) to varPtr(%[[PARM]] : !cir.ptr<!s32i>) {dataClause = #acc<data_clause 
acc_update_self>, name = "parmVar", structured = false} + +#pragma acc update self(parmVar) device_type(radeon) async(parmVar) + // CHECK-NEXT: %[[PARM_LOAD:.*]] = cir.load{{.*}} %[[PARM]] + // CHECK-NEXT: %[[PARM_CAST:.*]] = builtin.unrealized_conversion_cast %[[PARM_LOAD]] + // CHECK-NEXT: %[[GDP1:.*]] = acc.getdeviceptr varPtr(%[[PARM]] : !cir.ptr<!s32i>) async(%[[PARM_CAST]] : si32 [#acc.device_type<radeon>]) -> !cir.ptr<!s32i> {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} + // CHECK-NEXT: acc.update async(%[[PARM_CAST]] : si32 [#acc.device_type<radeon>]) dataOperands(%[[GDP1]] : !cir.ptr<!s32i>) + // CHECK-NEXT: acc.update_host accPtr(%[[GDP1]] : !cir.ptr<!s32i>) async(%[[PARM_CAST]] : si32 [#acc.device_type<radeon>]) to varPtr(%[[PARM]] : !cir.ptr<!s32i>) {dataClause = #acc<data_clause acc_update_self>, name = "parmVar", structured = false} } diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index 9aaf9040c25b7..276b74bd43772 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -3028,6 +3028,21 @@ def OpenACC_UpdateOp : OpenACC_Op<"update", /// Return the wait devnum value clause for the given device_type if /// present. mlir::Value getWaitDevnum(mlir::acc::DeviceType deviceType); + /// Add an entry to the 'async-only' attribute (clause spelled without + /// arguments) for each of the additional device types (or a none if it is + /// empty). + void addAsyncOnly(MLIRContext *, llvm::ArrayRef<DeviceType>); + /// Add a value to the 'async' with the current list of device types. + void addAsyncOperand(MLIRContext *, mlir::Value, + llvm::ArrayRef<DeviceType>); + /// Add an entry to the 'wait-only' attribute (clause spelled without + /// arguments) for each of the additional device types (or a none if it is + /// empty). 
+ void addWaitOnly(MLIRContext *, llvm::ArrayRef<DeviceType>); + /// Add an array-like entry to the 'wait' with the current list of device + /// types. + void addWaitOperands(MLIRContext *, bool hasDevnum, mlir::ValueRange, + llvm::ArrayRef<DeviceType>); }]; let assemblyFormat = [{ diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 0fcdf7be57c81..80c807e774a7e 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -3854,6 +3854,49 @@ mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) { deviceType); } +void UpdateOp::addAsyncOnly(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { + setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper( + context, getAsyncOnlyAttr(), effectiveDeviceTypes)); +} + +void UpdateOp::addAsyncOperand( + MLIRContext *context, mlir::Value newValue, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { + setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper( + context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue, + getAsyncOperandsMutable())); +} + +void UpdateOp::addWaitOnly(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { + setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(), + effectiveDeviceTypes)); +} + +void UpdateOp::addWaitOperands( + MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { + + llvm::SmallVector<int32_t> segments; + if (getWaitOperandsSegments()) + llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments)); + + setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper( + context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues, + getWaitOperandsMutable(), segments)); + setWaitOperandsSegments(segments); + + llvm::SmallVector<mlir::Attribute> hasDevnums; + if (getHasWaitDevnumAttr()) + 
llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums)); + hasDevnums.insert( + hasDevnums.end(), + std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)), + mlir::BoolAttr::get(context, hasDevnum)); + setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums)); +} + //===----------------------------------------------------------------------===// // WaitOp //===----------------------------------------------------------------------===// _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits