This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
     new 45eea269e8  [Spark Connector] Escape column names when querying Pinot (#10663)
45eea269e8 is described below

commit 45eea269e8942dd5c031c5b1789199a6c5397cd0
Author: Caner Balci <canerba...@gmail.com>
AuthorDate: Tue Apr 25 12:54:27 2023 -0700

    [Spark Connector] Escape column names when querying Pinot (#10663)
---
 .../spark/datasource/query/FilterPushDown.scala    | 32 ++++++++++++----------
 .../datasource/query/FilterPushDownTest.scala      | 21 +++++++++-----
 .../spark/v3/datasource/query/FilterPushDown.scala | 32 ++++++++++++----------
 .../ExampleSparkPinotConnectorTest.scala           |  1 +
 .../v3/datasource/query/FilterPushDownTest.scala   | 21 +++++++++-----
 .../spark/common/PinotClusterClient.scala          |  4 +--
 .../spark/common/query/ScanQueryGenerator.scala    |  6 +++-
 .../common/query/ScanQueryGeneratorTest.scala      | 24 ++++++++--------
 8 files changed, 84 insertions(+), 57 deletions(-)

diff --git a/pinot-connectors/pinot-spark-2-connector/src/main/scala/org/apache/pinot/connector/spark/datasource/query/FilterPushDown.scala b/pinot-connectors/pinot-spark-2-connector/src/main/scala/org/apache/pinot/connector/spark/datasource/query/FilterPushDown.scala
index 30954b566c..331594663a 100644
--- a/pinot-connectors/pinot-spark-2-connector/src/main/scala/org/apache/pinot/connector/spark/datasource/query/FilterPushDown.scala
+++ b/pinot-connectors/pinot-spark-2-connector/src/main/scala/org/apache/pinot/connector/spark/datasource/query/FilterPushDown.scala
@@ -81,25 +81,29 @@ private[pinot] object FilterPushDown {
     case _ => value
   }
 
+  private def escapeAttr(attr: String): String = {
+    if (attr.contains("\"")) attr else s""""$attr""""
+  }
+
   private def compileFilter(filter: Filter): Option[String] = {
     val whereCondition = filter match {
-      case EqualTo(attr, value) => s"$attr = ${compileValue(value)}"
+      case EqualTo(attr, value) => s"${escapeAttr(attr)} = ${compileValue(value)}"
       case EqualNullSafe(attr, value) =>
-        s"NOT ($attr != ${compileValue(value)} OR $attr IS NULL OR " +
+        s"NOT (${escapeAttr(attr)} != ${compileValue(value)} OR ${escapeAttr(attr)} IS NULL OR " +
           s"${compileValue(value)} IS NULL) OR " +
-          s"($attr IS NULL AND ${compileValue(value)} IS NULL)"
-      case LessThan(attr, value) => s"$attr < ${compileValue(value)}"
-      case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}"
-      case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}"
-      case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}"
-      case IsNull(attr) => s"$attr IS NULL"
-      case IsNotNull(attr) => s"$attr IS NOT NULL"
-      case StringStartsWith(attr, value) => s"$attr LIKE '$value%'"
-      case StringEndsWith(attr, value) => s"$attr LIKE '%$value'"
-      case StringContains(attr, value) => s"$attr LIKE '%$value%'"
+          s"(${escapeAttr(attr)} IS NULL AND ${compileValue(value)} IS NULL)"
+      case LessThan(attr, value) => s"${escapeAttr(attr)} < ${compileValue(value)}"
+      case GreaterThan(attr, value) => s"${escapeAttr(attr)} > ${compileValue(value)}"
+      case LessThanOrEqual(attr, value) => s"${escapeAttr(attr)} <= ${compileValue(value)}"
+      case GreaterThanOrEqual(attr, value) => s"${escapeAttr(attr)} >= ${compileValue(value)}"
+      case IsNull(attr) => s"${escapeAttr(attr)} IS NULL"
+      case IsNotNull(attr) => s"${escapeAttr(attr)} IS NOT NULL"
+      case StringStartsWith(attr, value) => s"${escapeAttr(attr)} LIKE '$value%'"
+      case StringEndsWith(attr, value) => s"${escapeAttr(attr)} LIKE '%$value'"
+      case StringContains(attr, value) => s"${escapeAttr(attr)} LIKE '%$value%'"
       case In(attr, value) if value.isEmpty =>
-        s"CASE WHEN $attr IS NULL THEN NULL ELSE FALSE END"
-      case In(attr, value) => s"$attr IN (${compileValue(value)})"
+        s"CASE WHEN ${escapeAttr(attr)} IS NULL THEN NULL ELSE FALSE END"
+      case In(attr, value) => s"${escapeAttr(attr)} IN (${compileValue(value)})"
       case Not(f) => compileFilter(f).map(p => s"NOT ($p)").orNull
       case Or(f1, f2) =>
         val or = Seq(f1, f2).flatMap(compileFilter)
diff --git a/pinot-connectors/pinot-spark-2-connector/src/test/scala/org/apache/pinot/connector/spark/datasource/query/FilterPushDownTest.scala b/pinot-connectors/pinot-spark-2-connector/src/test/scala/org/apache/pinot/connector/spark/datasource/query/FilterPushDownTest.scala
index 127aab211f..eeb961ef25 100644
--- a/pinot-connectors/pinot-spark-2-connector/src/test/scala/org/apache/pinot/connector/spark/datasource/query/FilterPushDownTest.scala
+++ b/pinot-connectors/pinot-spark-2-connector/src/test/scala/org/apache/pinot/connector/spark/datasource/query/FilterPushDownTest.scala
@@ -61,15 +61,22 @@ class FilterPushDownTest extends BaseTest {
   test("SQL query should be created from spark filters") {
     val whereClause = FilterPushDown.compileFiltersToSqlWhereClause(filters)
     val expectedOutput =
-      s"(attr1 = 1) AND (attr2 IN ('1', '2', '''5''')) AND (attr3 < 1) AND (attr4 <= 3) AND (attr5 > 10) AND " +
-        s"(attr6 >= 15) AND (NOT (attr7 = '1')) AND ((attr8 < 10) AND (attr9 <= 3)) AND " +
-        s"((attr10 = 'hello') OR (attr11 >= 13)) AND (attr12 LIKE '%pinot%') AND (attr13 IN (10, 20)) AND " +
-        s"(NOT (attr20 != '123' OR attr20 IS NULL OR '123' IS NULL) OR (attr20 IS NULL AND '123' IS NULL)) AND " +
-        s"(attr14 IS NULL) AND (attr15 IS NOT NULL) AND (attr16 LIKE 'pinot1%') AND (attr17 LIKE '%pinot2') AND " +
-        s"(attr18 = '2020-01-01 00:00:15.0') AND (attr19 < '2020-01-01') AND (attr21 = List(1, 2)) AND " +
-        s"(attr22 = 10.5)"
+      s"""("attr1" = 1) AND ("attr2" IN ('1', '2', '''5''')) AND ("attr3" < 1) AND ("attr4" <= 3) AND ("attr5" > 10) AND """ +
+        s"""("attr6" >= 15) AND (NOT ("attr7" = '1')) AND (("attr8" < 10) AND ("attr9" <= 3)) AND """ +
+        s"""(("attr10" = 'hello') OR ("attr11" >= 13)) AND ("attr12" LIKE '%pinot%') AND ("attr13" IN (10, 20)) AND """ +
+        s"""(NOT ("attr20" != '123' OR "attr20" IS NULL OR '123' IS NULL) OR ("attr20" IS NULL AND '123' IS NULL)) AND """ +
+        s"""("attr14" IS NULL) AND ("attr15" IS NOT NULL) AND ("attr16" LIKE 'pinot1%') AND ("attr17" LIKE '%pinot2') AND """ +
+        s"""("attr18" = '2020-01-01 00:00:15.0') AND ("attr19" < '2020-01-01') AND ("attr21" = List(1, 2)) AND """ +
+        s"""("attr22" = 10.5)"""
 
     whereClause.get shouldEqual expectedOutput
   }
+
+  test("Shouldn't escape column names which are already escaped") {
+    val whereClause = FilterPushDown.compileFiltersToSqlWhereClause(
+      Array(EqualTo("\"some\".\"nested\".\"column\"", 1)))
+    val expectedOutput = "(\"some\".\"nested\".\"column\" = 1)"
+
+    whereClause.get shouldEqual expectedOutput
+  }
 }
diff --git a/pinot-connectors/pinot-spark-3-connector/src/main/scala/org/apache/pinot/connector/spark/v3/datasource/query/FilterPushDown.scala b/pinot-connectors/pinot-spark-3-connector/src/main/scala/org/apache/pinot/connector/spark/v3/datasource/query/FilterPushDown.scala
index 3d4e3f658d..cac50ec031 100644
--- a/pinot-connectors/pinot-spark-3-connector/src/main/scala/org/apache/pinot/connector/spark/v3/datasource/query/FilterPushDown.scala
+++ b/pinot-connectors/pinot-spark-3-connector/src/main/scala/org/apache/pinot/connector/spark/v3/datasource/query/FilterPushDown.scala
@@ -81,25 +81,29 @@ private[pinot] object FilterPushDown {
     case _ => value
   }
 
+  private def escapeAttr(attr: String): String = {
+    if (attr.contains("\"")) attr else s""""$attr""""
+  }
+
   private def compileFilter(filter: Filter): Option[String] = {
     val whereCondition = filter match {
-      case EqualTo(attr, value) => s"$attr = ${compileValue(value)}"
+      case EqualTo(attr, value) => s"${escapeAttr(attr)} = ${compileValue(value)}"
       case EqualNullSafe(attr, value) =>
-        s"NOT ($attr != ${compileValue(value)} OR $attr IS NULL OR " +
+        s"NOT (${escapeAttr(attr)} != ${compileValue(value)} OR ${escapeAttr(attr)} IS NULL OR " +
           s"${compileValue(value)} IS NULL) OR " +
-          s"($attr IS NULL AND ${compileValue(value)} IS NULL)"
-      case LessThan(attr, value) => s"$attr < ${compileValue(value)}"
-      case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}"
-      case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}"
-      case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}"
-      case IsNull(attr) => s"$attr IS NULL"
-      case IsNotNull(attr) => s"$attr IS NOT NULL"
-      case StringStartsWith(attr, value) => s"$attr LIKE '$value%'"
-      case StringEndsWith(attr, value) => s"$attr LIKE '%$value'"
-      case StringContains(attr, value) => s"$attr LIKE '%$value%'"
+          s"(${escapeAttr(attr)} IS NULL AND ${compileValue(value)} IS NULL)"
+      case LessThan(attr, value) => s"${escapeAttr(attr)} < ${compileValue(value)}"
+      case GreaterThan(attr, value) => s"${escapeAttr(attr)} > ${compileValue(value)}"
+      case LessThanOrEqual(attr, value) => s"${escapeAttr(attr)} <= ${compileValue(value)}"
+      case GreaterThanOrEqual(attr, value) => s"${escapeAttr(attr)} >= ${compileValue(value)}"
+      case IsNull(attr) => s"${escapeAttr(attr)} IS NULL"
+      case IsNotNull(attr) => s"${escapeAttr(attr)} IS NOT NULL"
+      case StringStartsWith(attr, value) => s"${escapeAttr(attr)} LIKE '$value%'"
+      case StringEndsWith(attr, value) => s"${escapeAttr(attr)} LIKE '%$value'"
+      case StringContains(attr, value) => s"${escapeAttr(attr)} LIKE '%$value%'"
       case In(attr, value) if value.isEmpty =>
-        s"CASE WHEN $attr IS NULL THEN NULL ELSE FALSE END"
-      case In(attr, value) => s"$attr IN (${compileValue(value)})"
+        s"CASE WHEN ${escapeAttr(attr)} IS NULL THEN NULL ELSE FALSE END"
+      case In(attr, value) => s"${escapeAttr(attr)} IN (${compileValue(value)})"
       case Not(f) => compileFilter(f).map(p => s"NOT ($p)").orNull
       case Or(f1, f2) =>
         val or = Seq(f1, f2).flatMap(compileFilter)
diff --git a/pinot-connectors/pinot-spark-3-connector/src/test/scala/org/apache/pinot/connector/spark/v3/datasource/ExampleSparkPinotConnectorTest.scala b/pinot-connectors/pinot-spark-3-connector/src/test/scala/org/apache/pinot/connector/spark/v3/datasource/ExampleSparkPinotConnectorTest.scala
index 48692ae0f7..3c2755baf7 100644
--- a/pinot-connectors/pinot-spark-3-connector/src/test/scala/org/apache/pinot/connector/spark/v3/datasource/ExampleSparkPinotConnectorTest.scala
+++ b/pinot-connectors/pinot-spark-3-connector/src/test/scala/org/apache/pinot/connector/spark/v3/datasource/ExampleSparkPinotConnectorTest.scala
@@ -49,6 +49,7 @@ object ExampleSparkPinotConnectorTest extends Logging {
   }
 
   def readOffline()(implicit spark: SparkSession): Unit = {
+    import spark.implicits._
     log.info("## Reading `airlineStats_OFFLINE` table... ##")
     val data = spark.read
       .format("pinot")
diff --git a/pinot-connectors/pinot-spark-3-connector/src/test/scala/org/apache/pinot/connector/spark/v3/datasource/query/FilterPushDownTest.scala b/pinot-connectors/pinot-spark-3-connector/src/test/scala/org/apache/pinot/connector/spark/v3/datasource/query/FilterPushDownTest.scala
index 6202257b9b..1bf889ddee 100644
--- a/pinot-connectors/pinot-spark-3-connector/src/test/scala/org/apache/pinot/connector/spark/v3/datasource/query/FilterPushDownTest.scala
+++ b/pinot-connectors/pinot-spark-3-connector/src/test/scala/org/apache/pinot/connector/spark/v3/datasource/query/FilterPushDownTest.scala
@@ -61,15 +61,22 @@ class FilterPushDownTest extends BaseTest {
   test("SQL query should be created from spark filters") {
     val whereClause = FilterPushDown.compileFiltersToSqlWhereClause(filters)
     val expectedOutput =
-      s"(attr1 = 1) AND (attr2 IN ('1', '2', '''5''')) AND (attr3 < 1) AND (attr4 <= 3) AND (attr5 > 10) AND " +
-        s"(attr6 >= 15) AND (NOT (attr7 = '1')) AND ((attr8 < 10) AND (attr9 <= 3)) AND " +
-        s"((attr10 = 'hello') OR (attr11 >= 13)) AND (attr12 LIKE '%pinot%') AND (attr13 IN (10, 20)) AND " +
-        s"(NOT (attr20 != '123' OR attr20 IS NULL OR '123' IS NULL) OR (attr20 IS NULL AND '123' IS NULL)) AND " +
-        s"(attr14 IS NULL) AND (attr15 IS NOT NULL) AND (attr16 LIKE 'pinot1%') AND (attr17 LIKE '%pinot2') AND " +
-        s"(attr18 = '2020-01-01 00:00:15.0') AND (attr19 < '2020-01-01') AND (attr21 = List(1, 2)) AND " +
-        s"(attr22 = 10.5)"
+      s"""("attr1" = 1) AND ("attr2" IN ('1', '2', '''5''')) AND ("attr3" < 1) AND ("attr4" <= 3) AND ("attr5" > 10) AND """ +
+        s"""("attr6" >= 15) AND (NOT ("attr7" = '1')) AND (("attr8" < 10) AND ("attr9" <= 3)) AND """ +
+        s"""(("attr10" = 'hello') OR ("attr11" >= 13)) AND ("attr12" LIKE '%pinot%') AND ("attr13" IN (10, 20)) AND """ +
+        s"""(NOT ("attr20" != '123' OR "attr20" IS NULL OR '123' IS NULL) OR ("attr20" IS NULL AND '123' IS NULL)) AND """ +
+        s"""("attr14" IS NULL) AND ("attr15" IS NOT NULL) AND ("attr16" LIKE 'pinot1%') AND ("attr17" LIKE '%pinot2') AND """ +
+        s"""("attr18" = '2020-01-01 00:00:15.0') AND ("attr19" < '2020-01-01') AND ("attr21" = List(1, 2)) AND """ +
+        s"""("attr22" = 10.5)"""
 
     whereClause.get shouldEqual expectedOutput
   }
+
+  test("Shouldn't escape column names which are already escaped") {
+    val whereClause = FilterPushDown.compileFiltersToSqlWhereClause(
+      Array(EqualTo("\"some\".\"nested\".\"column\"", 1)))
+    val expectedOutput = "(\"some\".\"nested\".\"column\" = 1)"
+
+    whereClause.get shouldEqual expectedOutput
+  }
 }
diff --git a/pinot-connectors/pinot-spark-common/src/main/scala/org/apache/pinot/connector/spark/common/PinotClusterClient.scala b/pinot-connectors/pinot-spark-common/src/main/scala/org/apache/pinot/connector/spark/common/PinotClusterClient.scala
index 52e86cbaf5..1c5dafe2a5 100644
--- a/pinot-connectors/pinot-spark-common/src/main/scala/org/apache/pinot/connector/spark/common/PinotClusterClient.scala
+++ b/pinot-connectors/pinot-spark-common/src/main/scala/org/apache/pinot/connector/spark/common/PinotClusterClient.scala
@@ -213,9 +213,9 @@ private[pinot] object PinotClusterClient extends Logging {
 
 private[pinot] case class TimeBoundaryInfo(timeColumn: String, timeValue: String) {
 
-  def getOfflinePredicate: String = s"$timeColumn < $timeValue"
+  def getOfflinePredicate: String = s""""$timeColumn" < $timeValue"""
 
-  def getRealtimePredicate: String = s"$timeColumn >= $timeValue"
+  def getRealtimePredicate: String = s""""$timeColumn" >= $timeValue"""
 }
 
 private[pinot] case class InstanceInfo(instanceName: String,
diff --git a/pinot-connectors/pinot-spark-common/src/main/scala/org/apache/pinot/connector/spark/common/query/ScanQueryGenerator.scala b/pinot-connectors/pinot-spark-common/src/main/scala/org/apache/pinot/connector/spark/common/query/ScanQueryGenerator.scala
index 4616bbd7c7..e6c1afb9c8 100644
--- a/pinot-connectors/pinot-spark-common/src/main/scala/org/apache/pinot/connector/spark/common/query/ScanQueryGenerator.scala
+++ b/pinot-connectors/pinot-spark-common/src/main/scala/org/apache/pinot/connector/spark/common/query/ScanQueryGenerator.scala
@@ -46,7 +46,11 @@ private[pinot] class ScanQueryGenerator(
 
   /** Get all columns if selecting columns empty(eg: resultDataFrame.count()) */
   private def columnsAsExpression(): String = {
-    if (columns.isEmpty) "*" else columns.mkString(",")
+    if (columns.isEmpty) "*" else columns.map(escapeCol).mkString(",")
+  }
+
+  private def escapeCol(col: String): String = {
+    if (col.contains("\"")) col else s""""$col""""
   }
 
   /** Build realtime or offline SQL selection query. */
diff --git a/pinot-connectors/pinot-spark-common/src/test/scala/org/apache/pinot/connector/spark/common/query/ScanQueryGeneratorTest.scala b/pinot-connectors/pinot-spark-common/src/test/scala/org/apache/pinot/connector/spark/common/query/ScanQueryGeneratorTest.scala
index 9be29e3c44..73fc9d8584 100644
--- a/pinot-connectors/pinot-spark-common/src/test/scala/org/apache/pinot/connector/spark/common/query/ScanQueryGeneratorTest.scala
+++ b/pinot-connectors/pinot-spark-common/src/test/scala/org/apache/pinot/connector/spark/common/query/ScanQueryGeneratorTest.scala
@@ -25,19 +25,19 @@ import org.apache.pinot.spi.config.table.TableType
  * Test SQL query generation from spark push down filters, selection columns etc.
  */
 class ScanQueryGeneratorTest extends BaseTest {
-  private val columns = Array("c1, c2")
+  private val columns = Array("c1","c2")
   private val tableName = "tbl"
   private val tableType = Some(TableType.OFFLINE)
   private val whereClause = Some("c1 = 5 OR c2 = 'hello'")
-  private val limit = s"LIMIT ${Int.MaxValue}"
+  private val limit = s"""LIMIT ${Int.MaxValue}"""
 
   test("Queries should be created with given filters") {
     val pinotQueries =
       ScanQueryGenerator.generate(tableName, tableType, None, columns, whereClause, Set())
     val expectedRealtimeQuery =
-      s"SELECT c1, c2 FROM ${tableName}_REALTIME WHERE ${whereClause.get} $limit"
+      s"""SELECT "c1","c2" FROM ${tableName}_REALTIME WHERE ${whereClause.get} $limit"""
     val expectedOfflineQuery =
-      s"SELECT c1, c2 FROM ${tableName}_OFFLINE WHERE ${whereClause.get} $limit"
+      s"""SELECT "c1","c2" FROM ${tableName}_OFFLINE WHERE ${whereClause.get} $limit"""
 
     pinotQueries.realtimeSelectQuery shouldEqual expectedRealtimeQuery
     pinotQueries.offlineSelectQuery shouldEqual expectedOfflineQuery
@@ -48,12 +48,12 @@ class ScanQueryGeneratorTest extends BaseTest {
     val pinotQueries = ScanQueryGenerator
       .generate(tableName, tableType, Some(timeBoundaryInfo), columns, whereClause, Set())
 
-    val realtimeWhereClause = s"${whereClause.get} AND timeCol >= 12345"
-    val offlineWhereClause = s"${whereClause.get} AND timeCol < 12345"
+    val realtimeWhereClause = s"""${whereClause.get} AND "timeCol" >= 12345"""
+    val offlineWhereClause = s"""${whereClause.get} AND "timeCol" < 12345"""
     val expectedRealtimeQuery =
-      s"SELECT c1, c2 FROM ${tableName}_REALTIME WHERE $realtimeWhereClause $limit"
+      s"""SELECT "c1","c2" FROM ${tableName}_REALTIME WHERE $realtimeWhereClause $limit"""
     val expectedOfflineQuery =
-      s"SELECT c1, c2 FROM ${tableName}_OFFLINE WHERE $offlineWhereClause $limit"
+      s"""SELECT "c1","c2" FROM ${tableName}_OFFLINE WHERE $offlineWhereClause $limit"""
 
     pinotQueries.realtimeSelectQuery shouldEqual expectedRealtimeQuery
     pinotQueries.offlineSelectQuery shouldEqual expectedOfflineQuery
@@ -64,12 +64,12 @@ class ScanQueryGeneratorTest extends BaseTest {
     val pinotQueries = ScanQueryGenerator
       .generate(tableName, tableType, Some(timeBoundaryInfo), columns, None, Set())
 
-    val realtimeWhereClause = s"timeCol >= 12345"
-    val offlineWhereClause = s"timeCol < 12345"
+    val realtimeWhereClause = s""""timeCol" >= 12345"""
+    val offlineWhereClause = s""""timeCol" < 12345"""
     val expectedRealtimeQuery =
-      s"SELECT c1, c2 FROM ${tableName}_REALTIME WHERE $realtimeWhereClause $limit"
+      s"""SELECT "c1","c2" FROM ${tableName}_REALTIME WHERE $realtimeWhereClause $limit"""
     val expectedOfflineQuery =
-      s"SELECT c1, c2 FROM ${tableName}_OFFLINE WHERE $offlineWhereClause $limit"
+      s"""SELECT "c1","c2" FROM ${tableName}_OFFLINE WHERE $offlineWhereClause $limit"""
 
     pinotQueries.realtimeSelectQuery shouldEqual expectedRealtimeQuery
     pinotQueries.offlineSelectQuery shouldEqual expectedOfflineQuery
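Putting the pieces together, the generated scan queries now quote every projected column as well as every pushed-down filter attribute. A sketch of the output shape using the fixtures from ScanQueryGeneratorTest above (the table name, columns, and filter are the test's illustrative values, not anything special-cased by the generator):

// Inputs mirroring the test fixtures:
val columns = Array("c1", "c2")
val tableName = "tbl"
val whereClause = Some("c1 = 5 OR c2 = 'hello'")

// Shape of the generated SQL; LIMIT is Int.MaxValue (2147483647), so the
// scan is bounded only by the table size:
//   SELECT "c1","c2" FROM tbl_OFFLINE WHERE c1 = 5 OR c2 = 'hello' LIMIT 2147483647
//   SELECT "c1","c2" FROM tbl_REALTIME WHERE c1 = 5 OR c2 = 'hello' LIMIT 2147483647

Note that the where-clause fixture is already-compiled text and is not re-escaped at this stage: escaping happens once, where each fragment is produced -- in FilterPushDown for compiled filters, in TimeBoundaryInfo for the time predicates, and in escapeCol for the projection list.

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org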