This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new 0710ecce79 Improve StringArray(Utf8) sort performance (~2-4x faster) 
(#7860)
0710ecce79 is described below

commit 0710ecce798dbc5123ce68f5972d3c8928749d30
Author: Qi Zhu <[email protected]>
AuthorDate: Thu Aug 7 18:30:53 2025 +0800

    Improve StringArray(Utf8) sort performance (~2-4x faster) (#7860)
    
    # Which issue does this PR close?
    
    Improve StringArray(Utf8) sort performance
    
    - Closes [#7847](https://github.com/apache/arrow-rs/issues/7847)
    
    # Rationale for this change
    Support prefix comparison. I optimized it to use a u32 prefix and a u64
    increment compare, which gave the best performance in my experiments.
    
    
    # What changes are included in this PR?
    
    Support prefix comparison. I optimized it to use a u32 prefix and a u64
    increment compare, which gave the best performance in my experiments.
    
    # Are these changes tested?
    Yes
    
    
    ```rust
    critcmp  issue_7847  main --filter "sort string\["
    group                                         issue_7847                    
         main
    -----                                         ----------                    
         ----
    sort string[0-400] nulls to indices 2^12      1.00     51.4±0.56µs        ? 
?/sec    1.19     61.0±1.02µs        ? ?/sec
    sort string[0-400] to indices 2^12            1.00     96.5±1.63µs        ? 
?/sec    1.23    118.3±0.91µs        ? ?/sec
    sort string[10] dict nulls to indices 2^12    1.00     72.4±1.00µs        ? 
?/sec    1.00     72.5±0.61µs        ? ?/sec
    sort string[10] dict to indices 2^12          1.00    137.1±1.51µs        ? 
?/sec    1.01    138.1±1.06µs        ? ?/sec
    sort string[10] nulls to indices 2^12         1.00     47.5±0.69µs        ? 
?/sec    1.18     56.3±0.56µs        ? ?/sec
    sort string[10] to indices 2^12               1.00     86.4±1.37µs        ? 
?/sec    1.20    103.5±1.13µs        ? ?/sec
    ```
    
    # Are there any user-facing changes?
    
    If there are user-facing changes then we may require documentation to be
    updated before approving the PR.
    
    If there are any breaking changes to public APIs, please call them out.
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 arrow-ord/src/sort.rs | 381 +++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 377 insertions(+), 4 deletions(-)

diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs
index a405aa7a37..ba026af637 100644
--- a/arrow-ord/src/sort.rs
+++ b/arrow-ord/src/sort.rs
@@ -345,12 +345,88 @@ fn sort_bytes<T: ByteArrayType>(
     options: SortOptions,
     limit: Option<usize>,
 ) -> UInt32Array {
-    let mut valids = value_indices
+    // Note: Why do we use 4‑byte prefix?
+    // Compute the 4-byte prefix in BE order, zero-padded on the right if shorter.
+    // Most byte‐sequences differ in their first few bytes, so by
+    // comparing up to 4 bytes as a single u32 we avoid the overhead
+    // of a full lexicographical compare for the vast majority of cases.
+
+    // 1. Build a vector of (index, prefix, length) tuples
+    let mut valids: Vec<(u32, u32, u64)> = value_indices
         .into_iter()
-        .map(|index| (index, values.value(index as usize).as_ref()))
-        .collect::<Vec<(u32, &[u8])>>();
+        .map(|idx| unsafe {
+            let slice: &[u8] = values.value_unchecked(idx as usize).as_ref();
+            let len = slice.len() as u64;
+            // Compute the 4-byte prefix in BE order, zero-padded on the right if shorter
+            let prefix = if slice.len() >= 4 {
+                let raw = std::ptr::read_unaligned(slice.as_ptr() as *const 
u32);
+                u32::from_be(raw)
+            } else if slice.is_empty() {
+                // Handle empty slice case to avoid shift overflow
+                0u32
+            } else {
+                let mut v = 0u32;
+                for &b in slice {
+                    v = (v << 8) | (b as u32);
+                }
+                // Safe shift: slice.len() is in range [1, 3], so shift is in 
range [8, 24]
+                v << (8 * (4 - slice.len()))
+            };
+            (idx, prefix, len)
+        })
+        .collect();
 
-    sort_impl(options, &mut valids, &nulls, limit, Ord::cmp).into()
+    // 2. compute the number of non-null entries to partially sort
+    let vlimit = match (limit, options.nulls_first) {
+        (Some(l), true) => l.saturating_sub(nulls.len()).min(valids.len()),
+        _ => valids.len(),
+    };
+
+    // 3. Comparator: compare prefix, then (when either slice is shorter than 4) 
length, otherwise full slice
+    let cmp_bytes = |a: &(u32, u32, u64), b: &(u32, u32, u64)| unsafe {
+        let (ia, pa, la) = *a;
+        let (ib, pb, lb) = *b;
+        // 3.1 prefix (first 4 bytes)
+        let ord = pa.cmp(&pb);
+        if ord != Ordering::Equal {
+            return ord;
+        }
+        // 3.2 only if at least one slice had length < 4 (so its prefix was padded)
+        if la < 4 || lb < 4 {
+            let ord = la.cmp(&lb);
+            if ord != Ordering::Equal {
+                return ord;
+            }
+        }
+        // 3.3 full lexicographical compare
+        let a_bytes: &[u8] = values.value_unchecked(ia as usize).as_ref();
+        let b_bytes: &[u8] = values.value_unchecked(ib as usize).as_ref();
+        a_bytes.cmp(b_bytes)
+    };
+
+    // 4. Partially sort according to ascending/descending
+    if !options.descending {
+        sort_unstable_by(&mut valids, vlimit, cmp_bytes);
+    } else {
+        sort_unstable_by(&mut valids, vlimit, |x, y| cmp_bytes(x, 
y).reverse());
+    }
+
+    // 5. Assemble nulls and sorted indices into final output
+    let total = valids.len() + nulls.len();
+    let out_limit = limit.unwrap_or(total).min(total);
+    let mut out = Vec::with_capacity(out_limit);
+
+    if options.nulls_first {
+        out.extend_from_slice(&nulls[..nulls.len().min(out_limit)]);
+        let rem = out_limit - out.len();
+        out.extend(valids.iter().map(|&(i, _, _)| i).take(rem));
+    } else {
+        out.extend(valids.iter().map(|&(i, _, _)| i).take(out_limit));
+        let rem = out_limit - out.len();
+        out.extend_from_slice(&nulls[..rem]);
+    }
+
+    out.into()
 }
 
 fn sort_byte_view<T: ByteViewType>(
@@ -4841,4 +4917,301 @@ mod tests {
         assert_eq!(valid, vec![0, 2]);
         assert_eq!(nulls, vec![1, 3]);
     }
+
+    // Test specific edge case strings that exercise the 4-byte prefix logic
+    #[test]
+    fn test_specific_edge_cases() {
+        let test_cases = vec![
+            // Key test cases for lengths 1-4 that test prefix padding
+            "a", "ab", "ba", "baa", "abba", "abbc", "abc", "cda",
+            // Test cases where first 4 bytes are same but subsequent bytes 
differ
+            "abcd", "abcde", "abcdf", "abcdaaa", "abcdbbb",
+            // Test cases with length < 4 that require padding
+            "z", "za", "zaa", "zaaa", "zaaab", // Empty string
+            "",      // Test various length combinations with same prefix
+            "test", "test1", "test12", "test123", "test1234",
+        ];
+
+        // Use standard library sort as reference
+        let mut expected = test_cases.clone();
+        expected.sort();
+
+        // Use our sorting algorithm
+        let string_array = StringArray::from(test_cases.clone());
+        let indices: Vec<u32> = (0..test_cases.len() as u32).collect();
+        let result = sort_bytes(
+            &string_array,
+            indices,
+            vec![], // no nulls
+            SortOptions::default(),
+            None,
+        );
+
+        // Verify results
+        let sorted_strings: Vec<&str> = result
+            .values()
+            .iter()
+            .map(|&idx| test_cases[idx as usize])
+            .collect();
+
+        assert_eq!(sorted_strings, expected);
+    }
+
+    // Test sorting correctness for different length combinations
+    #[test]
+    fn test_length_combinations() {
+        let test_cases = vec![
+            // Focus on testing strings of length 1-4, as these affect padding 
logic
+            ("", 0),
+            ("a", 1),
+            ("ab", 2),
+            ("abc", 3),
+            ("abcd", 4),
+            ("abcde", 5),
+            ("b", 1),
+            ("ba", 2),
+            ("bab", 3),
+            ("babc", 4),
+            ("babcd", 5),
+            // Test same prefix with different lengths
+            ("test", 4),
+            ("test1", 5),
+            ("test12", 6),
+            ("test123", 7),
+        ];
+
+        let strings: Vec<&str> = test_cases.iter().map(|(s, _)| *s).collect();
+        let mut expected = strings.clone();
+        expected.sort();
+
+        let string_array = StringArray::from(strings.clone());
+        let indices: Vec<u32> = (0..strings.len() as u32).collect();
+        let result = sort_bytes(&string_array, indices, vec![], 
SortOptions::default(), None);
+
+        let sorted_strings: Vec<&str> = result
+            .values()
+            .iter()
+            .map(|&idx| strings[idx as usize])
+            .collect();
+
+        assert_eq!(sorted_strings, expected);
+    }
+
+    // Test UTF-8 string handling
+    #[test]
+    fn test_utf8_strings() {
+        let test_cases = vec![
+            "a",
+            "你",       // 3-byte UTF-8 character
+            "你好",     // 6 bytes
+            "你好世界", // 12 bytes
+            "🎉",       // 4-byte emoji
+            "🎉🎊",     // 8 bytes
+            "café",     // Contains accent character
+            "naïve",
+            "Москва", // Cyrillic script
+            "東京",   // Japanese kanji
+            "한국",   // Korean
+        ];
+
+        let mut expected = test_cases.clone();
+        expected.sort();
+
+        let string_array = StringArray::from(test_cases.clone());
+        let indices: Vec<u32> = (0..test_cases.len() as u32).collect();
+        let result = sort_bytes(&string_array, indices, vec![], 
SortOptions::default(), None);
+
+        let sorted_strings: Vec<&str> = result
+            .values()
+            .iter()
+            .map(|&idx| test_cases[idx as usize])
+            .collect();
+
+        assert_eq!(sorted_strings, expected);
+    }
+
+    // Fuzz testing: generate random UTF-8 strings and verify sort correctness
+    #[test]
+    fn test_fuzz_random_strings() {
+        let mut rng = StdRng::seed_from_u64(42); // Fixed seed for 
reproducibility
+
+        for _ in 0..100 {
+            // Run 100 rounds of fuzz testing
+            let mut test_strings = Vec::new();
+
+            // Generate 20-50 random strings
+            let num_strings = rng.random_range(20..=50);
+
+            for _ in 0..num_strings {
+                let string = generate_random_string(&mut rng);
+                test_strings.push(string);
+            }
+
+            // Use standard library sort as reference
+            let mut expected = test_strings.clone();
+            expected.sort();
+
+            // Use our sorting algorithm
+            let string_array = StringArray::from(test_strings.clone());
+            let indices: Vec<u32> = (0..test_strings.len() as u32).collect();
+            let result = sort_bytes(&string_array, indices, vec![], 
SortOptions::default(), None);
+
+            let sorted_strings: Vec<String> = result
+                .values()
+                .iter()
+                .map(|&idx| test_strings[idx as usize].clone())
+                .collect();
+
+            assert_eq!(
+                sorted_strings, expected,
+                "Fuzz test failed with input: {test_strings:?}"
+            );
+        }
+    }
+
+    // Helper function to generate random UTF-8 strings
+    fn generate_random_string(rng: &mut StdRng) -> String {
+        // Bias towards generating short strings, especially length 1-4
+        let length = if rng.random_bool(0.6) {
+            rng.random_range(0..=4) // 60% probability for 0-4 length strings
+        } else {
+            rng.random_range(5..=20) // 40% probability for longer strings
+        };
+
+        if length == 0 {
+            return String::new();
+        }
+
+        let mut result = String::new();
+        let mut current_len = 0;
+
+        while current_len < length {
+            let c = generate_random_char(rng);
+            let char_len = c.len_utf8();
+
+            // Ensure we don't exceed target length
+            if current_len + char_len <= length {
+                result.push(c);
+                current_len += char_len;
+            } else {
+                // If adding this character would exceed length, fill with 
ASCII
+                let remaining = length - current_len;
+                for _ in 0..remaining {
+                    result.push(rng.random_range('a'..='z'));
+                    current_len += 1;
+                }
+                break;
+            }
+        }
+
+        result
+    }
+
+    // Generate random characters (including various UTF-8 characters)
+    fn generate_random_char(rng: &mut StdRng) -> char {
+        match rng.random_range(0..10) {
+            0..=5 => rng.random_range('a'..='z'), // 60% ASCII lowercase
+            6 => rng.random_range('A'..='Z'),     // 10% ASCII uppercase
+            7 => rng.random_range('0'..='9'),     // 10% digits
+            8 => {
+                // 10% Chinese characters
+                let chinese_chars = ['你', '好', '世', '界', '测', '试', '中', '文'];
+                chinese_chars[rng.random_range(0..chinese_chars.len())]
+            }
+            9 => {
+                // 10% other Unicode characters (single `char`s)
+                let special_chars = ['é', 'ï', '🎉', '🎊', 'α', 'β', 'γ'];
+                special_chars[rng.random_range(0..special_chars.len())]
+            }
+            _ => unreachable!(),
+        }
+    }
+
+    // Test descending sort order
+    #[test]
+    fn test_descending_sort() {
+        let test_cases = vec!["a", "ab", "ba", "baa", "abba", "abbc", "abc", 
"cda"];
+
+        let mut expected = test_cases.clone();
+        expected.sort();
+        expected.reverse(); // Descending order
+
+        let string_array = StringArray::from(test_cases.clone());
+        let indices: Vec<u32> = (0..test_cases.len() as u32).collect();
+        let result = sort_bytes(
+            &string_array,
+            indices,
+            vec![],
+            SortOptions {
+                descending: true,
+                nulls_first: false,
+            },
+            None,
+        );
+
+        let sorted_strings: Vec<&str> = result
+            .values()
+            .iter()
+            .map(|&idx| test_cases[idx as usize])
+            .collect();
+
+        assert_eq!(sorted_strings, expected);
+    }
+
+    // Stress test: large number of strings with same prefix
+    #[test]
+    fn test_same_prefix_stress() {
+        let mut test_cases = Vec::new();
+        let prefix = "same";
+
+        // Generate many strings with the same prefix
+        for i in 0..1000 {
+            test_cases.push(format!("{prefix}{i:04}"));
+        }
+
+        let mut expected = test_cases.clone();
+        expected.sort();
+
+        let string_array = StringArray::from(test_cases.clone());
+        let indices: Vec<u32> = (0..test_cases.len() as u32).collect();
+        let result = sort_bytes(&string_array, indices, vec![], 
SortOptions::default(), None);
+
+        let sorted_strings: Vec<String> = result
+            .values()
+            .iter()
+            .map(|&idx| test_cases[idx as usize].clone())
+            .collect();
+
+        assert_eq!(sorted_strings, expected);
+    }
+
+    // Test limit parameter
+    #[test]
+    fn test_with_limit() {
+        let test_cases = vec!["z", "y", "x", "w", "v", "u", "t", "s"];
+        let limit = 3;
+
+        let mut expected = test_cases.clone();
+        expected.sort();
+        expected.truncate(limit);
+
+        let string_array = StringArray::from(test_cases.clone());
+        let indices: Vec<u32> = (0..test_cases.len() as u32).collect();
+        let result = sort_bytes(
+            &string_array,
+            indices,
+            vec![],
+            SortOptions::default(),
+            Some(limit),
+        );
+
+        let sorted_strings: Vec<&str> = result
+            .values()
+            .iter()
+            .map(|&idx| test_cases[idx as usize])
+            .collect();
+
+        assert_eq!(sorted_strings, expected);
+        assert_eq!(sorted_strings.len(), limit);
+    }
 }

Reply via email to