Copilot commented on code in PR #2871:
URL: https://github.com/apache/sedona/pull/2871#discussion_r3170017354
##########
spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala:
##########
@@ -843,9 +886,20 @@ class JoinQueryDetector(sparkSession: SparkSession)
extends SparkStrategy {
.map { distanceExpr =>
matchDistanceExpressionToJoinSide(distanceExpr, left, right) match {
case Some(side) =>
- if (broadcastSide.get == side) (Some(distanceExpr), None)
- else if (distanceExpr.references.isEmpty) (Some(distanceExpr),
None)
- else (None, Some(distanceExpr))
+ if (geographyShape) {
+ // Geography distance joins read the per-row radius from the
stream-side
+ // GeographyJoinShape inside GeographyDistanceRefiner, so the
radius MUST
+ // also flow to the stream side. For literal radii we still keep
the
+ // build-side expansion (mirror of the geometry literal
optimisation).
+ if (broadcastSide.get == side) (Some(distanceExpr),
Some(distanceExpr))
+ else if (distanceExpr.references.isEmpty)
+ (Some(distanceExpr), Some(distanceExpr))
+ else (None, Some(distanceExpr))
Review Comment:
For Geography ST_DWithin broadcast joins, this branch forwards the distance
expression to the stream side even when that expression is bound to the
broadcast/index side (i.e., `broadcastSide.get == side` and
`distanceExpr.references.nonEmpty`). At runtime, `BroadcastIndexJoinExec`
binds and evaluates the stream-side distance against `streamed.output`, so a
reference to a broadcast-side column will fail to bind or evaluate.

Consider either of these options:

(a) Do not plan `BroadcastIndexJoinExec` in this case.

(b) Keep the radius only on the index side, and teach `GeographyDistanceRefiner`
to read the per-row radius from the build-side `GeographyJoinShape.radius` when
no stream-side radius is available. Expand stream envelopes only when the radius
is stream-bound or literal.

The suggestion below is a minimal stopgap rather than either option. It rejects
the broadcast-bound, non-literal case with an explicit
`UnsupportedOperationException` instead of producing an invalid plan.
```suggestion
// Geography distance joins currently require the per-row
radius to be
// available on the streamed side for
GeographyDistanceRefiner. A
// distance expression that is bound to the broadcast/index
side cannot
// be forwarded as a stream-side expression because it will
later be
// bound against streamed.output in BroadcastIndexJoinExec.
//
// Keep literal radii on both sides (preserving build-side
expansion),
// allow stream-bound radii on the stream side, and reject the
// unsupported broadcast-bound non-literal case to avoid
planning an
// invalid broadcast geography join.
if (distanceExpr.references.isEmpty) {
(Some(distanceExpr), Some(distanceExpr))
} else if (broadcastSide.get == side) {
throw new UnsupportedOperationException(
"Geography distance broadcast joins do not support
non-literal " +
"distance expressions bound to the broadcast/index side")
} else {
(None, Some(distanceExpr))
}
```
##########
spark/common/src/test/scala/org/apache/sedona/sql/geography/BroadcastIndexJoinGeographySuite.scala:
##########
@@ -150,4 +150,212 @@ class BroadcastIndexJoinGeographySuite extends
TestBaseScala {
assert(pairs === Set((0, 0), (1, 1), (2, 2)))
}
}
+
+ describe("Geography broadcast spatial join (ST_Within)") {
+
+ it("plans BroadcastIndexJoinExec when the polygon side is broadcast") {
+ val joined =
+ pointGeogDf.join(broadcast(polygonGeogDf), expr("ST_Within(pt_geog,
poly_geog)"))
+ assert(planUsesBroadcastIndexJoin(joined))
+ assert(joined.count() === 3)
+ }
+
+ it("plans BroadcastIndexJoinExec when the point side is broadcast") {
+ val joined =
+ polygonGeogDf.join(broadcast(pointGeogDf), expr("ST_Within(pt_geog,
poly_geog)"))
+ assert(planUsesBroadcastIndexJoin(joined))
+ assert(joined.count() === 3)
+ }
+
+ it("returns the correct (poly_id, pt_id) pairs") {
+ val rows = pointGeogDf
+ .join(broadcast(polygonGeogDf), expr("ST_Within(pt_geog, poly_geog)"))
+ .selectExpr("poly_id", "pt_id")
+ .collect()
+ .map(r => (r.getInt(0), r.getInt(1)))
+ .toSet
+ assert(rows === Set((0, 0), (1, 1), (2, 2)))
+ }
+
+ it("supports LEFT OUTER with the polygon side broadcast") {
+ val joined = pointGeogDf
+ .join(broadcast(polygonGeogDf), expr("ST_Within(pt_geog, poly_geog)"),
"left_outer")
+ assert(planUsesBroadcastIndexJoin(joined))
+ assert(joined.count() === 6)
+ assert(joined.where("poly_id IS NULL").count() === 3)
+ }
+ }
+
+ describe("Geography broadcast spatial join (ST_Intersects)") {
+
+ it("plans BroadcastIndexJoinExec when the polygon side is broadcast") {
+ val joined =
+ pointGeogDf.join(broadcast(polygonGeogDf),
expr("ST_Intersects(poly_geog, pt_geog)"))
+ assert(planUsesBroadcastIndexJoin(joined))
+ assert(joined.count() === 3)
+ }
+
+ it("returns the correct (poly_id, pt_id) pairs") {
+ val rows = pointGeogDf
+ .join(broadcast(polygonGeogDf), expr("ST_Intersects(poly_geog,
pt_geog)"))
+ .selectExpr("poly_id", "pt_id")
+ .collect()
+ .map(r => (r.getInt(0), r.getInt(1)))
+ .toSet
+ assert(rows === Set((0, 0), (1, 1), (2, 2)))
+ }
+
+ it("handles antimeridian-spanning polygons correctly") {
+ import sparkSession.implicits._
+ val polyDf = Seq((100, "POLYGON((170 -1, -170 -1, -170 1, 170 1, 170
-1))"))
+ .toDF("poly_id", "wkt")
+ .selectExpr("poly_id", "ST_GeogFromWKT(wkt, 4326) AS poly_geog")
+
+ val ptDf = Seq((1, "POINT(175 0)"), (2, "POINT(-175 0)"), (3, "POINT(0
0)"))
+ .toDF("pt_id", "wkt")
+ .selectExpr("pt_id", "ST_GeogFromWKT(wkt, 4326) AS pt_geog")
+
+ val joined = ptDf.join(broadcast(polyDf), expr("ST_Intersects(poly_geog,
pt_geog)"))
+ assert(planUsesBroadcastIndexJoin(joined))
+ val matched = joined.selectExpr("pt_id").collect().map(_.getInt(0)).toSet
+ assert(matched === Set(1, 2))
+ }
+ }
+
+ private lazy val pointsLeftDf = {
+ import sparkSession.implicits._
+ Seq((0, "POINT(0 0)"), (1, "POINT(1 1)"), (2, "POINT(2 2)"), (3, "POINT(99
99)"))
+ .toDF("id_l", "wkt")
+ .selectExpr("id_l", "ST_GeogFromWKT(wkt, 4326) AS geog_l")
+ }
+ private lazy val pointsRightDf = {
+ import sparkSession.implicits._
+ Seq((10, "POINT(0 0)"), (11, "POINT(1 1)"), (12, "POINT(2 2)"), (13,
"POINT(50 50)"))
+ .toDF("id_r", "wkt")
+ .selectExpr("id_r", "ST_GeogFromWKT(wkt, 4326) AS geog_r")
+ }
+
+ describe("Geography broadcast spatial join (ST_Equals)") {
+
+ it("plans BroadcastIndexJoinExec and matches identical points") {
+ val joined =
+ pointsLeftDf.join(broadcast(pointsRightDf), expr("ST_Equals(geog_l,
geog_r)"))
+ assert(planUsesBroadcastIndexJoin(joined))
+ val pairs = joined
+ .selectExpr("id_l", "id_r")
+ .collect()
+ .map(r => (r.getInt(0), r.getInt(1)))
+ .toSet
+ assert(pairs === Set((0, 10), (1, 11), (2, 12)))
+ }
+ }
+
+ private lazy val pointsADf = {
+ import sparkSession.implicits._
+ Seq((0, "POINT(0 0)"), (1, "POINT(1 0)"), (2, "POINT(2 0)"))
+ .toDF("id_a", "wkt")
+ .selectExpr("id_a", "ST_GeogFromWKT(wkt, 4326) AS geog_a")
+ }
+ private lazy val pointsBDf = {
+ import sparkSession.implicits._
+ Seq(
+ (10, "POINT(0 0)"), // 0 m from (0,0)
+ (11, "POINT(1 0)"), // 0 m from (1,0)
+ (12, "POINT(0 1)") // ~111 km north of (0,0)
+ ).toDF("id_b", "wkt")
+ .selectExpr("id_b", "ST_GeogFromWKT(wkt, 4326) AS geog_b")
+ }
+
+ describe("Geography broadcast spatial join (ST_DWithin)") {
+
+ it("plans BroadcastIndexJoinExec when the right side is broadcast") {
+ val joined =
+ pointsADf.join(broadcast(pointsBDf), expr("ST_DWithin(geog_a, geog_b,
1000.0)"))
+ assert(planUsesBroadcastIndexJoin(joined))
+ }
+
+ it("returns only same-location pairs at a tight threshold (1 km)") {
+ val pairs = pointsADf
+ .join(broadcast(pointsBDf), expr("ST_DWithin(geog_a, geog_b, 1000.0)"))
+ .selectExpr("id_a", "id_b")
+ .collect()
+ .map(r => (r.getInt(0), r.getInt(1)))
+ .toSet
+ assert(pairs === Set((0, 10), (1, 11)))
+ }
+
+ it("returns the additional cross-row pair at a wide threshold (200 km)") {
+ // 200 km covers the ~111 km north neighbour from (0,0) -> (0,1) and the
~111 km
+ // east-west neighbours.
+ val pairs = pointsADf
+ .join(broadcast(pointsBDf), expr("ST_DWithin(geog_a, geog_b,
200000.0)"))
+ .selectExpr("id_a", "id_b")
+ .collect()
+ .map(r => (r.getInt(0), r.getInt(1)))
+ .toSet
+ // (0,0)↔(0,0), (0,0)↔(1,0), (0,0)↔(0,1)
+ // (1,0)↔(0,0), (1,0)↔(1,0), (1,0)↔(0,1)
+ // (2,0)↔(1,0) (only one within 200 km — (0,0) is ~222 km, (0,1) is
~244 km)
+ assert(pairs.contains((0, 10)))
+ assert(pairs.contains((0, 11)))
+ assert(pairs.contains((0, 12)))
+ assert(pairs.contains((1, 10)))
+ assert(pairs.contains((1, 11)))
+ assert(pairs.contains((1, 12)))
+ assert(pairs.contains((2, 11)))
+ assert(!pairs.contains((2, 10)))
+ }
+
+ it("supports a per-row column-distance threshold") {
+ import sparkSession.implicits._
+ val withRadius =
+ Seq((0, "POINT(0 0)", 1000.0), (1, "POINT(1 0)", 1.0), (2, "POINT(2
0)", 200000.0))
+ .toDF("id_a", "wkt", "radius_m")
+ .selectExpr("id_a", "ST_GeogFromWKT(wkt, 4326) AS geog_a",
"radius_m")
+
+ val joined =
+ withRadius.join(broadcast(pointsBDf), expr("ST_DWithin(geog_a, geog_b,
radius_m)"))
+ assert(planUsesBroadcastIndexJoin(joined))
+ val pairs = joined
+ .selectExpr("id_a", "id_b")
+ .collect()
+ .map(r => (r.getInt(0), r.getInt(1)))
+ .toSet
+ // id_a=0 with 1 km: only (0,0) self-match — id_b=10
+ // id_a=1 with 1 m: only (1,0) self-match — id_b=11
+ // id_a=2 with 200 km: only (1,0) at ~111 km — id_b=11
+ assert(pairs === Set((0, 10), (1, 11), (2, 11)))
+ }
+
+ it("supports LEFT OUTER with the right side broadcast") {
+ val joined = pointsADf.join(
+ broadcast(pointsBDf),
+ expr("ST_DWithin(geog_a, geog_b, 1000.0)"),
+ "left_outer")
+ assert(planUsesBroadcastIndexJoin(joined))
+ // id_a=2 has no match within 1km -> NULL right side. Counts:
(0,10),(1,11),(2,NULL).
+ assert(joined.count() === 3)
+ assert(joined.where("id_b IS NULL").count() === 1)
+ }
+
+ it("rejects ST_DWithin(geog, geog, dist, useSpheroid) at analysis time") {
+ // The 4-arg ST_DWithin is geometry-only; passing Geography arguments
fails at
+ // analysis time with a DATATYPE_MISMATCH before the planner runs. There
is no
+ // Geography overload of the 4-arg form because Geography is always
spheroidal.
+ val ex = intercept[Throwable] {
+ pointsADf
+ .join(broadcast(pointsBDf), expr("ST_DWithin(geog_a, geog_b, 1000.0,
true)"))
+ .queryExecution
+ .sparkPlan
+ }
+ val msg = Iterator
+ .iterate[Throwable](ex)(t => if (t == null) null else t.getCause)
+ .takeWhile(_ != null)
+ .map(_.getMessage)
+ .mkString(" | ")
+ assert(
+ msg.contains("st_dwithin") && msg.contains("data type mismatch"),
Review Comment:
This assertion is brittle across Spark versions and locales. It relies on exact,
case-sensitive substrings of the human-readable analysis exception message
(e.g., "st_dwithin" and "data type mismatch"), and that text can change between
releases.

To reduce cross-version flakiness, normalize the message first (e.g.,
`toLowerCase(java.util.Locale.ROOT)`, so the result does not depend on the JVM
default locale). Also, or instead, assert on the stable Spark error
class/condition (e.g., that the message contains "DATATYPE_MISMATCH") rather
than on the human-readable wording.
```suggestion
val normalizedMsg = msg.toLowerCase(java.util.Locale.ROOT)
assert(
normalizedMsg.contains("st_dwithin") &&
(normalizedMsg.contains("datatype_mismatch") ||
normalizedMsg.contains("data type mismatch")),
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]