Copilot commented on code in PR #2871:
URL: https://github.com/apache/sedona/pull/2871#discussion_r3170017354


##########
spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala:
##########
@@ -843,9 +886,20 @@ class JoinQueryDetector(sparkSession: SparkSession) 
extends SparkStrategy {
       .map { distanceExpr =>
         matchDistanceExpressionToJoinSide(distanceExpr, left, right) match {
           case Some(side) =>
-            if (broadcastSide.get == side) (Some(distanceExpr), None)
-            else if (distanceExpr.references.isEmpty) (Some(distanceExpr), 
None)
-            else (None, Some(distanceExpr))
+            if (geographyShape) {
+              // Geography distance joins read the per-row radius from the 
stream-side
+              // GeographyJoinShape inside GeographyDistanceRefiner, so the 
radius MUST
+              // also flow to the stream side. For literal radii we still keep 
the
+              // build-side expansion (mirror of the geometry literal 
optimisation).
+              if (broadcastSide.get == side) (Some(distanceExpr), 
Some(distanceExpr))
+              else if (distanceExpr.references.isEmpty)
+                (Some(distanceExpr), Some(distanceExpr))
+              else (None, Some(distanceExpr))

Review Comment:
   For Geography ST_DWithin broadcast joins, this branch forwards the distance 
expression to the stream side even when that expression is bound to the 
broadcast/index side (i.e., `broadcastSide.get == side` and 
`distanceExpr.references.nonEmpty`). At runtime `BroadcastIndexJoinExec` 
binds and evaluates the stream-side distance against `streamed.output`. A column 
that exists only on the broadcast side is not part of `streamed.output`, so 
binding (and therefore evaluation) will fail. Consider either (a) not planning 
`BroadcastIndexJoinExec` in this case, or (b) keeping the radius only on the 
index side and teaching `GeographyDistanceRefiner` to read the per-row radius 
from the build-side `GeographyJoinShape.radius` when no stream-side radius is 
available. Stream envelopes would then be expanded only when the radius is 
stream-bound or a literal.
   ```suggestion
                 // Geography distance joins currently require the per-row 
radius to be
                 // available on the streamed side for 
GeographyDistanceRefiner. A
                 // distance expression that is bound to the broadcast/index 
side cannot
                 // be forwarded as a stream-side expression because it will 
later be
                 // bound against streamed.output in BroadcastIndexJoinExec.
                 //
                 // Keep literal radii on both sides (preserving build-side 
expansion),
                 // allow stream-bound radii on the stream side, and reject the
                 // unsupported broadcast-bound non-literal case to avoid 
planning an
                 // invalid broadcast geography join.
                 if (distanceExpr.references.isEmpty) {
                   (Some(distanceExpr), Some(distanceExpr))
                 } else if (broadcastSide.get == side) {
                   throw new UnsupportedOperationException(
                     "Geography distance broadcast joins do not support 
non-literal " +
                       "distance expressions bound to the broadcast/index side")
                 } else {
                   (None, Some(distanceExpr))
                 }
   ```



##########
spark/common/src/test/scala/org/apache/sedona/sql/geography/BroadcastIndexJoinGeographySuite.scala:
##########
@@ -150,4 +150,212 @@ class BroadcastIndexJoinGeographySuite extends 
TestBaseScala {
       assert(pairs === Set((0, 0), (1, 1), (2, 2)))
     }
   }
+
+  describe("Geography broadcast spatial join (ST_Within)") {
+
+    it("plans BroadcastIndexJoinExec when the polygon side is broadcast") {
+      val joined =
+        pointGeogDf.join(broadcast(polygonGeogDf), expr("ST_Within(pt_geog, 
poly_geog)"))
+      assert(planUsesBroadcastIndexJoin(joined))
+      assert(joined.count() === 3)
+    }
+
+    it("plans BroadcastIndexJoinExec when the point side is broadcast") {
+      val joined =
+        polygonGeogDf.join(broadcast(pointGeogDf), expr("ST_Within(pt_geog, 
poly_geog)"))
+      assert(planUsesBroadcastIndexJoin(joined))
+      assert(joined.count() === 3)
+    }
+
+    it("returns the correct (poly_id, pt_id) pairs") {
+      val rows = pointGeogDf
+        .join(broadcast(polygonGeogDf), expr("ST_Within(pt_geog, poly_geog)"))
+        .selectExpr("poly_id", "pt_id")
+        .collect()
+        .map(r => (r.getInt(0), r.getInt(1)))
+        .toSet
+      assert(rows === Set((0, 0), (1, 1), (2, 2)))
+    }
+
+    it("supports LEFT OUTER with the polygon side broadcast") {
+      val joined = pointGeogDf
+        .join(broadcast(polygonGeogDf), expr("ST_Within(pt_geog, poly_geog)"), 
"left_outer")
+      assert(planUsesBroadcastIndexJoin(joined))
+      assert(joined.count() === 6)
+      assert(joined.where("poly_id IS NULL").count() === 3)
+    }
+  }
+
+  describe("Geography broadcast spatial join (ST_Intersects)") {
+
+    it("plans BroadcastIndexJoinExec when the polygon side is broadcast") {
+      val joined =
+        pointGeogDf.join(broadcast(polygonGeogDf), 
expr("ST_Intersects(poly_geog, pt_geog)"))
+      assert(planUsesBroadcastIndexJoin(joined))
+      assert(joined.count() === 3)
+    }
+
+    it("returns the correct (poly_id, pt_id) pairs") {
+      val rows = pointGeogDf
+        .join(broadcast(polygonGeogDf), expr("ST_Intersects(poly_geog, 
pt_geog)"))
+        .selectExpr("poly_id", "pt_id")
+        .collect()
+        .map(r => (r.getInt(0), r.getInt(1)))
+        .toSet
+      assert(rows === Set((0, 0), (1, 1), (2, 2)))
+    }
+
+    it("handles antimeridian-spanning polygons correctly") {
+      import sparkSession.implicits._
+      val polyDf = Seq((100, "POLYGON((170 -1, -170 -1, -170 1, 170 1, 170 
-1))"))
+        .toDF("poly_id", "wkt")
+        .selectExpr("poly_id", "ST_GeogFromWKT(wkt, 4326) AS poly_geog")
+
+      val ptDf = Seq((1, "POINT(175 0)"), (2, "POINT(-175 0)"), (3, "POINT(0 
0)"))
+        .toDF("pt_id", "wkt")
+        .selectExpr("pt_id", "ST_GeogFromWKT(wkt, 4326) AS pt_geog")
+
+      val joined = ptDf.join(broadcast(polyDf), expr("ST_Intersects(poly_geog, 
pt_geog)"))
+      assert(planUsesBroadcastIndexJoin(joined))
+      val matched = joined.selectExpr("pt_id").collect().map(_.getInt(0)).toSet
+      assert(matched === Set(1, 2))
+    }
+  }
+
+  private lazy val pointsLeftDf = {
+    import sparkSession.implicits._
+    Seq((0, "POINT(0 0)"), (1, "POINT(1 1)"), (2, "POINT(2 2)"), (3, "POINT(99 
99)"))
+      .toDF("id_l", "wkt")
+      .selectExpr("id_l", "ST_GeogFromWKT(wkt, 4326) AS geog_l")
+  }
+  private lazy val pointsRightDf = {
+    import sparkSession.implicits._
+    Seq((10, "POINT(0 0)"), (11, "POINT(1 1)"), (12, "POINT(2 2)"), (13, 
"POINT(50 50)"))
+      .toDF("id_r", "wkt")
+      .selectExpr("id_r", "ST_GeogFromWKT(wkt, 4326) AS geog_r")
+  }
+
+  describe("Geography broadcast spatial join (ST_Equals)") {
+
+    it("plans BroadcastIndexJoinExec and matches identical points") {
+      val joined =
+        pointsLeftDf.join(broadcast(pointsRightDf), expr("ST_Equals(geog_l, 
geog_r)"))
+      assert(planUsesBroadcastIndexJoin(joined))
+      val pairs = joined
+        .selectExpr("id_l", "id_r")
+        .collect()
+        .map(r => (r.getInt(0), r.getInt(1)))
+        .toSet
+      assert(pairs === Set((0, 10), (1, 11), (2, 12)))
+    }
+  }
+
+  private lazy val pointsADf = {
+    import sparkSession.implicits._
+    Seq((0, "POINT(0 0)"), (1, "POINT(1 0)"), (2, "POINT(2 0)"))
+      .toDF("id_a", "wkt")
+      .selectExpr("id_a", "ST_GeogFromWKT(wkt, 4326) AS geog_a")
+  }
+  private lazy val pointsBDf = {
+    import sparkSession.implicits._
+    Seq(
+      (10, "POINT(0 0)"), // 0 m from (0,0)
+      (11, "POINT(1 0)"), // 0 m from (1,0)
+      (12, "POINT(0 1)") // ~111 km north of (0,0)
+    ).toDF("id_b", "wkt")
+      .selectExpr("id_b", "ST_GeogFromWKT(wkt, 4326) AS geog_b")
+  }
+
+  describe("Geography broadcast spatial join (ST_DWithin)") {
+
+    it("plans BroadcastIndexJoinExec when the right side is broadcast") {
+      val joined =
+        pointsADf.join(broadcast(pointsBDf), expr("ST_DWithin(geog_a, geog_b, 
1000.0)"))
+      assert(planUsesBroadcastIndexJoin(joined))
+    }
+
+    it("returns only same-location pairs at a tight threshold (1 km)") {
+      val pairs = pointsADf
+        .join(broadcast(pointsBDf), expr("ST_DWithin(geog_a, geog_b, 1000.0)"))
+        .selectExpr("id_a", "id_b")
+        .collect()
+        .map(r => (r.getInt(0), r.getInt(1)))
+        .toSet
+      assert(pairs === Set((0, 10), (1, 11)))
+    }
+
+    it("returns the additional cross-row pair at a wide threshold (200 km)") {
+      // 200 km covers the ~111 km north neighbour from (0,0) -> (0,1) and the 
~111 km
+      // east-west neighbours.
+      val pairs = pointsADf
+        .join(broadcast(pointsBDf), expr("ST_DWithin(geog_a, geog_b, 
200000.0)"))
+        .selectExpr("id_a", "id_b")
+        .collect()
+        .map(r => (r.getInt(0), r.getInt(1)))
+        .toSet
+      // (0,0)↔(0,0), (0,0)↔(1,0), (0,0)↔(0,1)
+      // (1,0)↔(0,0), (1,0)↔(1,0), (1,0)↔(0,1)
+      // (2,0)↔(1,0)  (only one within 200 km — (0,0) is ~222 km, (0,1) is 
~244 km)
+      assert(pairs.contains((0, 10)))
+      assert(pairs.contains((0, 11)))
+      assert(pairs.contains((0, 12)))
+      assert(pairs.contains((1, 10)))
+      assert(pairs.contains((1, 11)))
+      assert(pairs.contains((1, 12)))
+      assert(pairs.contains((2, 11)))
+      assert(!pairs.contains((2, 10)))
+    }
+
+    it("supports a per-row column-distance threshold") {
+      import sparkSession.implicits._
+      val withRadius =
+        Seq((0, "POINT(0 0)", 1000.0), (1, "POINT(1 0)", 1.0), (2, "POINT(2 
0)", 200000.0))
+          .toDF("id_a", "wkt", "radius_m")
+          .selectExpr("id_a", "ST_GeogFromWKT(wkt, 4326) AS geog_a", 
"radius_m")
+
+      val joined =
+        withRadius.join(broadcast(pointsBDf), expr("ST_DWithin(geog_a, geog_b, 
radius_m)"))
+      assert(planUsesBroadcastIndexJoin(joined))
+      val pairs = joined
+        .selectExpr("id_a", "id_b")
+        .collect()
+        .map(r => (r.getInt(0), r.getInt(1)))
+        .toSet
+      // id_a=0 with 1 km: only (0,0) self-match — id_b=10
+      // id_a=1 with 1 m: only (1,0) self-match — id_b=11
+      // id_a=2 with 200 km: only (1,0) at ~111 km — id_b=11
+      assert(pairs === Set((0, 10), (1, 11), (2, 11)))
+    }
+
+    it("supports LEFT OUTER with the right side broadcast") {
+      val joined = pointsADf.join(
+        broadcast(pointsBDf),
+        expr("ST_DWithin(geog_a, geog_b, 1000.0)"),
+        "left_outer")
+      assert(planUsesBroadcastIndexJoin(joined))
+      // id_a=2 has no match within 1km -> NULL right side. Counts: 
(0,10),(1,11),(2,NULL).
+      assert(joined.count() === 3)
+      assert(joined.where("id_b IS NULL").count() === 1)
+    }
+
+    it("rejects ST_DWithin(geog, geog, dist, useSpheroid) at analysis time") {
+      // The 4-arg ST_DWithin is geometry-only; passing Geography arguments 
fails at
+      // analysis time with a DATATYPE_MISMATCH before the planner runs. There 
is no
+      // Geography overload of the 4-arg form because Geography is always 
spheroidal.
+      val ex = intercept[Throwable] {
+        pointsADf
+          .join(broadcast(pointsBDf), expr("ST_DWithin(geog_a, geog_b, 1000.0, 
true)"))
+          .queryExecution
+          .sparkPlan
+      }
+      val msg = Iterator
+        .iterate[Throwable](ex)(t => if (t == null) null else t.getCause)
+        .takeWhile(_ != null)
+        .map(_.getMessage)
+        .mkString(" | ")
+      assert(
+        msg.contains("st_dwithin") && msg.contains("data type mismatch"),

Review Comment:
   This assertion is brittle across Spark versions and locales because it relies 
on exact, case-sensitive substrings of the analysis-time exception message 
(e.g., "st_dwithin" and "data type mismatch"). To reduce cross-version 
flakiness, normalize the message (e.g., `toLowerCase(java.util.Locale.ROOT)`) 
and/or assert on a stable Spark error class/condition (e.g., check for 
"DATATYPE_MISMATCH") rather than on the human-readable text.
   ```suggestion
         val normalizedMsg = msg.toLowerCase(java.util.Locale.ROOT)
         assert(
           normalizedMsg.contains("st_dwithin") &&
             (normalizedMsg.contains("datatype_mismatch") ||
               normalizedMsg.contains("data type mismatch")),
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to