On 27/06/25 14:53 +0100, Jonathan Wakely wrote:
On 26/06/25 23:12 -0400, Patrick Palka wrote:
On Thu, 26 Jun 2025, Patrick Palka wrote:

        PR libstdc++/100795

libstdc++-v3/ChangeLog:

        * include/bits/ranges_algo.h (__sample_fn::operator()):
        Reimplement the forward_iterator branch directly.
        * testsuite/25_algorithms/sample/constrained.cc (test02):
        New test.
---
libstdc++-v3/include/bits/ranges_algo.h       | 70 +++++++++++++++++--
.../25_algorithms/sample/constrained.cc       | 28 ++++++++
2 files changed, 91 insertions(+), 7 deletions(-)

diff --git a/libstdc++-v3/include/bits/ranges_algo.h b/libstdc++-v3/include/bits/ranges_algo.h
index b12da2af1263..672a0ebce0de 100644
--- a/libstdc++-v3/include/bits/ranges_algo.h
+++ b/libstdc++-v3/include/bits/ranges_algo.h
@@ -1839,14 +1839,70 @@ namespace ranges
      operator()(_Iter __first, _Sent __last, _Out __out,
                 iter_difference_t<_Iter> __n, _Gen&& __g) const
      {
+       // FIXME: Correctly handle integer-class difference types.

On second thought maybe we don't need to teach uniform_int_distribution
to handle integer-class difference types.  We could just assert that
__n fits inside a long long and use that as the difference type?  Same
for shuffle.

Yeah, if we're being asked to take more than 1<<64 samples something
probably went very wrong somewhere.

But isn't it valid to pass in an enormous value of n, as long as
last - first is not ridiculous?

For example:

auto population = views::iota((__int128)0, (__int128)10);
using D = ranges::difference_t<decltype(population)>;
ranges::sample(population, out, numeric_limits<D>::max(), gen);

This n won't fit in long long, but min(last - first, n) will.

Does std::uniform_int_distribution currently support __int128? I think
it does, just using the slower "two divisions" path, because we don't
have a larger type to use for Lemire's algorithm.

        if constexpr (forward_iterator<_Iter>)
          {
-           // FIXME: Forwarding to std::sample here requires computing __lasti
-           // which may take linear time.
-           auto __lasti = ranges::next(__first, __last);
-           return _GLIBCXX_STD_A::
-             sample(std::move(__first), std::move(__lasti), std::move(__out),
-                    __n, std::forward<_Gen>(__g));
+           using _Size = iter_difference_t<_Iter>;
+           using __distrib_type = uniform_int_distribution<_Size>;
+           using __param_type = typename __distrib_type::param_type;
+           using _USize = __detail::__make_unsigned_like_t<_Size>;
+           using __uc_type
+             = common_type_t<typename remove_reference_t<_Gen>::result_type, _USize>;
+
+           if (__first == __last)
+             return __out;
+
+           __distrib_type __d{};
+           _Size __unsampled_sz = ranges::distance(__first, __last);
+           __n = std::min(__n, __unsampled_sz);
+
+           // If possible, we use __gen_two_uniform_ints to efficiently produce
+           // two random numbers using a single distribution invocation:
+
+           const __uc_type __urngrange = __g.max() - __g.min();
+           if (__urngrange / __uc_type(__unsampled_sz) >= __uc_type(__unsampled_sz))
+             // I.e. (__urngrange >= __unsampled_sz * __unsampled_sz) but without
+             // wrapping issues.
+             {
+               while (__n != 0 && __unsampled_sz >= 2)
+                 {
+                   const pair<_Size, _Size> __p =
+                     __gen_two_uniform_ints(__unsampled_sz, __unsampled_sz - 1, __g);
+
+                   --__unsampled_sz;
+                   if (__p.first < __n)
+                     {
+                       *__out = *__first;
+                       ++__out;
+                       --__n;
+                     }
+
+                   ++__first;
+
+                   if (__n == 0) break;
+
+                   --__unsampled_sz;
+                   if (__p.second < __n)
+                     {
+                       *__out = *__first;
+                       ++__out;
+                       --__n;
+                     }
+
+                   ++__first;
+                 }
+             }
+
+           // The loop above is otherwise equivalent to this one-at-a-time version:
+
+           for (; __n != 0; ++__first)
+             if (__d(__g, __param_type{0, --__unsampled_sz}) < __n)
+               {
+                 *__out = *__first;
+                 ++__out;
+                 --__n;
+               }
+           return __out;
          }
        else
          {
@@ -1867,7 +1923,7 @@ namespace ranges
                if (__k < __n)
                  __out[__k] = *__first;
              }
-           return __out + __sample_sz;
+           return __out + iter_difference_t<_Out>(__sample_sz);
          }
      }

diff --git a/libstdc++-v3/testsuite/25_algorithms/sample/constrained.cc b/libstdc++-v3/testsuite/25_algorithms/sample/constrained.cc
index b9945b164903..150e2d2036e0 100644
--- a/libstdc++-v3/testsuite/25_algorithms/sample/constrained.cc
+++ b/libstdc++-v3/testsuite/25_algorithms/sample/constrained.cc
@@ -20,6 +20,7 @@

#include <algorithm>
#include <random>
+#include <ranges>
#include <testsuite_hooks.h>
#include <testsuite_iterators.h>

@@ -59,9 +60,36 @@ test01()
    }
}

+void
+test02()
+{
+  // PR libstdc++/100795 - ranges::sample should not use std::sample
+#if 0 // FIXME: ranges::sample rejects integer-class difference types.
+#if __SIZEOF_INT128__
+  auto v = std::views::iota(__int128(0), __int128(20));
+#else
+  auto v = std::views::iota(0ll, 20ll);
+#endif
+#else
+  auto v = std::views::iota(0, 20);
+#endif
+
+  int storage[20] = {2,5,4,3,1,6,7,9,10,8,11,14,12,13,15,16,18,0,19,17};
+  auto w = v | std::views::transform([&](auto i) -> int& { return storage[i]; });
+  using type = decltype(w);
+  using cat = std::iterator_traits<std::ranges::iterator_t<type>>::iterator_category;
+  static_assert( std::same_as<cat, std::output_iterator_tag> );
+  static_assert( std::ranges::random_access_range<type> );
+
+  ranges::sample(v, w.begin(), 20, rng);
+  ranges::sort(w);
+  VERIFY( ranges::equal(w, v) );
+}
+
int
main()
{
  test01<forward_iterator_wrapper, output_iterator_wrapper>();
  test01<input_iterator_wrapper, random_access_iterator_wrapper>();
+  test02();
}
--
2.50.0.131.gcf6f63ea6b





Reply via email to