This is an automated email from the ASF dual-hosted git repository.

pan3793 pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.x by this push:
     new da311c4b69dc [SPARK-57040][SQL] JDBC connector supports pushdown 
TABLESAMPLE SYSTEM
da311c4b69dc is described below

commit da311c4b69dc70aa7342e78b3f15c1e9869d4041
Author: Cheng Pan <[email protected]>
AuthorDate: Wed May 27 02:48:55 2026 +0800

    [SPARK-57040][SQL] JDBC connector supports pushdown TABLESAMPLE SYSTEM
    
    ### What changes were proposed in this pull request?
    
    This PR contains 3 parts:
    
    1. `JdbcDialect` API change - deprecate `def supportsTableSample: Boolean` 
and `def getTableSample(sample: TableSampleInfo): String`, introduce a `def 
compileTableSample(sample: TableSampleInfo): Option[String]` as the replacement.
      1.1. This is a correctness fix: PostgreSQL and Databricks SQL do not 
support `withReplacement = true` (in fact, no mainstream RDBMS or OLAP engine 
seems to support TABLESAMPLE with replacement; Spark itself only supports it 
in the DataFrame API), but the current implementation ignores 
`withReplacement` and always treats it as `withReplacement = false`, which is 
incorrect.
      1.2. It is a prerequisite for supporting `TABLESAMPLE SYSTEM`, since 
some RDBMSs support only `TABLESAMPLE BERNOULLI` and not `TABLESAMPLE SYSTEM`.
    
    2. Mark the old `pushTableSample` method as deprecated and point to the 
new API; connectors are advised to implement only the new `pushTableSample` 
overload, which takes a `sampleMethod` parameter.
    
    3. Extend the built-in JDBC connector to support pushing down `TABLESAMPLE 
SYSTEM`. For now, this applies only to the PostgreSQL dialect.
    
    ### Why are the changes needed?
    
    Correctness fix - TABLESAMPLE with `withReplacement = true` should not be 
pushed down to the PostgreSQL and Databricks JDBC connectors.
    
    Make the `JdbcDialect` API more flexible - an RDBMS may only partially 
support TABLESAMPLE; a dialect can now decide whether a sample can be pushed 
down by inspecting the actual TABLESAMPLE request instead of blindly answering 
yes/no without knowing the input.
    
    Extend the built-in JDBC connector feature.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, see the above section for details.
    
    ### How was this patch tested?
    
    New UT/IT cases are added.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Contains Content Generated-by: MiMo-V2.5-Pro
    
    Closes #56092 from pan3793/SPARK-57040.
    
    Authored-by: Cheng Pan <[email protected]>
    Signed-off-by: Cheng Pan <[email protected]>
    (cherry picked from commit 499172a5f879623dffa4eba122284a3ec9315235)
    Signed-off-by: Cheng Pan <[email protected]>
---
 .../sql/jdbc/v2/PostgresIntegrationSuite.scala     |  2 ++
 .../org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala  | 38 ++++++++++++++++++++--
 .../read/SupportsPushDownTableSample.java          |  8 +++--
 .../catalyst/expressions/V2ExpressionUtils.scala   |  8 ++++-
 .../sql/execution/datasources/jdbc/JDBCRDD.scala   | 13 ++++----
 .../execution/datasources/jdbc/JDBCRelation.scala  |  5 ++-
 .../execution/datasources/v2/jdbc/JDBCScan.scala   |  5 ++-
 .../datasources/v2/jdbc/JDBCScanBuilder.scala      | 20 +++++++-----
 .../v2/jdbc/JDBCV1RelationFromV2Scan.scala         |  7 ++--
 .../apache/spark/sql/jdbc/DatabricksDialect.scala  | 10 +++---
 .../org/apache/spark/sql/jdbc/JdbcDialects.scala   | 22 +++++++++++++
 .../spark/sql/jdbc/JdbcSQLQueryBuilder.scala       | 10 ++++++
 .../apache/spark/sql/jdbc/PostgresDialect.scala    | 11 +++----
 .../datasources/jdbc/JdbcTaskInterruptSuite.scala  |  2 +-
 14 files changed, 119 insertions(+), 42 deletions(-)

diff --git 
a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
 
b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
index faeb39108c4f..d57d3aa5ea03 100644
--- 
a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
+++ 
b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
@@ -244,6 +244,8 @@ class PostgresIntegrationSuite extends 
DockerJDBCIntegrationV2Suite with V2JDBCT
 
   override def supportsTableSample: Boolean = true
 
+  override def supportsTableSampleSystem: Boolean = true
+
   override def supportsIndex: Boolean = true
 
   override def indexOptions: String = "FILLFACTOR=70"
diff --git 
a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
 
b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
index df5dfdf7deaf..79366189c20d 100644
--- 
a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
+++ 
b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
@@ -148,7 +148,7 @@ private[v2] trait V2JDBCTest
       partitionColumn: String)
   val tableNameToPartinioningOptions: Map[String, PartitioningInfo] = Map(
     "employee" -> PartitioningInfo("4", "1", "8", "dept"),
-    // new_table is used in "SPARK-37038: Test TABLESAMPLE" test
+    // new_table is used in "SPARK-37038,SPARK-57040: Test TABLESAMPLE" test
     "new_table" -> PartitioningInfo("4", "1", "20", "col1")
   )
 
@@ -470,6 +470,8 @@ private[v2] trait V2JDBCTest
 
   def supportsTableSample: Boolean = false
 
+  def supportsTableSampleSystem: Boolean = false
+
   test("SPARK-48172: Test CONTAINS") {
     val df1 = spark.sql(
       s"""
@@ -699,9 +701,20 @@ private[v2] trait V2JDBCTest
     assert(rows12(5).getString(0) === 
"special_character_underscorenot_present")
   }
 
+  test("SPARK-57040: TABLESAMPLE with replacement is not pushed down") {
+    withTable(s"$catalogName.new_table") {
+      sql(s"CREATE TABLE $catalogName.new_table (col1 INT, col2 INT)")
+      spark.range(10).select($"id" * 2, $"id" * 2 + 
1).write.insertInto(s"$catalogName.new_table")
+      val df = spark.read.table(s"$catalogName.new_table")
+        .sample(withReplacement = true, fraction = 0.5, seed = 12345)
+      checkSamplePushed(df, false)
+      assert(df.collect().length > 0)
+    }
+  }
+
   val partitioningEnabledTestCase = Seq(true, false)
   gridTest(
-    "SPARK-37038: Test TABLESAMPLE"
+    "SPARK-37038,SPARK-57040: Test TABLESAMPLE"
   )(partitioningEnabledTestCase) { partitioningEnabled =>
     if (supportsTableSample) {
       withTable(s"$catalogName.new_table") {
@@ -789,6 +802,27 @@ private[v2] trait V2JDBCTest
         checkSamplePushed(df8, false)
         checkFilterPushed(df8)
         assert(df8.collect().length < 10)
+
+        // SYSTEM sampling pushdown
+        if (supportsTableSampleSystem) {
+          val df9 = sql(s"SELECT * FROM $catalogName.new_table $tableOptions " 
+
+            "TABLESAMPLE SYSTEM (50 PERCENT)")
+          checkSamplePushed(df9)
+          if (partitioningEnabled) {
+            multiplePartitionAdditionalCheck(df9, partitionInfo)
+          }
+          assert(df9.collect().length <= 10)
+
+          // SYSTEM sampling + column pruning
+          val df10 = sql(s"SELECT col1 FROM $catalogName.new_table 
$tableOptions " +
+            "TABLESAMPLE SYSTEM (50 PERCENT)")
+          checkSamplePushed(df10)
+          checkColumnPruned(df10, "col1")
+          if (partitioningEnabled) {
+            multiplePartitionAdditionalCheck(df10, partitionInfo)
+          }
+          assert(df10.collect().length <= 10)
+        }
       }
     }
   }
diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java
index 3ceb7ed2de14..000588e9c318 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.connector.read;
 
 import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.errors.QueryCompilationErrors;
 
 /**
  * A mix-in interface for {@link Scan}. Data sources can implement this 
interface to
@@ -31,11 +32,14 @@ public interface SupportsPushDownTableSample extends 
ScanBuilder {
   /**
    * Pushes down BERNOULLI (row-level) SAMPLE to the data source.
    */
-  boolean pushTableSample(
+  @Deprecated(since = "4.2.0")
+  default boolean pushTableSample(
       double lowerBound,
       double upperBound,
       boolean withReplacement,
-      long seed);
+      long seed) {
+    throw QueryCompilationErrors.mustOverrideOneMethodError("pushTableSample");
+  }
 
   /**
    * Pushes down SAMPLE to the data source with the specified sampling method.
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
index d747bebd5cfe..c561677ed5ad 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
@@ -26,12 +26,13 @@ import org.apache.spark.sql.catalyst.{InternalRow, 
SQLConfHelper}
 import org.apache.spark.sql.catalyst.analysis.{NoSuchFunctionException, 
UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.encoders.EncoderUtils
 import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, 
LogicalPlan, SampleMethod}
 import org.apache.spark.sql.connector.catalog.{FunctionCatalog, Identifier}
 import org.apache.spark.sql.connector.catalog.functions._
 import 
org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
 import org.apache.spark.sql.connector.expressions.{BucketTransform, Cast => 
V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, 
IdentityTransform, Literal => V2Literal, NamedReference, NamedTransform, 
NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => 
V2SortOrder, SortValue, Transform}
 import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, 
AlwaysTrue}
+import org.apache.spark.sql.connector.read.{SampleMethod => V2SampleMethod}
 import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.connector.PartitionPredicateImpl
@@ -168,6 +169,11 @@ object V2ExpressionUtils extends SQLConfHelper with 
Logging {
     case V2NullOrdering.NULLS_LAST => NullsLast
   }
 
+  def toCatalyst(sampleMethod: V2SampleMethod): SampleMethod = sampleMethod 
match {
+    case V2SampleMethod.BERNOULLI => SampleMethod.Bernoulli
+    case V2SampleMethod.SYSTEM => SampleMethod.System
+  }
+
   def resolveScalarFunction(
       scalarFunc: ScalarFunction[_],
       arguments: Seq[Expression]): Expression = {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 16b25a9e6f70..425f98cad031 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -31,7 +31,6 @@ import org.apache.spark.sql.connector.catalog.Identifier
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.{DataSourceMetricsMixin, 
ExternalEngineDatasourceRDD}
-import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
 import org.apache.spark.sql.types._
@@ -141,7 +140,7 @@ object JDBCRDD extends Logging {
    * @param options - JDBC options that contains url, table and other 
information.
    * @param outputSchema - The schema of the columns or aggregate columns to 
SELECT.
    * @param groupByColumns - The pushed down group by columns.
-   * @param sample - The pushed down tableSample.
+   * @param sampleClause - The pushed down table sample SQL clause.
    * @param limit - The pushed down limit. If the value is 0, it means no 
limit or limit
    *                is not pushed down.
    * @param sortOrders - The sort orders cooperates with limit to realize top 
N.
@@ -158,7 +157,7 @@ object JDBCRDD extends Logging {
       options: JDBCOptions,
       outputSchema: Option[StructType] = None,
       groupByColumns: Option[Array[String]] = None,
-      sample: Option[TableSampleInfo] = None,
+      sampleClause: Option[String] = None,
       limit: Int = 0,
       sortOrders: Array[String] = Array.empty[String],
       offset: Int = 0,
@@ -184,7 +183,7 @@ object JDBCRDD extends Logging {
       options,
       databaseMetadata = 
JDBCDatabaseMetadata.fromJDBCConnectionFactory(connectionFactory),
       groupByColumns,
-      sample,
+      sampleClause,
       limit,
       sortOrders,
       offset,
@@ -209,7 +208,7 @@ class JDBCRDD(
     options: JDBCOptions,
     databaseMetadata: JDBCDatabaseMetadata,
     groupByColumns: Option[Array[String]],
-    sample: Option[TableSampleInfo],
+    sampleClause: Option[String],
     limit: Int,
     sortOrders: Array[String],
     offset: Int,
@@ -252,8 +251,8 @@ class JDBCRDD(
       builder = builder.withGroupByColumns(groupByKeys)
     }
 
-    sample.foreach { tableSampleInfo =>
-      builder = builder.withTableSample(tableSampleInfo)
+    sampleClause.foreach { clause =>
+      builder = builder.withTableSampleClause(clause)
     }
 
     builder.build()
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index 05e30207314a..972bb3e35ee6 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -30,7 +30,6 @@ import org.apache.spark.sql.catalyst.util.{DateFormatter, 
DateTimeUtils, Timesta
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, 
stringToDate, stringToTimestamp}
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.jdbc.JdbcDialects
@@ -314,7 +313,7 @@ private[sql] case class JDBCRelation(
       finalSchema: StructType,
       predicates: Array[Predicate],
       groupByColumns: Option[Array[String]],
-      tableSample: Option[TableSampleInfo],
+      tableSampleClause: Option[String],
       limit: Int,
       sortOrders: Array[String],
       offset: Int): RDD[Row] = {
@@ -328,7 +327,7 @@ private[sql] case class JDBCRelation(
       jdbcOptions,
       Some(finalSchema),
       groupByColumns,
-      tableSample,
+      tableSampleClause,
       limit,
       sortOrders,
       offset,
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
index 75f3b04287aa..510f3b525e97 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
@@ -20,7 +20,6 @@ import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.connector.read.V1Scan
 import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation
-import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
 import org.apache.spark.sql.sources.{BaseRelation, TableScan}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.ArrayImplicits._
@@ -31,7 +30,7 @@ case class JDBCScan(
     pushedPredicates: Array[Predicate],
     pushedAggregateColumn: Array[String] = Array(),
     groupByColumns: Option[Array[String]],
-    tableSample: Option[TableSampleInfo],
+    tableSampleClause: Option[String],
     pushedLimit: Int,
     sortOrders: Array[String],
     pushedOffset: Int) extends V1Scan {
@@ -46,7 +45,7 @@ case class JDBCScan(
       pushedPredicates,
       pushedAggregateColumn,
       groupByColumns,
-      tableSample,
+      tableSampleClause,
       pushedLimit,
       sortOrders,
       pushedOffset).asInstanceOf[T]
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
index b758ddd35e0d..45d5f920b9be 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
@@ -21,11 +21,12 @@ import scala.util.control.NonFatal
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.LogKeys.{JOIN_CONDITION, JOIN_TYPE, SCHEMA}
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.V2ExpressionUtils
 import org.apache.spark.sql.connector.expressions.{FieldReference, SortOrder}
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.connector.join.JoinType
-import org.apache.spark.sql.connector.read.{ScanBuilder, 
SupportsPushDownAggregates, SupportsPushDownJoin, SupportsPushDownLimit, 
SupportsPushDownOffset, SupportsPushDownRequiredColumns, 
SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters}
+import org.apache.spark.sql.connector.read.{SampleMethod, ScanBuilder, 
SupportsPushDownAggregates, SupportsPushDownJoin, SupportsPushDownLimit, 
SupportsPushDownOffset, SupportsPushDownRequiredColumns, 
SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters}
 import org.apache.spark.sql.execution.datasources.PartitioningUtils
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, 
JDBCPartition, JDBCRDD, JDBCRelation}
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
@@ -57,7 +58,7 @@ case class JDBCScanBuilder(
 
   private var finalSchema = schema
 
-  private var tableSample: Option[TableSampleInfo] = None
+  private var tableSampleClause: Option[String] = None
 
   private var pushedLimit = 0
 
@@ -251,7 +252,7 @@ case class JDBCScanBuilder(
     pushedPredicate = Array.empty[Predicate]
     // Table sample is pushed down already as well, so we need to reset it to 
None to not push it
     // down again when join pushdown is triggered again on this 
JDBCScanBuilder.
-    tableSample = None
+    tableSampleClause = None
 
     true
   }
@@ -275,10 +276,13 @@ case class JDBCScanBuilder(
       lowerBound: Double,
       upperBound: Double,
       withReplacement: Boolean,
-      seed: Long): Boolean = {
-    if (jdbcOptions.pushDownTableSample && dialect.supportsTableSample) {
-      this.tableSample = Some(TableSampleInfo(lowerBound, upperBound, 
withReplacement, seed))
-      return true
+      seed: Long,
+      sampleMethod: SampleMethod): Boolean = {
+    if (jdbcOptions.pushDownTableSample) {
+      val sample = TableSampleInfo(
+        lowerBound, upperBound, withReplacement, seed, 
V2ExpressionUtils.toCatalyst(sampleMethod))
+      this.tableSampleClause = dialect.compileTableSample(sample)
+      return this.tableSampleClause.isDefined
     }
     false
   }
@@ -343,7 +347,7 @@ case class JDBCScanBuilder(
     // be used in sql string.
     JDBCScan(JDBCRelation(schema, parts, jdbcOptions, 
additionalMetrics)(session),
       finalSchema, pushedPredicate, pushedAggregateList, pushedGroupBys,
-      tableSample, pushedLimit, sortOrders, pushedOffset)
+      tableSampleClause, pushedLimit, sortOrders, pushedOffset)
   }
 
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCV1RelationFromV2Scan.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCV1RelationFromV2Scan.scala
index feb33effae23..ef08d0ad94b9 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCV1RelationFromV2Scan.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCV1RelationFromV2Scan.scala
@@ -20,7 +20,6 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Row, SQLContext}
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation
-import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
 import org.apache.spark.sql.sources.{BaseRelation, TableScan}
 import org.apache.spark.sql.types.StructType
 
@@ -35,7 +34,7 @@ case class JDBCV1RelationFromV2Scan(
     pushedPredicates: Array[Predicate],
     pushedAggregateColumn: Array[String] = Array(),
     groupByColumns: Option[Array[String]],
-    tableSample: Option[TableSampleInfo],
+    tableSampleClause: Option[String],
     pushedLimit: Int,
     sortOrders: Array[String],
     pushedOffset: Int) extends BaseRelation with TableScan {
@@ -49,8 +48,8 @@ case class JDBCV1RelationFromV2Scan(
       pushedAggregateColumn
     }
 
-    relation.buildScan(columnList, prunedSchema, pushedPredicates, 
groupByColumns, tableSample,
-      pushedLimit, sortOrders, pushedOffset)
+    relation.buildScan(columnList, prunedSchema, pushedPredicates, 
groupByColumns,
+      tableSampleClause, pushedLimit, sortOrders, pushedOffset)
   }
 
   override def toString: String = "JDBC v1 Relation from v2 scan"
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DatabricksDialect.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DatabricksDialect.scala
index 9124c1b88909..a56aa90d6d72 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DatabricksDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DatabricksDialect.scala
@@ -18,9 +18,11 @@
 package org.apache.spark.sql.jdbc
 
 import java.sql.{Connection, SQLException}
+import java.util.Locale
 
 import scala.collection.mutable.ArrayBuilder
 
+import org.apache.spark.sql.catalyst.plans.logical.SampleMethod
 import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
 import org.apache.spark.sql.types._
@@ -71,10 +73,10 @@ private case class DatabricksDialect() extends JdbcDialect 
with NoLegacyJDBCErro
 
   override def supportsOffset: Boolean = true
 
-  override def supportsTableSample: Boolean = true
-
-  override def getTableSample(sample: TableSampleInfo): String = {
-    s"TABLESAMPLE (${(sample.upperBound - sample.lowerBound) * 100}) 
REPEATABLE (${sample.seed})"
+  override def compileTableSample(sample: TableSampleInfo): Option[String] = {
+    if (sample.withReplacement || sample.sampleMethod == SampleMethod.System) 
return None
+    Some(s"TABLESAMPLE 
${sample.sampleMethod.toString.toUpperCase(Locale.ROOT)}" +
+      s" (${(sample.upperBound - sample.lowerBound) * 100}) REPEATABLE 
(${sample.seed})")
   }
 
   override def supportsHint: Boolean = true
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 1ddf22834fbe..2c915c734b24 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -31,6 +31,7 @@ import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.plans.logical.SampleMethod
 import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, 
TimestampFormatter}
 import 
org.apache.spark.sql.catalyst.util.DateTimeUtils.{localDateTimeToMicros, 
toJavaTimestampNoRebase}
 import org.apache.spark.sql.catalyst.util.IntervalUtils.{fromDayTimeString, 
fromYearMonthString, getDuration}
@@ -868,11 +869,32 @@ abstract class JdbcDialect extends Serializable with 
Logging {
    */
   def supportsOffset: Boolean = false
 
+  @deprecated("Use compileTableSample instead", "4.2.0")
   def supportsTableSample: Boolean = false
 
+  @deprecated("Use compileTableSample instead", "4.2.0")
   def getTableSample(sample: TableSampleInfo): String =
     throw new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3183")
 
+  /**
+   * Compile a 
[[org.apache.spark.sql.execution.datasources.v2.TableSampleInfo]] into a
+   * SQL `TABLESAMPLE` clause, or return [[scala.None]] if the dialect cannot 
represent
+   * the requested sampling semantics (e.g. sampling with replacement).
+   *
+   * The default implementation delegates to [[getTableSample]] when 
[[supportsTableSample]]
+   * is true and the requested sample is BERNOULLI without replacement (the 
contract
+   * predating this method), and returns [[scala.None]] otherwise.
+   */
+  @Since("4.2.0")
+  def compileTableSample(sample: TableSampleInfo): Option[String] = {
+    if (supportsTableSample && !sample.withReplacement &&
+        sample.sampleMethod == SampleMethod.Bernoulli) {
+      Some(getTableSample(sample))
+    } else {
+      None
+    }
+  }
+
   def supportsHint: Boolean = false
 
   /**
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala
index 93af5890711c..4dd6631699cb 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala
@@ -172,12 +172,22 @@ class JdbcSQLQueryBuilder(dialect: JdbcDialect, options: 
JDBCOptions) {
   /**
    * Constructs the table sample clause that following dialect's SQL syntax.
    */
+  @deprecated("Use withTableSampleClause(String) instead", "4.2.0")
   def withTableSample(sample: TableSampleInfo): JdbcSQLQueryBuilder = {
     tableSampleClause = dialect.getTableSample(sample)
 
     this
   }
 
+  /**
+   * Sets a pre-compiled table sample clause directly.
+   */
+  def withTableSampleClause(clause: String): JdbcSQLQueryBuilder = {
+    tableSampleClause = clause
+
+    this
+  }
+
   /**
    * Represents JOIN subquery in case Join has been pushed down. This value 
should be used
    * instead of options.tableOrQuery if join has been pushed down.
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index 606d8f69760d..dd57c129179e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -376,15 +376,12 @@ private case class PostgresDialect()
 
   override def supportsOffset: Boolean = true
 
-  override def supportsTableSample: Boolean = true
-
   override def supportsJoin: Boolean = true
 
-  override def getTableSample(sample: TableSampleInfo): String = {
-    // hard-coded to BERNOULLI for now because Spark doesn't have a way to 
specify sample
-    // method name
-    "TABLESAMPLE BERNOULLI" +
-      s" (${(sample.upperBound - sample.lowerBound) * 100}) REPEATABLE 
(${sample.seed})"
+  override def compileTableSample(sample: TableSampleInfo): Option[String] = {
+    if (sample.withReplacement) return None
+    Some(s"TABLESAMPLE 
${sample.sampleMethod.toString.toUpperCase(Locale.ROOT)}" +
+      s" (${(sample.upperBound - sample.lowerBound) * 100}) REPEATABLE 
(${sample.seed})")
   }
 
   override def renameTable(oldTable: Identifier, newTable: Identifier): String 
= {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcTaskInterruptSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcTaskInterruptSuite.scala
index 7475fb34638d..768e36d64ee0 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcTaskInterruptSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcTaskInterruptSuite.scala
@@ -289,7 +289,7 @@ class JdbcTaskInterruptSuite extends SharedSparkSession {
         options = options,
         databaseMetadata = 
JDBCDatabaseMetadata.fromJDBCConnectionFactory(getConnection),
         groupByColumns = None,
-        sample = None,
+        sampleClause = None,
         limit = 0,
         sortOrders = Array.empty,
         offset = 0,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to