This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5f83b892b5d1 [SPARK-54021][GEO][SQL] Implement Geography and Geometry
accessors across Catalyst
5f83b892b5d1 is described below
commit 5f83b892b5d1ae06602b2dff716f67e451f87e23
Author: Uros Bojanic <[email protected]>
AuthorDate: Tue Oct 28 23:36:18 2025 +0800
[SPARK-54021][GEO][SQL] Implement Geography and Geometry accessors across
Catalyst
### What changes were proposed in this pull request?
Added Geography and Geometry accessors to core row/column interfaces,
extended codegen and physical type handling to properly recognize geospatial
types, enabled writing/reading of Geography and Geometry values in unsafe
writer, and added other necessary plumbing for Geography and Geometry in
projection/row utilities in order to thread through the new accessors.
Note that the GEOMETRY and GEOGRAPHY physical types were recently added
to Spark SQL as part of: https://github.com/apache/spark/pull/52629.
### Why are the changes needed?
To provide first-class support for GEOGRAPHY and GEOMETRY within Catalyst.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Added new tests to:
- `GenerateUnsafeProjectionSuite.scala`
- `UnsafeRowWriterSuite.scala`
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #52723 from uros-db/geo-interfaces.
Authored-by: Uros Bojanic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../catalyst/expressions/SpecializedGetters.java | 6 ++++++
.../expressions/SpecializedGettersReader.java | 6 ++++++
.../sql/catalyst/expressions/UnsafeArrayData.java | 14 +++++++++++++
.../spark/sql/catalyst/expressions/UnsafeRow.java | 14 +++++++++++++
.../catalyst/expressions/codegen/UnsafeWriter.java | 10 +++++++++
.../apache/spark/sql/vectorized/ColumnVector.java | 12 +++++++++++
.../apache/spark/sql/vectorized/ColumnarArray.java | 12 +++++++++++
.../spark/sql/vectorized/ColumnarBatchRow.java | 20 ++++++++++++++++++
.../apache/spark/sql/vectorized/ColumnarRow.java | 16 +++++++++++++++
.../spark/sql/catalyst/ProjectingInternalRow.scala | 10 ++++++++-
.../spark/sql/catalyst/encoders/EncoderUtils.scala | 6 ++++--
.../expressions/InterpretedUnsafeProjection.scala | 4 ++++
.../spark/sql/catalyst/expressions/JoinedRow.scala | 8 +++++++-
.../expressions/codegen/CodeGenerator.scala | 8 ++++++++
.../spark/sql/catalyst/expressions/rows.scala | 4 +++-
.../spark/sql/catalyst/util/GenericArrayData.scala | 4 +++-
.../codegen/GenerateUnsafeProjectionSuite.scala | 6 +++++-
.../expressions/codegen/UnsafeRowWriterSuite.scala | 24 +++++++++++++++++++++-
.../execution/vectorized/MutableColumnarRow.java | 16 +++++++++++++++
.../org/apache/spark/sql/execution/Columnar.scala | 16 +++++++++++++++
20 files changed, 208 insertions(+), 8 deletions(-)
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java
index b88a892db4b4..2a3a6884c3c6 100644
---
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java
@@ -25,6 +25,8 @@ import org.apache.spark.sql.catalyst.util.MapData;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
import org.apache.spark.unsafe.types.VariantVal;
+import org.apache.spark.unsafe.types.GeographyVal;
+import org.apache.spark.unsafe.types.GeometryVal;
public interface SpecializedGetters {
@@ -50,6 +52,10 @@ public interface SpecializedGetters {
byte[] getBinary(int ordinal);
+ GeographyVal getGeography(int ordinal);
+
+ GeometryVal getGeometry(int ordinal);
+
CalendarInterval getInterval(int ordinal);
VariantVal getVariant(int ordinal);
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java
index a771805f6e5d..830aa0d0d0fb 100644
---
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java
@@ -60,6 +60,12 @@ public final class SpecializedGettersReader {
if (physicalDataType instanceof PhysicalStringType) {
return obj.getUTF8String(ordinal);
}
+ if (physicalDataType instanceof PhysicalGeographyType) {
+ return obj.getGeography(ordinal);
+ }
+ if (physicalDataType instanceof PhysicalGeometryType) {
+ return obj.getGeometry(ordinal);
+ }
if (physicalDataType instanceof PhysicalDecimalType dt) {
return obj.getDecimal(ordinal, dt.precision(), dt.scale());
}
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index be3e5a7d5043..09ac634955fc 100644
---
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -40,6 +40,8 @@ import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
import org.apache.spark.unsafe.types.VariantVal;
+import org.apache.spark.unsafe.types.GeographyVal;
+import org.apache.spark.unsafe.types.GeometryVal;
import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
@@ -222,6 +224,18 @@ public final class UnsafeArrayData extends ArrayData
implements Externalizable,
return bytes;
}
+ @Override
+ public GeographyVal getGeography(int ordinal) {
+ byte[] bytes = getBinary(ordinal);
+ return (bytes == null) ? null : GeographyVal.fromBytes(bytes);
+ }
+
+ @Override
+ public GeometryVal getGeometry(int ordinal) {
+ byte[] bytes = getBinary(ordinal);
+ return (bytes == null) ? null : GeometryVal.fromBytes(bytes);
+ }
+
@Override
public CalendarInterval getInterval(int ordinal) {
if (isNullAt(ordinal)) return null;
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 8741c206f2bb..ff9eeea9bf12 100644
---
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -40,6 +40,8 @@ import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
import org.apache.spark.unsafe.types.VariantVal;
+import org.apache.spark.unsafe.types.GeographyVal;
+import org.apache.spark.unsafe.types.GeometryVal;
import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
@@ -417,6 +419,18 @@ public final class UnsafeRow extends InternalRow
implements Externalizable, Kryo
}
}
+ @Override
+ public GeographyVal getGeography(int ordinal) {
+ byte[] bytes = getBinary(ordinal);
+ return (bytes == null) ? null : GeographyVal.fromBytes(bytes);
+ }
+
+ @Override
+ public GeometryVal getGeometry(int ordinal) {
+ byte[] bytes = getBinary(ordinal);
+ return (bytes == null) ? null : GeometryVal.fromBytes(bytes);
+ }
+
@Override
public CalendarInterval getInterval(int ordinal) {
if (isNullAt(ordinal)) {
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
index 8e6d08bdadb8..e2abc108bb1b 100644
---
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
@@ -24,6 +24,8 @@ import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.bitset.BitSetMethods;
import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.GeographyVal;
+import org.apache.spark.unsafe.types.GeometryVal;
import org.apache.spark.unsafe.types.UTF8String;
import org.apache.spark.unsafe.types.VariantVal;
@@ -111,6 +113,14 @@ public abstract class UnsafeWriter {
writeUnalignedBytes(ordinal, input.getBaseObject(), input.getBaseOffset(),
input.numBytes());
}
+ public final void write(int ordinal, GeographyVal input) {
+ write(ordinal, input.getBytes());
+ }
+
+ public final void write(int ordinal, GeometryVal input) {
+ write(ordinal, input.getBytes());
+ }
+
public final void write(int ordinal, byte[] input) {
write(ordinal, input, 0, input.length);
}
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
index eb7142867776..8e9a5a620b3e 100644
---
a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
@@ -25,6 +25,8 @@ import org.apache.spark.sql.types.UserDefinedType;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
import org.apache.spark.unsafe.types.VariantVal;
+import org.apache.spark.unsafe.types.GeographyVal;
+import org.apache.spark.unsafe.types.GeometryVal;
/**
* An interface representing in-memory columnar data in Spark. This interface
defines the main APIs
@@ -288,6 +290,16 @@ public abstract class ColumnVector implements
AutoCloseable {
*/
public abstract byte[] getBinary(int rowId);
+ public GeographyVal getGeography(int rowId) {
+ byte[] bytes = getBinary(rowId);
+ return (bytes == null) ? null : GeographyVal.fromBytes(bytes);
+ }
+
+ public GeometryVal getGeometry(int rowId) {
+ byte[] bytes = getBinary(rowId);
+ return (bytes == null) ? null : GeometryVal.fromBytes(bytes);
+ }
+
/**
* Returns the calendar interval type value for {@code rowId}. If the slot
for
* {@code rowId} is null, it should return null.
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java
index 12a2879794b1..fad1817aca19 100644
---
a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java
@@ -26,6 +26,8 @@ import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
import org.apache.spark.unsafe.types.VariantVal;
+import org.apache.spark.unsafe.types.GeographyVal;
+import org.apache.spark.unsafe.types.GeometryVal;
/**
* Array abstraction in {@link ColumnVector}.
@@ -174,6 +176,16 @@ public final class ColumnarArray extends ArrayData {
return data.getBinary(offset + ordinal);
}
+ @Override
+ public GeographyVal getGeography(int ordinal) {
+ return data.getGeography(offset + ordinal);
+ }
+
+ @Override
+ public GeometryVal getGeometry(int ordinal) {
+ return data.getGeometry(offset + ordinal);
+ }
+
@Override
public CalendarInterval getInterval(int ordinal) {
return data.getInterval(offset + ordinal);
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java
index d05b3e2dc2d9..4be45dc5d399 100644
---
a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java
@@ -27,6 +27,8 @@ import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
import org.apache.spark.unsafe.types.VariantVal;
+import org.apache.spark.unsafe.types.GeographyVal;
+import org.apache.spark.unsafe.types.GeometryVal;
/**
* This class wraps an array of {@link ColumnVector} and provides a row view.
@@ -72,6 +74,10 @@ public final class ColumnarBatchRow extends InternalRow {
row.update(i, getUTF8String(i).copy());
} else if (pdt instanceof PhysicalBinaryType) {
row.update(i, getBinary(i));
+ } else if (pdt instanceof PhysicalGeographyType) {
+ row.update(i, getGeography(i));
+ } else if (pdt instanceof PhysicalGeometryType) {
+ row.update(i, getGeometry(i));
} else if (pdt instanceof PhysicalDecimalType t) {
row.setDecimal(i, getDecimal(i, t.precision(), t.scale()),
t.precision());
} else if (pdt instanceof PhysicalStructType t) {
@@ -132,6 +138,16 @@ public final class ColumnarBatchRow extends InternalRow {
return columns[ordinal].getBinary(rowId);
}
+ @Override
+ public GeographyVal getGeography(int ordinal) {
+ return columns[ordinal].getGeography(rowId);
+ }
+
+ @Override
+ public GeometryVal getGeometry(int ordinal) {
+ return columns[ordinal].getGeometry(rowId);
+ }
+
@Override
public CalendarInterval getInterval(int ordinal) {
return columns[ordinal].getInterval(rowId);
@@ -177,6 +193,10 @@ public final class ColumnarBatchRow extends InternalRow {
return getUTF8String(ordinal);
} else if (dataType instanceof BinaryType) {
return getBinary(ordinal);
+ } else if (dataType instanceof GeographyType) {
+ return getGeography(ordinal);
+ } else if (dataType instanceof GeometryType) {
+ return getGeometry(ordinal);
} else if (dataType instanceof DecimalType t) {
return getDecimal(ordinal, t.precision(), t.scale());
} else if (dataType instanceof DateType) {
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
index b14cd3429e47..d9e65afe1cb0 100644
---
a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
@@ -25,6 +25,8 @@ import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
import org.apache.spark.unsafe.types.VariantVal;
+import org.apache.spark.unsafe.types.GeographyVal;
+import org.apache.spark.unsafe.types.GeometryVal;
/**
* Row abstraction in {@link ColumnVector}.
@@ -77,6 +79,10 @@ public final class ColumnarRow extends InternalRow {
row.update(i, getUTF8String(i).copy());
} else if (pdt instanceof PhysicalBinaryType) {
row.update(i, getBinary(i));
+ } else if (pdt instanceof PhysicalGeographyType) {
+ row.update(i, getGeography(i));
+ } else if (pdt instanceof PhysicalGeometryType) {
+ row.update(i, getGeometry(i));
} else if (pdt instanceof PhysicalDecimalType t) {
row.setDecimal(i, getDecimal(i, t.precision(), t.scale()),
t.precision());
} else if (pdt instanceof PhysicalStructType t) {
@@ -137,6 +143,16 @@ public final class ColumnarRow extends InternalRow {
return data.getChild(ordinal).getBinary(rowId);
}
+ @Override
+ public GeographyVal getGeography(int ordinal) {
+ return data.getChild(ordinal).getGeography(rowId);
+ }
+
+ @Override
+ public GeometryVal getGeometry(int ordinal) {
+ return data.getChild(ordinal).getGeometry(rowId);
+ }
+
@Override
public CalendarInterval getInterval(int ordinal) {
return data.getChild(ordinal).getInterval(rowId);
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala
index 20cf80e88e42..0e451db6cfe2 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.SparkUnsupportedOperationException
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types.{DataType, Decimal, StructType}
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}
+import org.apache.spark.unsafe.types._
/**
* An [[InternalRow]] that projects particular columns from another
[[InternalRow]] without copying
@@ -93,6 +93,14 @@ case class ProjectingInternalRow(schema: StructType,
row.getBinary(colOrdinals(ordinal))
}
+ override def getGeography(ordinal: Int): GeographyVal = {
+ row.getGeography(colOrdinals(ordinal))
+ }
+
+ override def getGeometry(ordinal: Int): GeometryVal = {
+ row.getGeometry(colOrdinals(ordinal))
+ }
+
override def getInterval(ordinal: Int): CalendarInterval = {
row.getInterval(colOrdinals(ordinal))
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala
index a9f398c34654..e7b53344abbd 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala
@@ -24,8 +24,8 @@ import
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, C
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.types.{PhysicalBinaryType,
PhysicalIntegerType, PhysicalLongType}
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
-import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType,
ByteType, CalendarIntervalType, DataType, DateType, DayTimeIntervalType,
Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType,
ObjectType, ShortType, StringType, StructType, TimestampNTZType, TimestampType,
TimeType, UserDefinedType, VariantType, YearMonthIntervalType}
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}
+import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType,
ByteType, CalendarIntervalType, DataType, DateType, DayTimeIntervalType,
Decimal, DecimalType, DoubleType, FloatType, GeographyType, GeometryType,
IntegerType, LongType, MapType, ObjectType, ShortType, StringType, StructType,
TimestampNTZType, TimestampType, TimeType, UserDefinedType, VariantType,
YearMonthIntervalType}
+import org.apache.spark.unsafe.types.{CalendarInterval, GeographyVal,
GeometryVal, UTF8String, VariantVal}
/**
* :: DeveloperApi ::
@@ -107,6 +107,8 @@ object EncoderUtils {
case _: StructType => classOf[InternalRow]
case _: ArrayType => classOf[ArrayData]
case _: MapType => classOf[MapData]
+ case _: GeographyType => classOf[GeographyVal]
+ case _: GeometryType => classOf[GeometryVal]
case ObjectType(cls) => cls
case _ => typeJavaMapping.getOrElse(dt, classOf[java.lang.Object])
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
index 004cd576ace0..53b3e0598d58 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
@@ -163,6 +163,10 @@ object InterpretedUnsafeProjection {
case _: PhysicalStringType => (v, i) => writer.write(i,
v.getUTF8String(i))
+ case _: PhysicalGeographyType => (v, i) => writer.write(i,
v.getGeography(i))
+
+ case _: PhysicalGeometryType => (v, i) => writer.write(i,
v.getGeometry(i))
+
case PhysicalVariantType => (v, i) => writer.write(i, v.getVariant(i))
case PhysicalStructType(fields) =>
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala
index 345f2b3030b5..4211dd5e4df0 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}
+import org.apache.spark.unsafe.types._
/**
* A mutable wrapper that makes two rows appear as a single concatenated row.
Designed to
@@ -114,6 +114,12 @@ class JoinedRow extends InternalRow {
override def getBinary(i: Int): Array[Byte] =
if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i -
row1.numFields)
+ override def getGeography(i: Int): GeographyVal =
+ if (i < row1.numFields) row1.getGeography(i) else row2.getGeography(i -
row1.numFields)
+
+ override def getGeometry(i: Int): GeometryVal =
+ if (i < row1.numFields) row1.getGeometry(i) else row2.getGeometry(i -
row1.numFields)
+
override def getArray(i: Int): ArrayData =
if (i < row1.numFields) row1.getArray(i) else row2.getArray(i -
row1.numFields)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 8c702815e9b9..13b1d329f7ec 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -1518,6 +1518,8 @@ object CodeGenerator extends Logging {
classOf[Platform].getName,
classOf[InternalRow].getName,
classOf[UnsafeRow].getName,
+ classOf[GeographyVal].getName,
+ classOf[GeometryVal].getName,
classOf[UTF8String].getName,
classOf[Decimal].getName,
classOf[CalendarInterval].getName,
@@ -1682,6 +1684,8 @@ object CodeGenerator extends Logging {
case _ => PhysicalDataType(dataType) match {
case _: PhysicalArrayType => s"$input.getArray($ordinal)"
case PhysicalBinaryType => s"$input.getBinary($ordinal)"
+ case _: PhysicalGeographyType => s"$input.getGeography($ordinal)"
+ case _: PhysicalGeometryType => s"$input.getGeometry($ordinal)"
case PhysicalCalendarIntervalType => s"$input.getInterval($ordinal)"
case t: PhysicalDecimalType => s"$input.getDecimal($ordinal,
${t.precision}, ${t.scale})"
case _: PhysicalMapType => s"$input.getMap($ordinal)"
@@ -1960,6 +1964,8 @@ object CodeGenerator extends Logging {
* Returns the Java type for a DataType.
*/
def javaType(dt: DataType): String = dt match {
+ case _: GeographyType => "GeographyVal"
+ case _: GeometryType => "GeometryVal"
case udt: UserDefinedType[_] => javaType(udt.sqlType)
case ObjectType(cls) if cls.isArray =>
s"${javaType(ObjectType(cls.getComponentType))}[]"
case ObjectType(cls) => cls.getName
@@ -1995,6 +2001,8 @@ object CodeGenerator extends Logging {
case DoubleType => java.lang.Double.TYPE
case _: DecimalType => classOf[Decimal]
case BinaryType => classOf[Array[Byte]]
+ case _: GeographyType => classOf[GeographyVal]
+ case _: GeometryType => classOf[GeometryVal]
case _: StringType => classOf[UTF8String]
case CalendarIntervalType => classOf[CalendarInterval]
case _: StructType => classOf[InternalRow]
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 8379069c53d9..b8d6054fc6fc 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}
+import org.apache.spark.unsafe.types._
import org.apache.spark.util.ArrayImplicits._
/**
@@ -45,6 +45,8 @@ trait BaseGenericInternalRow extends InternalRow {
override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal =
getAs(ordinal)
override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal)
override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal)
+ override def getGeography(ordinal: Int): GeographyVal = getAs(ordinal)
+ override def getGeometry(ordinal: Int): GeometryVal = getAs(ordinal)
override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal)
override def getVariant(ordinal: Int): VariantVal = getAs(ordinal)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
index 7ff36bef5a4b..808a3d43bf20 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
@@ -21,7 +21,7 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{DataType, Decimal}
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}
+import org.apache.spark.unsafe.types._
class GenericArrayData(val array: Array[Any]) extends ArrayData {
@@ -72,6 +72,8 @@ class GenericArrayData(val array: Array[Any]) extends
ArrayData {
override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal =
getAs(ordinal)
override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal)
override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal)
+ override def getGeography(ordinal: Int): GeographyVal = getAs(ordinal)
+ override def getGeometry(ordinal: Int): GeometryVal = getAs(ordinal)
override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal)
override def getVariant(ordinal: Int): VariantVal = getAs(ordinal)
override def getStruct(ordinal: Int, numFields: Int): InternalRow =
getAs(ordinal)
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala
index eeb05139a3e5..9c0d610f35f6 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData,
MapData}
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}
+import org.apache.spark.unsafe.types._
class GenerateUnsafeProjectionSuite extends SparkFunSuite {
test("Test unsafe projection string access pattern") {
@@ -87,6 +87,8 @@ object AlwaysNull extends InternalRow {
override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal =
notSupported
override def getUTF8String(ordinal: Int): UTF8String = notSupported
override def getBinary(ordinal: Int): Array[Byte] = notSupported
+ override def getGeography(ordinal: Int): GeographyVal = notSupported
+ override def getGeometry(ordinal: Int): GeometryVal = notSupported
override def getInterval(ordinal: Int): CalendarInterval = notSupported
override def getVariant(ordinal: Int): VariantVal = notSupported
override def getStruct(ordinal: Int, numFields: Int): InternalRow =
notSupported
@@ -117,6 +119,8 @@ object AlwaysNonNull extends InternalRow {
override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal =
notSupported
override def getUTF8String(ordinal: Int): UTF8String =
UTF8String.fromString("test")
override def getBinary(ordinal: Int): Array[Byte] = notSupported
+ override def getGeography(ordinal: Int): GeographyVal = notSupported
+ override def getGeometry(ordinal: Int): GeometryVal = notSupported
override def getInterval(ordinal: Int): CalendarInterval = notSupported
override def getVariant(ordinal: Int): VariantVal = notSupported
override def getStruct(ordinal: Int, numFields: Int): InternalRow =
notSupported
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
index e2a416b773aa..de62f8b46b7d 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.Decimal
-import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal}
+import org.apache.spark.unsafe.types._
class UnsafeRowWriterSuite extends SparkFunSuite {
@@ -51,6 +51,28 @@ class UnsafeRowWriterSuite extends SparkFunSuite {
assert(res1 == res2)
}
+ test("write and get geography through UnsafeRowWriter") {
+ val rowWriter = new UnsafeRowWriter(2)
+ rowWriter.resetRowWriter()
+ rowWriter.setNullAt(0)
+ assert(rowWriter.getRow.isNullAt(0))
+ assert(rowWriter.getRow.getGeography(0) === null)
+ val geography = GeographyVal.fromBytes(Array[Byte](1, 2, 3))
+ rowWriter.write(1, geography)
+ assert(rowWriter.getRow.getGeography(1).getBytes sameElements
geography.getBytes)
+ }
+
+ test("write and get geometry through UnsafeRowWriter") {
+ val rowWriter = new UnsafeRowWriter(2)
+ rowWriter.resetRowWriter()
+ rowWriter.setNullAt(0)
+ assert(rowWriter.getRow.isNullAt(0))
+ assert(rowWriter.getRow.getGeometry(0) === null)
+ val geometry = GeometryVal.fromBytes(Array[Byte](1, 2, 3))
+ rowWriter.write(1, geometry)
+ assert(rowWriter.getRow.getGeometry(1).getBytes sameElements
geometry.getBytes)
+ }
+
test("write and get calendar intervals through UnsafeRowWriter") {
val rowWriter = new UnsafeRowWriter(2)
rowWriter.resetRowWriter()
diff --git
a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
index 42d39457330c..49c27f977562 100644
---
a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
+++
b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
@@ -29,6 +29,8 @@ import org.apache.spark.sql.vectorized.ColumnarBatch;
import org.apache.spark.sql.vectorized.ColumnarMap;
import org.apache.spark.sql.vectorized.ColumnarRow;
import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.GeographyVal;
+import org.apache.spark.unsafe.types.GeometryVal;
import org.apache.spark.unsafe.types.UTF8String;
import org.apache.spark.unsafe.types.VariantVal;
@@ -76,6 +78,10 @@ public final class MutableColumnarRow extends InternalRow {
row.update(i, getUTF8String(i).copy());
} else if (dt instanceof BinaryType) {
row.update(i, getBinary(i));
+ } else if (dt instanceof GeographyType) {
+ row.update(i, getGeography(i));
+ } else if (dt instanceof GeometryType) {
+ row.update(i, getGeometry(i));
} else if (dt instanceof DecimalType t) {
row.setDecimal(i, getDecimal(i, t.precision(), t.scale()),
t.precision());
} else if (dt instanceof DateType) {
@@ -142,6 +148,16 @@ public final class MutableColumnarRow extends InternalRow {
return columns[ordinal].getBinary(rowId);
}
+ @Override
+ public GeographyVal getGeography(int ordinal) {
+ return columns[ordinal].getGeography(rowId);
+ }
+
+ @Override
+ public GeometryVal getGeometry(int ordinal) {
+ return columns[ordinal].getGeometry(rowId);
+ }
+
@Override
public CalendarInterval getInterval(int ordinal) {
return columns[ordinal].getInterval(rowId);
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
index f9193cd0495f..877b5e638b6d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
@@ -267,6 +267,8 @@ private object RowToColumnConverter {
case LongType | TimestampType | TimestampNTZType | _:
DayTimeIntervalType => LongConverter
case DoubleType => DoubleConverter
case StringType => StringConverter
+ case _: GeographyType => GeographyConverter
+ case _: GeometryType => GeometryConverter
case CalendarIntervalType => CalendarConverter
case VariantType => VariantConverter
case at: ArrayType => ArrayConverter(getConverterForType(at.elementType,
at.containsNull))
@@ -338,6 +340,20 @@ private object RowToColumnConverter {
}
}
+ private object GeographyConverter extends TypeConverter {
+ override def append(row: SpecializedGetters, column: Int, cv:
WritableColumnVector): Unit = {
+ val data = row.getGeography(column).getBytes
+ cv.appendByteArray(data, 0, data.length)
+ }
+ }
+
+ private object GeometryConverter extends TypeConverter {
+ override def append(row: SpecializedGetters, column: Int, cv:
WritableColumnVector): Unit = {
+ val data = row.getGeometry(column).getBytes
+ cv.appendByteArray(data, 0, data.length)
+ }
+ }
+
private object CalendarConverter extends TypeConverter {
override def append(row: SpecializedGetters, column: Int, cv:
WritableColumnVector): Unit = {
val c = row.getInterval(column)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]