timshen updated this revision to Diff 139044.
timshen added a comment.
Herald added a subscriber: christof.

Rebase.


https://reviews.llvm.org/D41844

Files:
  libcxx/include/experimental/simd
  libcxx/test/std/experimental/simd/simd.horizontal/hmax.pass.cpp
  libcxx/test/std/experimental/simd/simd.horizontal/hmin.pass.cpp
  libcxx/test/std/experimental/simd/simd.horizontal/reduce.pass.cpp

Index: libcxx/test/std/experimental/simd/simd.horizontal/reduce.pass.cpp
===================================================================
--- libcxx/test/std/experimental/simd/simd.horizontal/reduce.pass.cpp
+++ libcxx/test/std/experimental/simd/simd.horizontal/reduce.pass.cpp
@@ -42,13 +42,48 @@
 
 inline int factorial(int n) { return n == 1 ? 1 : n * factorial(n - 1); }
 
-void test_reduce() {
+void test_reduce_simd() {
   int n = (int)native_simd<int>::size();
   assert(reduce(native_simd<int>([](int i) { return i; })) == n * (n - 1) / 2);
   assert(reduce(native_simd<int>([](int i) { return i; }), std::plus<int>()) ==
          n * (n - 1) / 2);
   assert(reduce(native_simd<int>([](int i) { return i + 1; }),
                 std::multiplies<int>()) == factorial(n));
 }
 
-int main() { test_reduce(); }
+void test_reduce_mask() {
+  {
+    fixed_size_simd<int, 4> a([](int i) { return i; });
+    assert(reduce(where(a < 2, a), 0, std::plus<int>()) == 0 + 1);
+    assert(reduce(where(a >= 2, a), 1, std::multiplies<int>()) == 2 * 3);
+    assert(reduce(where(a >= 2, a)) == 2 + 3);
+    assert(reduce(where(a >= 2, a), std::plus<int>()) == 2 + 3);
+    assert(reduce(where(a >= 2, a), std::multiplies<int>()) == 2 * 3);
+    assert(reduce(where(a >= 2, a), std::bit_and<int>()) == (2 & 3));
+    assert(reduce(where(a >= 2, a), std::bit_or<int>()) == (2 | 3));
+    assert(reduce(where(a >= 2, a), std::bit_xor<int>()) == (2 ^ 3));
+  }
+  {
+    fixed_size_simd_mask<int, 4> a;
+    a[0] = false;
+    a[1] = true;
+    a[2] = true;
+    a[3] = false;
+    assert(reduce(where(fixed_size_simd_mask<int, 4>(true), a)) == true);
+    assert(reduce(where(fixed_size_simd_mask<int, 4>(true), a),
+                  std::plus<bool>()) == true);
+    assert(reduce(where(fixed_size_simd_mask<int, 4>(true), a),
+                  std::multiplies<bool>()) == false);
+    assert(reduce(where(fixed_size_simd_mask<int, 4>(true), a),
+                  std::bit_and<bool>()) == false);
+    assert(reduce(where(fixed_size_simd_mask<int, 4>(true), a),
+                  std::bit_or<bool>()) == true);
+    assert(reduce(where(fixed_size_simd_mask<int, 4>(true), a),
+                  std::bit_xor<bool>()) == false);
+  }
+}
+
+int main() {
+  test_reduce_simd();
+  test_reduce_mask();
+}
Index: libcxx/test/std/experimental/simd/simd.horizontal/hmin.pass.cpp
===================================================================
--- libcxx/test/std/experimental/simd/simd.horizontal/hmin.pass.cpp
+++ libcxx/test/std/experimental/simd/simd.horizontal/hmin.pass.cpp
@@ -20,7 +20,7 @@
 
 using namespace std::experimental::parallelism_v2;
 
-void test_hmin() {
+void test_hmin_simd() {
   {
     int a[] = {2, 5, -4, 6};
     assert(hmin(fixed_size_simd<int, 4>(a, element_aligned_tag())) == -4);
@@ -39,4 +39,27 @@
   }
 }
 
-int main() { test_hmin(); }
+void test_hmin_mask() {
+  assert(hmin(where(native_simd_mask<int>(false), native_simd<int>())) ==
+         std::numeric_limits<int>::max());
+  {
+    int buffer[] = {2, 5, -4, 6};
+    fixed_size_simd<int, 4> a(buffer, element_aligned_tag());
+    assert(hmin(where(a >= -4, a)) == -4);
+    assert(hmin(where(a > -4, a)) == 2);
+    assert(hmin(where(a > 2, a)) == 5);
+    assert(hmin(where(a > 5, a)) == 6);
+    assert(hmin(where(a > 6, a)) == std::numeric_limits<int>::max());
+  }
+  {
+    bool buffer[] = {false, true, true, false};
+    fixed_size_simd_mask<int, 4> a(buffer, element_aligned_tag());
+    assert(hmin(where(fixed_size_simd_mask<int, 4>(true), a)) == false);
+    assert(hmin(where(a, a)) == true);
+  }
+}
+
+int main() {
+  test_hmin_simd();
+  test_hmin_mask();
+}
Index: libcxx/test/std/experimental/simd/simd.horizontal/hmax.pass.cpp
===================================================================
--- libcxx/test/std/experimental/simd/simd.horizontal/hmax.pass.cpp
+++ libcxx/test/std/experimental/simd/simd.horizontal/hmax.pass.cpp
@@ -20,7 +20,7 @@
 
 using namespace std::experimental::parallelism_v2;
 
-void test_hmax() {
+void test_hmax_simd() {
   {
     int a[] = {2, 5, -4, 6};
     assert(hmax(fixed_size_simd<int, 4>(a, element_aligned_tag())) == 6);
@@ -39,4 +39,34 @@
   }
 }
 
-int main() { test_hmax(); }
+void test_hmax_mask() {
+  assert(hmax(where(native_simd_mask<int>(false), native_simd<int>())) ==
+         std::numeric_limits<int>::lowest());
+  {
+    int buffer[] = {2, 5, -4, 6};
+    fixed_size_simd<int, 4> a(buffer, element_aligned_tag());
+    assert(hmax(where(a <= 6, a)) == 6);
+    assert(hmax(where(a < 6, a)) == 5);
+    assert(hmax(where(a < 5, a)) == 2);
+    assert(hmax(where(a < 2, a)) == -4);
+    assert(hmax(where(a < -4, a)) == std::numeric_limits<int>::lowest());
+  }
+  {
+    bool buffer[] = {false, true, true, false};
+    fixed_size_simd_mask<int, 4> a(buffer, element_aligned_tag());
+    assert(hmax(where(fixed_size_simd_mask<int, 4>(true), a)) == true);
+    assert(hmax(where(!a, a)) == false);
+  }
+  // Empty selection yields lowest(); for float that differs from min().
+  {
+    const fixed_size_simd<float, 1> a(0);
+    assert(hmax(where(fixed_size_simd_mask<float, 1>(true), a)) == 0.f);
+    assert(hmax(where(fixed_size_simd_mask<float, 1>(false), a)) ==
+           std::numeric_limits<float>::lowest());
+  }
+}
+
+int main() {
+  test_hmax_simd();
+  test_hmax_mask();
+}
Index: libcxx/include/experimental/simd
===================================================================
--- libcxx/include/experimental/simd
+++ libcxx/include/experimental/simd
@@ -594,6 +594,7 @@
 #include <cstddef>
 #include <cstring>
 #include <functional>
+#include <limits>
 
 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
 #pragma GCC system_header
@@ -1634,33 +1635,48 @@
 }
 
 // reductions [simd.reductions]
-template <class _Tp, class _Abi, class _BinaryOp = std::plus<_Tp>>
-_Tp reduce(const simd<_Tp, _Abi>& __v, _BinaryOp __op = _BinaryOp()) {
-  _Tp __acc = __v[0];
+template <class _SimdType, class _BinaryOp>
+typename _SimdType::value_type __reduce(const _SimdType& __v, _BinaryOp __op) {
+  auto __acc = __v[0];
   for (size_t __i = 1; __i < __v.size(); __i++) {
     __acc = __op(__acc, __v[__i]);
   }
   return __acc;
 }
 
-template <class _Tp, class _Abi>
-_Tp hmin(const simd<_Tp, _Abi>& __v) {
-  _Tp __acc = __v[0];
+template <class _SimdType>
+typename _SimdType::value_type __hmin(const _SimdType& __v) {
+  auto __acc = __v[0];
   for (size_t __i = 1; __i < __v.size(); __i++) {
     __acc = __acc > __v[__i] ? __v[__i] : __acc;
   }
   return __acc;
 }
 
-template <class _Tp, class _Abi>
-_Tp hmax(const simd<_Tp, _Abi>& __v) {
-  _Tp __acc = __v[0];
+template <class _SimdType>
+typename _SimdType::value_type __hmax(const _SimdType& __v) {
+  auto __acc = __v[0];
   for (size_t __i = 1; __i < __v.size(); __i++) {
     __acc = __acc < __v[__i] ? __v[__i] : __acc;
   }
   return __acc;
 }
 
+template <class _Tp, class _Abi, class _BinaryOp = std::plus<_Tp>>
+_Tp reduce(const simd<_Tp, _Abi>& __v, _BinaryOp __op = _BinaryOp()) {
+  return __reduce(__v, __op);
+}
+
+template <class _Tp, class _Abi>
+_Tp hmin(const simd<_Tp, _Abi>& __v) {
+  return __hmin(__v);
+}
+
+template <class _Tp, class _Abi>
+_Tp hmax(const simd<_Tp, _Abi>& __v) {
+  return __hmax(__v);
+}
+
 // algorithms [simd.alg]
 template <class _Tp, class _Abi>
 simd<_Tp, _Abi> min(const simd<_Tp, _Abi>& __a,
@@ -2295,6 +2311,19 @@
                                  const_where_expression<bool, const _Up>>::type
   where(_Mp __m, const _Up& __v) noexcept;
 
+  template <class _Mp, class _Vp, class _BinaryOp>
+  friend typename _Vp::value_type
+  reduce(const const_where_expression<_Mp, _Vp>& __w,
+         typename _Vp::value_type __identity, _BinaryOp __op);
+
+  template <class _Mp, class _Vp>
+  friend typename _Vp::value_type
+  hmin(const const_where_expression<_Mp, _Vp>& __w);
+
+  template <class _Mp, class _Vp>
+  friend typename _Vp::value_type
+  hmax(const const_where_expression<_Mp, _Vp>& __w);
+
 public:
   const_where_expression& operator=(const const_where_expression&) = delete;
 
@@ -2479,41 +2508,63 @@
 
 template <class _MaskType, class _SimdType, class _BinaryOp>
 typename _SimdType::value_type
-reduce(const const_where_expression<_MaskType, _SimdType>&,
-       typename _SimdType::value_type neutral_element, _BinaryOp binary_op);
+reduce(const const_where_expression<_MaskType, _SimdType>& __w,
+       typename _SimdType::value_type __identity, _BinaryOp __op) {
+  auto __v = __w.__v_;
+  where(!__w.__m_, __v) = _SimdType(__identity);
+  return __reduce(__v, __op);
+}
 
 template <class _MaskType, class _SimdType>
 typename _SimdType::value_type
-reduce(const const_where_expression<_MaskType, _SimdType>&,
-       plus<typename _SimdType::value_type> binary_op = {});
+reduce(const const_where_expression<_MaskType, _SimdType>& __w,
+       plus<typename _SimdType::value_type> __op = {}) {
+  return reduce(__w, typename _SimdType::value_type(0), __op);
+}
 
 template <class _MaskType, class _SimdType>
 typename _SimdType::value_type
-reduce(const const_where_expression<_MaskType, _SimdType>&,
-       multiplies<typename _SimdType::value_type> binary_op);
+reduce(const const_where_expression<_MaskType, _SimdType>& __w,
+       multiplies<typename _SimdType::value_type> __op) {
+  return reduce(__w, typename _SimdType::value_type(1), __op);
+}
 
 template <class _MaskType, class _SimdType>
 typename _SimdType::value_type
-reduce(const const_where_expression<_MaskType, _SimdType>&,
-       bit_and<typename _SimdType::value_type> binary_op);
+reduce(const const_where_expression<_MaskType, _SimdType>& __w,
+       bit_and<typename _SimdType::value_type> __op) {
+  return reduce(__w, typename _SimdType::value_type(-1), __op);
+}
 
 template <class _MaskType, class _SimdType>
 typename _SimdType::value_type
-reduce(const const_where_expression<_MaskType, _SimdType>&,
-       bit_or<typename _SimdType::value_type> binary_op);
+reduce(const const_where_expression<_MaskType, _SimdType>& __w,
+       bit_or<typename _SimdType::value_type> __op) {
+  return reduce(__w, typename _SimdType::value_type(0), __op);
+}
 
 template <class _MaskType, class _SimdType>
 typename _SimdType::value_type
-reduce(const const_where_expression<_MaskType, _SimdType>&,
-       bit_xor<typename _SimdType::value_type> binary_op);
+reduce(const const_where_expression<_MaskType, _SimdType>& __w,
+       bit_xor<typename _SimdType::value_type> __op) {
+  return reduce(__w, typename _SimdType::value_type(0), __op);
+}
 
 template <class _MaskType, class _SimdType>
 typename _SimdType::value_type
-hmin(const const_where_expression<_MaskType, _SimdType>&);
+hmin(const const_where_expression<_MaskType, _SimdType>& __w) {
+  return __hmin(__simd_mask_friend::__simd_select(
+      _SimdType(std::numeric_limits<typename _SimdType::value_type>::max()),
+      __w.__v_, __w.__m_));
+}
 
 template <class _MaskType, class _SimdType>
 typename _SimdType::value_type
-hmax(const const_where_expression<_MaskType, _SimdType>&);
+hmax(const const_where_expression<_MaskType, _SimdType>& __w) {
+  return __hmax(__simd_mask_friend::__simd_select(
+      _SimdType(std::numeric_limits<typename _SimdType::value_type>::lowest()),
+      __w.__v_, __w.__m_));
+}
 
 _LIBCPP_END_NAMESPACE_EXPERIMENTAL_SIMD
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to