This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.1 by this push:
new bd90d7168d6e [SPARK-54110][GEO][SQL] Introduce type encoders for Geography and Geometry types
bd90d7168d6e is described below
commit bd90d7168d6ee54de75e092ffa86fa06e67d3725
Author: Uros Bojanic <[email protected]>
AuthorDate: Tue Nov 4 02:51:30 2025 +0800
[SPARK-54110][GEO][SQL] Introduce type encoders for Geography and Geometry types
### What changes were proposed in this pull request?
This PR introduces type encoders for `Geography` and `Geometry`.
Note that the server-side geospatial classes were introduced earlier in https://github.com/apache/spark/pull/52737, and the client-side geospatial classes in the external API were subsequently introduced in https://github.com/apache/spark/pull/52804.
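For illustration only, a minimal sketch of how the new encoders can be used (assuming a local `SparkSession`; `Encoders.GEOMETRY`, `GeometryType`, and `Geometry.fromWKB` are the APIs added or referenced in the diff below):
```scala
import org.apache.spark.sql.{Encoders, SparkSession}
import org.apache.spark.sql.types.{Geometry, GeometryType}

val spark = SparkSession.builder().master("local[*]").getOrCreate()

// WKB bytes for POINT (17 7), the same test value used by the new suites.
val wkb = "010100000000000000000031400000000000001C40"
  .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray

// An explicit encoder is parameterized by the GeometryType, and hence by its SRID.
val ds = spark.createDataset(Seq(Geometry.fromWKB(wkb, 0)))(Encoders.GEOMETRY(GeometryType(0)))
ds.collect().foreach(g => println(g.getSrid)) // prints 0
```
The implicit encoders (`newGeometryEncoder`/`newGeographyEncoder` in `SQLImplicits`) appear to default to `Geometry.DEFAULT_SRID`/`Geography.DEFAULT_SRID`, so `Seq(...).toDF()` also works without an explicit encoder.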
### Why are the changes needed?
These encoders are used to translate between server-side Spark Catalyst types and client-side Java/Scala types.
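As a concrete sketch of the contract this enforces (reusing `spark` and `wkb` from the snippet above, and mirroring the new test suites): a value whose SRID disagrees with a fixed-SRID schema is rejected at encoding time with the new `GEO_ENCODER_SRID_MISMATCH_ERROR` condition.
```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{Geography, GeographyType, StructField, StructType}

// The value carries SRID 0, but the schema fixes SRID 4326; materializing the
// DataFrame throws a SparkRuntimeException with GEO_ENCODER_SRID_MISMATCH_ERROR.
val rows = java.util.Arrays.asList(Row(Geography.fromWKB(wkb, 0)))
val schema = StructType(Seq(StructField("col1", GeographyType(4326), nullable = false)))
spark.createDataFrame(rows, schema).collect()
```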
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Added new Scala unit test suites for DataFrames:
- `GeographyDataFrameSuite`
- `GeometryDataFrameSuite`
Also, added appropriate test cases to:
- `RowSuite`
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #52813 from uros-db/geo-expression-encoders.
Authored-by: Uros Bojanic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit 2f9431340937d4b2a6cde54cb35b780d6c03b512)
Signed-off-by: Wenchen Fan <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 6 +
.../main/scala/org/apache/spark/sql/Encoders.scala | 14 ++
.../src/main/scala/org/apache/spark/sql/Row.scala | 18 ++
.../scala/org/apache/spark/sql/SQLImplicits.scala | 8 +
.../spark/sql/catalyst/JavaTypeInference.scala | 6 +-
.../spark/sql/catalyst/ScalaReflection.scala | 4 +
.../sql/catalyst/encoders/AgnosticEncoder.scala | 6 +
.../spark/sql/catalyst/encoders/RowEncoder.scala | 4 +-
.../org/apache/spark/sql/types/GeographyType.scala | 28 +++-
.../org/apache/spark/sql/types/GeometryType.scala | 28 +++-
.../apache/spark/sql/catalyst/util/STUtils.java | 34 ++++
.../sql/catalyst/CatalystTypeConverters.scala | 42 ++++-
.../sql/catalyst/DeserializerBuildHelper.scala | 26 ++-
.../spark/sql/catalyst/SerializerBuildHelper.scala | 24 ++-
.../apache/spark/sql/GeographyDataFrameSuite.scala | 180 ++++++++++++++++++++
.../apache/spark/sql/GeometryDataFrameSuite.scala | 181 +++++++++++++++++++++
.../test/scala/org/apache/spark/sql/RowSuite.scala | 11 ++
17 files changed, 611 insertions(+), 9 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index c19e192d80ee..9d95d74cc21a 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -1888,6 +1888,12 @@
],
"sqlState" : "42623"
},
+ "GEO_ENCODER_SRID_MISMATCH_ERROR" : {
+ "message" : [
+ "Failed to encode <type> value because the provided SRID <valueSrid> of the value does not match the type SRID: <typeSrid>."
+ ],
+ "sqlState" : "42K09"
+ },
"GET_TABLES_BY_TYPE_UNSUPPORTED_BY_HIVE_VERSION" : {
"message" : [
"Hive 2.2 and lower versions don't support getTablesByType. Please use Hive 2.3 or higher version."
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala
index cb1402e1b0f4..7e698e58321e 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -162,6 +162,20 @@ object Encoders {
*/
def BINARY: Encoder[Array[Byte]] = BinaryEncoder
+ /**
+ * An encoder for Geometry data type.
+ *
+ * @since 4.1.0
+ */
+ def GEOMETRY(dt: GeometryType): Encoder[Geometry] = GeometryEncoder(dt)
+
+ /**
+ * An encoder for Geography data type.
+ *
+ * @since 4.1.0
+ */
+ def GEOGRAPHY(dt: GeographyType): Encoder[Geography] = GeographyEncoder(dt)
+
/**
* Creates an encoder that serializes instances of the `java.time.Duration` class to the
* internal representation of nullable Catalyst's DayTimeIntervalType.
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Row.scala b/sql/api/src/main/scala/org/apache/spark/sql/Row.scala
index 764bdb17b37e..1019d4c9a227 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Row.scala
@@ -302,6 +302,24 @@ trait Row extends Serializable {
*/
def getDecimal(i: Int): java.math.BigDecimal = getAs[java.math.BigDecimal](i)
+ /**
+ * Returns the value at position i of geometry type as org.apache.spark.sql.types.Geometry.
+ *
+ * @throws ClassCastException
+ * when data type does not match.
+ */
+ def getGeometry(i: Int): org.apache.spark.sql.types.Geometry =
+ getAs[org.apache.spark.sql.types.Geometry](i)
+
+ /**
+ * Returns the value at position i of geography type as org.apache.spark.sql.types.Geography.
+ *
+ * @throws ClassCastException
+ * when data type does not match.
+ */
+ def getGeography(i: Int): org.apache.spark.sql.types.Geography =
+ getAs[org.apache.spark.sql.types.Geography](i)
+
/**
* Returns the value at position i of date type as java.sql.Date.
*
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/api/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index a5b1060ca03d..9d64225b9663 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -104,6 +104,14 @@ trait EncoderImplicits extends LowPrioritySQLImplicits with Serializable {
implicit def newScalaDecimalEncoder: Encoder[scala.math.BigDecimal] =
DEFAULT_SCALA_DECIMAL_ENCODER
+ /** @since 4.1.0 */
+ implicit def newGeometryEncoder: Encoder[org.apache.spark.sql.types.Geometry] =
+ DEFAULT_GEOMETRY_ENCODER
+
+ /** @since 4.1.0 */
+ implicit def newGeographyEncoder: Encoder[org.apache.spark.sql.types.Geography] =
+ DEFAULT_GEOGRAPHY_ENCODER
+
/** @since 2.2.0 */
implicit def newDateEncoder: Encoder[java.sql.Date] = Encoders.DATE
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 906e6419b360..91947cf416fb 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -27,7 +27,7 @@ import scala.reflect.ClassTag
import org.apache.commons.lang3.reflect.{TypeUtils => JavaTypeUtils}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, DayTimeIntervalEncoder, DEFAULT_JAVA_DECIMAL_ENCODER, EncoderField, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaEnumEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, Primit [...]
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, DayTimeIntervalEncoder, DEFAULT_GEOGRAPHY_ENCODER, DEFAULT_GEOMETRY_ENCODER, DEFAULT_JAVA_DECIMAL_ENCODER, EncoderField, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaEnumEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, PrimitiveBooleanEncoder, [...]
import org.apache.spark.sql.errors.ExecutionErrors
import org.apache.spark.sql.types._
import org.apache.spark.util.ArrayImplicits._
@@ -86,6 +86,10 @@ object JavaTypeInference {
case c: Class[_] if c == classOf[java.lang.String] => StringEncoder
case c: Class[_] if c == classOf[Array[Byte]] => BinaryEncoder
+ case c: Class[_] if c == classOf[org.apache.spark.sql.types.Geometry] =>
+ DEFAULT_GEOMETRY_ENCODER
+ case c: Class[_] if c == classOf[org.apache.spark.sql.types.Geography] =>
+ DEFAULT_GEOGRAPHY_ENCODER
case c: Class[_] if c == classOf[java.math.BigDecimal] => DEFAULT_JAVA_DECIMAL_ENCODER
case c: Class[_] if c == classOf[java.math.BigInteger] => JavaBigIntEncoder
case c: Class[_] if c == classOf[java.time.LocalDate] => STRICT_LOCAL_DATE_ENCODER
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index d2e0053597e4..6f5c4be42bbd 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -332,6 +332,10 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) => LocalDateTimeEncoder
case t if isSubtype(t, localTypeOf[java.time.LocalTime]) => LocalTimeEncoder
case t if isSubtype(t, localTypeOf[VariantVal]) => VariantEncoder
+ case t if isSubtype(t, localTypeOf[Geography]) =>
+ DEFAULT_GEOGRAPHY_ENCODER
+ case t if isSubtype(t, localTypeOf[Geometry]) =>
+ DEFAULT_GEOMETRY_ENCODER
case t if isSubtype(t, localTypeOf[Row]) => UnboundRowEncoder
// UDT encoders
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
index 0c5295176608..20949c188cb8 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
@@ -246,6 +246,8 @@ object AgnosticEncoders {
case object DayTimeIntervalEncoder extends LeafEncoder[Duration](DayTimeIntervalType())
case object YearMonthIntervalEncoder extends LeafEncoder[Period](YearMonthIntervalType())
case object VariantEncoder extends LeafEncoder[VariantVal](VariantType)
+ case class GeographyEncoder(dt: GeographyType) extends LeafEncoder[Geography](dt)
+ case class GeometryEncoder(dt: GeometryType) extends LeafEncoder[Geometry](dt)
case class DateEncoder(override val lenientSerialization: Boolean)
extends LeafEncoder[jsql.Date](DateType)
case class LocalDateEncoder(override val lenientSerialization: Boolean)
@@ -277,6 +279,10 @@ object AgnosticEncoders {
ScalaDecimalEncoder(DecimalType.SYSTEM_DEFAULT)
val DEFAULT_JAVA_DECIMAL_ENCODER: JavaDecimalEncoder =
JavaDecimalEncoder(DecimalType.SYSTEM_DEFAULT, lenientSerialization = false)
+ val DEFAULT_GEOMETRY_ENCODER: GeometryEncoder =
+ GeometryEncoder(GeometryType(Geometry.DEFAULT_SRID))
+ val DEFAULT_GEOGRAPHY_ENCODER: GeographyEncoder =
+ GeographyEncoder(GeographyType(Geography.DEFAULT_SRID))
/**
* Encoder that transforms external data into a representation that can be further processed by
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 620278c66d21..73152017cf22 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
import scala.reflect.classTag
import org.apache.spark.sql.{AnalysisException, Row}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, NullEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder, Timesta [...]
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, GeographyEncoder, GeometryEncoder, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, NullEncoder, RowEncoder => Agnosti [...]
import org.apache.spark.sql.errors.DataTypeErrorsBase
import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.types._
@@ -120,6 +120,8 @@ object RowEncoder extends DataTypeErrorsBase {
field.nullable,
field.metadata)
}.toImmutableArraySeq)
+ case g: GeographyType => GeographyEncoder(g)
+ case g: GeometryType => GeometryEncoder(g)
case _ =>
throw new AnalysisException(
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/GeographyType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/GeographyType.scala
index 638ae7935184..d72e5987abeb 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/GeographyType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/GeographyType.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.types
import org.json4s.JsonAST.{JString, JValue}
-import org.apache.spark.SparkIllegalArgumentException
+import org.apache.spark.{SparkIllegalArgumentException, SparkRuntimeException}
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.internal.types.GeographicSpatialReferenceSystemMapper
@@ -133,6 +133,27 @@ class GeographyType private (val crs: String, val algorithm: EdgeInterpolationAl
// If the SRID is not mixed, we can only accept the same SRID.
isMixedSrid || gt.srid == srid
}
+
+ private[sql] def assertSridAllowedForType(otherSrid: Int): Unit = {
+ // If SRID is not mixed, SRIDs must match.
+ if (!isMixedSrid && otherSrid != srid) {
+ throw new SparkRuntimeException(
+ errorClass = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+ messageParameters = Map(
+ "type" -> "GEOGRAPHY",
+ "valueSrid" -> otherSrid.toString,
+ "typeSrid" -> srid.toString))
+ } else if (isMixedSrid) {
+ // For fixed SRID types, the branch above checks that the value SRID matches the type SRID.
+ // For mixed SRID we need an explicit check here, as a MIXED SRID type can accept any SRID.
+ // However, it should accept only valid SRIDs.
+ if (!GeographyType.isSridSupported(otherSrid)) {
+ throw new SparkIllegalArgumentException(
+ errorClass = "ST_INVALID_SRID_VALUE",
+ messageParameters = Map("srid" -> otherSrid.toString))
+ }
+ }
+ }
}
@Experimental
@@ -157,6 +178,11 @@ object GeographyType extends SpatialType {
private final val GEOGRAPHY_MIXED_TYPE: GeographyType =
GeographyType(MIXED_CRS, GEOGRAPHY_DEFAULT_ALGORITHM)
+ /** Returns whether the given SRID is supported. */
+ private[types] def isSridSupported(srid: Int): Boolean = {
+ GeographicSpatialReferenceSystemMapper.getStringId(srid) != null
+ }
+
/**
* Constructors for GeographyType.
*/
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/GeometryType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/GeometryType.scala
index 77a6b365c042..f5bbbcba6706 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/GeometryType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/GeometryType.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.types
import org.json4s.JsonAST.{JString, JValue}
-import org.apache.spark.SparkIllegalArgumentException
+import org.apache.spark.{SparkIllegalArgumentException, SparkRuntimeException}
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.internal.types.CartesianSpatialReferenceSystemMapper
@@ -130,6 +130,27 @@ class GeometryType private (val crs: String) extends AtomicType with Serializabl
// If the SRID is not mixed, we can only accept the same SRID.
isMixedSrid || gt.srid == srid
}
+
+ private[sql] def assertSridAllowedForType(otherSrid: Int): Unit = {
+ // If SRID is not mixed, SRIDs must match.
+ if (!isMixedSrid && otherSrid != srid) {
+ throw new SparkRuntimeException(
+ errorClass = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+ messageParameters = Map(
+ "type" -> "GEOMETRY",
+ "valueSrid" -> otherSrid.toString,
+ "typeSrid" -> srid.toString))
+ } else if (isMixedSrid) {
+ // For fixed SRID types, the branch above checks that the value SRID matches the type SRID.
+ // For mixed SRID we need an explicit check here, as a MIXED SRID type can accept any SRID.
+ // However, it should accept only valid SRIDs.
+ if (!GeometryType.isSridSupported(otherSrid)) {
+ throw new SparkIllegalArgumentException(
+ errorClass = "ST_INVALID_SRID_VALUE",
+ messageParameters = Map("srid" -> otherSrid.toString))
+ }
+ }
+ }
}
@Experimental
@@ -149,6 +170,11 @@ object GeometryType extends SpatialType {
private final val GEOMETRY_MIXED_TYPE: GeometryType =
GeometryType(MIXED_CRS)
+ /** Returns whether the given SRID is supported. */
+ private[types] def isSridSupported(srid: Int): Boolean = {
+ CartesianSpatialReferenceSystemMapper.getStringId(srid) != null
+ }
+
/**
* Constructors for GeometryType.
*/
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java
index aca3fdf1f100..9edeee26eb98 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java
@@ -16,6 +16,8 @@
*/
package org.apache.spark.sql.catalyst.util;
+import org.apache.spark.sql.types.GeographyType;
+import org.apache.spark.sql.types.GeometryType;
import org.apache.spark.unsafe.types.GeographyVal;
import org.apache.spark.unsafe.types.GeometryVal;
@@ -46,6 +48,38 @@ public final class STUtils {
return g.getValue();
}
+ /** Geospatial type encoder/decoder utilities. */
+
+ public static GeometryVal serializeGeomFromWKB(org.apache.spark.sql.types.Geometry geometry,
+ GeometryType gt) {
+ int geometrySrid = geometry.getSrid();
+ gt.assertSridAllowedForType(geometrySrid);
+ return toPhysVal(Geometry.fromWkb(geometry.getBytes(), geometrySrid));
+ }
+
+ public static GeographyVal serializeGeogFromWKB(org.apache.spark.sql.types.Geography geography,
+ GeographyType gt) {
+ int geographySrid = geography.getSrid();
+ gt.assertSridAllowedForType(geographySrid);
+ return toPhysVal(Geography.fromWkb(geography.getBytes(), geographySrid));
+ }
+
+ public static org.apache.spark.sql.types.Geometry deserializeGeom(
+ GeometryVal geometry, GeometryType gt) {
+ int geometrySrid = stSrid(geometry);
+ gt.assertSridAllowedForType(geometrySrid);
+ byte[] wkb = stAsBinary(geometry);
+ return org.apache.spark.sql.types.Geometry.fromWKB(wkb, geometrySrid);
+ }
+
+ public static org.apache.spark.sql.types.Geography deserializeGeog(
+ GeographyVal geography, GeographyType gt) {
+ int geographySrid = stSrid(geography);
+ gt.assertSridAllowedForType(geographySrid);
+ byte[] wkb = stAsBinary(geography);
+ return org.apache.spark.sql.types.Geography.fromWKB(wkb, geographySrid);
+ }
+
/** Methods for implementing ST expressions. */
// ST_AsBinary
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index c1e0674d391d..b8eee5e1c7c6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DayTimeIntervalType._
import org.apache.spark.sql.types.YearMonthIntervalType._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{GeographyVal, GeometryVal, UTF8String}
import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.collection.Utils
@@ -69,6 +69,10 @@ object CatalystTypeConverters {
case CharType(length) => new CharConverter(length)
case VarcharType(length) => new VarcharConverter(length)
case _: StringType => StringConverter
+ case g: GeographyType =>
+ new GeographyConverter(g)
+ case g: GeometryType =>
+ new GeometryConverter(g)
case DateType if SQLConf.get.datetimeJava8ApiEnabled => LocalDateConverter
case DateType => DateConverter
case _: TimeType => TimeConverter
@@ -345,6 +349,42 @@ object CatalystTypeConverters {
row.getUTF8String(column).toString
}
+ private class GeometryConverter(dataType: GeometryType)
+ extends CatalystTypeConverter[Any, org.apache.spark.sql.types.Geometry, GeometryVal] {
+ override def toCatalystImpl(scalaValue: Any): GeometryVal = scalaValue match {
+ case g: org.apache.spark.sql.types.Geometry => STUtils.serializeGeomFromWKB(g, dataType)
+ case other => throw new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_3219",
+ messageParameters = scala.collection.immutable.Map(
+ "other" -> other.toString,
+ "otherClass" -> other.getClass.getCanonicalName,
+ "dataType" -> dataType.sql))
+ }
+ override def toScala(catalystValue: GeometryVal): org.apache.spark.sql.types.Geometry =
+ if (catalystValue == null) null
+ else STUtils.deserializeGeom(catalystValue, dataType)
+ override def toScalaImpl(row: InternalRow, column: Int): org.apache.spark.sql.types.Geometry =
+ STUtils.deserializeGeom(row.getGeometry(column), dataType)
+ }
+
+ private class GeographyConverter(dataType: GeographyType)
+ extends CatalystTypeConverter[Any, org.apache.spark.sql.types.Geography, GeographyVal] {
+ override def toCatalystImpl(scalaValue: Any): GeographyVal = scalaValue match {
+ case g: org.apache.spark.sql.types.Geography => STUtils.serializeGeogFromWKB(g, dataType)
+ case other => throw new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_3219",
+ messageParameters = scala.collection.immutable.Map(
+ "other" -> other.toString,
+ "otherClass" -> other.getClass.getCanonicalName,
+ "dataType" -> dataType.sql))
+ }
+ override def toScala(catalystValue: GeographyVal): org.apache.spark.sql.types.Geography =
+ if (catalystValue == null) null
+ else STUtils.deserializeGeog(catalystValue, dataType)
+ override def toScalaImpl(row: InternalRow, column: Int): org.apache.spark.sql.types.Geography =
+ STUtils.deserializeGeog(row.getGeography(column), dataType)
+ }
+
private object DateConverter extends CatalystTypeConverter[Any, Date, Any] {
override def toCatalystImpl(scalaValue: Any): Int = scalaValue match {
case d: Date => DateTimeUtils.fromJavaDate(d)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
index a051205829a1..60de179edb79 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
@@ -20,11 +20,11 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.{expressions => exprs}
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, AgnosticExpressionPathEncoder, Codec, JavaSerializationCodec, KryoSerializationCodec}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedLeafEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEnco [...]
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedLeafEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, GeographyEncoder, GeometryEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, Primi [...]
import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{dataTypeForClass, externalDataTypeFor, isNativeEncoder}
import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, Literal, MapKeys, MapValues, UpCast}
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, DecodeUsingSerializer, InitializeJavaBean, Invoke, NewInstance, StaticInvoke, UnresolvedCatalystToExternalMap, UnresolvedMapObjects, WrapOption}
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharCodegenUtils, DateTimeUtils, IntervalUtils}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharCodegenUtils, DateTimeUtils, IntervalUtils, STUtils}
import org.apache.spark.sql.types._
object DeserializerBuildHelper {
@@ -80,6 +80,24 @@ object DeserializerBuildHelper {
returnNullable = false)
}
+ def createDeserializerForGeometryType(inputObject: Expression, gt: GeometryType): Expression = {
+ StaticInvoke(
+ classOf[STUtils],
+ ObjectType(classOf[Geometry]),
+ "deserializeGeom",
+ inputObject :: Literal.fromObject(gt) :: Nil,
+ returnNullable = false)
+ }
+
+ def createDeserializerForGeographyType(inputObject: Expression, gt: GeographyType): Expression = {
+ StaticInvoke(
+ classOf[STUtils],
+ ObjectType(classOf[Geography]),
+ "deserializeGeog",
+ inputObject :: Literal.fromObject(gt) :: Nil,
+ returnNullable = false)
+ }
+
def createDeserializerForChar(
path: Expression,
returnNullable: Boolean,
@@ -290,6 +308,10 @@ object DeserializerBuildHelper {
"withName",
createDeserializerForString(path, returnNullable = false) :: Nil,
returnNullable = false)
+ case g: GeographyEncoder =>
+ createDeserializerForGeographyType(path, g.dt)
+ case g: GeometryEncoder =>
+ createDeserializerForGeometryType(path, g.dt)
case CharEncoder(length) =>
createDeserializerForChar(path, returnNullable = false, length)
case VarcharEncoder(length) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
index 82b3cdc508bf..06267bca0218 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
@@ -22,11 +22,11 @@ import scala.language.existentials
import org.apache.spark.sql.catalyst.{expressions => exprs}
import org.apache.spark.sql.catalyst.DeserializerBuildHelper.expressionWithNullSafety
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, AgnosticExpressionPathEncoder, Codec, JavaSerializationCodec, KryoSerializationCodec}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLeafEncoder, BoxedLongEncoder, BoxedShortEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, OptionEncoder, PrimitiveLeafEncoder, P [...]
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLeafEncoder, BoxedLongEncoder, BoxedShortEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, GeographyEncoder, GeometryEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, Opt [...]
import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder, lenientExternalDataTypeFor}
import org.apache.spark.sql.catalyst.expressions.{BoundReference, CheckOverflow, CreateNamedStruct, Expression, IsNull, KnownNotNull, Literal, UnsafeArrayData}
import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{ArrayData, CharVarcharCodegenUtils, DateTimeUtils, GenericArrayData, IntervalUtils}
+import org.apache.spark.sql.catalyst.util.{ArrayData, CharVarcharCodegenUtils, DateTimeUtils, GenericArrayData, IntervalUtils, STUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -63,6 +63,24 @@ object SerializerBuildHelper {
Invoke(inputObject, "doubleValue", DoubleType)
}
+ def createSerializerForGeographyType(inputObject: Expression, gt: GeographyType): Expression = {
+ StaticInvoke(
+ classOf[STUtils],
+ gt,
+ "serializeGeogFromWKB",
+ inputObject :: Literal.fromObject(gt) :: Nil,
+ returnNullable = false)
+ }
+
+ def createSerializerForGeometryType(inputObject: Expression, gt: GeometryType): Expression = {
+ StaticInvoke(
+ classOf[STUtils],
+ gt,
+ "serializeGeomFromWKB",
+ inputObject :: Literal.fromObject(gt) :: Nil,
+ returnNullable = false)
+ }
+
def createSerializerForChar(inputObject: Expression, length: Int): Expression = {
StaticInvoke(
classOf[CharVarcharCodegenUtils],
@@ -326,6 +344,8 @@ object SerializerBuildHelper {
case BoxedDoubleEncoder => createSerializerForDouble(input)
case JavaEnumEncoder(_) => createSerializerForJavaEnum(input)
case ScalaEnumEncoder(_, _) => createSerializerForScalaEnum(input)
+ case g: GeographyEncoder => createSerializerForGeographyType(input, g.dt)
+ case g: GeometryEncoder => createSerializerForGeometryType(input, g.dt)
case CharEncoder(length) => createSerializerForChar(input, length)
case VarcharEncoder(length) => createSerializerForVarchar(input, length)
case StringEncoder => createSerializerForString(input)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeographyDataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeographyDataFrameSuite.scala
new file mode 100644
index 000000000000..eeb1ba5ea9e2
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GeographyDataFrameSuite.scala
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.immutable.Seq
+
+import org.apache.spark.{SparkIllegalArgumentException, SparkRuntimeException}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types._
+
+class GeographyDataFrameSuite extends QueryTest with SharedSparkSession {
+
+ val point1 = "010100000000000000000031400000000000001C40"
+ .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+ val point2 = "010100000000000000000035400000000000001E40"
+ .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+
+ test("decode geography value: SRID schema does not match input SRID data schema") {
+ val rdd = sparkContext.parallelize(Seq(Row(Geography.fromWKB(point1, 0))))
+ val schema = StructType(Seq(StructField("col1", GeographyType(4326), nullable = false)))
+ checkError(
+ // We look for the cause, as all expression encoder errors are wrapped in
+ // EXPRESSION_ENCODING_FAILED.
+ exception = intercept[SparkRuntimeException] {
+ spark.createDataFrame(rdd, schema).collect()
+ }.getCause.asInstanceOf[SparkRuntimeException],
+ condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+ parameters = Map("type" -> "GEOGRAPHY", "valueSrid" -> "0", "typeSrid" -> "4326")
+ )
+
+ val javaRDD = sparkContext.parallelize(Seq(Row(Geography.fromWKB(point1, 0)))).toJavaRDD()
+ checkError(
+ // We look for the cause, as all expression encoder errors are wrapped in
+ // EXPRESSION_ENCODING_FAILED.
+ exception = intercept[SparkRuntimeException] {
+ spark.createDataFrame(javaRDD, schema).collect()
+ }.getCause.asInstanceOf[SparkRuntimeException],
+ condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+ parameters = Map("type" -> "GEOGRAPHY", "valueSrid" -> "0", "typeSrid" -> "4326")
+ )
+
+ // This API goes through CatalystTypeConverters rather than expression encoders,
+ // so the error is not wrapped and we do not look at the cause.
+ val javaList = java.util.Arrays.asList(Row(Geography.fromWKB(point1, 0)))
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ spark.createDataFrame(javaList, schema).collect()
+ },
+ condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+ parameters = Map("type" -> "GEOGRAPHY", "valueSrid" -> "0", "typeSrid" -> "4326")
+ )
+
+ val geography1 = Geography.fromWKB(point1, 0)
+ val rdd2 = sparkContext.parallelize(Seq((geography1, 1)))
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ spark.createDataFrame(rdd2).collect()
+ },
+ condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+ parameters = Map("type" -> "GEOGRAPHY", "valueSrid" -> "0", "typeSrid" -> "4326")
+ )
+
+ // This API goes through CatalystTypeConverters rather than expression encoders,
+ // so the error is not wrapped and we do not look at the cause.
+ val seq = Seq((geography1, 1))
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ spark.createDataFrame(seq).collect()
+ },
+ condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+ parameters = Map("type" -> "GEOGRAPHY", "valueSrid" -> "0", "typeSrid" -> "4326")
+ )
+
+ import testImplicits._
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ Seq(geography1).toDF().collect()
+ }.getCause.asInstanceOf[SparkRuntimeException],
+ condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+ parameters = Map("type" -> "GEOGRAPHY", "valueSrid" -> "0", "typeSrid" -> "4326")
+ )
+ }
+
+ test("decode geography value: mixed SRID schema is provided") {
+ val rdd = sparkContext.parallelize(
+ Seq(Row(Geography.fromWKB(point1, 4326)), Row(Geography.fromWKB(point2, 4326))))
+ val schema = StructType(Seq(StructField("col1", GeographyType("ANY"), nullable = false)))
+ val expectedResult = Seq(
+ Row(Geography.fromWKB(point1, 4326)), Row(Geography.fromWKB(point2, 4326)))
+
+ val resultDF = spark.createDataFrame(rdd, schema)
+ checkAnswer(resultDF, expectedResult)
+
+ val javaRDD = sparkContext.parallelize(
+ Seq(Row(Geography.fromWKB(point1, 4326)), Row(Geography.fromWKB(point2, 4326)))).toJavaRDD()
+ val resultJavaDF = spark.createDataFrame(javaRDD, schema)
+ checkAnswer(resultJavaDF, expectedResult)
+
+ val javaList = java.util.Arrays.asList(
+ Row(Geography.fromWKB(point1, 4326)), Row(Geography.fromWKB(point2, 4326)))
+ val resultJavaListDF = spark.createDataFrame(javaList, schema)
+ checkAnswer(resultJavaListDF, expectedResult)
+
+ // Test that unsupported SRID with mixed schema will throw an error.
+ val rdd2 = sparkContext.parallelize(
+ Seq(Row(Geography.fromWKB(point1, 0)), Row(Geography.fromWKB(point2, 4326))))
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ spark.createDataFrame(rdd2, schema).collect()
+ }.getCause.asInstanceOf[SparkIllegalArgumentException],
+ condition = "ST_INVALID_SRID_VALUE",
+ parameters = Map("srid" -> "0")
+ )
+ }
+
+ test("createDataFrame APIs with Geography.fromWKB") {
+ // 1. Test createDataFrame with RDD of Geography objects
+ val geography1 = Geography.fromWKB(point1, 4326)
+ val geography2 = Geography.fromWKB(point2)
+ val rdd = sparkContext.parallelize(Seq((geography1, 1), (geography2, 2), (null, 3)))
+ val dfFromRDD = spark.createDataFrame(rdd)
+ checkAnswer(dfFromRDD, Seq(Row(geography1, 1), Row(geography2, 2), Row(null, 3)))
+
+ // 2. Test createDataFrame with Seq of Geography objects
+ val seq = Seq((geography1, 1), (geography2, 2), (null, 3))
+ val dfFromSeq = spark.createDataFrame(seq)
+ checkAnswer(dfFromSeq, Seq(Row(geography1, 1), Row(geography2, 2), Row(null, 3)))
+
+ // 3. Test createDataFrame with RDD of Rows and StructType schema
+ val geography3 = Geography.fromWKB(point1, 4326)
+ val geography4 = Geography.fromWKB(point2, 4326)
+ val rowRDD = sparkContext.parallelize(Seq(Row(geography3), Row(geography4), Row(null)))
+ val schema = StructType(Seq(
+ StructField("geography", GeographyType(4326), nullable = true)
+ ))
+ val dfFromRowRDD = spark.createDataFrame(rowRDD, schema)
+ checkAnswer(dfFromRowRDD, Seq(Row(geography3), Row(geography4), Row(null)))
+
+ // 4. Test createDataFrame with JavaRDD of Rows and StructType schema
+ val javaRDD = sparkContext.parallelize(Seq(Row(geography3), Row(geography4), Row(null)))
+ .toJavaRDD()
+ val dfFromJavaRDD = spark.createDataFrame(javaRDD, schema)
+ checkAnswer(dfFromJavaRDD, Seq(Row(geography3), Row(geography4), Row(null)))
+
+ // 5. Test createDataFrame with Java List of Rows and StructType schema
+ val javaList = java.util.Arrays.asList(Row(geography3), Row(geography4), Row(null))
+ val dfFromJavaList = spark.createDataFrame(javaList, schema)
+ checkAnswer(dfFromJavaList, Seq(Row(geography3), Row(geography4), Row(null)))
+
+ // 6. Implicit conversion from Seq to DF
+ import testImplicits._
+ val implicitDf = Seq(geography1, geography2, null).toDF()
+ checkAnswer(implicitDf, Seq(Row(geography1), Row(geography2), Row(null)))
+ }
+
+ test("encode geography type") {
+ // A test WKB value corresponding to: POINT (17 7).
+ val pointString: String = "010100000000000000000031400000000000001C40"
+ val pointBytes: Array[Byte] = pointString
+ .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+ val df = spark.sql(s"SELECT ST_GeogFromWKB(X'$pointString')")
+ val expectedGeog = Geography.fromWKB(pointBytes, 4326)
+ checkAnswer(df, Seq(Row(expectedGeog)))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeometryDataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeometryDataFrameSuite.scala
new file mode 100644
index 000000000000..bcc3cee7ebe3
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GeometryDataFrameSuite.scala
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.immutable.Seq
+
+import org.apache.spark.{SparkIllegalArgumentException, SparkRuntimeException}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types._
+
+class GeometryDataFrameSuite extends QueryTest with SharedSparkSession {
+
+ val point1 = "010100000000000000000031400000000000001C40"
+ .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+ val point2 = "010100000000000000000035400000000000001E40"
+ .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+
+ test("decode geometry value: SRID schema does not match input SRID data schema") {
+ val rdd = sparkContext.parallelize(Seq(Row(Geometry.fromWKB(point1, 0))))
+ val schema = StructType(Seq(StructField("col1", GeometryType(3857), nullable = false)))
+ checkError(
+ // We look for the cause, as all expression encoder errors are wrapped in
+ // EXPRESSION_ENCODING_FAILED.
+ exception = intercept[SparkRuntimeException] {
+ spark.createDataFrame(rdd, schema).collect()
+ }.getCause.asInstanceOf[SparkRuntimeException],
+ condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+ parameters = Map("type" -> "GEOMETRY", "valueSrid" -> "0", "typeSrid" -> "3857")
+ )
+
+ val schema2 = StructType(Seq(StructField("col1", GeometryType(0), nullable = false)))
+ val javaRDD = sparkContext.parallelize(Seq(Row(Geometry.fromWKB(point1, 4326)))).toJavaRDD()
+ checkError(
+ // We look for the cause, as all expression encoder errors are wrapped in
+ // EXPRESSION_ENCODING_FAILED.
+ exception = intercept[SparkRuntimeException] {
+ spark.createDataFrame(javaRDD, schema2).collect()
+ }.getCause.asInstanceOf[SparkRuntimeException],
+ condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+ parameters = Map("type" -> "GEOMETRY", "valueSrid" -> "4326", "typeSrid" -> "0")
+ )
+
+ // This API goes through CatalystTypeConverters rather than expression encoders,
+ // so the error is not wrapped and we do not look at the cause.
+ val javaList = java.util.Arrays.asList(Row(Geometry.fromWKB(point1, 4326)))
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ spark.createDataFrame(javaList, schema).collect()
+ },
+ condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+ parameters = Map("type" -> "GEOMETRY", "valueSrid" -> "4326", "typeSrid" -> "3857")
+ )
+
+ val geometry1 = Geometry.fromWKB(point1, 4326)
+ val rdd2 = sparkContext.parallelize(Seq((geometry1, 1)))
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ spark.createDataFrame(rdd2).collect()
+ },
+ condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+ parameters = Map("type" -> "GEOMETRY", "valueSrid" -> "4326", "typeSrid" -> "0")
+ )
+
+ // This API goes through CatalystTypeConverters rather than expression encoders,
+ // so the error is not wrapped and we do not look at the cause.
+ val seq = Seq((geometry1, 1))
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ spark.createDataFrame(seq).collect()
+ },
+ condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+ parameters = Map("type" -> "GEOMETRY", "valueSrid" -> "4326", "typeSrid" -> "0")
+ )
+
+ import testImplicits._
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ Seq(geometry1).toDF().collect()
+ }.getCause.asInstanceOf[SparkRuntimeException],
+ condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+ parameters = Map("type" -> "GEOMETRY", "valueSrid" -> "4326", "typeSrid" -> "0")
+ )
+ }
+
+ test("decode geometry value: mixed SRID schema is provided") {
+ val rdd = sparkContext.parallelize(
+ Seq(Row(Geometry.fromWKB(point1, 0)), Row(Geometry.fromWKB(point2, 4326))))
+ val schema = StructType(Seq(StructField("col1", GeometryType("ANY"), nullable = false)))
+ val expectedResult = Seq(
+ Row(Geometry.fromWKB(point1, 0)), Row(Geometry.fromWKB(point2, 4326)))
+
+ val resultDF = spark.createDataFrame(rdd, schema)
+ checkAnswer(resultDF, expectedResult)
+
+ val javaRDD = sparkContext.parallelize(
+ Seq(Row(Geometry.fromWKB(point1, 0)), Row(Geometry.fromWKB(point2, 4326)))).toJavaRDD()
+ val resultJavaDF = spark.createDataFrame(javaRDD, schema)
+ checkAnswer(resultJavaDF, expectedResult)
+
+ val javaList = java.util.Arrays.asList(
+ Row(Geometry.fromWKB(point1, 0)), Row(Geometry.fromWKB(point2, 4326)))
+ val resultJavaListDF = spark.createDataFrame(javaList, schema)
+ checkAnswer(resultJavaListDF, expectedResult)
+
+ // Test that unsupported SRID with mixed schema will throw an error.
+ val rdd2 = sparkContext.parallelize(
+ Seq(Row(Geometry.fromWKB(point1, 1)), Row(Geometry.fromWKB(point2, 4326))))
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ spark.createDataFrame(rdd2, schema).collect()
+ }.getCause.asInstanceOf[SparkIllegalArgumentException],
+ condition = "ST_INVALID_SRID_VALUE",
+ parameters = Map("srid" -> "1")
+ )
+ }
+
+ test("createDataFrame APIs with Geometry.fromWKB") {
+ // 1. Test createDataFrame with RDD of Geometry objects
+ val geometry1 = Geometry.fromWKB(point1, 0)
+ val geometry2 = Geometry.fromWKB(point2, 0)
+ val rdd = sparkContext.parallelize(Seq((geometry1, 1), (geometry2, 2), (null, 3)))
+ val dfFromRDD = spark.createDataFrame(rdd)
+ checkAnswer(dfFromRDD, Seq(Row(geometry1, 1), Row(geometry2, 2), Row(null, 3)))
+
+ // 2. Test createDataFrame with Seq of Geometry objects
+ val seq = Seq((geometry1, 1), (geometry2, 2), (null, 3))
+ val dfFromSeq = spark.createDataFrame(seq)
+ checkAnswer(dfFromSeq, Seq(Row(geometry1, 1), Row(geometry2, 2), Row(null, 3)))
+
+ // 3. Test createDataFrame with RDD of Rows and StructType schema
+ val geometry3 = Geometry.fromWKB(point1, 4326)
+ val geometry4 = Geometry.fromWKB(point2, 4326)
+ val rowRDD = sparkContext.parallelize(Seq(Row(geometry3), Row(geometry4), Row(null)))
+ val schema = StructType(Seq(
+ StructField("geometry", GeometryType(4326), nullable = true)
+ ))
+ val dfFromRowRDD = spark.createDataFrame(rowRDD, schema)
+ checkAnswer(dfFromRowRDD, Seq(Row(geometry3), Row(geometry4), Row(null)))
+
+ // 4. Test createDataFrame with JavaRDD of Rows and StructType schema
+ val javaRDD = sparkContext.parallelize(Seq(Row(geometry3), Row(geometry4), Row(null)))
+ .toJavaRDD()
+ val dfFromJavaRDD = spark.createDataFrame(javaRDD, schema)
+ checkAnswer(dfFromJavaRDD, Seq(Row(geometry3), Row(geometry4), Row(null)))
+
+ // 5. Test createDataFrame with Java List of Rows and StructType schema
+ val javaList = java.util.Arrays.asList(Row(geometry3), Row(geometry4), Row(null))
+ val dfFromJavaList = spark.createDataFrame(javaList, schema)
+ checkAnswer(dfFromJavaList, Seq(Row(geometry3), Row(geometry4), Row(null)))
+
+ // 6. Implicit conversion from Seq to DF
+ import testImplicits._
+ val implicitDf = Seq(geometry1, geometry2, null).toDF()
+ checkAnswer(implicitDf, Seq(Row(geometry1), Row(geometry2), Row(null)))
+ }
+
+ test("encode geometry type") {
+ // A test WKB value corresponding to: POINT (17 7).
+ val pointString: String = "010100000000000000000031400000000000001C40"
+ val pointBytes: Array[Byte] = pointString
+ .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+ val df = spark.sql(s"SELECT ST_GeomFromWKB(X'$pointString')")
+ val expectedGeom = Geometry.fromWKB(pointBytes, 0)
+ checkAnswer(df, Seq(Row(expectedGeom)))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
index 5de4170a1c11..eb36b68cd617 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
@@ -136,4 +136,15 @@ class RowSuite extends SparkFunSuite with SharedSparkSession {
parameters = Map("index" -> position.toString)
)
}
+
+ test("Geospatial row API - Geography and Geometry") {
+ // A test WKB value corresponding to: POINT (17 7).
+ val point = "010100000000000000000031400000000000001C40"
+ .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+
+ val row = Row(Geometry.fromWKB(point), Geography.fromWKB(point))
+
+ assert(row.getGeometry(0).getBytes().sameElements(point))
+ assert(row.getGeography(1).getBytes().sameElements(point))
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]