This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new 1bf9a1bef6 Add hooks to json encoder to override default encoding or add support for unsupported types (#7015)
1bf9a1bef6 is described below

commit 1bf9a1bef697615e1dff2b5dfb3656ddd26be129
Author: Adrian Garcia Badaracco <[email protected]>
AuthorDate: Tue Mar 25 12:40:50 2025 -0500

    Add hooks to json encoder to override default encoding or add support for unsupported types (#7015)
    
    * add public json encoder trait
    
    * tweak bench
    
    * fmt
    
    * tweak bench
    
    * wip
    
    * refactor EncoderOptions into a builder
    
    * remove unused new method
    
    * fmt
    
    * wip
    
    * remove dynamic dispatch
    
    * clippy
    
    * remove bench
    
    * remove bench
    
    * address perf
    
    * add doctest
    
    * fmt
    
    * rename EncoderWithNullBuffer to NullableEncoder
    
    * Update arrow-json/src/writer/mod.rs
    
    Co-authored-by: Raphael Taylor-Davies <[email protected]>
    
    * Update mod.rs
    
    Co-authored-by: Raphael Taylor-Davies <[email protected]>
    
    ---------
    
    Co-authored-by: Raphael Taylor-Davies <[email protected]>
    Co-authored-by: Andrew Lamb <[email protected]>
---
 arrow-json/src/lib.rs            |   5 +-
 arrow-json/src/writer/encoder.rs | 421 +++++++++++++++++++++++++++---------
 arrow-json/src/writer/mod.rs     | 451 ++++++++++++++++++++++++++++++++++++++-
 3 files changed, 770 insertions(+), 107 deletions(-)

diff --git a/arrow-json/src/lib.rs b/arrow-json/src/lib.rs
index ea0446c3d6..6d7ab4400b 100644
--- a/arrow-json/src/lib.rs
+++ b/arrow-json/src/lib.rs
@@ -75,7 +75,10 @@ pub mod reader;
 pub mod writer;
 
 pub use self::reader::{Reader, ReaderBuilder};
-pub use self::writer::{ArrayWriter, LineDelimitedWriter, Writer, WriterBuilder};
+pub use self::writer::{
+    ArrayWriter, Encoder, EncoderFactory, EncoderOptions, LineDelimitedWriter, Writer,
+    WriterBuilder,
+};
 use half::f16;
 use serde_json::{Number, Value};
 
diff --git a/arrow-json/src/writer/encoder.rs b/arrow-json/src/writer/encoder.rs
index 0b3c788d55..ee6af03101 100644
--- a/arrow-json/src/writer/encoder.rs
+++ b/arrow-json/src/writer/encoder.rs
@@ -14,6 +14,8 @@
 // KIND, either express or implied.  See the License for the
 // specific language governing permissions and limitations
 // under the License.
+use std::io::Write;
+use std::sync::Arc;
 
 use crate::StructMode;
 use arrow_array::cast::AsArray;
@@ -25,126 +27,322 @@ use arrow_schema::{ArrowError, DataType, FieldRef};
 use half::f16;
 use lexical_core::FormattedSize;
 use serde::Serializer;
-use std::io::Write;
 
+/// Configuration options for the JSON encoder.
 #[derive(Debug, Clone, Default)]
 pub struct EncoderOptions {
-    pub explicit_nulls: bool,
-    pub struct_mode: StructMode,
+    /// Whether to include nulls in the output or elide them.
+    explicit_nulls: bool,
+    /// Whether to encode structs as JSON objects or JSON arrays of their values.
+    struct_mode: StructMode,
+    /// An optional hook for customizing encoding behavior.
+    encoder_factory: Option<Arc<dyn EncoderFactory>>,
+}
+
+impl EncoderOptions {
+    /// Set whether to include nulls in the output or elide them.
+    pub fn with_explicit_nulls(mut self, explicit_nulls: bool) -> Self {
+        self.explicit_nulls = explicit_nulls;
+        self
+    }
+
+    /// Set whether to encode structs as JSON objects or JSON arrays of their values.
+    pub fn with_struct_mode(mut self, struct_mode: StructMode) -> Self {
+        self.struct_mode = struct_mode;
+        self
+    }
+
+    /// Set an optional hook for customizing encoding behavior.
+    pub fn with_encoder_factory(mut self, encoder_factory: Arc<dyn EncoderFactory>) -> Self {
+        self.encoder_factory = Some(encoder_factory);
+        self
+    }
+
+    /// Get whether to include nulls in the output or elide them.
+    pub fn explicit_nulls(&self) -> bool {
+        self.explicit_nulls
+    }
+
+    /// Get whether to encode structs as JSON objects or JSON arrays of their values.
+    pub fn struct_mode(&self) -> StructMode {
+        self.struct_mode
+    }
+
+    /// Get the optional hook for customizing encoding behavior.
+    pub fn encoder_factory(&self) -> Option<&Arc<dyn EncoderFactory>> {
+        self.encoder_factory.as_ref()
+    }
+}
+
+/// A trait to create custom encoders for specific data types.
+///
+/// This allows overriding the default encoders for specific data types,
+/// or adding new encoders for custom data types.
+///
+/// # Examples
+///
+/// ```
+/// use std::io::Write;
+/// use arrow_array::{ArrayAccessor, Array, BinaryArray, Float64Array, RecordBatch};
+/// use arrow_array::cast::AsArray;
+/// use arrow_schema::{DataType, Field, Schema, FieldRef};
+/// use arrow_json::{writer::{WriterBuilder, JsonArray, NullableEncoder}, StructMode};
+/// use arrow_json::{Encoder, EncoderFactory, EncoderOptions};
+/// use arrow_schema::ArrowError;
+/// use std::sync::Arc;
+/// use serde_json::json;
+/// use serde_json::Value;
+///
+/// struct IntArrayBinaryEncoder<B> {
+///     array: B,
+/// }
+///
+/// impl<'a, B> Encoder for IntArrayBinaryEncoder<B>
+/// where
+///     B: ArrayAccessor<Item = &'a [u8]>,
+/// {
+///     fn encode(&mut self, idx: usize, out: &mut Vec<u8>) {
+///         out.push(b'[');
+///         let child = self.array.value(idx);
+///         for (idx, byte) in child.iter().enumerate() {
+///             write!(out, "{byte}").unwrap();
+///             if idx < child.len() - 1 {
+///                 out.push(b',');
+///             }
+///         }
+///         out.push(b']');
+///     }
+/// }
+///
+/// #[derive(Debug)]
+/// struct IntArayBinaryEncoderFactory;
+///
+/// impl EncoderFactory for IntArayBinaryEncoderFactory {
+///     fn make_default_encoder<'a>(
+///         &self,
+///         _field: &'a FieldRef,
+///         array: &'a dyn Array,
+///         _options: &'a EncoderOptions,
+///     ) -> Result<Option<NullableEncoder<'a>>, ArrowError> {
+///         match array.data_type() {
+///             DataType::Binary => {
+///                 let array = array.as_binary::<i32>();
+///                 let encoder = IntArrayBinaryEncoder { array };
+///                 let array_encoder = Box::new(encoder) as Box<dyn Encoder + 'a>;
+///                 let nulls = array.nulls().cloned();
+///                 Ok(Some(NullableEncoder::new(array_encoder, nulls)))
+///             }
+///             _ => Ok(None),
+///         }
+///     }
+/// }
+///
+/// let binary_array = BinaryArray::from_iter([Some(b"a".as_slice()), None, Some(b"b".as_slice())]);
+/// let float_array = Float64Array::from(vec![Some(1.0), Some(2.3), None]);
+/// let fields = vec![
+///     Field::new("bytes", DataType::Binary, true),
+///     Field::new("float", DataType::Float64, true),
+/// ];
+/// let batch = RecordBatch::try_new(
+///     Arc::new(Schema::new(fields)),
+///     vec![
+///         Arc::new(binary_array) as Arc<dyn Array>,
+///         Arc::new(float_array) as Arc<dyn Array>,
+///     ],
+/// )
+/// .unwrap();
+///
+/// let json_value: Value = {
+///     let mut buf = Vec::new();
+///     let mut writer = WriterBuilder::new()
+///         .with_encoder_factory(Arc::new(IntArayBinaryEncoderFactory))
+///         .build::<_, JsonArray>(&mut buf);
+///     writer.write_batches(&[&batch]).unwrap();
+///     writer.finish().unwrap();
+///     serde_json::from_slice(&buf).unwrap()
+/// };
+///
+/// let expected = json!([
+///     {"bytes": [97], "float": 1.0},
+///     {"float": 2.3},
+///     {"bytes": [98]},
+/// ]);
+///
+/// assert_eq!(json_value, expected);
+/// ```
+pub trait EncoderFactory: std::fmt::Debug + Send + Sync {
+    /// Make an encoder that overrides the default encoder for a specific field and array or provides an encoder for a custom data type.
+    /// This can be used to override how e.g. binary data is encoded so that it is an encoded string or an array of integers.
+    ///
+    /// Note that the type of the field may not match the type of the array: for dictionary arrays unless the top-level dictionary is handled this
+    /// will be called again for the keys and values of the dictionary, at which point the field type will still be the outer dictionary type but the
+    /// array will have a different type.
+    /// For example, `field` might have the type `Dictionary(i32, Utf8)` but `array` will be `Utf8`.
+    fn make_default_encoder<'a>(
+        &self,
+        _field: &'a FieldRef,
+        _array: &'a dyn Array,
+        _options: &'a EncoderOptions,
+    ) -> Result<Option<NullableEncoder<'a>>, ArrowError> {
+        Ok(None)
+    }
+}
+
+/// An encoder + a null buffer.
+/// This is packaged together into a wrapper struct to minimize dynamic dispatch for null checks.
+pub struct NullableEncoder<'a> {
+    encoder: Box<dyn Encoder + 'a>,
+    nulls: Option<NullBuffer>,
+}
+
+impl<'a> NullableEncoder<'a> {
+    /// Create a new encoder with a null buffer.
+    pub fn new(encoder: Box<dyn Encoder + 'a>, nulls: Option<NullBuffer>) -> Self {
+        Self { encoder, nulls }
+    }
+
+    /// Encode the value at index `idx` to `out`.
+    pub fn encode(&mut self, idx: usize, out: &mut Vec<u8>) {
+        self.encoder.encode(idx, out)
+    }
+
+    /// Returns whether the value at index `idx` is null.
+    pub fn is_null(&self, idx: usize) -> bool {
+        self.nulls.as_ref().is_some_and(|nulls| nulls.is_null(idx))
+    }
+
+    /// Returns whether the encoder has any nulls.
+    pub fn has_nulls(&self) -> bool {
+        match self.nulls {
+            Some(ref nulls) => nulls.null_count() > 0,
+            None => false,
+        }
+    }
+}
+
+impl Encoder for NullableEncoder<'_> {
+    fn encode(&mut self, idx: usize, out: &mut Vec<u8>) {
+        self.encoder.encode(idx, out)
+    }
 }
 
 /// A trait to format array values as JSON values
 ///
 /// Nullability is handled by the caller to allow encoding nulls implicitly, i.e. `{}` instead of `{"a": null}`
 pub trait Encoder {
-    /// Encode the non-null value at index `idx` to `out`
+    /// Encode the non-null value at index `idx` to `out`.
     ///
-    /// The behaviour is unspecified if `idx` corresponds to a null index
+    /// The behaviour is unspecified if `idx` corresponds to a null index.
     fn encode(&mut self, idx: usize, out: &mut Vec<u8>);
 }
 
+/// Creates an encoder for the given array and field.
+///
+/// This first calls the EncoderFactory if one is provided, and then falls back to the default encoders.
 pub fn make_encoder<'a>(
+    field: &'a FieldRef,
     array: &'a dyn Array,
-    options: &EncoderOptions,
-) -> Result<Box<dyn Encoder + 'a>, ArrowError> {
-    let (encoder, nulls) = make_encoder_impl(array, options)?;
-    assert!(nulls.is_none(), "root cannot be nullable");
-    Ok(encoder)
-}
-
-fn make_encoder_impl<'a>(
-    array: &'a dyn Array,
-    options: &EncoderOptions,
-) -> Result<(Box<dyn Encoder + 'a>, Option<NullBuffer>), ArrowError> {
+    options: &'a EncoderOptions,
+) -> Result<NullableEncoder<'a>, ArrowError> {
     macro_rules! primitive_helper {
         ($t:ty) => {{
             let array = array.as_primitive::<$t>();
             let nulls = array.nulls().cloned();
-            (Box::new(PrimitiveEncoder::new(array)) as _, nulls)
+            NullableEncoder::new(Box::new(PrimitiveEncoder::new(array)), nulls)
         }};
     }
 
-    Ok(downcast_integer! {
+    if let Some(factory) = options.encoder_factory() {
+        if let Some(encoder) = factory.make_default_encoder(field, array, options)? {
+            return Ok(encoder);
+        }
+    }
+
+    let nulls = array.nulls().cloned();
+    let encoder = downcast_integer! {
         array.data_type() => (primitive_helper),
         DataType::Float16 => primitive_helper!(Float16Type),
         DataType::Float32 => primitive_helper!(Float32Type),
         DataType::Float64 => primitive_helper!(Float64Type),
         DataType::Boolean => {
             let array = array.as_boolean();
-            (Box::new(BooleanEncoder(array)), array.nulls().cloned())
+            NullableEncoder::new(Box::new(BooleanEncoder(array)), array.nulls().cloned())
         }
-        DataType::Null => (Box::new(NullEncoder), array.logical_nulls()),
+        DataType::Null => NullableEncoder::new(Box::new(NullEncoder), array.logical_nulls()),
         DataType::Utf8 => {
             let array = array.as_string::<i32>();
-            (Box::new(StringEncoder(array)) as _, array.nulls().cloned())
+            NullableEncoder::new(Box::new(StringEncoder(array)), array.nulls().cloned())
         }
         DataType::LargeUtf8 => {
             let array = array.as_string::<i64>();
-            (Box::new(StringEncoder(array)) as _, array.nulls().cloned())
+            NullableEncoder::new(Box::new(StringEncoder(array)), array.nulls().cloned())
         }
         DataType::Utf8View => {
             let array = array.as_string_view();
-            (Box::new(StringViewEncoder(array)) as _, array.nulls().cloned())
+            NullableEncoder::new(Box::new(StringViewEncoder(array)), array.nulls().cloned())
         }
         DataType::List(_) => {
             let array = array.as_list::<i32>();
-            (Box::new(ListEncoder::try_new(array, options)?) as _, array.nulls().cloned())
+            NullableEncoder::new(Box::new(ListEncoder::try_new(field, array, options)?), array.nulls().cloned())
         }
         DataType::LargeList(_) => {
             let array = array.as_list::<i64>();
-            (Box::new(ListEncoder::try_new(array, options)?) as _, array.nulls().cloned())
+            NullableEncoder::new(Box::new(ListEncoder::try_new(field, array, options)?), array.nulls().cloned())
         }
         DataType::FixedSizeList(_, _) => {
             let array = array.as_fixed_size_list();
-            (Box::new(FixedSizeListEncoder::try_new(array, options)?) as _, array.nulls().cloned())
+            NullableEncoder::new(Box::new(FixedSizeListEncoder::try_new(field, array, options)?), array.nulls().cloned())
         }
 
         DataType::Dictionary(_, _) => downcast_dictionary_array! {
-            array => (Box::new(DictionaryEncoder::try_new(array, options)?) as _,  array.logical_nulls()),
+            array => {
+                NullableEncoder::new(Box::new(DictionaryEncoder::try_new(field, array, options)?), array.nulls().cloned())
+            },
             _ => unreachable!()
         }
 
         DataType::Map(_, _) => {
             let array = array.as_map();
-            (Box::new(MapEncoder::try_new(array, options)?) as _,  array.nulls().cloned())
+            NullableEncoder::new(Box::new(MapEncoder::try_new(field, array, options)?), array.nulls().cloned())
         }
 
         DataType::FixedSizeBinary(_) => {
             let array = array.as_fixed_size_binary();
-            (Box::new(BinaryEncoder::new(array)) as _, array.nulls().cloned())
+            NullableEncoder::new(Box::new(BinaryEncoder::new(array)) as _, array.nulls().cloned())
         }
 
         DataType::Binary => {
             let array: &BinaryArray = array.as_binary();
-            (Box::new(BinaryEncoder::new(array)) as _, array.nulls().cloned())
+            NullableEncoder::new(Box::new(BinaryEncoder::new(array)), array.nulls().cloned())
         }
 
         DataType::LargeBinary => {
             let array: &LargeBinaryArray = array.as_binary();
-            (Box::new(BinaryEncoder::new(array)) as _, array.nulls().cloned())
+            NullableEncoder::new(Box::new(BinaryEncoder::new(array)), array.nulls().cloned())
         }
 
         DataType::Struct(fields) => {
             let array = array.as_struct();
             let encoders = fields.iter().zip(array.columns()).map(|(field, array)| {
-                let (encoder, nulls) = make_encoder_impl(array, options)?;
+                let encoder = make_encoder(field, array, options)?;
                 Ok(FieldEncoder{
                     field: field.clone(),
-                    encoder, nulls
+                    encoder,
                 })
             }).collect::<Result<Vec<_>, ArrowError>>()?;
 
             let encoder = StructArrayEncoder{
                 encoders,
-                explicit_nulls: options.explicit_nulls,
-                struct_mode: options.struct_mode,
+                explicit_nulls: options.explicit_nulls(),
+                struct_mode: options.struct_mode(),
             };
-            (Box::new(encoder) as _, array.nulls().cloned())
+            let nulls = array.nulls().cloned();
+            NullableEncoder::new(Box::new(encoder) as Box<dyn Encoder + 'a>, nulls)
         }
         DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
             let options = FormatOptions::new().with_display_error(true);
-            let formatter = ArrayFormatter::try_new(array, &options)?;
-            (Box::new(RawArrayFormatter(formatter)) as _, array.nulls().cloned())
+            let formatter = JsonArrayFormatter::new(ArrayFormatter::try_new(array, &options)?);
+            NullableEncoder::new(Box::new(RawArrayFormatter(formatter)) as Box<dyn Encoder + 'a>, nulls)
         }
         d => match d.is_temporal() {
             true => {
@@ -154,11 +352,17 @@ fn make_encoder_impl<'a>(
                 // may need to be revisited
                 let options = FormatOptions::new().with_display_error(true);
                 let formatter = ArrayFormatter::try_new(array, &options)?;
-                (Box::new(formatter) as _, array.nulls().cloned())
+                let formatter = JsonArrayFormatter::new(formatter);
+                NullableEncoder::new(Box::new(formatter) as Box<dyn Encoder + 'a>, nulls)
             }
-            false => return Err(ArrowError::InvalidArgumentError(format!("JSON Writer does not support data type: {d}"))),
+            false => return Err(ArrowError::JsonError(format!(
+                "Unsupported data type for JSON encoding: {:?}",
+                d
+            )))
         }
-    })
+    };
+
+    Ok(encoder)
 }
 
 fn encode_string(s: &str, out: &mut Vec<u8>) {
@@ -168,8 +372,13 @@ fn encode_string(s: &str, out: &mut Vec<u8>) {
 
 struct FieldEncoder<'a> {
     field: FieldRef,
-    encoder: Box<dyn Encoder + 'a>,
-    nulls: Option<NullBuffer>,
+    encoder: NullableEncoder<'a>,
+}
+
+impl FieldEncoder<'_> {
+    fn is_null(&self, idx: usize) -> bool {
+        self.encoder.is_null(idx)
+    }
 }
 
 struct StructArrayEncoder<'a> {
@@ -196,9 +405,10 @@ impl Encoder for StructArrayEncoder<'_> {
         let mut is_first = true;
         // Nulls can only be dropped in explicit mode
         let drop_nulls = (self.struct_mode == StructMode::ObjectOnly) && !self.explicit_nulls;
-        for field_encoder in &mut self.encoders {
-            let is_null = is_some_and(field_encoder.nulls.as_ref(), |n| n.is_null(idx));
-            if drop_nulls && is_null {
+
+        for field_encoder in self.encoders.iter_mut() {
+            let is_null = field_encoder.is_null(idx);
+            if is_null && drop_nulls {
                 continue;
             }
 
@@ -212,9 +422,10 @@ impl Encoder for StructArrayEncoder<'_> {
                 out.push(b':');
             }
 
-            match is_null {
-                true => out.extend_from_slice(b"null"),
-                false => field_encoder.encoder.encode(idx, out),
+            if is_null {
+                out.extend_from_slice(b"null");
+            } else {
+                field_encoder.encoder.encode(idx, out);
             }
         }
         match self.struct_mode {
@@ -339,20 +550,19 @@ impl Encoder for StringViewEncoder<'_> {
 
 struct ListEncoder<'a, O: OffsetSizeTrait> {
     offsets: OffsetBuffer<O>,
-    nulls: Option<NullBuffer>,
-    encoder: Box<dyn Encoder + 'a>,
+    encoder: NullableEncoder<'a>,
 }
 
 impl<'a, O: OffsetSizeTrait> ListEncoder<'a, O> {
     fn try_new(
+        field: &'a FieldRef,
         array: &'a GenericListArray<O>,
-        options: &EncoderOptions,
+        options: &'a EncoderOptions,
     ) -> Result<Self, ArrowError> {
-        let (encoder, nulls) = make_encoder_impl(array.values().as_ref(), options)?;
+        let encoder = make_encoder(field, array.values().as_ref(), options)?;
         Ok(Self {
             offsets: array.offsets().clone(),
             encoder,
-            nulls,
         })
     }
 }
@@ -362,22 +572,25 @@ impl<O: OffsetSizeTrait> Encoder for ListEncoder<'_, O> {
         let end = self.offsets[idx + 1].as_usize();
         let start = self.offsets[idx].as_usize();
         out.push(b'[');
-        match self.nulls.as_ref() {
-            Some(n) => (start..end).for_each(|idx| {
+
+        if self.encoder.has_nulls() {
+            for idx in start..end {
                 if idx != start {
                     out.push(b',')
                 }
-                match n.is_null(idx) {
-                    true => out.extend_from_slice(b"null"),
-                    false => self.encoder.encode(idx, out),
+                if self.encoder.is_null(idx) {
+                    out.extend_from_slice(b"null");
+                } else {
+                    self.encoder.encode(idx, out);
                 }
-            }),
-            None => (start..end).for_each(|idx| {
+            }
+        } else {
+            for idx in start..end {
                 if idx != start {
                     out.push(b',')
                 }
                 self.encoder.encode(idx, out);
-            }),
+            }
         }
         out.push(b']');
     }
@@ -385,19 +598,18 @@ impl<O: OffsetSizeTrait> Encoder for ListEncoder<'_, O> {
 
 struct FixedSizeListEncoder<'a> {
     value_length: usize,
-    nulls: Option<NullBuffer>,
-    encoder: Box<dyn Encoder + 'a>,
+    encoder: NullableEncoder<'a>,
 }
 
 impl<'a> FixedSizeListEncoder<'a> {
     fn try_new(
+        field: &'a FieldRef,
         array: &'a FixedSizeListArray,
-        options: &EncoderOptions,
+        options: &'a EncoderOptions,
     ) -> Result<Self, ArrowError> {
-        let (encoder, nulls) = make_encoder_impl(array.values().as_ref(), options)?;
+        let encoder = make_encoder(field, array.values().as_ref(), options)?;
         Ok(Self {
             encoder,
-            nulls,
             value_length: array.value_length().as_usize(),
         })
     }
@@ -408,23 +620,24 @@ impl Encoder for FixedSizeListEncoder<'_> {
         let start = idx * self.value_length;
         let end = start + self.value_length;
         out.push(b'[');
-        match self.nulls.as_ref() {
-            Some(n) => (start..end).for_each(|idx| {
+        if self.encoder.has_nulls() {
+            for idx in start..end {
                 if idx != start {
-                    out.push(b',');
+                    out.push(b',')
                 }
-                if n.is_null(idx) {
+                if self.encoder.is_null(idx) {
                     out.extend_from_slice(b"null");
                 } else {
                     self.encoder.encode(idx, out);
                 }
-            }),
-            None => (start..end).for_each(|idx| {
+            }
+        } else {
+            for idx in start..end {
                 if idx != start {
-                    out.push(b',');
+                    out.push(b',')
                 }
                 self.encoder.encode(idx, out);
-            }),
+            }
         }
         out.push(b']');
     }
@@ -432,15 +645,16 @@ impl Encoder for FixedSizeListEncoder<'_> {
 
 struct DictionaryEncoder<'a, K: ArrowDictionaryKeyType> {
     keys: ScalarBuffer<K::Native>,
-    encoder: Box<dyn Encoder + 'a>,
+    encoder: NullableEncoder<'a>,
 }
 
 impl<'a, K: ArrowDictionaryKeyType> DictionaryEncoder<'a, K> {
     fn try_new(
+        field: &'a FieldRef,
         array: &'a DictionaryArray<K>,
-        options: &EncoderOptions,
+        options: &'a EncoderOptions,
     ) -> Result<Self, ArrowError> {
-        let (encoder, _) = make_encoder_impl(array.values().as_ref(), options)?;
+        let encoder = make_encoder(field, array.values().as_ref(), options)?;
 
         Ok(Self {
             keys: array.keys().values().clone(),
@@ -455,22 +669,33 @@ impl<K: ArrowDictionaryKeyType> Encoder for DictionaryEncoder<'_, K> {
     }
 }
 
-impl Encoder for ArrayFormatter<'_> {
+/// A newtype wrapper around [`ArrayFormatter`] to keep our usage of it private and not implement `Encoder` for the public type
+struct JsonArrayFormatter<'a> {
+    formatter: ArrayFormatter<'a>,
+}
+
+impl<'a> JsonArrayFormatter<'a> {
+    fn new(formatter: ArrayFormatter<'a>) -> Self {
+        Self { formatter }
+    }
+}
+
+impl Encoder for JsonArrayFormatter<'_> {
     fn encode(&mut self, idx: usize, out: &mut Vec<u8>) {
         out.push(b'"');
         // Should be infallible
         // Note: We are making an assumption that the formatter does not produce characters that require escaping
-        let _ = write!(out, "{}", self.value(idx));
+        let _ = write!(out, "{}", self.formatter.value(idx));
         out.push(b'"')
     }
 }
 
-/// A newtype wrapper around [`ArrayFormatter`] that skips surrounding the value with `"`
-struct RawArrayFormatter<'a>(ArrayFormatter<'a>);
+/// A newtype wrapper around [`JsonArrayFormatter`] that skips surrounding the value with `"`
+struct RawArrayFormatter<'a>(JsonArrayFormatter<'a>);
 
 impl Encoder for RawArrayFormatter<'_> {
     fn encode(&mut self, idx: usize, out: &mut Vec<u8>) {
-        let _ = write!(out, "{}", self.0.value(idx));
+        let _ = write!(out, "{}", self.0.formatter.value(idx));
     }
 }
 
@@ -484,14 +709,17 @@ impl Encoder for NullEncoder {
 
 struct MapEncoder<'a> {
     offsets: OffsetBuffer<i32>,
-    keys: Box<dyn Encoder + 'a>,
-    values: Box<dyn Encoder + 'a>,
-    value_nulls: Option<NullBuffer>,
+    keys: NullableEncoder<'a>,
+    values: NullableEncoder<'a>,
     explicit_nulls: bool,
 }
 
 impl<'a> MapEncoder<'a> {
-    fn try_new(array: &'a MapArray, options: &EncoderOptions) -> Result<Self, ArrowError> {
+    fn try_new(
+        field: &'a FieldRef,
+        array: &'a MapArray,
+        options: &'a EncoderOptions,
+    ) -> Result<Self, ArrowError> {
         let values = array.values();
         let keys = array.keys();
 
@@ -502,11 +730,11 @@ impl<'a> MapEncoder<'a> {
             )));
         }
 
-        let (keys, key_nulls) = make_encoder_impl(keys, options)?;
-        let (values, value_nulls) = make_encoder_impl(values, options)?;
+        let keys = make_encoder(field, keys, options)?;
+        let values = make_encoder(field, values, options)?;
 
         // We sanity check nulls as these are currently not enforced by MapArray (#1697)
-        if is_some_and(key_nulls, |x| x.null_count() != 0) {
+        if keys.has_nulls() {
             return Err(ArrowError::InvalidArgumentError(
                 "Encountered nulls in MapArray keys".to_string(),
             ));
@@ -522,8 +750,7 @@ impl<'a> MapEncoder<'a> {
             offsets: array.offsets().clone(),
             keys,
             values,
-            value_nulls,
-            explicit_nulls: options.explicit_nulls,
+            explicit_nulls: options.explicit_nulls(),
         })
     }
 }
@@ -536,8 +763,9 @@ impl Encoder for MapEncoder<'_> {
         let mut is_first = true;
 
         out.push(b'{');
+
         for idx in start..end {
-            let is_null = is_some_and(self.value_nulls.as_ref(), |n| n.is_null(idx));
+            let is_null = self.values.is_null(idx);
             if is_null && !self.explicit_nulls {
                 continue;
             }
@@ -550,9 +778,10 @@ impl Encoder for MapEncoder<'_> {
             self.keys.encode(idx, out);
             out.push(b':');
 
-            match is_null {
-                true => out.extend_from_slice(b"null"),
-                false => self.values.encode(idx, out),
+            if is_null {
+                out.extend_from_slice(b"null");
+            } else {
+                self.values.encode(idx, out);
             }
         }
         out.push(b'}');
diff --git a/arrow-json/src/writer/mod.rs b/arrow-json/src/writer/mod.rs
index 5d3e558480..ee1b5fabe5 100644
--- a/arrow-json/src/writer/mod.rs
+++ b/arrow-json/src/writer/mod.rs
@@ -106,13 +106,13 @@
 //! ```
 mod encoder;
 
-use std::{fmt::Debug, io::Write};
+use std::{fmt::Debug, io::Write, sync::Arc};
 
 use crate::StructMode;
 use arrow_array::*;
 use arrow_schema::*;
 
-use encoder::{make_encoder, EncoderOptions};
+pub use encoder::{make_encoder, Encoder, EncoderFactory, EncoderOptions, NullableEncoder};
 
 /// This trait defines how to format a sequence of JSON objects to a
 /// byte stream.
@@ -225,7 +225,7 @@ impl WriterBuilder {
 
     /// Returns `true` if this writer is configured to keep keys with null values.
     pub fn explicit_nulls(&self) -> bool {
-        self.0.explicit_nulls
+        self.0.explicit_nulls()
     }
 
     /// Set whether to keep keys with null values, or to omit writing them.
@@ -251,13 +251,13 @@ impl WriterBuilder {
     /// Default is to skip nulls (set to `false`). If `struct_mode == ListOnly`,
     /// nulls will be written explicitly regardless of this setting.
     pub fn with_explicit_nulls(mut self, explicit_nulls: bool) -> Self {
-        self.0.explicit_nulls = explicit_nulls;
+        self.0 = self.0.with_explicit_nulls(explicit_nulls);
         self
     }
 
     /// Returns if this writer is configured to write structs as JSON Objects or Arrays.
     pub fn struct_mode(&self) -> StructMode {
-        self.0.struct_mode
+        self.0.struct_mode()
     }
 
     /// Set the [`StructMode`] for the writer, which determines whether structs
@@ -266,7 +266,16 @@ impl WriterBuilder {
     /// `ListOnly`, nulls will be written explicitly regardless of the
     /// `explicit_nulls` setting.
     pub fn with_struct_mode(mut self, struct_mode: StructMode) -> Self {
-        self.0.struct_mode = struct_mode;
+        self.0 = self.0.with_struct_mode(struct_mode);
+        self
+    }
+
+    /// Set an encoder factory to use when creating encoders for writing JSON.
+    ///
+    /// This can be used to override how some types are encoded or to provide
+    /// a fallback for types that are not supported by the default encoder.
+    pub fn with_encoder_factory(mut self, factory: Arc<dyn EncoderFactory>) -> Self {
+        self.0 = self.0.with_encoder_factory(factory);
         self
     }
 
@@ -351,8 +360,16 @@ where
         }
 
         let array = StructArray::from(batch.clone());
-        let mut encoder = make_encoder(&array, &self.options)?;
+        let field = Arc::new(Field::new_struct(
+            "",
+            batch.schema().fields().clone(),
+            false,
+        ));
+
+        let mut encoder = make_encoder(&field, &array, &self.options)?;
 
+        // Validate that the root is not nullable
+        assert!(!encoder.has_nulls(), "root cannot be nullable");
         for idx in 0..batch.num_rows() {
             self.format.start_row(&mut buffer, is_first_row)?;
             is_first_row = false;
@@ -419,15 +436,19 @@ where
 #[cfg(test)]
 mod tests {
     use core::str;
+    use std::collections::HashMap;
     use std::fs::{read_to_string, File};
     use std::io::{BufReader, Seek};
     use std::sync::Arc;
 
+    use arrow_array::cast::AsArray;
     use serde_json::{json, Value};
 
+    use super::LineDelimited;
+    use super::{Encoder, WriterBuilder};
     use arrow_array::builder::*;
     use arrow_array::types::*;
-    use arrow_buffer::{i256, Buffer, NullBuffer, OffsetBuffer, ToByteSlice};
+    use arrow_buffer::{i256, Buffer, NullBuffer, OffsetBuffer, ScalarBuffer, ToByteSlice};
     use arrow_data::ArrayData;
 
     use crate::reader::*;
@@ -446,7 +467,7 @@ mod tests {
             .map(|s| (!s.is_empty()).then(|| serde_json::from_slice(s).unwrap()))
             .collect();
 
-        assert_eq!(expected, actual);
+        assert_eq!(actual, expected);
     }
 
     #[test]
@@ -1891,7 +1912,7 @@ mod tests {
         let json_str = str::from_utf8(&json).unwrap();
         assert_eq!(
             json_str,
-            r#"[{"my_dict":"a"},{"my_dict":null},{"my_dict":null}]"#
+            r#"[{"my_dict":"a"},{"my_dict":null},{"my_dict":""}]"#
         )
     }
 
@@ -2036,4 +2057,414 @@ mod tests {
         }
         assert_json_eq(&buf, expected);
     }
+
+    fn make_fallback_encoder_test_data() -> (RecordBatch, Arc<dyn EncoderFactory>) {
+        // Note: this is not intended to be an efficient implementation.
+        // Just a simple example to demonstrate how to implement a custom encoder.
+        #[derive(Debug)]
+        enum UnionValue {
+            Int32(i32),
+            String(String),
+        }
+
+        #[derive(Debug)]
+        struct UnionEncoder {
+            array: Vec<Option<UnionValue>>,
+        }
+
+        impl Encoder for UnionEncoder {
+            fn encode(&mut self, idx: usize, out: &mut Vec<u8>) {
+                match &self.array[idx] {
+                    None => out.extend_from_slice(b"null"),
+                    Some(UnionValue::Int32(v)) => out.extend_from_slice(v.to_string().as_bytes()),
+                    Some(UnionValue::String(v)) => {
+                        out.extend_from_slice(format!("\"{}\"", v).as_bytes())
+                    }
+                }
+            }
+        }
+
+        #[derive(Debug)]
+        struct UnionEncoderFactory;
+
+        impl EncoderFactory for UnionEncoderFactory {
+            fn make_default_encoder<'a>(
+                &self,
+                _field: &'a FieldRef,
+                array: &'a dyn Array,
+                _options: &'a EncoderOptions,
+            ) -> Result<Option<NullableEncoder<'a>>, ArrowError> {
+                let data_type = array.data_type();
+                let fields = match data_type {
+                    DataType::Union(fields, UnionMode::Sparse) => fields,
+                    _ => return Ok(None),
+                };
+                // check that the fields are supported
+                let fields = fields.iter().map(|(_, f)| f).collect::<Vec<_>>();
+                for f in fields.iter() {
+                    match f.data_type() {
+                        DataType::Null => {}
+                        DataType::Int32 => {}
+                        DataType::Utf8 => {}
+                        _ => return Ok(None),
+                    }
+                }
+                let (_, type_ids, _, buffers) = array.as_union().clone().into_parts();
+                let mut values = Vec::with_capacity(type_ids.len());
+                for idx in 0..type_ids.len() {
+                    let type_id = type_ids[idx];
+                    let field = &fields[type_id as usize];
+                    let value = match field.data_type() {
+                        DataType::Null => None,
+                        DataType::Int32 => Some(UnionValue::Int32(
+                            buffers[type_id as usize]
+                                .as_primitive::<Int32Type>()
+                                .value(idx),
+                        )),
+                        DataType::Utf8 => Some(UnionValue::String(
+                            buffers[type_id as usize]
+                                .as_string::<i32>()
+                                .value(idx)
+                                .to_string(),
+                        )),
+                        _ => unreachable!(),
+                    };
+                    values.push(value);
+                }
+                let array_encoder =
+                    Box::new(UnionEncoder { array: values }) as Box<dyn Encoder + 'a>;
+                let nulls = array.nulls().cloned();
+                Ok(Some(NullableEncoder::new(array_encoder, nulls)))
+            }
+        }
+
+        let int_array = Int32Array::from(vec![Some(1), None, None]);
+        let string_array = StringArray::from(vec![None, Some("a"), None]);
+        let null_array = NullArray::new(3);
+        let type_ids = [0_i8, 1, 2].into_iter().collect::<ScalarBuffer<i8>>();
+
+        let union_fields = [
+            (0, Arc::new(Field::new("A", DataType::Int32, false))),
+            (1, Arc::new(Field::new("B", DataType::Utf8, false))),
+            (2, Arc::new(Field::new("C", DataType::Null, false))),
+        ]
+        .into_iter()
+        .collect::<UnionFields>();
+
+        let children = vec![
+            Arc::new(int_array) as Arc<dyn Array>,
+            Arc::new(string_array),
+            Arc::new(null_array),
+        ];
+
+        let array = UnionArray::try_new(union_fields.clone(), type_ids, None, children).unwrap();
+
+        let float_array = Float64Array::from(vec![Some(1.0), None, Some(3.4)]);
+
+        let fields = vec![
+            Field::new(
+                "union",
+                DataType::Union(union_fields, UnionMode::Sparse),
+                true,
+            ),
+            Field::new("float", DataType::Float64, true),
+        ];
+
+        let batch = RecordBatch::try_new(
+            Arc::new(Schema::new(fields)),
+            vec![
+                Arc::new(array) as Arc<dyn Array>,
+                Arc::new(float_array) as Arc<dyn Array>,
+            ],
+        )
+        .unwrap();
+
+        (batch, Arc::new(UnionEncoderFactory))
+    }
+
+    #[test]
+    fn test_fallback_encoder_factory_line_delimited_implicit_nulls() {
+        let (batch, encoder_factory) = make_fallback_encoder_test_data();
+
+        let mut buf = Vec::new();
+        {
+            let mut writer = WriterBuilder::new()
+                .with_encoder_factory(encoder_factory)
+                .with_explicit_nulls(false)
+                .build::<_, LineDelimited>(&mut buf);
+            writer.write_batches(&[&batch]).unwrap();
+            writer.finish().unwrap();
+        }
+
+        println!("{}", str::from_utf8(&buf).unwrap());
+
+        assert_json_eq(
+            &buf,
+            r#"{"union":1,"float":1.0}
+{"union":"a"}
+{"union":null,"float":3.4}
+"#,
+        );
+    }
+
+    #[test]
+    fn test_fallback_encoder_factory_line_delimited_explicit_nulls() {
+        let (batch, encoder_factory) = make_fallback_encoder_test_data();
+
+        let mut buf = Vec::new();
+        {
+            let mut writer = WriterBuilder::new()
+                .with_encoder_factory(encoder_factory)
+                .with_explicit_nulls(true)
+                .build::<_, LineDelimited>(&mut buf);
+            writer.write_batches(&[&batch]).unwrap();
+            writer.finish().unwrap();
+        }
+
+        assert_json_eq(
+            &buf,
+            r#"{"union":1,"float":1.0}
+{"union":"a","float":null}
+{"union":null,"float":3.4}
+"#,
+        );
+    }
+
+    #[test]
+    fn test_fallback_encoder_factory_array_implicit_nulls() {
+        let (batch, encoder_factory) = make_fallback_encoder_test_data();
+
+        let json_value: Value = {
+            let mut buf = Vec::new();
+            let mut writer = WriterBuilder::new()
+                .with_encoder_factory(encoder_factory)
+                .build::<_, JsonArray>(&mut buf);
+            writer.write_batches(&[&batch]).unwrap();
+            writer.finish().unwrap();
+            serde_json::from_slice(&buf).unwrap()
+        };
+
+        let expected = json!([
+            {"union":1,"float":1.0},
+            {"union":"a"},
+            {"float":3.4,"union":null},
+        ]);
+
+        assert_eq!(json_value, expected);
+    }
+
+    #[test]
+    fn test_fallback_encoder_factory_array_explicit_nulls() {
+        let (batch, encoder_factory) = make_fallback_encoder_test_data();
+
+        let json_value: Value = {
+            let mut buf = Vec::new();
+            let mut writer = WriterBuilder::new()
+                .with_encoder_factory(encoder_factory)
+                .with_explicit_nulls(true)
+                .build::<_, JsonArray>(&mut buf);
+            writer.write_batches(&[&batch]).unwrap();
+            writer.finish().unwrap();
+            serde_json::from_slice(&buf).unwrap()
+        };
+
+        let expected = json!([
+            {"union":1,"float":1.0},
+            {"union":"a", "float": null},
+            {"union":null,"float":3.4},
+        ]);
+
+        assert_eq!(json_value, expected);
+    }
+
+    #[test]
+    fn test_default_encoder_byte_array() {
+        struct IntArrayBinaryEncoder<B> {
+            array: B,
+        }
+
+        impl<'a, B> Encoder for IntArrayBinaryEncoder<B>
+        where
+            B: ArrayAccessor<Item = &'a [u8]>,
+        {
+            fn encode(&mut self, idx: usize, out: &mut Vec<u8>) {
+                out.push(b'[');
+                let child = self.array.value(idx);
+                for (idx, byte) in child.iter().enumerate() {
+                    write!(out, "{byte}").unwrap();
+                    if idx < child.len() - 1 {
+                        out.push(b',');
+                    }
+                }
+                out.push(b']');
+            }
+        }
+
+        #[derive(Debug)]
+        struct IntArayBinaryEncoderFactory;
+
+        impl EncoderFactory for IntArayBinaryEncoderFactory {
+            fn make_default_encoder<'a>(
+                &self,
+                _field: &'a FieldRef,
+                array: &'a dyn Array,
+                _options: &'a EncoderOptions,
+            ) -> Result<Option<NullableEncoder<'a>>, ArrowError> {
+                match array.data_type() {
+                    DataType::Binary => {
+                        let array = array.as_binary::<i32>();
+                        let encoder = IntArrayBinaryEncoder { array };
+                        let array_encoder = Box::new(encoder) as Box<dyn Encoder + 'a>;
+                        let nulls = array.nulls().cloned();
+                        Ok(Some(NullableEncoder::new(array_encoder, nulls)))
+                    }
+                    _ => Ok(None),
+                }
+            }
+        }
+
+        let binary_array = BinaryArray::from_opt_vec(vec![Some(b"a"), None, Some(b"b")]);
+        let float_array = Float64Array::from(vec![Some(1.0), Some(2.3), None]);
+        let fields = vec![
+            Field::new("bytes", DataType::Binary, true),
+            Field::new("float", DataType::Float64, true),
+        ];
+        let batch = RecordBatch::try_new(
+            Arc::new(Schema::new(fields)),
+            vec![
+                Arc::new(binary_array) as Arc<dyn Array>,
+                Arc::new(float_array) as Arc<dyn Array>,
+            ],
+        )
+        .unwrap();
+
+        let json_value: Value = {
+            let mut buf = Vec::new();
+            let mut writer = WriterBuilder::new()
+                .with_encoder_factory(Arc::new(IntArayBinaryEncoderFactory))
+                .build::<_, JsonArray>(&mut buf);
+            writer.write_batches(&[&batch]).unwrap();
+            writer.finish().unwrap();
+            serde_json::from_slice(&buf).unwrap()
+        };
+
+        let expected = json!([
+            {"bytes": [97], "float": 1.0},
+            {"float": 2.3},
+            {"bytes": [98]},
+        ]);
+
+        assert_eq!(json_value, expected);
+    }
+
+    #[test]
+    fn test_encoder_factory_customize_dictionary() {
+        // Test that we can customize the encoding of T even when it shows up as Dictionary<_, T>.
+
+        // No particular reason to choose this example.
+        // Just trying to add some variety to the test cases and demonstrate use cases of the encoder factory.
+        struct PaddedInt32Encoder {
+            array: Int32Array,
+        }
+
+        impl Encoder for PaddedInt32Encoder {
+            fn encode(&mut self, idx: usize, out: &mut Vec<u8>) {
+                let value = self.array.value(idx);
+                write!(out, "\"{value:0>8}\"").unwrap();
+            }
+        }
+
+        #[derive(Debug)]
+        struct CustomEncoderFactory;
+
+        impl EncoderFactory for CustomEncoderFactory {
+            fn make_default_encoder<'a>(
+                &self,
+                field: &'a FieldRef,
+                array: &'a dyn Array,
+                _options: &'a EncoderOptions,
+            ) -> Result<Option<NullableEncoder<'a>>, ArrowError> {
+                // The point here is:
+                // 1. You can use information from Field to determine how to do the encoding.
+                // 2. For dictionary arrays the Field is always the outer field but the array may be the keys or values array
+                //    and thus the data type of `field` may not match the data type of `array`.
+                let padded = field
+                    .metadata()
+                    .get("padded")
+                    .map(|v| v == "true")
+                    .unwrap_or_default();
+                match (array.data_type(), padded) {
+                    (DataType::Int32, true) => {
+                        let array = array.as_primitive::<Int32Type>();
+                        let nulls = array.nulls().cloned();
+                        let encoder = PaddedInt32Encoder {
+                            array: array.clone(),
+                        };
+                        let array_encoder = Box::new(encoder) as Box<dyn Encoder + 'a>;
+                        Ok(Some(NullableEncoder::new(array_encoder, nulls)))
+                    }
+                    _ => Ok(None),
+                }
+            }
+        }
+
+        let to_json = |batch| {
+            let mut buf = Vec::new();
+            let mut writer = WriterBuilder::new()
+                .with_encoder_factory(Arc::new(CustomEncoderFactory))
+                .build::<_, JsonArray>(&mut buf);
+            writer.write_batches(&[batch]).unwrap();
+            writer.finish().unwrap();
+            serde_json::from_slice::<Value>(&buf).unwrap()
+        };
+
+        // Control case: no dictionary wrapping works as expected.
+        let array = Int32Array::from(vec![Some(1), None, Some(2)]);
+        let field = Arc::new(Field::new("int", DataType::Int32, true).with_metadata(
+            HashMap::from_iter(vec![("padded".to_string(), "true".to_string())]),
+        ));
+        let batch = RecordBatch::try_new(
+            Arc::new(Schema::new(vec![field.clone()])),
+            vec![Arc::new(array)],
+        )
+        .unwrap();
+
+        let json_value = to_json(&batch);
+
+        let expected = json!([
+            {"int": "00000001"},
+            {},
+            {"int": "00000002"},
+        ]);
+
+        assert_eq!(json_value, expected);
+
+        // Now make a dictionary batch
+        let mut array_builder = PrimitiveDictionaryBuilder::<UInt16Type, Int32Type>::new();
+        array_builder.append_value(1);
+        array_builder.append_null();
+        array_builder.append_value(1);
+        let array = array_builder.finish();
+        let field = Field::new(
+            "int",
+            DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Int32)),
+            true,
+        )
+        .with_metadata(HashMap::from_iter(vec![(
+            "padded".to_string(),
+            "true".to_string(),
+        )]));
+        let batch = RecordBatch::try_new(Arc::new(Schema::new(vec![field])), vec![Arc::new(array)])
+            .unwrap();
+
+        let json_value = to_json(&batch);
+
+        let expected = json!([
+            {"int": "00000001"},
+            {},
+            {"int": "00000001"},
+        ]);
+
+        assert_eq!(json_value, expected);
+    }
 }

