This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/main by this push:
new 0710ecce79 Improve StringArray(Utf8) sort performance (~2-4x faster)
(#7860)
0710ecce79 is described below
commit 0710ecce798dbc5123ce68f5972d3c8928749d30
Author: Qi Zhu <[email protected]>
AuthorDate: Thu Aug 7 18:30:53 2025 +0800
Improve StringArray(Utf8) sort performance (~2-4x faster) (#7860)
# Which issue does this PR close?
Improve StringArray(Utf8) sort performance
- Closes [#7847](https://github.com/apache/arrow-rs/issues/7847)
# Rationale for this change
Add a prefix comparison to the byte-array sort. I settled on a u32 prefix
with a u64 increment compare, which gave the best performance in my experiments.
# What changes are included in this PR?
Add a prefix comparison to the byte-array sort. I settled on a u32 prefix
with a u64 increment compare, which gave the best performance in my experiments.
# Are these changes tested?
Yes
```text
critcmp issue_7847 main --filter "sort string\["
group                                         issue_7847                          main
-----                                         ----------                          ----
sort string[0-400] nulls to indices 2^12      1.00     51.4±0.56µs ? ?/sec       1.19     61.0±1.02µs ? ?/sec
sort string[0-400] to indices 2^12            1.00     96.5±1.63µs ? ?/sec       1.23    118.3±0.91µs ? ?/sec
sort string[10] dict nulls to indices 2^12    1.00     72.4±1.00µs ? ?/sec       1.00     72.5±0.61µs ? ?/sec
sort string[10] dict to indices 2^12          1.00    137.1±1.51µs ? ?/sec       1.01    138.1±1.06µs ? ?/sec
sort string[10] nulls to indices 2^12         1.00     47.5±0.69µs ? ?/sec       1.18     56.3±0.56µs ? ?/sec
sort string[10] to indices 2^12               1.00     86.4±1.37µs ? ?/sec       1.20    103.5±1.13µs ? ?/sec
```
# Are there any user-facing changes?
If there are user-facing changes then we may require documentation to be
updated before approving the PR.
If there are any breaking changes to public APIs, please call them out.
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
arrow-ord/src/sort.rs | 381 +++++++++++++++++++++++++++++++++++++++++++++++++-
1 file changed, 377 insertions(+), 4 deletions(-)
diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs
index a405aa7a37..ba026af637 100644
--- a/arrow-ord/src/sort.rs
+++ b/arrow-ord/src/sort.rs
@@ -345,12 +345,88 @@ fn sort_bytes<T: ByteArrayType>(
options: SortOptions,
limit: Option<usize>,
) -> UInt32Array {
- let mut valids = value_indices
+ // Note: Why do we use a 4-byte prefix?
+ // Pack each value's first 4 bytes into a big-endian u32 (shorter values
+ // are zero-padded on the right). Most byte sequences differ in their
+ // first few bytes, so comparing up to 4 bytes as a single u32 avoids the
+ // overhead of a full lexicographical compare for the vast majority of cases.
+
+ // 1. Build a vector of (index, prefix, length) tuples
+ let mut valids: Vec<(u32, u32, u64)> = value_indices
.into_iter()
- .map(|index| (index, values.value(index as usize).as_ref()))
- .collect::<Vec<(u32, &[u8])>>();
+ .map(|idx| unsafe {
+ let slice: &[u8] = values.value_unchecked(idx as usize).as_ref();
+ let len = slice.len() as u64;
+ // Compute the 4-byte prefix in BE order, zero-padding on the right if shorter
+ let prefix = if slice.len() >= 4 {
+ let raw = std::ptr::read_unaligned(slice.as_ptr() as *const
u32);
+ u32::from_be(raw)
+ } else if slice.is_empty() {
+ // Handle empty slice case to avoid shift overflow
+ 0u32
+ } else {
+ let mut v = 0u32;
+ for &b in slice {
+ v = (v << 8) | (b as u32);
+ }
+ // Safe shift: slice.len() is in range [1, 3], so shift is in range [8, 24]
+ v << (8 * (4 - slice.len()))
+ };
+ (idx, prefix, len)
+ })
+ .collect();
- sort_impl(options, &mut valids, &nulls, limit, Ord::cmp).into()
+ // 2. compute the number of non-null entries to partially sort
+ let vlimit = match (limit, options.nulls_first) {
+ (Some(l), true) => l.saturating_sub(nulls.len()).min(valids.len()),
+ _ => valids.len(),
+ };
+
+ // 3. Comparator: compare prefix, then (when either slice is shorter than 4) length, otherwise full slice
+ let cmp_bytes = |a: &(u32, u32, u64), b: &(u32, u32, u64)| unsafe {
+ let (ia, pa, la) = *a;
+ let (ib, pb, lb) = *b;
+ // 3.1 prefix (first 4 bytes)
+ let ord = pa.cmp(&pb);
+ if ord != Ordering::Equal {
+ return ord;
+ }
+ // 3.2 only if either slice had length < 4 (so its prefix was padded)
+ if la < 4 || lb < 4 {
+ let ord = la.cmp(&lb);
+ if ord != Ordering::Equal {
+ return ord;
+ }
+ }
+ // 3.3 full lexicographical compare
+ let a_bytes: &[u8] = values.value_unchecked(ia as usize).as_ref();
+ let b_bytes: &[u8] = values.value_unchecked(ib as usize).as_ref();
+ a_bytes.cmp(b_bytes)
+ };
+
+ // 4. Partially sort according to ascending/descending
+ if !options.descending {
+ sort_unstable_by(&mut valids, vlimit, cmp_bytes);
+ } else {
+ sort_unstable_by(&mut valids, vlimit, |x, y| cmp_bytes(x,
y).reverse());
+ }
+
+ // 5. Assemble nulls and sorted indices into final output
+ let total = valids.len() + nulls.len();
+ let out_limit = limit.unwrap_or(total).min(total);
+ let mut out = Vec::with_capacity(out_limit);
+
+ if options.nulls_first {
+ out.extend_from_slice(&nulls[..nulls.len().min(out_limit)]);
+ let rem = out_limit - out.len();
+ out.extend(valids.iter().map(|&(i, _, _)| i).take(rem));
+ } else {
+ out.extend(valids.iter().map(|&(i, _, _)| i).take(out_limit));
+ let rem = out_limit - out.len();
+ out.extend_from_slice(&nulls[..rem]);
+ }
+
+ out.into()
}
fn sort_byte_view<T: ByteViewType>(
@@ -4841,4 +4917,301 @@ mod tests {
assert_eq!(valid, vec![0, 2]);
assert_eq!(nulls, vec![1, 3]);
}
+
+ // Test specific edge case strings that exercise the 4-byte prefix logic
+ #[test]
+ fn test_specific_edge_cases() {
+ let test_cases = vec![
+ // Key test cases for lengths 1-4 that test prefix padding
+ "a", "ab", "ba", "baa", "abba", "abbc", "abc", "cda",
+ // Test cases where first 4 bytes are same but subsequent bytes differ
+ "abcd", "abcde", "abcdf", "abcdaaa", "abcdbbb",
+ // Test cases with length < 4 that require padding
+ "z", "za", "zaa", "zaaa", "zaaab", // Empty string
+ "", // Test various length combinations with same prefix
+ "test", "test1", "test12", "test123", "test1234",
+ ];
+
+ // Use standard library sort as reference
+ let mut expected = test_cases.clone();
+ expected.sort();
+
+ // Use our sorting algorithm
+ let string_array = StringArray::from(test_cases.clone());
+ let indices: Vec<u32> = (0..test_cases.len() as u32).collect();
+ let result = sort_bytes(
+ &string_array,
+ indices,
+ vec![], // no nulls
+ SortOptions::default(),
+ None,
+ );
+
+ // Verify results
+ let sorted_strings: Vec<&str> = result
+ .values()
+ .iter()
+ .map(|&idx| test_cases[idx as usize])
+ .collect();
+
+ assert_eq!(sorted_strings, expected);
+ }
+
+ // Test sorting correctness for different length combinations
+ #[test]
+ fn test_length_combinations() {
+ let test_cases = vec![
+ // Focus on testing strings of length 1-4, as these affect padding logic
+ ("", 0),
+ ("a", 1),
+ ("ab", 2),
+ ("abc", 3),
+ ("abcd", 4),
+ ("abcde", 5),
+ ("b", 1),
+ ("ba", 2),
+ ("bab", 3),
+ ("babc", 4),
+ ("babcd", 5),
+ // Test same prefix with different lengths
+ ("test", 4),
+ ("test1", 5),
+ ("test12", 6),
+ ("test123", 7),
+ ];
+
+ let strings: Vec<&str> = test_cases.iter().map(|(s, _)| *s).collect();
+ let mut expected = strings.clone();
+ expected.sort();
+
+ let string_array = StringArray::from(strings.clone());
+ let indices: Vec<u32> = (0..strings.len() as u32).collect();
+ let result = sort_bytes(&string_array, indices, vec![],
SortOptions::default(), None);
+
+ let sorted_strings: Vec<&str> = result
+ .values()
+ .iter()
+ .map(|&idx| strings[idx as usize])
+ .collect();
+
+ assert_eq!(sorted_strings, expected);
+ }
+
+ // Test UTF-8 string handling
+ #[test]
+ fn test_utf8_strings() {
+ let test_cases = vec![
+ "a",
+ "你", // 3-byte UTF-8 character
+ "你好", // 6 bytes
+ "你好世界", // 12 bytes
+ "🎉", // 4-byte emoji
+ "🎉🎊", // 8 bytes
+ "café", // Contains accent character
+ "naïve",
+ "Москва", // Cyrillic script
+ "東京", // Japanese kanji
+ "한국", // Korean
+ ];
+
+ let mut expected = test_cases.clone();
+ expected.sort();
+
+ let string_array = StringArray::from(test_cases.clone());
+ let indices: Vec<u32> = (0..test_cases.len() as u32).collect();
+ let result = sort_bytes(&string_array, indices, vec![],
SortOptions::default(), None);
+
+ let sorted_strings: Vec<&str> = result
+ .values()
+ .iter()
+ .map(|&idx| test_cases[idx as usize])
+ .collect();
+
+ assert_eq!(sorted_strings, expected);
+ }
+
+ // Fuzz testing: generate random UTF-8 strings and verify sort correctness
+ #[test]
+ fn test_fuzz_random_strings() {
+ let mut rng = StdRng::seed_from_u64(42); // Fixed seed for
reproducibility
+
+ for _ in 0..100 {
+ // Run 100 rounds of fuzz testing
+ let mut test_strings = Vec::new();
+
+ // Generate 20-50 random strings
+ let num_strings = rng.random_range(20..=50);
+
+ for _ in 0..num_strings {
+ let string = generate_random_string(&mut rng);
+ test_strings.push(string);
+ }
+
+ // Use standard library sort as reference
+ let mut expected = test_strings.clone();
+ expected.sort();
+
+ // Use our sorting algorithm
+ let string_array = StringArray::from(test_strings.clone());
+ let indices: Vec<u32> = (0..test_strings.len() as u32).collect();
+ let result = sort_bytes(&string_array, indices, vec![],
SortOptions::default(), None);
+
+ let sorted_strings: Vec<String> = result
+ .values()
+ .iter()
+ .map(|&idx| test_strings[idx as usize].clone())
+ .collect();
+
+ assert_eq!(
+ sorted_strings, expected,
+ "Fuzz test failed with input: {test_strings:?}"
+ );
+ }
+ }
+
+ // Helper function to generate random UTF-8 strings
+ fn generate_random_string(rng: &mut StdRng) -> String {
+ // Bias towards generating short strings, especially length 1-4
+ let length = if rng.random_bool(0.6) {
+ rng.random_range(0..=4) // 60% probability for 0-4 length strings
+ } else {
+ rng.random_range(5..=20) // 40% probability for longer strings
+ };
+
+ if length == 0 {
+ return String::new();
+ }
+
+ let mut result = String::new();
+ let mut current_len = 0;
+
+ while current_len < length {
+ let c = generate_random_char(rng);
+ let char_len = c.len_utf8();
+
+ // Ensure we don't exceed target length
+ if current_len + char_len <= length {
+ result.push(c);
+ current_len += char_len;
+ } else {
+ // If adding this character would exceed length, fill with ASCII
+ let remaining = length - current_len;
+ for _ in 0..remaining {
+ result.push(rng.random_range('a'..='z'));
+ current_len += 1;
+ }
+ break;
+ }
+ }
+
+ result
+ }
+
+ // Generate random characters (including various UTF-8 characters)
+ fn generate_random_char(rng: &mut StdRng) -> char {
+ match rng.random_range(0..10) {
+ 0..=5 => rng.random_range('a'..='z'), // 60% ASCII lowercase
+ 6 => rng.random_range('A'..='Z'), // 10% ASCII uppercase
+ 7 => rng.random_range('0'..='9'), // 10% digits
+ 8 => {
+ // 10% Chinese characters
+ let chinese_chars = ['你', '好', '世', '界', '测', '试', '中', '文'];
+ chinese_chars[rng.random_range(0..chinese_chars.len())]
+ }
+ 9 => {
+ // 10% other Unicode characters (single `char`s)
+ let special_chars = ['é', 'ï', '🎉', '🎊', 'α', 'β', 'γ'];
+ special_chars[rng.random_range(0..special_chars.len())]
+ }
+ _ => unreachable!(),
+ }
+ }
+
+ // Test descending sort order
+ #[test]
+ fn test_descending_sort() {
+ let test_cases = vec!["a", "ab", "ba", "baa", "abba", "abbc", "abc",
"cda"];
+
+ let mut expected = test_cases.clone();
+ expected.sort();
+ expected.reverse(); // Descending order
+
+ let string_array = StringArray::from(test_cases.clone());
+ let indices: Vec<u32> = (0..test_cases.len() as u32).collect();
+ let result = sort_bytes(
+ &string_array,
+ indices,
+ vec![],
+ SortOptions {
+ descending: true,
+ nulls_first: false,
+ },
+ None,
+ );
+
+ let sorted_strings: Vec<&str> = result
+ .values()
+ .iter()
+ .map(|&idx| test_cases[idx as usize])
+ .collect();
+
+ assert_eq!(sorted_strings, expected);
+ }
+
+ // Stress test: large number of strings with same prefix
+ #[test]
+ fn test_same_prefix_stress() {
+ let mut test_cases = Vec::new();
+ let prefix = "same";
+
+ // Generate many strings with the same prefix
+ for i in 0..1000 {
+ test_cases.push(format!("{prefix}{i:04}"));
+ }
+
+ let mut expected = test_cases.clone();
+ expected.sort();
+
+ let string_array = StringArray::from(test_cases.clone());
+ let indices: Vec<u32> = (0..test_cases.len() as u32).collect();
+ let result = sort_bytes(&string_array, indices, vec![],
SortOptions::default(), None);
+
+ let sorted_strings: Vec<String> = result
+ .values()
+ .iter()
+ .map(|&idx| test_cases[idx as usize].clone())
+ .collect();
+
+ assert_eq!(sorted_strings, expected);
+ }
+
+ // Test limit parameter
+ #[test]
+ fn test_with_limit() {
+ let test_cases = vec!["z", "y", "x", "w", "v", "u", "t", "s"];
+ let limit = 3;
+
+ let mut expected = test_cases.clone();
+ expected.sort();
+ expected.truncate(limit);
+
+ let string_array = StringArray::from(test_cases.clone());
+ let indices: Vec<u32> = (0..test_cases.len() as u32).collect();
+ let result = sort_bytes(
+ &string_array,
+ indices,
+ vec![],
+ SortOptions::default(),
+ Some(limit),
+ );
+
+ let sorted_strings: Vec<&str> = result
+ .values()
+ .iter()
+ .map(|&idx| test_cases[idx as usize])
+ .collect();
+
+ assert_eq!(sorted_strings, expected);
+ assert_eq!(sorted_strings.len(), limit);
+ }
}