DarkSharpness opened a new pull request, #228:
URL: https://github.com/apache/tvm-ffi/pull/228
Similar to pybind, we add a `stl.h` which supports `array`, `vector`,
`tuple`, `optional` and `variant`. After this header is included, users can use
native C++ standard-library types directly. This should improve compatibility and
reduce the manual effort of converting between tvm::ffi containers and C++ containers.
Example C++ code (saved as `stl.cpp`):
```cpp
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/stl.h>
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/dtype.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/function.h>

#include <algorithm>
#include <array>
#include <cstddef>
#include <numeric>
#include <optional>
#include <tuple>
#include <utility>
#include <variant>
#include <vector>
namespace {

// `std::optional`, `std::array`, `std::vector` and `std::tuple` are supported
// as argument and return types.
//
// Returns {true, per-row sums} when `arg` holds a value, and {false, {}} when
// it is empty (the Python driver passes `None` for this case).
auto sum_row(std::optional<std::vector<std::array<int, 2>>> arg)
    -> std::tuple<bool, std::vector<int>> {
  if (!arg.has_value()) {
    return {false, {}};
  }
  auto result = std::vector<int>{};
  result.reserve(arg->size());
  for (const auto& row : *arg) {
    result.push_back(std::accumulate(row.begin(), row.end(), 0));
  }
  // Move the sums into the tuple rather than copying the vector.
  return {true, std::move(result)};
}

// Parameters taken by (const) reference are also supported. This is not
// recommended and brings no performance gain: every argument is first
// converted to a C++ value, which is then passed by reference.
//
// Returns the first index at which `a` and `b` differ. When one vector is a
// prefix of the other, it returns min(a.size(), b.size()).
auto find_diff(const std::vector<int>& a, std::vector<int>& b) -> std::size_t {
  // The four-iterator overload stops at the end of the shorter range.
  const auto mismatch = std::mismatch(a.begin(), a.end(), b.begin(), b.end());
  return static_cast<std::size_t>(mismatch.first - a.begin());
}

// `std::variant` is supported as both argument and return type.
//
// - int n:       returns the vector [0, 1, ..., n - 1] (empty when n <= 0).
// - float x:     returns x truncated toward zero as an int. NOTE: NaN or a
//                value outside the range of int makes this cast undefined.
// - vector<int>: returns the same vector, reversed.
auto test_variant(std::variant<int, float, std::vector<int>>&& arg)
    -> std::variant<int, std::vector<int>> {
  if (const int* count = std::get_if<int>(&arg)) {
    // Clamp at zero: a negative count must produce an empty vector. Converting
    // it to std::size_t directly would request a huge size and throw
    // std::length_error.
    std::vector<int> result(static_cast<std::size_t>(std::max(*count, 0)));
    std::iota(result.begin(), result.end(), 0);
    return result;
  }
  if (const float* value = std::get_if<float>(&arg)) {
    return static_cast<int>(*value);
  }
  // Only the vector alternative remains: reverse it in place and move it out.
  auto& values = std::get<std::vector<int>>(arg);
  std::reverse(values.begin(), values.end());
  return std::move(values);
}

TVM_FFI_DLL_EXPORT_TYPED_FUNC(sum_row, sum_row);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(find_diff, find_diff);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(test_variant, test_variant);

}  // namespace
```
Python driver that compiles and calls the example:
```python
from __future__ import annotations
from tvm_ffi.cpp import load_inline
from pathlib import Path
cur_path = Path(__file__).parent
with open(cur_path / "stl.cpp") as f:
cpp_source = f.read()
module = load_inline(
"test_stl",
cpp_sources = cpp_source,
)
print(module.sum_row([[1, 2], [3, 4]])) # Expected output: (True, [3, 7])
print(module.sum_row(None)) # Expected output: (False, [])
print(module.find_diff([1, 2, 3, 4], [1, 2, 4, 3])) # Expected output: 2
(index = 2)
print(module.test_variant(2)) # Expected output: [0, 1]
print(module.test_variant(3.1)) # Expected output: 3
print(module.test_variant([1, 2])) # Expected output: [2, 1]
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]