This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new e6e50de5ea GH-36753: [C++] Properly pretty-print and diff HalfFloatArrays (#46857)
e6e50de5ea is described below
commit e6e50de5ea8e117c52101a67388efeb2b78bf60b
Author: Eric Dinse <[email protected]>
AuthorDate: Tue Jun 24 04:01:24 2025 -0400
GH-36753: [C++] Properly pretty-print and diff HalfFloatArrays (#46857)
### Rationale for this change
#36753 requested this feature, which can now be implemented because a
half-float library is available.
### What changes are included in this PR?
Pretty-printing and diffing of HalfFloatArrays now display floating-point values
instead of the raw uint16 bit patterns.
### Are these changes tested?
Yes, with tests in C++ and Python.
### Are there any user-facing changes?
Pretty-printed and diffed float16 arrays now show floating-point values rather
than raw uint16 integers.
* GitHub Issue: #36753
Authored-by: Eric Dinse <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
---
cpp/src/arrow/array/diff.cc | 9 +++++++++
cpp/src/arrow/array/diff_test.cc | 16 ++++++++++++++++
cpp/src/arrow/pretty_print.cc | 6 ------
cpp/src/arrow/pretty_print_test.cc | 35 +++++++++++++++++++++++++++++++++++
python/pyarrow/tests/test_array.py | 18 ++++++++++++++++--
python/pyarrow/types.pxi | 4 ++--
6 files changed, 78 insertions(+), 10 deletions(-)
diff --git a/cpp/src/arrow/array/diff.cc b/cpp/src/arrow/array/diff.cc
index 4a640e6b9c..cf53c32155 100644
--- a/cpp/src/arrow/array/diff.cc
+++ b/cpp/src/arrow/array/diff.cc
@@ -43,6 +43,7 @@
#include "arrow/type_traits.h"
#include "arrow/util/bit_util.h"
#include "arrow/util/checked_cast.h"
+#include "arrow/util/float16.h"
#include "arrow/util/logging_internal.h"
#include "arrow/util/range.h"
#include "arrow/util/ree_util.h"
@@ -627,6 +628,14 @@ class MakeFormatterImpl {
return Status::OK();
}
+ Status Visit(const HalfFloatType&) {
+ impl_ = [](const Array& array, int64_t index, std::ostream* os) {
+ const auto& float16_arr = checked_cast<const HalfFloatArray&>(array);
+ *os << arrow::util::Float16::FromBits(float16_arr.Value(index));
+ };
+ return Status::OK();
+ }
+
// format Numerics with std::ostream defaults
template <typename T>
enable_if_number<T, Status> Visit(const T&) {
diff --git a/cpp/src/arrow/array/diff_test.cc b/cpp/src/arrow/array/diff_test.cc
index 02bcf5bbb4..3effe2a037 100644
--- a/cpp/src/arrow/array/diff_test.cc
+++ b/cpp/src/arrow/array/diff_test.cc
@@ -35,6 +35,7 @@
#include "arrow/testing/random.h"
#include "arrow/testing/util.h"
#include "arrow/type.h"
+#include "arrow/util/float16.h"
#include "arrow/util/logging.h"
namespace arrow {
@@ -815,4 +816,19 @@ TEST_F(DiffTest, CompareRandomStruct) {
}
}
+TEST_F(DiffTest, CompareHalfFloat) {
+ auto first = ArrayFromJSON(float16(), "[1.1, 2.0, 2.5, 3.3]");
+ auto second = ArrayFromJSON(float16(), "[1.1, 4.0, 3.5, 3.3]");
+ auto expected_diff = R"(
+@@ -1, +1 @@
+-2
+-2.5
++4
++3.5
+)";
+
+ auto diff = first->Diff(*second);
+ ASSERT_EQ(diff, expected_diff);
+}
+
} // namespace arrow
diff --git a/cpp/src/arrow/pretty_print.cc b/cpp/src/arrow/pretty_print.cc
index 807498b6bf..7234499285 100644
--- a/cpp/src/arrow/pretty_print.cc
+++ b/cpp/src/arrow/pretty_print.cc
@@ -239,12 +239,6 @@ class ArrayPrinter : public PrettyPrinter {
return WritePrimitiveValues(array);
}
- Status WriteDataValues(const HalfFloatArray& array) {
- // XXX do not know how to format half floats yet
- StringFormatter<Int16Type> formatter{array.type().get()};
- return WritePrimitiveValues(array, &formatter);
- }
-
template <typename ArrayType, typename T = typename ArrayType::TypeClass>
enable_if_has_string_view<T, Status> WriteDataValues(const ArrayType& array)
{
return WriteValues(array, [&](int64_t i) {
diff --git a/cpp/src/arrow/pretty_print_test.cc b/cpp/src/arrow/pretty_print_test.cc
index 0dfe3c9db3..c90b03bbda 100644
--- a/cpp/src/arrow/pretty_print_test.cc
+++ b/cpp/src/arrow/pretty_print_test.cc
@@ -19,6 +19,7 @@
#include <gtest/gtest.h>
+#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>
@@ -33,10 +34,13 @@
#include "arrow/testing/builder.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/type.h"
+#include "arrow/util/float16.h"
#include "arrow/util/key_value_metadata.h"
namespace arrow {
+using util::Float16;
+
class TestPrettyPrint : public ::testing::Test {
public:
void SetUp() {}
@@ -330,6 +334,37 @@ TEST_F(TestPrettyPrint, UInt64) {
expected);
}
+TEST_F(TestPrettyPrint, HalfFloat) {
+ static const char* expected = R"expected([
+ -inf,
+ -1234,
+ -0,
+ 0,
+ 1,
+ 1.2001953125,
+ 2.5,
+ 3.9921875,
+ 4.125,
+ 10000,
+ 12344,
+ inf,
+ nan,
+ null
+])expected";
+
+ std::vector<uint16_t> values = {
+ Float16(-1e10f).bits(), Float16(-1234.0f).bits(),
Float16(-0.0f).bits(),
+ Float16(0.0f).bits(), Float16(1.0f).bits(), Float16(1.2f).bits(),
+ Float16(2.5f).bits(), Float16(3.9921875f).bits(),
Float16(4.125f).bits(),
+ Float16(1e4f).bits(), Float16(12345.0f).bits(), Float16(1e5f).bits(),
+ Float16(NAN).bits(), Float16(6.10f).bits()};
+
+ std::vector<bool> is_valid(values.size(), true);
+ is_valid.back() = false;
+
+ CheckPrimitive<HalfFloatType, uint16_t>({0, 10}, is_valid, values, expected);
+}
+
TEST_F(TestPrettyPrint, DateTimeTypes) {
std::vector<bool> is_valid = {true, true, false, true, false};
diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py
index 97425df0f9..0cd76c700b 100644
--- a/python/pyarrow/tests/test_array.py
+++ b/python/pyarrow/tests/test_array.py
@@ -568,6 +568,8 @@ def test_array_diff():
arr2 = pa.array(['foo', 'bar', None], type=pa.utf8())
arr3 = pa.array([1, 2, 3])
arr4 = pa.array([[], [1], None], type=pa.list_(pa.int64()))
+ arr5 = pa.array([1.5, 3, 6], type=pa.float16())
+ arr6 = pa.array([1, 3], type=pa.float16())
assert arr1.diff(arr1) == ''
assert arr1.diff(arr2) == '''
@@ -579,6 +581,14 @@ def test_array_diff():
assert arr1.diff(arr3).strip() == '# Array types differed: string vs int64'
assert arr1.diff(arr4).strip() == ('# Array types differed: string vs '
'list<item: int64>')
+ assert arr5.diff(arr5) == ''
+ assert arr5.diff(arr6) == '''
+@@ -0, +0 @@
+-1.5
++1
+@@ -2, +2 @@
+-6
+'''
def test_array_iter():
@@ -1706,9 +1716,13 @@ def test_floating_point_truncate_unsafe():
def test_half_float_array_from_python():
# GH-46611
- arr = pa.array([1.0, 2.0, 3, None, 12345.6789, 1.234567],
type=pa.float16())
+ vals = [-5, 0, 1.0, 2.0, 3, None, 12345.6789, 1.234567, float('inf')]
+ arr = pa.array(vals, type=pa.float16())
assert arr.type == pa.float16()
- assert arr.to_pylist() == [1.0, 2.0, 3.0, None, 12344.0, 1.234375]
+ assert arr.to_pylist() == [-5, 0, 1.0, 2.0, 3, None, 12344.0,
+ 1.234375, float('inf')]
+ assert str(arr) == ("[\n -5,\n 0,\n 1,\n 2,\n 3,\n null,\n 12344,"
+ "\n 1.234375,\n inf\n]")
msg1 = "Could not convert 'a' with type str: tried to convert to float16"
with pytest.raises(pa.ArrowInvalid, match=msg1):
arr = pa.array(['a', 3, None], type=pa.float16())
diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
index 9b2f8881e3..62457de0bb 100644
--- a/python/pyarrow/types.pxi
+++ b/python/pyarrow/types.pxi
@@ -4459,8 +4459,8 @@ def float16():
>>> a
<pyarrow.lib.HalfFloatArray object at ...>
[
- 15872,
- 32256
+ 1.5,
+ nan
]
Note that unlike other float types, if you convert this array