This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.1 by this push:
new bde919479e01 [SPARK-54176][GEO][PYTHON] Introduce Geography and
Geometry data types to PySpark Connect
bde919479e01 is described below
commit bde919479e0143b2f44cff4c1b7debc312e838d2
Author: Uros Bojanic <[email protected]>
AuthorDate: Wed Nov 5 10:49:50 2025 -0800
[SPARK-54176][GEO][PYTHON] Introduce Geography and Geometry data types to
PySpark Connect
### What changes were proposed in this pull request?
Introduce `GeographyType` and `GeometryType` to PySpark Connect. Note that
the geospatial data types have already been introduced in PySpark as part of:
https://github.com/apache/spark/pull/52627.
Also, introduce classes to represent `Geography` and `Geometry` values in
Python. Note that the corresponding classes have already been introduced on
the Scala side as part of: https://github.com/apache/spark/pull/52804.
### Why are the changes needed?
Enabling geospatial types in Spark Connect.
### Does this PR introduce _any_ user-facing change?
Yes, `GeographyType` and `GeometryType` are now available in PySpark
Connect.
### How was this patch tested?
Added new Python Connect tests:
- `test_parity_geographytype`
- `test_parity_geometrytype`
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #52871 from uros-db/geo-spark-connect.
Authored-by: Uros Bojanic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
dev/sparktestsupport/modules.py | 2 +
python/pyspark/errors/error-conditions.json | 10 ++
python/pyspark/sql/connect/types.py | 18 +++
python/pyspark/sql/conversion.py | 74 +++++++++
.../sql/tests/connect/test_parity_geographytype.py | 38 +++++
.../sql/tests/connect/test_parity_geometrytype.py | 38 +++++
python/pyspark/sql/types.py | 168 +++++++++++++++++++++
7 files changed, 348 insertions(+)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 07ac4c76b91a..aa8ca58a5a75 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -1114,6 +1114,8 @@ pyspark_connect = Module(
"pyspark.sql.tests.connect.test_connect_retry",
"pyspark.sql.tests.connect.test_connect_session",
"pyspark.sql.tests.connect.test_connect_stat",
+ "pyspark.sql.tests.connect.test_parity_geographytype",
+ "pyspark.sql.tests.connect.test_parity_geometrytype",
"pyspark.sql.tests.connect.test_parity_datasources",
"pyspark.sql.tests.connect.test_parity_errors",
"pyspark.sql.tests.connect.test_parity_catalog",
diff --git a/python/pyspark/errors/error-conditions.json
b/python/pyspark/errors/error-conditions.json
index d169e6293a1b..51bbdd862516 100644
--- a/python/pyspark/errors/error-conditions.json
+++ b/python/pyspark/errors/error-conditions.json
@@ -549,6 +549,16 @@
"<arg1> and <arg2> should be of the same length, got <arg1_length> and
<arg2_length>."
]
},
+ "MALFORMED_GEOGRAPHY": {
+ "message": [
+ "Geography binary is malformed. Please check the data source is valid."
+ ]
+ },
+ "MALFORMED_GEOMETRY": {
+ "message": [
+ "Geometry binary is malformed. Please check the data source is valid."
+ ]
+ },
"MALFORMED_VARIANT": {
"message": [
"Variant binary is malformed. Please check the data source is valid."
diff --git a/python/pyspark/sql/connect/types.py
b/python/pyspark/sql/connect/types.py
index 7e8f76861079..d3352b618d7c 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -50,6 +50,8 @@ from pyspark.sql.types import (
NullType,
NumericType,
VariantType,
+ GeographyType,
+ GeometryType,
UserDefinedType,
)
from pyspark.errors import PySparkAssertionError, PySparkValueError
@@ -191,6 +193,10 @@ def pyspark_types_to_proto_types(data_type: DataType) ->
pb2.DataType:
ret.array.contains_null = data_type.containsNull
elif isinstance(data_type, VariantType):
ret.variant.CopyFrom(pb2.DataType.Variant())
+ elif isinstance(data_type, GeometryType):
+ ret.geometry.srid = data_type.srid
+ elif isinstance(data_type, GeographyType):
+ ret.geography.srid = data_type.srid
elif isinstance(data_type, UserDefinedType):
json_value = data_type.jsonValue()
ret.udt.type = "udt"
@@ -303,6 +309,18 @@ def proto_schema_to_pyspark_data_type(schema:
pb2.DataType) -> DataType:
)
elif schema.HasField("variant"):
return VariantType()
+ elif schema.HasField("geometry"):
+ srid = schema.geometry.srid
+ if srid == GeometryType.MIXED_SRID:
+ return GeometryType("ANY")
+ else:
+ return GeometryType(srid)
+ elif schema.HasField("geography"):
+ srid = schema.geography.srid
+ if srid == GeographyType.MIXED_SRID:
+ return GeographyType("ANY")
+ else:
+ return GeographyType(srid)
elif schema.HasField("udt"):
assert schema.udt.type == "udt"
json_value = {}
diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py
index a8f621277a0a..f73727d1d534 100644
--- a/python/pyspark/sql/conversion.py
+++ b/python/pyspark/sql/conversion.py
@@ -28,6 +28,10 @@ from pyspark.sql.types import (
BinaryType,
DataType,
DecimalType,
+ GeographyType,
+ Geography,
+ GeometryType,
+ Geometry,
MapType,
NullType,
Row,
@@ -89,6 +93,10 @@ class LocalDataToArrowConversion:
return True
elif isinstance(dataType, VariantType):
return True
+ elif isinstance(dataType, GeometryType):
+ return True
+ elif isinstance(dataType, GeographyType):
+ return True
else:
return False
@@ -392,6 +400,34 @@ class LocalDataToArrowConversion:
return convert_variant
+ elif isinstance(dataType, GeographyType):
+
+ def convert_geography(value: Any) -> Any:
+ if value is None:
+ if not nullable:
+ raise PySparkValueError(f"input for {dataType} must
not be None")
+ return None
+ elif isinstance(value, Geography):
+ return dataType.toInternal(value)
+ else:
+ raise PySparkValueError(errorClass="MALFORMED_GEOGRAPHY")
+
+ return convert_geography
+
+ elif isinstance(dataType, GeometryType):
+
+ def convert_geometry(value: Any) -> Any:
+ if value is None:
+ if not nullable:
+ raise PySparkValueError(f"input for {dataType} must
not be None")
+ return None
+ elif isinstance(value, Geometry):
+ return dataType.toInternal(value)
+ else:
+ raise PySparkValueError(errorClass="MALFORMED_GEOMETRY")
+
+ return convert_geometry
+
elif not nullable:
def convert_other(value: Any) -> Any:
@@ -511,6 +547,10 @@ class ArrowTableToRowsConversion:
return True
elif isinstance(dataType, VariantType):
return True
+ elif isinstance(dataType, GeographyType):
+ return True
+ elif isinstance(dataType, GeometryType):
+ return True
else:
return False
@@ -719,6 +759,40 @@ class ArrowTableToRowsConversion:
return convert_variant
+ elif isinstance(dataType, GeographyType):
+
+ def convert_geography(value: Any) -> Any:
+ if value is None:
+ return None
+ elif (
+ isinstance(value, dict)
+ and all(key in value for key in ["wkb", "srid"])
+ and isinstance(value["wkb"], bytes)
+ and isinstance(value["srid"], int)
+ ):
+ return Geography.fromWKB(value["wkb"], value["srid"])
+ else:
+ raise PySparkValueError(errorClass="MALFORMED_GEOGRAPHY")
+
+ return convert_geography
+
+ elif isinstance(dataType, GeometryType):
+
+ def convert_geometry(value: Any) -> Any:
+ if value is None:
+ return None
+ elif (
+ isinstance(value, dict)
+ and all(key in value for key in ["wkb", "srid"])
+ and isinstance(value["wkb"], bytes)
+ and isinstance(value["srid"], int)
+ ):
+ return Geometry.fromWKB(value["wkb"], value["srid"])
+ else:
+ raise PySparkValueError(errorClass="MALFORMED_GEOMETRY")
+
+ return convert_geometry
+
else:
if none_on_identity:
return None
diff --git a/python/pyspark/sql/tests/connect/test_parity_geographytype.py
b/python/pyspark/sql/tests/connect/test_parity_geographytype.py
new file mode 100644
index 000000000000..501bbed20ff1
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_geographytype.py
@@ -0,0 +1,38 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.sql.tests.test_geographytype import GeographyTypeTestMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class GeographyTypeParityTest(GeographyTypeTestMixin, ReusedConnectTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.sql.tests.connect.test_parity_geographytype import * # noqa:
F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports",
verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/connect/test_parity_geometrytype.py
b/python/pyspark/sql/tests/connect/test_parity_geometrytype.py
new file mode 100644
index 000000000000..b95321b3c61b
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_geometrytype.py
@@ -0,0 +1,38 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.sql.tests.test_geometrytype import GeometryTypeTestMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class GeometryTypeParityTest(GeometryTypeTestMixin, ReusedConnectTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.sql.tests.connect.test_parity_geometrytype import * # noqa:
F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports",
verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 440100dba931..8aae39880072 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -90,6 +90,8 @@ __all__ = [
"TimestampNTZType",
"DecimalType",
"DoubleType",
+ "Geography",
+ "Geometry",
"FloatType",
"ByteType",
"IntegerType",
@@ -616,6 +618,20 @@ class GeographyType(SpatialType):
# The JSON representation always uses the CRS and algorithm value.
return f"geography({self._crs}, {self._alg})"
+ def needConversion(self) -> bool:
+ return True
+
+ def fromInternal(self, obj: Dict) -> Optional["Geography"]:
+ if obj is None or not all(key in obj for key in ["srid", "bytes"]):
+ return None
+ return Geography(obj["bytes"], obj["srid"])
+
+ def toInternal(self, geography: Any) -> Any:
+ if geography is None:
+ return None
+ assert isinstance(geography, Geography)
+ return {"srid": geography.srid, "wkb": geography.wkb}
+
class GeometryType(SpatialType):
"""
@@ -700,6 +716,20 @@ class GeometryType(SpatialType):
# The JSON representation always uses the CRS value.
return f"geometry({self._crs})"
+ def needConversion(self) -> bool:
+ return True
+
+ def fromInternal(self, obj: Dict) -> Optional["Geometry"]:
+ if obj is None or not all(key in obj for key in ["srid", "bytes"]):
+ return None
+ return Geometry(obj["bytes"], obj["srid"])
+
+ def toInternal(self, geometry: Any) -> Any:
+ if geometry is None:
+ return None
+ assert isinstance(geometry, Geometry)
+ return {"srid": geometry.srid, "wkb": geometry.wkb}
+
class ByteType(IntegralType):
"""Byte data type, representing signed 8-bit integers."""
@@ -2039,6 +2069,144 @@ class VariantVal:
return VariantVal(value, metadata)
+class Geography:
+ """
+ A class to represent a Geography value in Python.
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ wkb : bytes
+ The bytes representing the WKB of Geography.
+
+ srid : integer
+ The integer value representing SRID of Geography.
+
+ Methods
+ -------
+ getBytes()
+ Returns the WKB of Geography.
+
+ getSrid()
+ Returns the SRID of Geography.
+
+ Examples
+ --------
+ >>> g =
Geography.fromWKB(bytes.fromhex('010100000000000000000031400000000000001c40'),
4326)
+ >>> g.getBytes().hex()
+ '010100000000000000000031400000000000001c40'
+ >>> g.getSrid()
+ 4326
+ """
+
+ def __init__(self, wkb: bytes, srid: int):
+ self.wkb = wkb
+ self.srid = srid
+
+ def __str__(self) -> str:
+ return "Geography(%r, %d)" % (self.wkb, self.srid)
+
+ def __repr__(self) -> str:
+ return "Geography(%r, %d)" % (self.wkb, self.srid)
+
+ def getSrid(self) -> int:
+ """
+ Returns the SRID of Geography.
+ """
+ return self.srid
+
+ def getBytes(self) -> bytes:
+ """
+ Returns the WKB of Geography.
+ """
+ return self.wkb
+
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, Geography):
+ # Don't attempt to compare against unrelated types.
+ return NotImplemented
+
+ return self.wkb == other.wkb and self.srid == other.srid
+
+ @classmethod
+ def fromWKB(cls, wkb: bytes, srid: int) -> "Geography":
+ """
+ Construct Python Geography object from WKB.
+ :return: Python representation of the Geography type value.
+ """
+ return Geography(wkb, srid)
+
+
+class Geometry:
+ """
+ A class to represent a Geometry value in Python.
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ wkb : bytes
+ The bytes representing the WKB of Geometry.
+
+ srid : integer
+ The integer value representing SRID of Geometry.
+
+ Methods
+ -------
+ getBytes()
+ Returns the WKB of Geometry.
+
+ getSrid()
+ Returns the SRID of Geometry.
+
+ Examples
+ --------
+ >>> g =
Geometry.fromWKB(bytes.fromhex('010100000000000000000031400000000000001c40'), 0)
+ >>> g.getBytes().hex()
+ '010100000000000000000031400000000000001c40'
+ >>> g.getSrid()
+ 0
+ """
+
+ def __init__(self, wkb: bytes, srid: int):
+ self.wkb = wkb
+ self.srid = srid
+
+ def __str__(self) -> str:
+ return "Geometry(%r, %d)" % (self.wkb, self.srid)
+
+ def __repr__(self) -> str:
+ return "Geometry(%r, %d)" % (self.wkb, self.srid)
+
+ def getSrid(self) -> int:
+ """
+ Returns the SRID of Geometry.
+ """
+ return self.srid
+
+ def getBytes(self) -> bytes:
+ """
+ Returns the WKB of Geometry.
+ """
+ return self.wkb
+
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, Geometry):
+ # Don't attempt to compare against unrelated types.
+ return NotImplemented
+
+ return self.wkb == other.wkb and self.srid == other.srid
+
+ @classmethod
+ def fromWKB(cls, wkb: bytes, srid: int) -> "Geometry":
+ """
+ Construct Python Geometry object from WKB.
+ :return: Python representation of the Geometry type value.
+ """
+ return Geometry(wkb, srid)
+
+
_atomic_types: List[Type[DataType]] = [
StringType,
CharType,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]