This is an automated email from the ASF dual-hosted git repository.
zanmato pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 7de2f61762 GH-48167: [Python][C++][Compute] Add python bindings for
scatter, inverse_permutation (#48267)
7de2f61762 is described below
commit 7de2f61762e09044073659d859ac5e87dd66f6b9
Author: tadeja <[email protected]>
AuthorDate: Mon Dec 29 08:14:36 2025 +0100
GH-48167: [Python][C++][Compute] Add python bindings for scatter,
inverse_permutation (#48267)
### Rationale for this change
This PR closes (or opens discussion on) issue #48167.
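The `inverse_permutation` and `scatter` functions were implemented in C++ via #44393
(PR #44394), but their options classes, `InversePermutationOptions` and
`ScatterOptions`, were not exposed in Python.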
### What changes are included in this PR?
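Python bindings for `InversePermutationOptions` and `ScatterOptions`, plus Python
tests for the `scatter` and `inverse_permutation` kernels. On the C++ side,
`InversePermutationOptions::output_type` becomes a `std::optional`, and the function
docs for `inverse_permutation` and `scatter` now name their options classes.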
### Are these changes tested?
Yes, tests added in test_compute.py.
### Are there any user-facing changes?
Bindings for `InversePermutationOptions` and `ScatterOptions` are added.
#### This PR includes breaking changes to public APIs.
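The `output_type` constructor parameter of `InversePermutationOptions` (and the
public `output_type` member) changed from
`std::shared_ptr<DataType> output_type = NULLPTR`
to
`std::optional<std::shared_ptr<DataType>> output_type = std::nullopt`.
Callers that compared `output_type` against `nullptr` must now check
`output_type.has_value()` (or compare against `std::nullopt`).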
* GitHub Issue: #48167
Lead-authored-by: Tadeja Kadunc <[email protected]>
Co-authored-by: Rossi Sun <[email protected]>
Co-authored-by: tadeja <[email protected]>
Signed-off-by: Rossi Sun <[email protected]>
---
cpp/src/arrow/compute/api_vector.cc | 2 +-
cpp/src/arrow/compute/api_vector.h | 11 +++--
cpp/src/arrow/compute/function_internal.h | 7 +--
cpp/src/arrow/compute/kernels/vector_swizzle.cc | 16 +++----
.../arrow/compute/kernels/vector_swizzle_test.cc | 3 +-
docs/source/python/api/compute.rst | 4 +-
python/pyarrow/_compute.pyx | 54 ++++++++++++++++++++++
python/pyarrow/compute.py | 2 +
python/pyarrow/includes/libarrow.pxd | 11 +++++
python/pyarrow/tests/test_compute.py | 34 +++++++++++++-
10 files changed, 122 insertions(+), 22 deletions(-)
diff --git a/cpp/src/arrow/compute/api_vector.cc
b/cpp/src/arrow/compute/api_vector.cc
index 538cdccaf2..1bf4de9352 100644
--- a/cpp/src/arrow/compute/api_vector.cc
+++ b/cpp/src/arrow/compute/api_vector.cc
@@ -257,7 +257,7 @@ ListFlattenOptions::ListFlattenOptions(bool recursive)
constexpr char ListFlattenOptions::kTypeName[];
InversePermutationOptions::InversePermutationOptions(
- int64_t max_index, std::shared_ptr<DataType> output_type)
+ int64_t max_index, std::optional<std::shared_ptr<DataType>> output_type)
: FunctionOptions(internal::kInversePermutationOptionsType),
max_index(max_index),
output_type(std::move(output_type)) {}
diff --git a/cpp/src/arrow/compute/api_vector.h
b/cpp/src/arrow/compute/api_vector.h
index b1676219b1..159a787641 100644
--- a/cpp/src/arrow/compute/api_vector.h
+++ b/cpp/src/arrow/compute/api_vector.h
@@ -298,8 +298,9 @@ class ARROW_EXPORT ListFlattenOptions : public
FunctionOptions {
/// \brief Options for inverse_permutation function
class ARROW_EXPORT InversePermutationOptions : public FunctionOptions {
public:
- explicit InversePermutationOptions(int64_t max_index = -1,
- std::shared_ptr<DataType> output_type =
NULLPTR);
+ explicit InversePermutationOptions(
+ int64_t max_index = -1,
+ std::optional<std::shared_ptr<DataType>> output_type = std::nullopt);
static constexpr const char kTypeName[] = "InversePermutationOptions";
static InversePermutationOptions Defaults() { return
InversePermutationOptions(); }
@@ -308,11 +309,11 @@ class ARROW_EXPORT InversePermutationOptions : public
FunctionOptions {
/// of the input indices minus 1 and the length of the function's output
will be the
/// length of the input indices.
int64_t max_index = -1;
- /// \brief The type of the output inverse permutation. If null, the output
will be of
- /// the same type as the input indices, otherwise must be signed integer
type. An
+ /// \brief The data type for the output array of inverse permutation.
Defaults to the
+ /// type of the input indices when `nullopt`. Must be a signed integer type.
An
/// invalid error will be reported if this type is not able to store the
length of the
/// input indices.
- std::shared_ptr<DataType> output_type = NULLPTR;
+ std::optional<std::shared_ptr<DataType>> output_type;
};
/// \brief Options for scatter function
diff --git a/cpp/src/arrow/compute/function_internal.h
b/cpp/src/arrow/compute/function_internal.h
index 9d8928466b..7bea4043a5 100644
--- a/cpp/src/arrow/compute/function_internal.h
+++ b/cpp/src/arrow/compute/function_internal.h
@@ -382,9 +382,10 @@ static inline Result<std::shared_ptr<Scalar>>
GenericToScalar(std::nullopt_t) {
}
template <typename T>
-static inline auto GenericToScalar(const std::optional<T>& value)
- -> Result<decltype(MakeScalar(value.value()))> {
- return value.has_value() ? MakeScalar(value.value()) :
std::make_shared<NullScalar>();
+static inline Result<std::shared_ptr<Scalar>> GenericToScalar(
+ const std::optional<T>& value) {
+ return value.has_value() ? GenericToScalar(value.value())
+ : std::make_shared<NullScalar>();
}
template <typename T>
diff --git a/cpp/src/arrow/compute/kernels/vector_swizzle.cc
b/cpp/src/arrow/compute/kernels/vector_swizzle.cc
index aa82f55c2b..cf9f5379a6 100644
--- a/cpp/src/arrow/compute/kernels/vector_swizzle.cc
+++ b/cpp/src/arrow/compute/kernels/vector_swizzle.cc
@@ -32,7 +32,8 @@ namespace {
const FunctionDoc inverse_permutation_doc(
"Return the inverse permutation of the given indices",
- "For the `i`-th `index` in `indices`, the `index`-th output is `i`",
{"indices"});
+ "For the `i`-th `index` in `indices`, the `index`-th output is `i`",
{"indices"},
+ "InversePermutationOptions");
const InversePermutationOptions* GetDefaultInversePermutationOptions() {
static const auto kDefaultInversePermutationOptions =
@@ -50,10 +51,8 @@ Result<TypeHolder> ResolveInversePermutationOutputType(
DCHECK_EQ(input_types.size(), 1);
DCHECK_NE(input_types[0], nullptr);
- std::shared_ptr<DataType> output_type =
InversePermutationState::Get(ctx).output_type;
- if (!output_type) {
- output_type = input_types[0].owned_type;
- }
+ std::shared_ptr<DataType> output_type =
+
InversePermutationState::Get(ctx).output_type.value_or(input_types[0].owned_type);
if (!is_signed_integer(output_type->id())) {
return Status::TypeError(
"Output type of inverse_permutation must be signed integer, got " +
@@ -77,10 +76,7 @@ struct InversePermutationImpl {
// Apply default options semantics.
int64_t output_length = options.max_index < 0 ? input_length :
options.max_index + 1;
- std::shared_ptr<DataType> output_type = options.output_type;
- if (!output_type) {
- output_type = input_type;
- }
+ std::shared_ptr<DataType> output_type =
options.output_type.value_or(input_type);
ThisType impl(ctx, indices, input_length, output_length);
RETURN_NOT_OK(VisitTypeInline(*output_type, &impl));
@@ -332,7 +328,7 @@ void RegisterVectorInversePermutation(FunctionRegistry*
registry) {
const FunctionDoc scatter_doc(
"Scatter the values into specified positions according to the indices",
"Place the `i`-th value at the position specified by the `i`-th index",
- {"values", "indices"});
+ {"values", "indices"}, "ScatterOptions");
const ScatterOptions* GetDefaultScatterOptions() {
static const auto kDefaultScatterOptions = ScatterOptions::Defaults();
diff --git a/cpp/src/arrow/compute/kernels/vector_swizzle_test.cc
b/cpp/src/arrow/compute/kernels/vector_swizzle_test.cc
index 0879955ec4..22b78a016d 100644
--- a/cpp/src/arrow/compute/kernels/vector_swizzle_test.cc
+++ b/cpp/src/arrow/compute/kernels/vector_swizzle_test.cc
@@ -162,7 +162,8 @@ TEST(InversePermutation, DefaultOptions) {
ARROW_SCOPED_TRACE("Default options values");
InversePermutationOptions options;
ASSERT_EQ(options.max_index, -1);
- ASSERT_EQ(options.output_type, nullptr);
+ ASSERT_EQ(options.output_type, std::nullopt);
+ ASSERT_FALSE(options.output_type.has_value());
}
{
ARROW_SCOPED_TRACE("Default options semantics");
diff --git a/docs/source/python/api/compute.rst
b/docs/source/python/api/compute.rst
index b74d674ac6..f58856c5bd 100644
--- a/docs/source/python/api/compute.rst
+++ b/docs/source/python/api/compute.rst
@@ -532,8 +532,8 @@ Selections
drop_null
filter
inverse_permutation
- take
scatter
+ take
Sorts and Partitions
--------------------
@@ -606,6 +606,7 @@ Compute Options
ExtractRegexSpanOptions
FilterOptions
IndexOptions
+ InversePermutationOptions
JoinOptions
ListFlattenOptions
ListSliceOptions
@@ -630,6 +631,7 @@ Compute Options
RoundToMultipleOptions
RunEndEncodeOptions
ScalarAggregateOptions
+ ScatterOptions
SelectKOptions
SetLookupOptions
SkewOptions
diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index 59fd775b5a..c80e4f9316 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -1444,6 +1444,60 @@ class RunEndEncodeOptions(_RunEndEncodeOptions):
self._set_options(run_end_type)
+cdef class _InversePermutationOptions(FunctionOptions):
+ def _set_options(self, max_index=-1, output_type=None):
+ cdef optional[shared_ptr[CDataType]] c_output_type = nullopt
+ if output_type is not None:
+ c_output_type = pyarrow_unwrap_data_type(ensure_type(output_type))
+ self.wrapped.reset(
+ new CInversePermutationOptions(max_index, c_output_type))
+
+
+class InversePermutationOptions(_InversePermutationOptions):
+ """
+ Options for `inverse_permutation` function.
+
+ Parameters
+ ----------
+ max_index : int64, default -1
+ The max value in the input indices to allow.
+ The length of the function’s output will be this value plus 1.
+ If negative, this value will be set to the length of the input indices
+ minus 1 and the length of the function’s output will be the length
+ of the input indices.
+ output_type : DataType, default None
+ The data type for the output array of inverse permutation.
+ If None, the output will be of the same type as the input indices,
otherwise
+ must be a signed integer type. An invalid error will be reported if
this type
+ is not able to store the length of the input indices.
+ """
+
+ def __init__(self, max_index=-1, output_type=None):
+ self._set_options(max_index, output_type)
+
+
+cdef class _ScatterOptions(FunctionOptions):
+ def _set_options(self, max_index):
+ self.wrapped.reset(new CScatterOptions(max_index))
+
+
+class ScatterOptions(_ScatterOptions):
+ """
+ Options for `scatter` function.
+
+ Parameters
+ ----------
+ max_index : int64, default -1
+ The max value in the input indices to allow.
+ The length of the function’s output will be this value plus 1.
+ If negative, this value will be set to the length of the input indices
minus 1
+ and the length of the function’s output will be the length of the
input indices.
+ """
+
+ def __init__(self, max_index=-1):
+ self._set_options(max_index)
+
+
cdef class _TakeOptions(FunctionOptions):
def _set_options(self, boundscheck):
self.wrapped.reset(new CTakeOptions(boundscheck))
diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py
index fe0afdb0a8..8177948aae 100644
--- a/python/pyarrow/compute.py
+++ b/python/pyarrow/compute.py
@@ -43,6 +43,7 @@ from pyarrow._compute import ( # noqa
ExtractRegexSpanOptions,
FilterOptions,
IndexOptions,
+ InversePermutationOptions,
JoinOptions,
ListSliceOptions,
ListFlattenOptions,
@@ -66,6 +67,7 @@ from pyarrow._compute import ( # noqa
RoundTemporalOptions,
RoundToMultipleOptions,
ScalarAggregateOptions,
+ ScatterOptions,
SelectKOptions,
SetLookupOptions,
SkewOptions,
diff --git a/python/pyarrow/includes/libarrow.pxd
b/python/pyarrow/includes/libarrow.pxd
index c03bf20026..e96a7d8469 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -2588,6 +2588,17 @@ cdef extern from "arrow/compute/api.h" namespace
"arrow::compute" nogil:
CTakeOptions(c_bool boundscheck)
c_bool boundscheck
+ cdef cppclass CInversePermutationOptions \
+ "arrow::compute::InversePermutationOptions"(CFunctionOptions):
+ CInversePermutationOptions(int64_t max_index,
optional[shared_ptr[CDataType]] output_type)
+ int64_t max_index
+ optional[shared_ptr[CDataType]] output_type
+
+ cdef cppclass CScatterOptions \
+ "arrow::compute::ScatterOptions"(CFunctionOptions):
+ CScatterOptions(int64_t max_index)
+ int64_t max_index
+
cdef cppclass CStrptimeOptions \
"arrow::compute::StrptimeOptions"(CFunctionOptions):
CStrptimeOptions(c_string format, TimeUnit unit, c_bool raise_error)
diff --git a/python/pyarrow/tests/test_compute.py
b/python/pyarrow/tests/test_compute.py
index fe810a6dc9..c6b17e4791 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -40,7 +40,7 @@ except ImportError:
import pyarrow as pa
import pyarrow.compute as pc
-from pyarrow.lib import ArrowNotImplementedError
+from pyarrow.lib import ArrowNotImplementedError, ArrowIndexError
try:
import pyarrow.substrait as pas
@@ -1590,6 +1590,38 @@ def test_filter_null_type():
assert len(table.filter(mask).column(0)) == 5
+def test_inverse_permutation():
+ arr0 = pa.array([], type=pa.int32())
+ arr = pa.chunked_array([
+ arr0, [9, 7, 5, 3, 1], [0], [2, 4, 6], [8], arr0,
+ ])
+ expected = pa.chunked_array([[5, 4, 6, 3, 7, 2, 8, 1, 9, 0]],
type=pa.int32())
+ assert pc.inverse_permutation(arr).equals(expected)
+
+ options = pc.InversePermutationOptions(max_index=9, output_type=pa.int32())
+ assert pc.inverse_permutation(arr, options=options).equals(expected)
+ assert pc.inverse_permutation(arr, max_index=-1).equals(expected)
+
+ with pytest.raises(ArrowIndexError, match="Index out of bounds: 9"):
+ pc.inverse_permutation(arr, max_index=4)
+
+
+def test_scatter():
+ values = pa.array([True, False, True, True, False, False, True, True,
True, False])
+ indices = pa.array([9, 8, 7, 6, 5, 4, 3, 2, 1, 0])
+ expected = pa.array([False, True, True, True, False,
+ False, True, True, False, True])
+ result = pc.scatter(values, indices)
+ assert result.equals(expected)
+
+ options = pc.ScatterOptions(max_index=-1)
+ assert pc.scatter(values, indices, options=options).equals(expected)
+ assert pc.scatter(values, indices, max_index=9).equals(expected)
+
+ with pytest.raises(ArrowIndexError, match="Index out of bounds: 9"):
+ pc.scatter(values, indices, max_index=4)
+
+
@pytest.mark.parametrize("typ", ["array", "chunked_array"])
def test_compare_array(typ):
if typ == "array":