This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3c7f5e25b70c [SPARK-50790][PYTHON] Implement parse json in pyspark
3c7f5e25b70c is described below

commit 3c7f5e25b70ce8332c31bee50b704dc55d810bf1
Author: Gene Pang <[email protected]>
AuthorDate: Tue Jan 14 12:58:42 2025 +0900

    [SPARK-50790][PYTHON] Implement parse json in pyspark
    
    ### What changes were proposed in this pull request?
    
    Implement the parseJson functionality in PySpark, for parsing a JSON string into a VariantVal.
    
    ### Why are the changes needed?
    
    Currently, there is no way to create a VariantVal from Python; it can only be created from Spark SQL.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Added `VariantVal.parseJson`, which takes a JSON string and returns a `VariantVal`.
    
    ### How was this patch tested?
    
    Added unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #49450 from gene-db/py-parse-json.
    
    Authored-by: Gene Pang <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../source/reference/pyspark.sql/variant_val.rst   |   1 +
 python/pyspark/sql/tests/test_types.py             |  11 +
 python/pyspark/sql/types.py                        |   9 +
 python/pyspark/sql/variant_utils.py                | 327 ++++++++++++++++++++-
 4 files changed, 346 insertions(+), 2 deletions(-)

diff --git a/python/docs/source/reference/pyspark.sql/variant_val.rst 
b/python/docs/source/reference/pyspark.sql/variant_val.rst
index 8630ae8aace1..883b4c8fdc3d 100644
--- a/python/docs/source/reference/pyspark.sql/variant_val.rst
+++ b/python/docs/source/reference/pyspark.sql/variant_val.rst
@@ -26,3 +26,4 @@ VariantVal
 
     VariantVal.toPython
     VariantVal.toJson
+    VariantVal.parseJson
diff --git a/python/pyspark/sql/tests/test_types.py 
b/python/pyspark/sql/tests/test_types.py
index 432ddd083c80..75c28ac0dec1 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -2240,6 +2240,17 @@ class TypesTestsMixin:
             PySparkValueError, lambda: str(VariantVal(bytes([32, 10, 1, 0, 0, 
0]), metadata))
         )
 
+        # check parse_json
+        for key, json, obj in expected_values:
+            self.assertEqual(VariantVal.parseJson(json).toJson(), json)
+            self.assertEqual(VariantVal.parseJson(json).toPython(), obj)
+
+        # compare the parse_json in Spark vs python. `json_str` contains all 
of `expected_values`.
+        parse_json_spark_output = variants[0]
+        parse_json_python_output = VariantVal.parseJson(json_str)
+        self.assertEqual(parse_json_spark_output.value, 
parse_json_python_output.value)
+        self.assertEqual(parse_json_spark_output.metadata, 
parse_json_python_output.metadata)
+
     def test_to_ddl(self):
         schema = StructType().add("a", NullType()).add("b", 
BooleanType()).add("c", BinaryType())
         self.assertEqual(schema.toDDL(), "a VOID,b BOOLEAN,c BINARY")
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index f40a8bf62b29..b913e05e16d2 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1770,6 +1770,15 @@ class VariantVal:
         """
         return VariantUtils.to_json(self.value, self.metadata, zone_id)
 
+    @classmethod
+    def parseJson(cls, json_str: str) -> "VariantVal":
+        """
+        Parse the JSON string `json_str` into a VariantVal.
+        :return: VariantVal holding the Variant binary (value, metadata) of the JSON
+        """
+        (value, metadata) = VariantUtils.parse_json(json_str)
+        return VariantVal(value, metadata)
+
 
 _atomic_types: List[Type[DataType]] = [
     StringType,
diff --git a/python/pyspark/sql/variant_utils.py 
b/python/pyspark/sql/variant_utils.py
index 40cc69c1f096..3025523064e1 100644
--- a/python/pyspark/sql/variant_utils.py
+++ b/python/pyspark/sql/variant_utils.py
@@ -21,7 +21,7 @@ import datetime
 import json
 import struct
 from array import array
-from typing import Any, Callable, Dict, List, Tuple
+from typing import Any, Callable, Dict, List, NamedTuple, Tuple
 from pyspark.errors import PySparkValueError
 from zoneinfo import ZoneInfo
 
@@ -108,8 +108,25 @@ class VariantUtils:
     # string size) + (size bytes of string content).
     LONG_STR = 16
 
+    VERSION = 1
+    # The lower 4 bits of the first metadata byte contain the version.
+    VERSION_MASK = 0x0F
+
+    U8_MAX = 0xFF
+    U16_MAX = 0xFFFF
+    U24_MAX = 0xFFFFFF
+    U24_SIZE = 3
     U32_SIZE = 4
 
+    I8_MAX = 0x7F
+    I8_MIN = -0x80
+    I16_MAX = 0x7FFF
+    I16_MIN = -0x8000
+    I32_MAX = 0x7FFFFFFF
+    I32_MIN = -0x80000000
+    I64_MAX = 0x7FFFFFFFFFFFFFFF
+    I64_MIN = -0x8000000000000000
+
     EPOCH = datetime.datetime(
         year=1970, month=1, day=1, hour=0, minute=0, second=0, 
tzinfo=datetime.timezone.utc
     )
@@ -140,6 +157,15 @@ class VariantUtils:
         """
         return cls._to_python(value, metadata, 0)
 
+    @classmethod
+    def parse_json(cls, json_str: str) -> Tuple[bytes, bytes]:
+        """
+        Parses the JSON string with a fresh VariantBuilder into the Variant binary
+        :return: tuple of 2 binary values (value, metadata)
+        """
+        builder = VariantBuilder()
+        return builder.build(json_str)
+
     @classmethod
     def _read_long(cls, data: bytes, pos: int, num_bytes: int, signed: bool) 
-> int:
         cls._check_index(pos, len(data))
@@ -468,7 +494,10 @@ class VariantUtils:
                 value, offset_start + offset_size * i, offset_size, 
signed=False
             )
             value_pos = data_start + offset
-            key_value_pos_list.append((cls._get_metadata_key(metadata, id), 
value_pos))
+            if metadata is not None:
+                key_value_pos_list.append((cls._get_metadata_key(metadata, 
id), value_pos))
+            else:
+                key_value_pos_list.append(("", value_pos))
         return func(key_value_pos_list)
 
     @classmethod
@@ -496,3 +525,297 @@ class VariantUtils:
             element_pos = data_start + offset
             value_pos_list.append(element_pos)
         return func(value_pos_list)
+
+
+class FieldEntry(NamedTuple):
+    """
+    Info about an object field: key, dictionary id, and value offset from the object data start
+    """
+
+    key: str
+    id: int
+    offset: int
+
+
+class VariantBuilder:
+    """
+    A utility class for building the Variant binary (value, metadata) of a JSON string.
+    """
+
+    DEFAULT_SIZE_LIMIT = 16 * 1024 * 1024
+
+    def __init__(self, size_limit: int = DEFAULT_SIZE_LIMIT):
+        self.value = bytearray()
+        self.dictionary = dict[str, int]()
+        self.dictionary_keys = list[bytes]()
+        self.size_limit = size_limit
+
+    def build(self, json_str: str) -> Tuple[bytes, bytes]:
+        parsed = json.loads(json_str, parse_float=self._handle_float)
+        self._process_parsed_json(parsed)
+
+        num_keys = len(self.dictionary_keys)
+        dictionary_string_size = sum(len(key) for key in self.dictionary_keys)
+
+        # Determine the number of bytes required per offset entry.
+        # The largest offset is the one-past-the-end value, which is total 
string size. It's very
+        # unlikely that the number of keys could be larger, but incorporate 
that into the
+        # calculation in case of pathological data.
+        max_size = max(dictionary_string_size, num_keys)
+        if max_size > self.size_limit:
+            raise PySparkValueError(errorClass="VARIANT_SIZE_LIMIT_EXCEEDED", 
messageParameters={})
+        offset_size = self._get_integer_size(max_size)
+
+        offset_start = 1 + offset_size
+        string_start = offset_start + (num_keys + 1) * offset_size
+        metadata_size = string_start + dictionary_string_size
+        if metadata_size > self.size_limit:
+            raise PySparkValueError(errorClass="VARIANT_SIZE_LIMIT_EXCEEDED", 
messageParameters={})
+
+        metadata = bytearray()
+        header_byte = VariantUtils.VERSION | ((offset_size - 1) << 6)
+        metadata.extend(header_byte.to_bytes(1, byteorder="little"))
+        metadata.extend(num_keys.to_bytes(offset_size, byteorder="little"))
+        # write offsets
+        current_offset = 0
+        for key in self.dictionary_keys:
+            metadata.extend(current_offset.to_bytes(offset_size, 
byteorder="little"))
+            current_offset += len(key)
+        metadata.extend(current_offset.to_bytes(offset_size, 
byteorder="little"))
+        # write key data
+        for key in self.dictionary_keys:
+            metadata.extend(key)
+        return (bytes(self.value), bytes(metadata))
+
+    def _process_parsed_json(self, parsed: Any) -> None:
+        if type(parsed) is dict:
+            fields = list[FieldEntry]()
+            start = len(self.value)
+            for key, value in parsed.items():
+                id = self._add_key(key)
+                fields.append(FieldEntry(key, id, len(self.value) - start))
+                self._process_parsed_json(value)
+            self._finish_writing_object(start, fields)
+        elif type(parsed) is list:
+            offsets = []
+            start = len(self.value)
+            for elem in parsed:
+                offsets.append(len(self.value) - start)
+                self._process_parsed_json(elem)
+            self._finish_writing_array(start, offsets)
+        elif type(parsed) is str:
+            self._append_string(parsed)
+        elif type(parsed) is int:
+            if not self._append_int(parsed):
+                self._process_parsed_json(self._handle_float(str(parsed)))
+        elif type(parsed) is float:
+            self._append_float(parsed)
+        elif type(parsed) is decimal.Decimal:
+            self._append_decimal(parsed)
+        elif type(parsed) is bool:
+            self._append_boolean(parsed)
+        elif parsed is None:
+            self._append_null()
+        else:
+            raise PySparkValueError(errorClass="MALFORMED_VARIANT", 
messageParameters={})
+
+    # Choose the smallest unsigned integer type that can store `value`. It 
must be within
+    # [0, U24_MAX].
+    def _get_integer_size(self, value: int) -> int:
+        if value <= VariantUtils.U8_MAX:
+            return 1
+        if value <= VariantUtils.U16_MAX:
+            return 2
+        return VariantUtils.U24_SIZE
+
+    def _check_capacity(self, additional: int) -> None:
+        required = len(self.value) + additional
+        if required > self.size_limit:
+            raise PySparkValueError(errorClass="VARIANT_SIZE_LIMIT_EXCEEDED", 
messageParameters={})
+
+    def _primitive_header(self, type: int) -> bytes:
+        return bytes([(type << 2) | VariantUtils.PRIMITIVE])
+
+    def _short_string_header(self, size: int) -> bytes:
+        return bytes([size << 2 | VariantUtils.SHORT_STR])
+
+    def _array_header(self, large_size: bool, offset_size: int) -> bytes:
+        return bytes(
+            [
+                (
+                    (large_size << (VariantUtils.BASIC_TYPE_BITS + 2))
+                    | ((offset_size - 1) << VariantUtils.BASIC_TYPE_BITS)
+                    | VariantUtils.ARRAY
+                )
+            ]
+        )
+
+    def _object_header(self, large_size: bool, id_size: int, offset_size: int) 
-> bytes:
+        return bytes(
+            [
+                (
+                    (large_size << (VariantUtils.BASIC_TYPE_BITS + 4))
+                    | ((id_size - 1) << (VariantUtils.BASIC_TYPE_BITS + 2))
+                    | ((offset_size - 1) << VariantUtils.BASIC_TYPE_BITS)
+                    | VariantUtils.OBJECT
+                )
+            ]
+        )
+
+    # Add a key to the variant dictionary. If the key already exists, the 
dictionary is
+    # not modified. In either case, return the id of the key.
+    def _add_key(self, key: str) -> int:
+        if key in self.dictionary:
+            return self.dictionary[key]
+        id = len(self.dictionary_keys)
+        self.dictionary[key] = id
+        self.dictionary_keys.append(key.encode("utf-8"))
+        return id
+
+    def _handle_float(self, num_str: str) -> Any:
+        # a float can be a decimal if it only contains digits, '-', or '.'.
+        if all([ch.isdecimal() or ch == "-" or ch == "." for ch in num_str]):
+            dec = decimal.Decimal(num_str)
+            precision = len(dec.as_tuple().digits)
+            scale = -int(dec.as_tuple().exponent)
+
+            if (
+                scale <= VariantUtils.MAX_DECIMAL16_PRECISION
+                and precision <= VariantUtils.MAX_DECIMAL16_PRECISION
+            ):
+                return dec
+        return float(num_str)
+
+    def _append_boolean(self, b: bool) -> None:
+        self._check_capacity(1)
+        self.value.extend(self._primitive_header(VariantUtils.TRUE if b else 
VariantUtils.FALSE))
+
+    def _append_null(self) -> None:
+        self._check_capacity(1)
+        self.value.extend(self._primitive_header(VariantUtils.NULL))
+
+    def _append_string(self, s: str) -> None:
+        text = s.encode("utf-8")
+        long_str = len(text) > VariantUtils.MAX_SHORT_STR_SIZE
+        additional = (1 + VariantUtils.U32_SIZE) if long_str else 1
+        self._check_capacity(additional + len(text))
+        if long_str:
+            self.value.extend(self._primitive_header(VariantUtils.LONG_STR))
+            self.value.extend(len(text).to_bytes(VariantUtils.U32_SIZE, 
byteorder="little"))
+        else:
+            self.value.extend(self._short_string_header(len(text)))
+        self.value.extend(text)
+
+    def _append_int(self, i: int) -> bool:
+        self._check_capacity(1 + 8)
+        if i >= VariantUtils.I8_MIN and i <= VariantUtils.I8_MAX:
+            self.value.extend(self._primitive_header(VariantUtils.INT1))
+            self.value.extend(i.to_bytes(1, byteorder="little", signed=True))
+        elif i >= VariantUtils.I16_MIN and i <= VariantUtils.I16_MAX:
+            self.value.extend(self._primitive_header(VariantUtils.INT2))
+            self.value.extend(i.to_bytes(2, byteorder="little", signed=True))
+        elif i >= VariantUtils.I32_MIN and i <= VariantUtils.I32_MAX:
+            self.value.extend(self._primitive_header(VariantUtils.INT4))
+            self.value.extend(i.to_bytes(4, byteorder="little", signed=True))
+        elif i >= VariantUtils.I64_MIN and i <= VariantUtils.I64_MAX:
+            self.value.extend(self._primitive_header(VariantUtils.INT8))
+            self.value.extend(i.to_bytes(8, byteorder="little", signed=True))
+        else:
+            return False
+        return True
+
+    # Append a decimal value to the variant builder. The caller should 
guarantee that its precision
+    # and scale fit into `MAX_DECIMAL16_PRECISION`.
+    def _append_decimal(self, d: decimal.Decimal) -> None:
+        self._check_capacity(2 + 16)
+        precision = len(d.as_tuple().digits)
+        scale = -int(d.as_tuple().exponent)
+        unscaled = int("".join(map(str, d.as_tuple().digits)))
+        unscaled = -unscaled if d < 0 else unscaled
+        if (
+            scale <= VariantUtils.MAX_DECIMAL4_PRECISION
+            and precision <= VariantUtils.MAX_DECIMAL4_PRECISION
+        ):
+            self.value.extend(self._primitive_header(VariantUtils.DECIMAL4))
+            self.value.extend(scale.to_bytes(1, byteorder="little"))
+            self.value.extend(unscaled.to_bytes(4, byteorder="little", 
signed=True))
+        elif (
+            scale <= VariantUtils.MAX_DECIMAL8_PRECISION
+            and precision <= VariantUtils.MAX_DECIMAL8_PRECISION
+        ):
+            self.value.extend(self._primitive_header(VariantUtils.DECIMAL8))
+            self.value.extend(scale.to_bytes(1, byteorder="little"))
+            self.value.extend(unscaled.to_bytes(8, byteorder="little", 
signed=True))
+        else:
+            assert (
+                scale <= VariantUtils.MAX_DECIMAL16_PRECISION
+                and precision <= VariantUtils.MAX_DECIMAL16_PRECISION
+            )
+            self.value.extend(self._primitive_header(VariantUtils.DECIMAL16))
+            self.value.extend(scale.to_bytes(1, byteorder="little"))
+            self.value.extend(unscaled.to_bytes(16, byteorder="little", 
signed=True))
+
+    def _append_float(self, f: float) -> None:
+        self._check_capacity(1 + 8)
+        self.value.extend(self._primitive_header(VariantUtils.DOUBLE))
+        self.value.extend(struct.pack("<d", f))
+
+    # Finish writing a variant array after all of its elements have already 
been written.
+    def _finish_writing_array(self, start: int, offsets: List[int]) -> None:
+        data_size = len(self.value) - start
+        num_offsets = len(offsets)
+        large_size = num_offsets > VariantUtils.U8_MAX
+        size_bytes = VariantUtils.U32_SIZE if large_size else 1
+        offset_size = self._get_integer_size(data_size)
+        # The space for header byte, array size (num elements), and offset list.
+        header_size = 1 + size_bytes + (num_offsets + 1) * offset_size
+        self._check_capacity(header_size)
+        self.value.extend(bytearray(header_size))
+        # Shift the just-written element data to make room for the header 
section.
+        self.value[start + header_size :] = bytes(self.value[start : start + 
data_size])
+        # Write the header byte, num offsets
+        offset_start = start + 1 + size_bytes
+        self.value[start : start + 1] = self._array_header(large_size, 
offset_size)
+        self.value[start + 1 : offset_start] = 
num_offsets.to_bytes(size_bytes, byteorder="little")
+        # write offset list
+        offset_list = bytearray()
+        for offset in offsets:
+            offset_list.extend(offset.to_bytes(offset_size, 
byteorder="little"))
+        offset_list.extend(data_size.to_bytes(offset_size, byteorder="little"))
+        self.value[offset_start : offset_start + len(offset_list)] = 
offset_list
+
+    # Finish writing a variant object after all of its fields have already 
been written.
+    def _finish_writing_object(self, start: int, fields: List[FieldEntry]) -> 
None:
+        num_fields = len(fields)
+        # object fields are from a python dictionary, so keys are already 
distinct
+        fields.sort(key=lambda f: f.key)
+        max_id = 0
+        for field in fields:
+            max_id = max(max_id, field.id)
+
+        data_size = len(self.value) - start
+        large_size = num_fields > VariantUtils.U8_MAX
+        size_bytes = VariantUtils.U32_SIZE if large_size else 1
+        id_size = self._get_integer_size(max_id)
+        offset_size = self._get_integer_size(data_size)
+        # The space for header byte, object size, id list, and offset list.
+        header_size = 1 + size_bytes + num_fields * id_size + (num_fields + 1) 
* offset_size
+        self._check_capacity(header_size)
+        self.value.extend(bytearray(header_size))
+        # Shift the just-written field data to make room for the object header 
section.
+        self.value[start + header_size :] = self.value[start : start + 
data_size]
+        # Write the header byte, num fields, id list, offset list
+        self.value[start : start + 1] = self._object_header(large_size, 
id_size, offset_size)
+        self.value[start + 1 : start + 1 + size_bytes] = num_fields.to_bytes(
+            size_bytes, byteorder="little"
+        )
+        id_start = start + 1 + size_bytes
+        offset_start = id_start + num_fields * id_size
+        id_list = bytearray()
+        offset_list = bytearray()
+        for field in fields:
+            id_list.extend(field.id.to_bytes(id_size, byteorder="little"))
+            offset_list.extend(field.offset.to_bytes(offset_size, 
byteorder="little"))
+        offset_list.extend(data_size.to_bytes(offset_size, byteorder="little"))
+        self.value[id_start : id_start + len(id_list)] = id_list
+        self.value[offset_start : offset_start + len(offset_list)] = 
offset_list


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to