llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-flang-runtime Author: Slava Zakharin (vzakhari) <details> <summary>Changes</summary> During device compilation, the visitor only allows Internal.*IoStatementState variants to be visited; if any other variant is encountered, a runtime error is produced. Because the other variants' classes are not referenced during device compilation, this helps, for example, to avoid warnings about __host__-only methods being referenced in __device__ code. I had problems parameterizing the Fortran::common visitor to limit the allowed variants, but I can give it another try if creating a copy looks inappropriate. --- Full diff: https://github.com/llvm/llvm-project/pull/85179.diff 2 Files Affected: - (modified) flang/runtime/io-stmt.cpp (+20-27) - (modified) flang/runtime/io-stmt.h (+56-2) ``````````diff diff --git a/flang/runtime/io-stmt.cpp b/flang/runtime/io-stmt.cpp index 075d7b5ae518a4..efefbc5e1a1c08 100644 --- a/flang/runtime/io-stmt.cpp +++ b/flang/runtime/io-stmt.cpp @@ -467,69 +467,66 @@ int ExternalFormattedIoStatementState<DIR, CHAR>::EndIoStatement() { } Fortran::common::optional<DataEdit> IoStatementState::GetNextDataEdit(int n) { - return common::visit( - [&](auto &x) { return x.get().GetNextDataEdit(*this, n); }, u_); + return visit([&](auto &x) { return x.get().GetNextDataEdit(*this, n); }, u_); } bool IoStatementState::Emit( const char *data, std::size_t bytes, std::size_t elementBytes) { - return common::visit( + return visit( [=](auto &x) { return x.get().Emit(data, bytes, elementBytes); }, u_); } bool IoStatementState::Receive( char *data, std::size_t n, std::size_t elementBytes) { - return common::visit( + return visit( [=](auto &x) { return x.get().Receive(data, n, elementBytes); }, u_); } std::size_t IoStatementState::GetNextInputBytes(const char *&p) { - return common::visit( - [&](auto &x) { return x.get().GetNextInputBytes(p); }, u_); + return visit([&](auto &x) { return x.get().GetNextInputBytes(p); }, u_); } bool 
IoStatementState::AdvanceRecord(int n) { - return common::visit([=](auto &x) { return x.get().AdvanceRecord(n); }, u_); + return visit([=](auto &x) { return x.get().AdvanceRecord(n); }, u_); } void IoStatementState::BackspaceRecord() { - common::visit([](auto &x) { x.get().BackspaceRecord(); }, u_); + visit([](auto &x) { x.get().BackspaceRecord(); }, u_); } void IoStatementState::HandleRelativePosition(std::int64_t n) { - common::visit([=](auto &x) { x.get().HandleRelativePosition(n); }, u_); + visit([=](auto &x) { x.get().HandleRelativePosition(n); }, u_); } void IoStatementState::HandleAbsolutePosition(std::int64_t n) { - common::visit([=](auto &x) { x.get().HandleAbsolutePosition(n); }, u_); + visit([=](auto &x) { x.get().HandleAbsolutePosition(n); }, u_); } void IoStatementState::CompleteOperation() { - common::visit([](auto &x) { x.get().CompleteOperation(); }, u_); + visit([](auto &x) { x.get().CompleteOperation(); }, u_); } int IoStatementState::EndIoStatement() { - return common::visit([](auto &x) { return x.get().EndIoStatement(); }, u_); + return visit([](auto &x) { return x.get().EndIoStatement(); }, u_); } ConnectionState &IoStatementState::GetConnectionState() { - return common::visit( + return visit( [](auto &x) -> ConnectionState & { return x.get().GetConnectionState(); }, u_); } MutableModes &IoStatementState::mutableModes() { - return common::visit( + return visit( [](auto &x) -> MutableModes & { return x.get().mutableModes(); }, u_); } bool IoStatementState::BeginReadingRecord() { - return common::visit( - [](auto &x) { return x.get().BeginReadingRecord(); }, u_); + return visit([](auto &x) { return x.get().BeginReadingRecord(); }, u_); } IoErrorHandler &IoStatementState::GetIoErrorHandler() const { - return common::visit( + return visit( [](auto &x) -> IoErrorHandler & { return static_cast<IoErrorHandler &>(x.get()); }, @@ -537,8 +534,7 @@ IoErrorHandler &IoStatementState::GetIoErrorHandler() const { } ExternalFileUnit 
*IoStatementState::GetExternalFileUnit() const { - return common::visit( - [](auto &x) { return x.get().GetExternalFileUnit(); }, u_); + return visit([](auto &x) { return x.get().GetExternalFileUnit(); }, u_); } Fortran::common::optional<char32_t> IoStatementState::GetCurrentChar( @@ -664,28 +660,25 @@ bool IoStatementState::CheckForEndOfRecord(std::size_t afterReading) { bool IoStatementState::Inquire( InquiryKeywordHash inquiry, char *out, std::size_t chars) { - return common::visit( + return visit( [&](auto &x) { return x.get().Inquire(inquiry, out, chars); }, u_); } bool IoStatementState::Inquire(InquiryKeywordHash inquiry, bool &out) { - return common::visit( - [&](auto &x) { return x.get().Inquire(inquiry, out); }, u_); + return visit([&](auto &x) { return x.get().Inquire(inquiry, out); }, u_); } bool IoStatementState::Inquire( InquiryKeywordHash inquiry, std::int64_t id, bool &out) { - return common::visit( - [&](auto &x) { return x.get().Inquire(inquiry, id, out); }, u_); + return visit([&](auto &x) { return x.get().Inquire(inquiry, id, out); }, u_); } bool IoStatementState::Inquire(InquiryKeywordHash inquiry, std::int64_t &n) { - return common::visit( - [&](auto &x) { return x.get().Inquire(inquiry, n); }, u_); + return visit([&](auto &x) { return x.get().Inquire(inquiry, n); }, u_); } std::int64_t IoStatementState::InquirePos() { - return common::visit([&](auto &x) { return x.get().InquirePos(); }, u_); + return visit([&](auto &x) { return x.get().InquirePos(); }, u_); } void IoStatementState::GotChar(int n) { diff --git a/flang/runtime/io-stmt.h b/flang/runtime/io-stmt.h index e00d54980aae59..7fecf4d9e41754 100644 --- a/flang/runtime/io-stmt.h +++ b/flang/runtime/io-stmt.h @@ -18,7 +18,6 @@ #include "io-error.h" #include "flang/Common/optional.h" #include "flang/Common/reference-wrapper.h" -#include "flang/Common/visit.h" #include "flang/Runtime/descriptor.h" #include "flang/Runtime/io-api.h" #include <functional> @@ -113,7 +112,7 @@ class 
IoStatementState { // N.B.: this also works with base classes template <typename A> A *get_if() const { - return common::visit( + return visit( [](auto &x) -> A * { if constexpr (std::is_convertible_v<decltype(x.get()), A &>) { return &x.get(); @@ -211,6 +210,61 @@ class IoStatementState { } private: + // Define special visitor for the variants of IoStatementState. + // During the device code compilation the visitor only allows + // visiting those variants that are supported on the device. + // In particular, only the internal IO variants are supported. + // TODO: parameterize Fortran::common::log2visit instead of + // creating a copy here. + template <class T, class... Ts> + struct is_any_type : std::bool_constant<(std::is_same_v<T, Ts> || ...)> {}; + + template <std::size_t LOW, std::size_t HIGH, typename RESULT, + typename VISITOR, typename VARIANT> + static inline RT_API_ATTRS RESULT Log2VisitHelper( + VISITOR &&visitor, std::size_t which, VARIANT &&u) { +#if !defined(RT_DEVICE_COMPILATION) + constexpr bool isDevice{false}; +#else + constexpr bool isDevice{true}; +#endif + if constexpr (LOW == HIGH) { + if constexpr (!isDevice || + is_any_type< + std::variant_alternative_t<LOW, std::decay_t<decltype(u)>>, + Fortran::common::reference_wrapper< + InternalListIoStatementState<Direction::Output>>, + Fortran::common::reference_wrapper< + InternalFormattedIoStatementState<Direction::Output>>>:: + value) { + return visitor(std::get<LOW>(std::forward<VARIANT>(u))); + } else { + Terminator{__FILE__, __LINE__}.Crash( + "not implemented yet: IoStatementState variant %d\n", + static_cast<int>(LOW)); + } + } else { + static constexpr std::size_t mid{(HIGH + LOW) / 2}; + if (which <= mid) { + return Log2VisitHelper<LOW, mid, RESULT>( + std::forward<VISITOR>(visitor), which, std::forward<VARIANT>(u)); + } else { + return Log2VisitHelper<(mid + 1), HIGH, RESULT>( + std::forward<VISITOR>(visitor), which, std::forward<VARIANT>(u)); + } + } + } + + template <typename VISITOR, 
typename VARIANT> + static inline RT_API_ATTRS auto visit(VISITOR &&visitor, VARIANT &&u) + -> decltype(visitor(std::get<0>(std::forward<VARIANT>(u)))) { + using Result = decltype(visitor(std::get<0>(std::forward<VARIANT>(u)))); + static constexpr std::size_t high{ + std::variant_size_v<std::decay_t<decltype(u)>> - 1}; + return Log2VisitHelper<0, high, Result>( + std::forward<VISITOR>(visitor), u.index(), std::forward<VARIANT>(u)); + } + std::variant<Fortran::common::reference_wrapper<OpenStatementState>, Fortran::common::reference_wrapper<CloseStatementState>, Fortran::common::reference_wrapper<NoopStatementState>, `````````` </details> https://github.com/llvm/llvm-project/pull/85179 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits