This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new cb0db21
[SPARK-25556][SPARK-17636][SPARK-31026][SPARK-31060][SQL][TEST-HIVE1.2] Nested Column Predicate Pushdown for Parquet
cb0db21 is described below
commit cb0db213736de5c5c02b09a2d5c3e17254708ce1
Author: DB Tsai <[email protected]>
AuthorDate: Fri Mar 27 14:28:57 2020 +0800
[SPARK-25556][SPARK-17636][SPARK-31026][SPARK-31060][SQL][TEST-HIVE1.2] Nested Column Predicate Pushdown for Parquet
### What changes were proposed in this pull request?
1. `DataSourceStrategy.scala` is extended to create `org.apache.spark.sql.sources.Filter` from nested expressions.
2. Translation from nested `org.apache.spark.sql.sources.Filter` to `org.apache.parquet.filter2.predicate.FilterPredicate` is implemented to support nested predicate pushdown for Parquet.
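As a rough illustration (not part of this patch), the sketch below shows the kind of query this enables: a filter on a nested struct field stored in Parquet. The session, path, and column names are assumptions made for the example; only the config key `spark.sql.optimizer.nestedPredicatePushdown.enabled` comes from this change.
```scala
// Hedged sketch: assumes a local SparkSession and a scratch path; only the config key
// spark.sql.optimizer.nestedPredicatePushdown.enabled is taken from this patch.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, struct}

val spark = SparkSession.builder().master("local[*]").appName("nested-pushdown-example").getOrCreate()
import spark.implicits._

// Write a small Parquet dataset with a nested column `a.b`.
(1 to 100).toDF("temp")
  .withColumn("a", struct($"temp".as("b")))
  .drop("temp")
  .write.mode("overwrite").parquet("/tmp/nested_pushdown_example")

// With the flag enabled (it defaults to true), a predicate on the nested field `a.b`
// is translated into a data source Filter and then into a Parquet FilterPredicate,
// so row groups that cannot match can be skipped at scan time.
spark.conf.set("spark.sql.optimizer.nestedPredicatePushdown.enabled", "true")
val filtered = spark.read.parquet("/tmp/nested_pushdown_example").filter(col("a.b") > 90)
filtered.explain() // PushedFilters is expected to list predicates on `a.b`
filtered.show()
```
Before this change, such a predicate was only evaluated after the scan, because `PushableColumn` only matched top-level `Attribute`s.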
### Why are the changes needed?
Better performance for handling nested predicate pushdown.
### Does this PR introduce any user-facing change?
No
### How was this patch tested?
New tests are added.
Closes #27728 from dbtsai/SPARK-17636.
Authored-by: DB Tsai <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/connector/catalog/CatalogV2Implicits.scala | 8 +
.../org/apache/spark/sql/internal/SQLConf.scala | 13 +
.../org/apache/spark/sql/sources/filters.scala | 60 +-
.../execution/datasources/DataSourceStrategy.scala | 15 +-
.../datasources/parquet/ParquetFilters.scala | 79 ++-
.../datasources/v2/orc/OrcScanBuilder.scala | 7 +-
.../datasources/DataSourceStrategySuite.scala | 54 +-
.../datasources/parquet/ParquetFilterSuite.scala | 754 ++++++++++++---------
.../datasources/parquet/ParquetIOSuite.scala | 20 +-
.../datasources/parquet/ParquetTest.scala | 12 +-
.../apache/spark/sql/sources/FiltersSuite.scala | 173 +++--
.../sql/execution/datasources/orc/OrcFilters.scala | 65 +-
.../sql/execution/datasources/orc/OrcFilters.scala | 65 +-
.../org/apache/spark/sql/hive/orc/OrcFilters.scala | 7 +-
14 files changed, 852 insertions(+), 480 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala
index 71bab62..d90804f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala
@@ -20,7 +20,9 @@ package org.apache.spark.sql.connector.catalog
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.BucketSpec
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.connector.expressions.{BucketTransform, IdentityTransform, LogicalExpressions, Transform}
+import org.apache.spark.sql.internal.SQLConf
/**
* Conversion helpers for working with v2 [[CatalogPlugin]].
@@ -132,4 +134,10 @@ private[sql] object CatalogV2Implicits {
part
}
}
+
+ private lazy val catalystSqlParser = new CatalystSqlParser(SQLConf.get)
+
+ def parseColumnPath(name: String): Seq[String] = {
+ catalystSqlParser.parseMultipartIdentifier(name)
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 9065bd8..90a889a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2062,6 +2062,17 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val NESTED_PREDICATE_PUSHDOWN_ENABLED =
+ buildConf("spark.sql.optimizer.nestedPredicatePushdown.enabled")
+ .internal()
+ .doc("When true, Spark tries to push down predicates for nested columns
and or names " +
+ "containing `dots` to data sources. Currently, Parquet implements both
optimizations " +
+ "while ORC only supports predicates for names containing `dots`. The
other data sources" +
+ "don't support this feature yet.")
+ .version("3.0.0")
+ .booleanConf
+ .createWithDefault(true)
+
val SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED =
buildConf("spark.sql.optimizer.serializer.nestedSchemaPruning.enabled")
.internal()
@@ -3048,6 +3059,8 @@ class SQLConf extends Serializable with Logging {
def nestedSchemaPruningEnabled: Boolean =
getConf(NESTED_SCHEMA_PRUNING_ENABLED)
+ def nestedPredicatePushdownEnabled: Boolean = getConf(NESTED_PREDICATE_PUSHDOWN_ENABLED)
+
def serializerNestedSchemaPruningEnabled: Boolean =
getConf(SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala
index 020dd79..319073e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.sources
import org.apache.spark.annotation.{Evolving, Stable}
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.parseColumnPath
////////////////////////////////////////////////////////////////////////////////////////////////////
// This file defines all the filters that we can push down to the data sources.
@@ -32,6 +33,10 @@ import org.apache.spark.annotation.{Evolving, Stable}
sealed abstract class Filter {
/**
* List of columns that are referenced by this filter.
+ *
+ * Note that each element in `references` represents a column; `dots` are used as separators
+ * for nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion.
+ *
* @since 2.1.0
*/
def references: Array[String]
@@ -40,12 +45,32 @@ sealed abstract class Filter {
case f: Filter => f.references
case _ => Array.empty
}
+
+ /**
+ * List of columns that are referenced by this filter.
+ *
+ * @return each element is a column name as an array of string multi-identifier
+ * @since 3.0.0
+ */
+ def v2references: Array[Array[String]] = {
+ this.references.map(parseColumnPath(_).toArray)
+ }
+
+ /**
+ * Whether any of the references of this filter contains a nested column.
+ */
+ private[sql] def containsNestedColumn: Boolean = {
+ this.v2references.exists(_.length > 1)
+ }
}
/**
- * A filter that evaluates to `true` iff the attribute evaluates to a value
+ * A filter that evaluates to `true` iff the column evaluates to a value
* equal to `value`.
*
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ * for nested columns. If any part of the names contains `dots`,
+ * it is quoted to avoid confusion.
* @since 1.3.0
*/
@Stable
@@ -58,6 +83,9 @@ case class EqualTo(attribute: String, value: Any) extends Filter {
* in that it returns `true` (rather than NULL) if both inputs are NULL, and `false`
* (rather than NULL) if one of the input is NULL and the other is not NULL.
*
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ * for nested columns. If any part of the names contains `dots`,
+ * it is quoted to avoid confusion.
* @since 1.5.0
*/
@Stable
@@ -69,6 +97,9 @@ case class EqualNullSafe(attribute: String, value: Any) extends Filter {
* A filter that evaluates to `true` iff the attribute evaluates to a value
* greater than `value`.
*
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ * for nested columns. If any part of the names contains `dots`,
+ * it is quoted to avoid confusion.
* @since 1.3.0
*/
@Stable
@@ -80,6 +111,9 @@ case class GreaterThan(attribute: String, value: Any) extends Filter {
* A filter that evaluates to `true` iff the attribute evaluates to a value
* greater than or equal to `value`.
*
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ * for nested columns. If any part of the names contains `dots`,
+ * it is quoted to avoid confusion.
* @since 1.3.0
*/
@Stable
@@ -91,6 +125,9 @@ case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter {
* A filter that evaluates to `true` iff the attribute evaluates to a value
* less than `value`.
*
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ * for nested columns. If any part of the names contains `dots`,
+ * it is quoted to avoid confusion.
* @since 1.3.0
*/
@Stable
@@ -102,6 +139,9 @@ case class LessThan(attribute: String, value: Any) extends Filter {
* A filter that evaluates to `true` iff the attribute evaluates to a value
* less than or equal to `value`.
*
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ * for nested columns. If any part of the names contains `dots`,
+ * it is quoted to avoid confusion.
* @since 1.3.0
*/
@Stable
@@ -112,6 +152,9 @@ case class LessThanOrEqual(attribute: String, value: Any) extends Filter {
/**
* A filter that evaluates to `true` iff the attribute evaluates to one of the values in the array.
*
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ * for nested columns. If any part of the names contains `dots`,
+ * it is quoted to avoid confusion.
* @since 1.3.0
*/
@Stable
@@ -139,6 +182,9 @@ case class In(attribute: String, values: Array[Any]) extends Filter {
/**
* A filter that evaluates to `true` iff the attribute evaluates to null.
*
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ * for nested columns. If any part of the names contains `dots`,
+ * it is quoted to avoid confusion.
* @since 1.3.0
*/
@Stable
@@ -149,6 +195,9 @@ case class IsNull(attribute: String) extends Filter {
/**
* A filter that evaluates to `true` iff the attribute evaluates to a non-null value.
*
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ * for nested columns. If any part of the names contains `dots`,
+ * it is quoted to avoid confusion.
* @since 1.3.0
*/
@Stable
@@ -190,6 +239,9 @@ case class Not(child: Filter) extends Filter {
* A filter that evaluates to `true` iff the attribute evaluates to
* a string that starts with `value`.
*
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ * for nested columns. If any part of the names contains `dots`,
+ * it is quoted to avoid confusion.
* @since 1.3.1
*/
@Stable
@@ -201,6 +253,9 @@ case class StringStartsWith(attribute: String, value: String) extends Filter {
* A filter that evaluates to `true` iff the attribute evaluates to
* a string that ends with `value`.
*
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ * for nested columns. If any part of the names contains `dots`,
+ * it is quoted to avoid confusion.
* @since 1.3.1
*/
@Stable
@@ -212,6 +267,9 @@ case class StringEndsWith(attribute: String, value: String) extends Filter {
* A filter that evaluates to `true` iff the attribute evaluates to
* a string that contains the string `value`.
*
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ * for nested columns. If any part of the names contains `dots`,
+ * it is quoted to avoid confusion.
* @since 1.3.1
*/
@Stable
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 1641b66..faf3760 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -652,10 +652,19 @@ object DataSourceStrategy {
*/
object PushableColumn {
def unapply(e: Expression): Option[String] = {
- def helper(e: Expression) = e match {
- case a: Attribute => Some(a.name)
+ val nestedPredicatePushdownEnabled = SQLConf.get.nestedPredicatePushdownEnabled
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
+ def helper(e: Expression): Option[Seq[String]] = e match {
+ case a: Attribute =>
+ if (nestedPredicatePushdownEnabled || !a.name.contains(".")) {
+ Some(Seq(a.name))
+ } else {
+ None
+ }
+ case s: GetStructField if nestedPredicatePushdownEnabled =>
+ helper(s.child).map(_ :+ s.childSchema(s.ordinal).name)
case _ => None
}
- helper(e)
+ helper(e).map(_.quoted)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
index 0706501..f206f59 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
@@ -27,7 +27,7 @@ import scala.collection.JavaConverters.asScalaBufferConverter
import org.apache.parquet.filter2.predicate._
import org.apache.parquet.filter2.predicate.SparkFilterApi._
import org.apache.parquet.io.api.Binary
-import org.apache.parquet.schema.{DecimalMetadata, MessageType, OriginalType, PrimitiveComparator}
+import org.apache.parquet.schema.{DecimalMetadata, GroupType, MessageType, OriginalType, PrimitiveComparator, PrimitiveType, Type}
import org.apache.parquet.schema.OriginalType._
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
@@ -49,15 +49,35 @@ class ParquetFilters(
pushDownInFilterThreshold: Int,
caseSensitive: Boolean) {
// A map which contains parquet field name and data type, if predicate push down applies.
- private val nameToParquetField : Map[String, ParquetField] = {
- // Here we don't flatten the fields in the nested schema but just look up through
- // root fields. Currently, accessing to nested fields does not push down filters
- // and it does not support to create filters for them.
- val primitiveFields =
- schema.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f =>
- f.getName -> ParquetField(f.getName,
- ParquetSchemaType(f.getOriginalType,
- f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata))
+ //
+ // Each key in `nameToParquetField` represents a column; `dots` are used as separators for
+ // nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion.
+ // See `org.apache.spark.sql.connector.catalog.quote` for implementation details.
+ private val nameToParquetField : Map[String, ParquetPrimitiveField] = {
+ // Recursively traverse the parquet schema to get primitive fields that can be pushed-down.
+ // `parentFieldNames` is used to keep track of the current nested level when traversing.
+ def getPrimitiveFields(
+ fields: Seq[Type],
+ parentFieldNames: Array[String] = Array.empty): Seq[ParquetPrimitiveField] = {
+ fields.flatMap {
+ case p: PrimitiveType =>
+ Some(ParquetPrimitiveField(fieldNames = parentFieldNames :+ p.getName,
+ fieldType = ParquetSchemaType(p.getOriginalType,
+ p.getPrimitiveTypeName, p.getTypeLength, p.getDecimalMetadata)))
+ // Note that when g is a `Struct`, `g.getOriginalType` is `null`.
+ // When g is a `Map`, `g.getOriginalType` is `MAP`.
+ // When g is a `List`, `g.getOriginalType` is `LIST`.
+ case g: GroupType if g.getOriginalType == null =>
+ getPrimitiveFields(g.getFields.asScala, parentFieldNames :+ g.getName)
+ // Parquet only supports push-down for primitive types; as a result, Map and List types
+ // are removed.
+ case _ => None
+ }
+ }
+
+ val primitiveFields = getPrimitiveFields(schema.getFields.asScala).map { field =>
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
+ (field.fieldNames.toSeq.quoted, field)
}
if (caseSensitive) {
primitiveFields.toMap
@@ -75,13 +95,13 @@ class ParquetFilters(
}
/**
- * Holds a single field information stored in the underlying parquet file.
+ * Holds a single primitive field information stored in the underlying parquet file.
*
- * @param fieldName field name in parquet file
+ * @param fieldNames a field name as an array of string multi-identifier in parquet file
* @param fieldType field type related info in parquet file
*/
- private case class ParquetField(
- fieldName: String,
+ private case class ParquetPrimitiveField(
+ fieldNames: Array[String],
fieldType: ParquetSchemaType)
private case class ParquetSchemaType(
@@ -472,13 +492,8 @@ class ParquetFilters(
case _ => false
}
- // Parquet does not allow dots in the column name because dots are used as a column path
- // delimiter. Since Parquet 1.8.2 (PARQUET-389), Parquet accepts the filter predicates
- // with missing columns. The incorrect results could be got from Parquet when we push down
- // filters for the column having dots in the names. Thus, we do not push down such filters.
- // See SPARK-20364.
private def canMakeFilterOn(name: String, value: Any): Boolean = {
- nameToParquetField.contains(name) && !name.contains(".") && valueCanMakeFilterOn(name, value)
+ nameToParquetField.contains(name) && valueCanMakeFilterOn(name, value)
}
/**
@@ -509,38 +524,38 @@ class ParquetFilters(
predicate match {
case sources.IsNull(name) if canMakeFilterOn(name, null) =>
makeEq.lift(nameToParquetField(name).fieldType)
- .map(_(Array(nameToParquetField(name).fieldName), null))
+ .map(_(nameToParquetField(name).fieldNames, null))
case sources.IsNotNull(name) if canMakeFilterOn(name, null) =>
makeNotEq.lift(nameToParquetField(name).fieldType)
- .map(_(Array(nameToParquetField(name).fieldName), null))
+ .map(_(nameToParquetField(name).fieldNames, null))
case sources.EqualTo(name, value) if canMakeFilterOn(name, value) =>
makeEq.lift(nameToParquetField(name).fieldType)
- .map(_(Array(nameToParquetField(name).fieldName), value))
+ .map(_(nameToParquetField(name).fieldNames, value))
case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) =>
makeNotEq.lift(nameToParquetField(name).fieldType)
- .map(_(Array(nameToParquetField(name).fieldName), value))
+ .map(_(nameToParquetField(name).fieldNames, value))
case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) =>
makeEq.lift(nameToParquetField(name).fieldType)
- .map(_(Array(nameToParquetField(name).fieldName), value))
+ .map(_(nameToParquetField(name).fieldNames, value))
case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) =>
makeNotEq.lift(nameToParquetField(name).fieldType)
- .map(_(Array(nameToParquetField(name).fieldName), value))
+ .map(_(nameToParquetField(name).fieldNames, value))
case sources.LessThan(name, value) if canMakeFilterOn(name, value) =>
makeLt.lift(nameToParquetField(name).fieldType)
- .map(_(Array(nameToParquetField(name).fieldName), value))
+ .map(_(nameToParquetField(name).fieldNames, value))
case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) =>
makeLtEq.lift(nameToParquetField(name).fieldType)
- .map(_(Array(nameToParquetField(name).fieldName), value))
+ .map(_(nameToParquetField(name).fieldNames, value))
case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) =>
makeGt.lift(nameToParquetField(name).fieldType)
- .map(_(Array(nameToParquetField(name).fieldName), value))
+ .map(_(nameToParquetField(name).fieldNames, value))
case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) =>
makeGtEq.lift(nameToParquetField(name).fieldType)
- .map(_(Array(nameToParquetField(name).fieldName), value))
+ .map(_(nameToParquetField(name).fieldNames, value))
case sources.And(lhs, rhs) =>
// At here, it is not safe to just convert one side and remove the other side
@@ -591,13 +606,13 @@ class ParquetFilters(
&& values.distinct.length <= pushDownInFilterThreshold =>
values.distinct.flatMap { v =>
makeEq.lift(nameToParquetField(name).fieldType)
- .map(_(Array(nameToParquetField(name).fieldName), v))
+ .map(_(nameToParquetField(name).fieldNames, v))
}.reduceLeftOption(FilterApi.or)
case sources.StringStartsWith(name, prefix)
if pushDownStartWith && canMakeFilterOn(name, prefix) =>
Option(prefix).map { v =>
- FilterApi.userDefined(binaryColumn(Array(nameToParquetField(name).fieldName)),
+ FilterApi.userDefined(binaryColumn(nameToParquetField(name).fieldNames),
new UserDefinedPredicate[Binary] with Serializable {
private val strToBinary = Binary.fromReusedByteArray(v.getBytes)
private val size = strToBinary.length
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala
index 1421ffd..9f40f5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala
@@ -22,6 +22,7 @@ import scala.collection.JavaConverters._
import org.apache.orc.mapreduce.OrcInputFormat
import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters}
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.orc.OrcFilters
@@ -59,8 +60,10 @@ case class OrcScanBuilder(
// changed `hadoopConf` in executors.
OrcInputFormat.setSearchArgument(hadoopConf, f, schema.fieldNames)
}
- val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap
- _pushedFilters = OrcFilters.convertibleFilters(schema, dataTypeMap, filters).toArray
+ val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap
+ // TODO (SPARK-25557): ORC doesn't support nested predicate pushdown, so they are removed.
+ val newFilters = filters.filter(!_.containsNestedColumn)
+ _pushedFilters = OrcFilters.convertibleFilters(schema, dataTypeMap, newFilters).toArray
}
filters
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala
index 7bd3213..a775a97 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala
@@ -26,15 +26,61 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructT
class DataSourceStrategySuite extends PlanTest with SharedSparkSession {
val attrInts = Seq(
- 'cint.int
+ 'cint.int,
+ Symbol("c.int").int,
+ GetStructField('a.struct(StructType(
+ StructField("cstr", StringType, nullable = true) ::
+ StructField("cint", IntegerType, nullable = true) :: Nil)), 1, None),
+ GetStructField('a.struct(StructType(
+ StructField("c.int", IntegerType, nullable = true) ::
+ StructField("cstr", StringType, nullable = true) :: Nil)), 0, None),
+ GetStructField(Symbol("a.b").struct(StructType(
+ StructField("cstr1", StringType, nullable = true) ::
+ StructField("cstr2", StringType, nullable = true) ::
+ StructField("cint", IntegerType, nullable = true) :: Nil)), 2, None),
+ GetStructField(Symbol("a.b").struct(StructType(
+ StructField("c.int", IntegerType, nullable = true) :: Nil)), 0, None),
+ GetStructField(GetStructField('a.struct(StructType(
+ StructField("cstr1", StringType, nullable = true) ::
+ StructField("b", StructType(StructField("cint", IntegerType, nullable
= true) ::
+ StructField("cstr2", StringType, nullable = true) :: Nil)) :: Nil)),
1, None), 0, None)
).zip(Seq(
- "cint"
+ "cint",
+ "`c.int`", // single level field that contains `dot` in name
+ "a.cint", // two level nested field
+ "a.`c.int`", // two level nested field, and nested level contains `dot`
+ "`a.b`.cint", // two level nested field, and top level contains `dot`
+ "`a.b`.`c.int`", // two level nested field, and both levels contain `dot`
+ "a.b.cint" // three level nested field
))
val attrStrs = Seq(
- 'cstr.string
+ 'cstr.string,
+ Symbol("c.str").string,
+ GetStructField('a.struct(StructType(
+ StructField("cint", IntegerType, nullable = true) ::
+ StructField("cstr", StringType, nullable = true) :: Nil)), 1, None),
+ GetStructField('a.struct(StructType(
+ StructField("c.str", StringType, nullable = true) ::
+ StructField("cint", IntegerType, nullable = true) :: Nil)), 0, None),
+ GetStructField(Symbol("a.b").struct(StructType(
+ StructField("cint1", IntegerType, nullable = true) ::
+ StructField("cint2", IntegerType, nullable = true) ::
+ StructField("cstr", StringType, nullable = true) :: Nil)), 2, None),
+ GetStructField(Symbol("a.b").struct(StructType(
+ StructField("c.str", StringType, nullable = true) :: Nil)), 0, None),
+ GetStructField(GetStructField('a.struct(StructType(
+ StructField("cint1", IntegerType, nullable = true) ::
+ StructField("b", StructType(StructField("cstr", StringType, nullable =
true) ::
+ StructField("cint2", IntegerType, nullable = true) :: Nil)) ::
Nil)), 1, None), 0, None)
).zip(Seq(
- "cstr"
+ "cstr",
+ "`c.str`", // single level field that contains `dot` in name
+ "a.cstr", // two level nested field
+ "a.`c.str`", // two level nested field, and nested level contains `dot`
+ "`a.b`.cstr", // two level nested field, and top level contains `dot`
+ "`a.b`.`c.str`", // two level nested field, and both levels contain `dot`
+ "a.b.cstr" // three level nested field
))
test("translate simple expression") { attrInts.zip(attrStrs)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index 4e0c1c2..d1161e3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -103,22 +103,42 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
checkFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df)
}
- private def checkBinaryFilterPredicate
- (predicate: Predicate, filterClass: Class[_ <: FilterPredicate],
expected: Seq[Row])
- (implicit df: DataFrame): Unit = {
- def checkBinaryAnswer(df: DataFrame, expected: Seq[Row]) = {
- assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).sorted)
{
-
df.rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted
- }
+ /**
+ * Takes a single-level `inputDF` dataframe to generate multi-level nested
+ * dataframes as new test data.
+ */
+ private def withNestedDataFrame(inputDF: DataFrame)
+ (runTest: (DataFrame, String, Any => Any) => Unit): Unit = {
+ assert(inputDF.schema.fields.length == 1)
+ assert(!inputDF.schema.fields.head.dataType.isInstanceOf[StructType])
+ val df = inputDF.toDF("temp")
+ Seq(
+ (
+ df.withColumnRenamed("temp", "a"),
+ "a", // zero nesting
+ (x: Any) => x),
+ (
+ df.withColumn("a", struct(df("temp") as "b")).drop("temp"),
+ "a.b", // one level nesting
+ (x: Any) => Row(x)),
+ (
+ df.withColumn("a", struct(struct(df("temp") as "c") as
"b")).drop("temp"),
+ "a.b.c", // two level nesting
+ (x: Any) => Row(Row(x))
+ ),
+ (
+ df.withColumnRenamed("temp", "a.b"),
+ "`a.b`", // zero nesting with column name containing `dots`
+ (x: Any) => x
+ ),
+ (
+ df.withColumn("a.b", struct(df("temp") as "c.d") ).drop("temp"),
+ "`a.b`.`c.d`", // one level nesting with column names containing `dots`
+ (x: Any) => Row(x)
+ )
+ ).foreach { case (df, colName, resultFun) =>
+ runTest(df, colName, resultFun)
}
-
- checkFilterPredicate(df, predicate, filterClass, checkBinaryAnswer _,
expected)
- }
-
- private def checkBinaryFilterPredicate
- (predicate: Predicate, filterClass: Class[_ <: FilterPredicate],
expected: Array[Byte])
- (implicit df: DataFrame): Unit = {
- checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df)
}
private def testTimestampPushdown(data: Seq[Timestamp]): Unit = {
@@ -128,36 +148,38 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
val ts3 = data(2)
val ts4 = data(3)
- withParquetDataFrame(data.map(i => Tuple1(i))) { implicit df =>
- checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], data.map(i =>
Row.apply(i)))
-
- checkFilterPredicate('_1 === ts1, classOf[Eq[_]], ts1)
- checkFilterPredicate('_1 <=> ts1, classOf[Eq[_]], ts1)
- checkFilterPredicate('_1 =!= ts1, classOf[NotEq[_]],
- Seq(ts2, ts3, ts4).map(i => Row.apply(i)))
-
- checkFilterPredicate('_1 < ts2, classOf[Lt[_]], ts1)
- checkFilterPredicate('_1 > ts1, classOf[Gt[_]], Seq(ts2, ts3, ts4).map(i
=> Row.apply(i)))
- checkFilterPredicate('_1 <= ts1, classOf[LtEq[_]], ts1)
- checkFilterPredicate('_1 >= ts4, classOf[GtEq[_]], ts4)
-
- checkFilterPredicate(Literal(ts1) === '_1, classOf[Eq[_]], ts1)
- checkFilterPredicate(Literal(ts1) <=> '_1, classOf[Eq[_]], ts1)
- checkFilterPredicate(Literal(ts2) > '_1, classOf[Lt[_]], ts1)
- checkFilterPredicate(Literal(ts3) < '_1, classOf[Gt[_]], ts4)
- checkFilterPredicate(Literal(ts1) >= '_1, classOf[LtEq[_]], ts1)
- checkFilterPredicate(Literal(ts4) <= '_1, classOf[GtEq[_]], ts4)
-
- checkFilterPredicate(!('_1 < ts4), classOf[GtEq[_]], ts4)
- checkFilterPredicate('_1 < ts2 || '_1 > ts3, classOf[Operators.Or],
Seq(Row(ts1), Row(ts4)))
- }
- }
-
- private def testDecimalPushDown(data: DataFrame)(f: DataFrame => Unit): Unit
= {
- withTempPath { file =>
- data.write.parquet(file.getCanonicalPath)
- readParquetFile(file.toString)(f)
+ import testImplicits._
+ withNestedDataFrame(data.map(i => Tuple1(i)).toDF()) { case (inputDF,
colName, resultFun) =>
+ withParquetDataFrame(inputDF) { implicit df =>
+ val tsAttr = df(colName).expr
+ assert(df(colName).expr.dataType === TimestampType)
+
+ checkFilterPredicate(tsAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate(tsAttr.isNotNull, classOf[NotEq[_]],
+ data.map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(tsAttr === ts1, classOf[Eq[_]], resultFun(ts1))
+ checkFilterPredicate(tsAttr <=> ts1, classOf[Eq[_]], resultFun(ts1))
+ checkFilterPredicate(tsAttr =!= ts1, classOf[NotEq[_]],
+ Seq(ts2, ts3, ts4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(tsAttr < ts2, classOf[Lt[_]], resultFun(ts1))
+ checkFilterPredicate(tsAttr > ts1, classOf[Gt[_]],
+ Seq(ts2, ts3, ts4).map(i => Row.apply(resultFun(i))))
+ checkFilterPredicate(tsAttr <= ts1, classOf[LtEq[_]], resultFun(ts1))
+ checkFilterPredicate(tsAttr >= ts4, classOf[GtEq[_]], resultFun(ts4))
+
+ checkFilterPredicate(Literal(ts1) === tsAttr, classOf[Eq[_]],
resultFun(ts1))
+ checkFilterPredicate(Literal(ts1) <=> tsAttr, classOf[Eq[_]],
resultFun(ts1))
+ checkFilterPredicate(Literal(ts2) > tsAttr, classOf[Lt[_]],
resultFun(ts1))
+ checkFilterPredicate(Literal(ts3) < tsAttr, classOf[Gt[_]],
resultFun(ts4))
+ checkFilterPredicate(Literal(ts1) >= tsAttr, classOf[LtEq[_]],
resultFun(ts1))
+ checkFilterPredicate(Literal(ts4) <= tsAttr, classOf[GtEq[_]],
resultFun(ts4))
+
+ checkFilterPredicate(!(tsAttr < ts4), classOf[GtEq[_]], resultFun(ts4))
+ checkFilterPredicate(tsAttr < ts2 || tsAttr > ts3,
classOf[Operators.Or],
+ Seq(Row(resultFun(ts1)), Row(resultFun(ts4))))
+ }
}
}
@@ -187,201 +209,273 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
}
test("filter pushdown - boolean") {
- withParquetDataFrame((true :: false :: Nil).map(b =>
Tuple1.apply(Option(b)))) { implicit df =>
- checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], Seq(Row(true),
Row(false)))
-
- checkFilterPredicate('_1 === true, classOf[Eq[_]], true)
- checkFilterPredicate('_1 <=> true, classOf[Eq[_]], true)
- checkFilterPredicate('_1 =!= true, classOf[NotEq[_]], false)
+ val data = (true :: false :: Nil).map(b => Tuple1.apply(Option(b)))
+ import testImplicits._
+ withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) =>
+ withParquetDataFrame(inputDF) { implicit df =>
+ val booleanAttr = df(colName).expr
+ assert(df(colName).expr.dataType === BooleanType)
+
+ checkFilterPredicate(booleanAttr.isNull, classOf[Eq[_]],
Seq.empty[Row])
+ checkFilterPredicate(booleanAttr.isNotNull, classOf[NotEq[_]],
+ Seq(Row(resultFun(true)), Row(resultFun(false))))
+
+ checkFilterPredicate(booleanAttr === true, classOf[Eq[_]],
resultFun(true))
+ checkFilterPredicate(booleanAttr <=> true, classOf[Eq[_]],
resultFun(true))
+ checkFilterPredicate(booleanAttr =!= true, classOf[NotEq[_]],
resultFun(false))
+ }
}
}
test("filter pushdown - tinyint") {
- withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toByte)))) {
implicit df =>
- assert(df.schema.head.dataType === ByteType)
- checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to
4).map(Row.apply(_)))
-
- checkFilterPredicate('_1 === 1.toByte, classOf[Eq[_]], 1)
- checkFilterPredicate('_1 <=> 1.toByte, classOf[Eq[_]], 1)
- checkFilterPredicate('_1 =!= 1.toByte, classOf[NotEq[_]], (2 to
4).map(Row.apply(_)))
-
- checkFilterPredicate('_1 < 2.toByte, classOf[Lt[_]], 1)
- checkFilterPredicate('_1 > 3.toByte, classOf[Gt[_]], 4)
- checkFilterPredicate('_1 <= 1.toByte, classOf[LtEq[_]], 1)
- checkFilterPredicate('_1 >= 4.toByte, classOf[GtEq[_]], 4)
-
- checkFilterPredicate(Literal(1.toByte) === '_1, classOf[Eq[_]], 1)
- checkFilterPredicate(Literal(1.toByte) <=> '_1, classOf[Eq[_]], 1)
- checkFilterPredicate(Literal(2.toByte) > '_1, classOf[Lt[_]], 1)
- checkFilterPredicate(Literal(3.toByte) < '_1, classOf[Gt[_]], 4)
- checkFilterPredicate(Literal(1.toByte) >= '_1, classOf[LtEq[_]], 1)
- checkFilterPredicate(Literal(4.toByte) <= '_1, classOf[GtEq[_]], 4)
-
- checkFilterPredicate(!('_1 < 4.toByte), classOf[GtEq[_]], 4)
- checkFilterPredicate('_1 < 2.toByte || '_1 > 3.toByte,
- classOf[Operators.Or], Seq(Row(1), Row(4)))
+ val data = (1 to 4).map(i => Tuple1(Option(i.toByte)))
+ import testImplicits._
+ withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) =>
+ withParquetDataFrame(inputDF) { implicit df =>
+ val tinyIntAttr = df(colName).expr
+ assert(df(colName).expr.dataType === ByteType)
+
+ checkFilterPredicate(tinyIntAttr.isNull, classOf[Eq[_]],
Seq.empty[Row])
+ checkFilterPredicate(tinyIntAttr.isNotNull, classOf[NotEq[_]],
+ (1 to 4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(tinyIntAttr === 1.toByte, classOf[Eq[_]],
resultFun(1))
+ checkFilterPredicate(tinyIntAttr <=> 1.toByte, classOf[Eq[_]],
resultFun(1))
+ checkFilterPredicate(tinyIntAttr =!= 1.toByte, classOf[NotEq[_]],
+ (2 to 4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(tinyIntAttr < 2.toByte, classOf[Lt[_]],
resultFun(1))
+ checkFilterPredicate(tinyIntAttr > 3.toByte, classOf[Gt[_]],
resultFun(4))
+ checkFilterPredicate(tinyIntAttr <= 1.toByte, classOf[LtEq[_]],
resultFun(1))
+ checkFilterPredicate(tinyIntAttr >= 4.toByte, classOf[GtEq[_]],
resultFun(4))
+
+ checkFilterPredicate(Literal(1.toByte) === tinyIntAttr,
classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(Literal(1.toByte) <=> tinyIntAttr,
classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(Literal(2.toByte) > tinyIntAttr, classOf[Lt[_]],
resultFun(1))
+ checkFilterPredicate(Literal(3.toByte) < tinyIntAttr, classOf[Gt[_]],
resultFun(4))
+ checkFilterPredicate(Literal(1.toByte) >= tinyIntAttr,
classOf[LtEq[_]], resultFun(1))
+ checkFilterPredicate(Literal(4.toByte) <= tinyIntAttr,
classOf[GtEq[_]], resultFun(4))
+
+ checkFilterPredicate(!(tinyIntAttr < 4.toByte), classOf[GtEq[_]],
resultFun(4))
+ checkFilterPredicate(tinyIntAttr < 2.toByte || tinyIntAttr > 3.toByte,
+ classOf[Operators.Or], Seq(Row(resultFun(1)), Row(resultFun(4))))
+ }
}
}
test("filter pushdown - smallint") {
- withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toShort)))) {
implicit df =>
- assert(df.schema.head.dataType === ShortType)
- checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to
4).map(Row.apply(_)))
-
- checkFilterPredicate('_1 === 1.toShort, classOf[Eq[_]], 1)
- checkFilterPredicate('_1 <=> 1.toShort, classOf[Eq[_]], 1)
- checkFilterPredicate('_1 =!= 1.toShort, classOf[NotEq[_]], (2 to
4).map(Row.apply(_)))
-
- checkFilterPredicate('_1 < 2.toShort, classOf[Lt[_]], 1)
- checkFilterPredicate('_1 > 3.toShort, classOf[Gt[_]], 4)
- checkFilterPredicate('_1 <= 1.toShort, classOf[LtEq[_]], 1)
- checkFilterPredicate('_1 >= 4.toShort, classOf[GtEq[_]], 4)
-
- checkFilterPredicate(Literal(1.toShort) === '_1, classOf[Eq[_]], 1)
- checkFilterPredicate(Literal(1.toShort) <=> '_1, classOf[Eq[_]], 1)
- checkFilterPredicate(Literal(2.toShort) > '_1, classOf[Lt[_]], 1)
- checkFilterPredicate(Literal(3.toShort) < '_1, classOf[Gt[_]], 4)
- checkFilterPredicate(Literal(1.toShort) >= '_1, classOf[LtEq[_]], 1)
- checkFilterPredicate(Literal(4.toShort) <= '_1, classOf[GtEq[_]], 4)
-
- checkFilterPredicate(!('_1 < 4.toShort), classOf[GtEq[_]], 4)
- checkFilterPredicate('_1 < 2.toShort || '_1 > 3.toShort,
- classOf[Operators.Or], Seq(Row(1), Row(4)))
+ val data = (1 to 4).map(i => Tuple1(Option(i.toShort)))
+ import testImplicits._
+ withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) =>
+ withParquetDataFrame(inputDF) { implicit df =>
+ val smallIntAttr = df(colName).expr
+ assert(df(colName).expr.dataType === ShortType)
+
+ checkFilterPredicate(smallIntAttr.isNull, classOf[Eq[_]],
Seq.empty[Row])
+ checkFilterPredicate(smallIntAttr.isNotNull, classOf[NotEq[_]],
+ (1 to 4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(smallIntAttr === 1.toShort, classOf[Eq[_]],
resultFun(1))
+ checkFilterPredicate(smallIntAttr <=> 1.toShort, classOf[Eq[_]],
resultFun(1))
+ checkFilterPredicate(smallIntAttr =!= 1.toShort, classOf[NotEq[_]],
+ (2 to 4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(smallIntAttr < 2.toShort, classOf[Lt[_]],
resultFun(1))
+ checkFilterPredicate(smallIntAttr > 3.toShort, classOf[Gt[_]],
resultFun(4))
+ checkFilterPredicate(smallIntAttr <= 1.toShort, classOf[LtEq[_]],
resultFun(1))
+ checkFilterPredicate(smallIntAttr >= 4.toShort, classOf[GtEq[_]],
resultFun(4))
+
+ checkFilterPredicate(Literal(1.toShort) === smallIntAttr,
classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(Literal(1.toShort) <=> smallIntAttr,
classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(Literal(2.toShort) > smallIntAttr,
classOf[Lt[_]], resultFun(1))
+ checkFilterPredicate(Literal(3.toShort) < smallIntAttr,
classOf[Gt[_]], resultFun(4))
+ checkFilterPredicate(Literal(1.toShort) >= smallIntAttr,
classOf[LtEq[_]], resultFun(1))
+ checkFilterPredicate(Literal(4.toShort) <= smallIntAttr,
classOf[GtEq[_]], resultFun(4))
+
+ checkFilterPredicate(!(smallIntAttr < 4.toShort), classOf[GtEq[_]],
resultFun(4))
+ checkFilterPredicate(smallIntAttr < 2.toShort || smallIntAttr >
3.toShort,
+ classOf[Operators.Or], Seq(Row(resultFun(1)), Row(resultFun(4))))
+ }
}
}
test("filter pushdown - integer") {
- withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df =>
- checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to
4).map(Row.apply(_)))
-
- checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1)
- checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1)
- checkFilterPredicate('_1 =!= 1, classOf[NotEq[_]], (2 to
4).map(Row.apply(_)))
-
- checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1)
- checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4)
- checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1)
- checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4)
-
- checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1)
- checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1)
- checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1)
- checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4)
- checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
- checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
-
- checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
- checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or],
Seq(Row(1), Row(4)))
+ val data = (1 to 4).map(i => Tuple1(Option(i)))
+ import testImplicits._
+ withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) =>
+ withParquetDataFrame(inputDF) { implicit df =>
+ val intAttr = df(colName).expr
+ assert(df(colName).expr.dataType === IntegerType)
+
+ checkFilterPredicate(intAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate(intAttr.isNotNull, classOf[NotEq[_]],
+ (1 to 4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(intAttr === 1, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(intAttr <=> 1, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(intAttr =!= 1, classOf[NotEq[_]],
+ (2 to 4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(intAttr < 2, classOf[Lt[_]], resultFun(1))
+ checkFilterPredicate(intAttr > 3, classOf[Gt[_]], resultFun(4))
+ checkFilterPredicate(intAttr <= 1, classOf[LtEq[_]], resultFun(1))
+ checkFilterPredicate(intAttr >= 4, classOf[GtEq[_]], resultFun(4))
+
+ checkFilterPredicate(Literal(1) === intAttr, classOf[Eq[_]],
resultFun(1))
+ checkFilterPredicate(Literal(1) <=> intAttr, classOf[Eq[_]],
resultFun(1))
+ checkFilterPredicate(Literal(2) > intAttr, classOf[Lt[_]],
resultFun(1))
+ checkFilterPredicate(Literal(3) < intAttr, classOf[Gt[_]],
resultFun(4))
+ checkFilterPredicate(Literal(1) >= intAttr, classOf[LtEq[_]],
resultFun(1))
+ checkFilterPredicate(Literal(4) <= intAttr, classOf[GtEq[_]],
resultFun(4))
+
+ checkFilterPredicate(!(intAttr < 4), classOf[GtEq[_]], resultFun(4))
+ checkFilterPredicate(intAttr < 2 || intAttr > 3, classOf[Operators.Or],
+ Seq(Row(resultFun(1)), Row(resultFun(4))))
+ }
}
}
test("filter pushdown - long") {
- withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toLong)))) {
implicit df =>
- checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to
4).map(Row.apply(_)))
-
- checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1)
- checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1)
- checkFilterPredicate('_1 =!= 1, classOf[NotEq[_]], (2 to
4).map(Row.apply(_)))
-
- checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1)
- checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4)
- checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1)
- checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4)
-
- checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1)
- checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1)
- checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1)
- checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4)
- checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
- checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
-
- checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
- checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or],
Seq(Row(1), Row(4)))
+ val data = (1 to 4).map(i => Tuple1(Option(i.toLong)))
+ import testImplicits._
+ withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) =>
+ withParquetDataFrame(inputDF) { implicit df =>
+ val longAttr = df(colName).expr
+ assert(df(colName).expr.dataType === LongType)
+
+ checkFilterPredicate(longAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate(longAttr.isNotNull, classOf[NotEq[_]],
+ (1 to 4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(longAttr === 1, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(longAttr <=> 1, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(longAttr =!= 1, classOf[NotEq[_]],
+ (2 to 4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(longAttr < 2, classOf[Lt[_]], resultFun(1))
+ checkFilterPredicate(longAttr > 3, classOf[Gt[_]], resultFun(4))
+ checkFilterPredicate(longAttr <= 1, classOf[LtEq[_]], resultFun(1))
+ checkFilterPredicate(longAttr >= 4, classOf[GtEq[_]], resultFun(4))
+
+ checkFilterPredicate(Literal(1) === longAttr, classOf[Eq[_]],
resultFun(1))
+ checkFilterPredicate(Literal(1) <=> longAttr, classOf[Eq[_]],
resultFun(1))
+ checkFilterPredicate(Literal(2) > longAttr, classOf[Lt[_]],
resultFun(1))
+ checkFilterPredicate(Literal(3) < longAttr, classOf[Gt[_]],
resultFun(4))
+ checkFilterPredicate(Literal(1) >= longAttr, classOf[LtEq[_]],
resultFun(1))
+ checkFilterPredicate(Literal(4) <= longAttr, classOf[GtEq[_]],
resultFun(4))
+
+ checkFilterPredicate(!(longAttr < 4), classOf[GtEq[_]], resultFun(4))
+ checkFilterPredicate(longAttr < 2 || longAttr > 3,
classOf[Operators.Or],
+ Seq(Row(resultFun(1)), Row(resultFun(4))))
+ }
}
}
test("filter pushdown - float") {
- withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toFloat)))) {
implicit df =>
- checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to
4).map(Row.apply(_)))
-
- checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1)
- checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1)
- checkFilterPredicate('_1 =!= 1, classOf[NotEq[_]], (2 to
4).map(Row.apply(_)))
-
- checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1)
- checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4)
- checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1)
- checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4)
-
- checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1)
- checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1)
- checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1)
- checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4)
- checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
- checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
-
- checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
- checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or],
Seq(Row(1), Row(4)))
+ val data = (1 to 4).map(i => Tuple1(Option(i.toFloat)))
+ import testImplicits._
+ withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) =>
+ withParquetDataFrame(inputDF) { implicit df =>
+ val floatAttr = df(colName).expr
+ assert(df(colName).expr.dataType === FloatType)
+
+ checkFilterPredicate(floatAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate(floatAttr.isNotNull, classOf[NotEq[_]],
+ (1 to 4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(floatAttr === 1, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(floatAttr <=> 1, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(floatAttr =!= 1, classOf[NotEq[_]],
+ (2 to 4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(floatAttr < 2, classOf[Lt[_]], resultFun(1))
+ checkFilterPredicate(floatAttr > 3, classOf[Gt[_]], resultFun(4))
+ checkFilterPredicate(floatAttr <= 1, classOf[LtEq[_]], resultFun(1))
+ checkFilterPredicate(floatAttr >= 4, classOf[GtEq[_]], resultFun(4))
+
+ checkFilterPredicate(Literal(1) === floatAttr, classOf[Eq[_]],
resultFun(1))
+ checkFilterPredicate(Literal(1) <=> floatAttr, classOf[Eq[_]],
resultFun(1))
+ checkFilterPredicate(Literal(2) > floatAttr, classOf[Lt[_]],
resultFun(1))
+ checkFilterPredicate(Literal(3) < floatAttr, classOf[Gt[_]],
resultFun(4))
+ checkFilterPredicate(Literal(1) >= floatAttr, classOf[LtEq[_]],
resultFun(1))
+ checkFilterPredicate(Literal(4) <= floatAttr, classOf[GtEq[_]],
resultFun(4))
+
+ checkFilterPredicate(!(floatAttr < 4), classOf[GtEq[_]], resultFun(4))
+ checkFilterPredicate(floatAttr < 2 || floatAttr > 3,
classOf[Operators.Or],
+ Seq(Row(resultFun(1)), Row(resultFun(4))))
+ }
}
}
test("filter pushdown - double") {
- withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toDouble)))) {
implicit df =>
- checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to
4).map(Row.apply(_)))
-
- checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1)
- checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1)
- checkFilterPredicate('_1 =!= 1, classOf[NotEq[_]], (2 to
4).map(Row.apply(_)))
-
- checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1)
- checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4)
- checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1)
- checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4)
-
- checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1)
- checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1)
- checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1)
- checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4)
- checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
- checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
-
- checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
- checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or],
Seq(Row(1), Row(4)))
+ val data = (1 to 4).map(i => Tuple1(Option(i.toDouble)))
+ import testImplicits._
+ withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) =>
+ withParquetDataFrame(inputDF) { implicit df =>
+ val doubleAttr = df(colName).expr
+ assert(df(colName).expr.dataType === DoubleType)
+
+ checkFilterPredicate(doubleAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate(doubleAttr.isNotNull, classOf[NotEq[_]],
+ (1 to 4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(doubleAttr === 1, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(doubleAttr <=> 1, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(doubleAttr =!= 1, classOf[NotEq[_]],
+ (2 to 4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(doubleAttr < 2, classOf[Lt[_]], resultFun(1))
+ checkFilterPredicate(doubleAttr > 3, classOf[Gt[_]], resultFun(4))
+ checkFilterPredicate(doubleAttr <= 1, classOf[LtEq[_]], resultFun(1))
+ checkFilterPredicate(doubleAttr >= 4, classOf[GtEq[_]], resultFun(4))
+
+ checkFilterPredicate(Literal(1) === doubleAttr, classOf[Eq[_]],
resultFun(1))
+ checkFilterPredicate(Literal(1) <=> doubleAttr, classOf[Eq[_]],
resultFun(1))
+ checkFilterPredicate(Literal(2) > doubleAttr, classOf[Lt[_]],
resultFun(1))
+ checkFilterPredicate(Literal(3) < doubleAttr, classOf[Gt[_]],
resultFun(4))
+ checkFilterPredicate(Literal(1) >= doubleAttr, classOf[LtEq[_]],
resultFun(1))
+ checkFilterPredicate(Literal(4) <= doubleAttr, classOf[GtEq[_]],
resultFun(4))
+
+ checkFilterPredicate(!(doubleAttr < 4), classOf[GtEq[_]], resultFun(4))
+ checkFilterPredicate(doubleAttr < 2 || doubleAttr > 3,
classOf[Operators.Or],
+ Seq(Row(resultFun(1)), Row(resultFun(4))))
+ }
}
}
test("filter pushdown - string") {
- withParquetDataFrame((1 to 4).map(i => Tuple1(i.toString))) { implicit df
=>
- checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate(
- '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i =>
Row.apply(i.toString)))
-
- checkFilterPredicate('_1 === "1", classOf[Eq[_]], "1")
- checkFilterPredicate('_1 <=> "1", classOf[Eq[_]], "1")
- checkFilterPredicate(
- '_1 =!= "1", classOf[NotEq[_]], (2 to 4).map(i =>
Row.apply(i.toString)))
-
- checkFilterPredicate('_1 < "2", classOf[Lt[_]], "1")
- checkFilterPredicate('_1 > "3", classOf[Gt[_]], "4")
- checkFilterPredicate('_1 <= "1", classOf[LtEq[_]], "1")
- checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4")
-
- checkFilterPredicate(Literal("1") === '_1, classOf[Eq[_]], "1")
- checkFilterPredicate(Literal("1") <=> '_1, classOf[Eq[_]], "1")
- checkFilterPredicate(Literal("2") > '_1, classOf[Lt[_]], "1")
- checkFilterPredicate(Literal("3") < '_1, classOf[Gt[_]], "4")
- checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1")
- checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4")
-
- checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4")
- checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or],
Seq(Row("1"), Row("4")))
+ val data = (1 to 4).map(i => Tuple1(Option(i.toString)))
+ import testImplicits._
+ withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) =>
+ withParquetDataFrame(inputDF) { implicit df =>
+ val stringAttr = df(colName).expr
+ assert(df(colName).expr.dataType === StringType)
+
+ checkFilterPredicate(stringAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate(stringAttr.isNotNull, classOf[NotEq[_]],
+ (1 to 4).map(i => Row.apply(resultFun(i.toString))))
+
+ checkFilterPredicate(stringAttr === "1", classOf[Eq[_]],
resultFun("1"))
+ checkFilterPredicate(stringAttr <=> "1", classOf[Eq[_]],
resultFun("1"))
+ checkFilterPredicate(stringAttr =!= "1", classOf[NotEq[_]],
+ (2 to 4).map(i => Row.apply(resultFun(i.toString))))
+
+ checkFilterPredicate(stringAttr < "2", classOf[Lt[_]], resultFun("1"))
+ checkFilterPredicate(stringAttr > "3", classOf[Gt[_]], resultFun("4"))
+ checkFilterPredicate(stringAttr <= "1", classOf[LtEq[_]],
resultFun("1"))
+ checkFilterPredicate(stringAttr >= "4", classOf[GtEq[_]],
resultFun("4"))
+
+ checkFilterPredicate(Literal("1") === stringAttr, classOf[Eq[_]],
resultFun("1"))
+ checkFilterPredicate(Literal("1") <=> stringAttr, classOf[Eq[_]],
resultFun("1"))
+ checkFilterPredicate(Literal("2") > stringAttr, classOf[Lt[_]],
resultFun("1"))
+ checkFilterPredicate(Literal("3") < stringAttr, classOf[Gt[_]],
resultFun("4"))
+ checkFilterPredicate(Literal("1") >= stringAttr, classOf[LtEq[_]],
resultFun("1"))
+ checkFilterPredicate(Literal("4") <= stringAttr, classOf[GtEq[_]],
resultFun("4"))
+
+ checkFilterPredicate(!(stringAttr < "4"), classOf[GtEq[_]],
resultFun("4"))
+ checkFilterPredicate(stringAttr < "2" || stringAttr > "3",
classOf[Operators.Or],
+ Seq(Row(resultFun("1")), Row(resultFun("4"))))
+ }
}
}
@@ -390,32 +484,39 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
def b: Array[Byte] = int.toString.getBytes(StandardCharsets.UTF_8)
}
- withParquetDataFrame((1 to 4).map(i => Tuple1(i.b))) { implicit df =>
- checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq[_]], 1.b)
- checkBinaryFilterPredicate('_1 <=> 1.b, classOf[Eq[_]], 1.b)
-
- checkBinaryFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkBinaryFilterPredicate(
- '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i =>
Row.apply(i.b)).toSeq)
-
- checkBinaryFilterPredicate(
- '_1 =!= 1.b, classOf[NotEq[_]], (2 to 4).map(i =>
Row.apply(i.b)).toSeq)
-
- checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt[_]], 1.b)
- checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt[_]], 4.b)
- checkBinaryFilterPredicate('_1 <= 1.b, classOf[LtEq[_]], 1.b)
- checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b)
-
- checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq[_]], 1.b)
- checkBinaryFilterPredicate(Literal(1.b) <=> '_1, classOf[Eq[_]], 1.b)
- checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt[_]], 1.b)
- checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt[_]], 4.b)
- checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b)
- checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b)
-
- checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b)
- checkBinaryFilterPredicate(
- '_1 < 2.b || '_1 > 3.b, classOf[Operators.Or], Seq(Row(1.b), Row(4.b)))
+ val data = (1 to 4).map(i => Tuple1(Option(i.b)))
+ import testImplicits._
+ withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) =>
+ withParquetDataFrame(inputDF) { implicit df =>
+ val binaryAttr: Expression = df(colName).expr
+ assert(df(colName).expr.dataType === BinaryType)
+
+ checkFilterPredicate(binaryAttr === 1.b, classOf[Eq[_]],
resultFun(1.b))
+ checkFilterPredicate(binaryAttr <=> 1.b, classOf[Eq[_]],
resultFun(1.b))
+
+ checkFilterPredicate(binaryAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate(binaryAttr.isNotNull, classOf[NotEq[_]],
+ (1 to 4).map(i => Row.apply(resultFun(i.b))))
+
+ checkFilterPredicate(binaryAttr =!= 1.b, classOf[NotEq[_]],
+ (2 to 4).map(i => Row.apply(resultFun(i.b))))
+
+ checkFilterPredicate(binaryAttr < 2.b, classOf[Lt[_]], resultFun(1.b))
+ checkFilterPredicate(binaryAttr > 3.b, classOf[Gt[_]], resultFun(4.b))
+ checkFilterPredicate(binaryAttr <= 1.b, classOf[LtEq[_]],
resultFun(1.b))
+ checkFilterPredicate(binaryAttr >= 4.b, classOf[GtEq[_]],
resultFun(4.b))
+
+ checkFilterPredicate(Literal(1.b) === binaryAttr, classOf[Eq[_]],
resultFun(1.b))
+ checkFilterPredicate(Literal(1.b) <=> binaryAttr, classOf[Eq[_]],
resultFun(1.b))
+ checkFilterPredicate(Literal(2.b) > binaryAttr, classOf[Lt[_]],
resultFun(1.b))
+ checkFilterPredicate(Literal(3.b) < binaryAttr, classOf[Gt[_]],
resultFun(4.b))
+ checkFilterPredicate(Literal(1.b) >= binaryAttr, classOf[LtEq[_]],
resultFun(1.b))
+ checkFilterPredicate(Literal(4.b) <= binaryAttr, classOf[GtEq[_]],
resultFun(4.b))
+
+ checkFilterPredicate(!(binaryAttr < 4.b), classOf[GtEq[_]],
resultFun(4.b))
+ checkFilterPredicate(binaryAttr < 2.b || binaryAttr > 3.b,
classOf[Operators.Or],
+ Seq(Row(resultFun(1.b)), Row(resultFun(4.b))))
+ }
}
}
@@ -424,40 +525,53 @@ abstract class ParquetFilterSuite extends QueryTest with
ParquetTest with Shared
def date: Date = Date.valueOf(s)
}
- val data = Seq("2018-03-18", "2018-03-19", "2018-03-20", "2018-03-21")
-
- withParquetDataFrame(data.map(i => Tuple1(i.date))) { implicit df =>
- checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], data.map(i =>
Row.apply(i.date)))
-
- checkFilterPredicate('_1 === "2018-03-18".date, classOf[Eq[_]],
"2018-03-18".date)
- checkFilterPredicate('_1 <=> "2018-03-18".date, classOf[Eq[_]],
"2018-03-18".date)
- checkFilterPredicate('_1 =!= "2018-03-18".date, classOf[NotEq[_]],
- Seq("2018-03-19", "2018-03-20", "2018-03-21").map(i =>
Row.apply(i.date)))
-
- checkFilterPredicate('_1 < "2018-03-19".date, classOf[Lt[_]],
"2018-03-18".date)
- checkFilterPredicate('_1 > "2018-03-20".date, classOf[Gt[_]],
"2018-03-21".date)
- checkFilterPredicate('_1 <= "2018-03-18".date, classOf[LtEq[_]],
"2018-03-18".date)
- checkFilterPredicate('_1 >= "2018-03-21".date, classOf[GtEq[_]],
"2018-03-21".date)
-
- checkFilterPredicate(
- Literal("2018-03-18".date) === '_1, classOf[Eq[_]], "2018-03-18".date)
- checkFilterPredicate(
- Literal("2018-03-18".date) <=> '_1, classOf[Eq[_]], "2018-03-18".date)
- checkFilterPredicate(
- Literal("2018-03-19".date) > '_1, classOf[Lt[_]], "2018-03-18".date)
- checkFilterPredicate(
- Literal("2018-03-20".date) < '_1, classOf[Gt[_]], "2018-03-21".date)
- checkFilterPredicate(
- Literal("2018-03-18".date) >= '_1, classOf[LtEq[_]], "2018-03-18".date)
- checkFilterPredicate(
- Literal("2018-03-21".date) <= '_1, classOf[GtEq[_]], "2018-03-21".date)
-
- checkFilterPredicate(!('_1 < "2018-03-21".date), classOf[GtEq[_]],
"2018-03-21".date)
- checkFilterPredicate(
- '_1 < "2018-03-19".date || '_1 > "2018-03-20".date,
- classOf[Operators.Or],
- Seq(Row("2018-03-18".date), Row("2018-03-21".date)))
+ val data = Seq("2018-03-18", "2018-03-19", "2018-03-20",
"2018-03-21").map(_.date)
+ import testImplicits._
+ withNestedDataFrame(data.map(i => Tuple1(i)).toDF()) { case (inputDF,
colName, resultFun) =>
+ withParquetDataFrame(inputDF) { implicit df =>
+ val dateAttr: Expression = df(colName).expr
+ assert(df(colName).expr.dataType === DateType)
+
+ checkFilterPredicate(dateAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate(dateAttr.isNotNull, classOf[NotEq[_]],
+ data.map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(dateAttr === "2018-03-18".date, classOf[Eq[_]],
+ resultFun("2018-03-18".date))
+ checkFilterPredicate(dateAttr <=> "2018-03-18".date, classOf[Eq[_]],
+ resultFun("2018-03-18".date))
+ checkFilterPredicate(dateAttr =!= "2018-03-18".date, classOf[NotEq[_]],
+ Seq("2018-03-19", "2018-03-20", "2018-03-21").map(i =>
Row.apply(resultFun(i.date))))
+
+ checkFilterPredicate(dateAttr < "2018-03-19".date, classOf[Lt[_]],
+ resultFun("2018-03-18".date))
+ checkFilterPredicate(dateAttr > "2018-03-20".date, classOf[Gt[_]],
+ resultFun("2018-03-21".date))
+ checkFilterPredicate(dateAttr <= "2018-03-18".date, classOf[LtEq[_]],
+ resultFun("2018-03-18".date))
+ checkFilterPredicate(dateAttr >= "2018-03-21".date, classOf[GtEq[_]],
+ resultFun("2018-03-21".date))
+
+ checkFilterPredicate(Literal("2018-03-18".date) === dateAttr,
classOf[Eq[_]],
+ resultFun("2018-03-18".date))
+ checkFilterPredicate(Literal("2018-03-18".date) <=> dateAttr,
classOf[Eq[_]],
+ resultFun("2018-03-18".date))
+ checkFilterPredicate(Literal("2018-03-19".date) > dateAttr,
classOf[Lt[_]],
+ resultFun("2018-03-18".date))
+ checkFilterPredicate(Literal("2018-03-20".date) < dateAttr,
classOf[Gt[_]],
+ resultFun("2018-03-21".date))
+ checkFilterPredicate(Literal("2018-03-18".date) >= dateAttr,
classOf[LtEq[_]],
+ resultFun("2018-03-18".date))
+ checkFilterPredicate(Literal("2018-03-21".date) <= dateAttr,
classOf[GtEq[_]],
+ resultFun("2018-03-21".date))
+
+ checkFilterPredicate(!(dateAttr < "2018-03-21".date), classOf[GtEq[_]],
+ resultFun("2018-03-21".date))
+ checkFilterPredicate(
+ dateAttr < "2018-03-19".date || dateAttr > "2018-03-20".date,
+ classOf[Operators.Or],
+ Seq(Row(resultFun("2018-03-18".date)),
Row(resultFun("2018-03-21".date))))
+ }
}
}
@@ -485,7 +599,8 @@ abstract class ParquetFilterSuite extends QueryTest with
ParquetTest with Shared
// spark.sql.parquet.outputTimestampType = INT96 doesn't support pushdown
withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key ->
ParquetOutputTimestampType.INT96.toString) {
- withParquetDataFrame(millisData.map(i => Tuple1(i))) { implicit df =>
+ import testImplicits._
+ withParquetDataFrame(millisData.map(i => Tuple1(i)).toDF()) { implicit
df =>
val schema = new SparkToParquetSchemaConverter(conf).convert(df.schema)
assertResult(None) {
createParquetFilters(schema).createFilter(sources.IsNull("_1"))
@@ -502,33 +617,39 @@ abstract class ParquetFilterSuite extends QueryTest with
ParquetTest with Shared
(false, DecimalType.MAX_PRECISION) // binaryWriterUsingUnscaledBytes
).foreach { case (legacyFormat, precision) =>
withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key ->
legacyFormat.toString) {
- val schema = StructType.fromDDL(s"a decimal($precision, 2)")
val rdd =
spark.sparkContext.parallelize((1 to 4).map(i => Row(new
java.math.BigDecimal(i))))
- val dataFrame = spark.createDataFrame(rdd, schema)
- testDecimalPushDown(dataFrame) { implicit df =>
- assert(df.schema === schema)
- checkFilterPredicate('a.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate('a.isNotNull, classOf[NotEq[_]], (1 to
4).map(Row.apply(_)))
-
- checkFilterPredicate('a === 1, classOf[Eq[_]], 1)
- checkFilterPredicate('a <=> 1, classOf[Eq[_]], 1)
- checkFilterPredicate('a =!= 1, classOf[NotEq[_]], (2 to
4).map(Row.apply(_)))
-
- checkFilterPredicate('a < 2, classOf[Lt[_]], 1)
- checkFilterPredicate('a > 3, classOf[Gt[_]], 4)
- checkFilterPredicate('a <= 1, classOf[LtEq[_]], 1)
- checkFilterPredicate('a >= 4, classOf[GtEq[_]], 4)
-
- checkFilterPredicate(Literal(1) === 'a, classOf[Eq[_]], 1)
- checkFilterPredicate(Literal(1) <=> 'a, classOf[Eq[_]], 1)
- checkFilterPredicate(Literal(2) > 'a, classOf[Lt[_]], 1)
- checkFilterPredicate(Literal(3) < 'a, classOf[Gt[_]], 4)
- checkFilterPredicate(Literal(1) >= 'a, classOf[LtEq[_]], 1)
- checkFilterPredicate(Literal(4) <= 'a, classOf[GtEq[_]], 4)
-
- checkFilterPredicate(!('a < 4), classOf[GtEq[_]], 4)
- checkFilterPredicate('a < 2 || 'a > 3, classOf[Operators.Or],
Seq(Row(1), Row(4)))
+ val dataFrame = spark.createDataFrame(rdd, StructType.fromDDL(s"a
decimal($precision, 2)"))
+ withNestedDataFrame(dataFrame) { case (inputDF, colName, resultFun) =>
+ withParquetDataFrame(inputDF) { implicit df =>
+ val decimalAttr: Expression = df(colName).expr
+ assert(df(colName).expr.dataType === DecimalType(precision, 2))
+
+ checkFilterPredicate(decimalAttr.isNull, classOf[Eq[_]],
Seq.empty[Row])
+ checkFilterPredicate(decimalAttr.isNotNull, classOf[NotEq[_]],
+ (1 to 4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(decimalAttr === 1, classOf[Eq[_]],
resultFun(1))
+ checkFilterPredicate(decimalAttr <=> 1, classOf[Eq[_]],
resultFun(1))
+ checkFilterPredicate(decimalAttr =!= 1, classOf[NotEq[_]],
+ (2 to 4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(decimalAttr < 2, classOf[Lt[_]], resultFun(1))
+ checkFilterPredicate(decimalAttr > 3, classOf[Gt[_]], resultFun(4))
+ checkFilterPredicate(decimalAttr <= 1, classOf[LtEq[_]],
resultFun(1))
+ checkFilterPredicate(decimalAttr >= 4, classOf[GtEq[_]],
resultFun(4))
+
+ checkFilterPredicate(Literal(1) === decimalAttr, classOf[Eq[_]],
resultFun(1))
+ checkFilterPredicate(Literal(1) <=> decimalAttr, classOf[Eq[_]],
resultFun(1))
+ checkFilterPredicate(Literal(2) > decimalAttr, classOf[Lt[_]],
resultFun(1))
+ checkFilterPredicate(Literal(3) < decimalAttr, classOf[Gt[_]],
resultFun(4))
+ checkFilterPredicate(Literal(1) >= decimalAttr, classOf[LtEq[_]],
resultFun(1))
+ checkFilterPredicate(Literal(4) <= decimalAttr, classOf[GtEq[_]],
resultFun(4))
+
+ checkFilterPredicate(!(decimalAttr < 4), classOf[GtEq[_]],
resultFun(4))
+ checkFilterPredicate(decimalAttr < 2 || decimalAttr > 3,
classOf[Operators.Or],
+ Seq(Row(resultFun(1)), Row(resultFun(4))))
+ }
}
}
}
@@ -1042,7 +1163,8 @@ abstract class ParquetFilterSuite extends QueryTest with
ParquetTest with Shared
}
test("SPARK-16371 Do not push down filters when inner name and outer name
are the same") {
- withParquetDataFrame((1 to 4).map(i => Tuple1(Tuple1(i)))) { implicit df =>
+ import testImplicits._
+ withParquetDataFrame((1 to 4).map(i => Tuple1(Tuple1(i))).toDF()) {
implicit df =>
// Here the schema becomes as below:
//
// root
@@ -1107,7 +1229,7 @@ abstract class ParquetFilterSuite extends QueryTest with
ParquetTest with Shared
}
}
- test("SPARK-20364: Disable Parquet predicate pushdown for fields having dots
in the names") {
+ test("SPARK-31026: Parquet predicate pushdown for fields having dots in the
names") {
import testImplicits._
Seq(true, false).foreach { vectorized =>
@@ -1120,6 +1242,28 @@ abstract class ParquetFilterSuite extends QueryTest with
ParquetTest with Shared
assert(readBack.count() == 1)
}
}
+
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key ->
vectorized.toString,
+ // Makes sure that row group level filtering is still applied even when
+ // 'spark.sql.parquet.recordFilter' is disabled.
+ SQLConf.PARQUET_RECORD_FILTER_ENABLED.key -> "false",
+ SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") {
+
+ withTempPath { path =>
+ val data = (1 to 1024)
+ data.toDF("col.dots").coalesce(1)
+ .write.option("parquet.block.size", 512)
+ .parquet(path.getAbsolutePath)
+ val df = spark.read.parquet(path.getAbsolutePath).filter("`col.dots`
== 500")
+ // Here, we strip the Spark side filter and check the actual results
from Parquet.
+ val actual = stripSparkFilter(df).collect().length
+ // Since those are filtered at row group level, the result count
should be less
+ // than the total length but should not be a single record.
+ // Note that, if record level filtering is enabled, it should be a
single record.
+ // If no filter is pushed down to Parquet, it should be the total
length of data.
+ assert(actual > 1 && actual < data.length)
+ }
+ }
}
}
@@ -1162,7 +1306,10 @@ abstract class ParquetFilterSuite extends QueryTest with
ParquetTest with Shared
}
test("filter pushdown - StringStartsWith") {
- withParquetDataFrame((1 to 4).map(i => Tuple1(i + "str" + i))) { implicit
df =>
+ withParquetDataFrame {
+ import testImplicits._
+ (1 to 4).map(i => Tuple1(i + "str" + i)).toDF()
+ } { implicit df =>
checkFilterPredicate(
'_1.startsWith("").asInstanceOf[Predicate],
classOf[UserDefinedByInstance[_, _]],
@@ -1208,7 +1355,10 @@ abstract class ParquetFilterSuite extends QueryTest with
ParquetTest with Shared
}
// SPARK-28371: make sure filter is null-safe.
- withParquetDataFrame(Seq(Tuple1[String](null))) { implicit df =>
+ withParquetDataFrame {
+ import testImplicits._
+ Seq(Tuple1[String](null)).toDF()
+ } { implicit df =>
checkFilterPredicate(
'_1.startsWith("blah").asInstanceOf[Predicate],
classOf[UserDefinedByInstance[_, _]],
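For readers skimming the ParquetFilterSuite changes above: the new `withNestedDataFrame` wrapper re-runs each predicate over several layouts of the same data (flat and nested struct variants), with `colName` and `resultFun` adapting the attribute and the expected rows. A minimal, self-contained sketch of the user-facing effect under this patch (the object name and scratch path are illustrative, and a local SparkSession is assumed):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{col, struct}

    object NestedPushdownSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[*]")
          .appName("nested-pushdown-sketch")
          .getOrCreate()
        import spark.implicits._

        val path = "/tmp/nested_pushdown_sketch"  // illustrative scratch path

        // Write a DataFrame with a struct column `a` containing a nested field `b`.
        (1 to 1024).toDF("b")
          .select(struct(col("b")).as("a"))
          .write.mode("overwrite").parquet(path)

        // Filter on the nested field. With this patch the comparison can be translated
        // into a data source Filter and pushed to the Parquet reader, instead of being
        // evaluated only on the Spark side.
        val readBack = spark.read.parquet(path).filter(col("a.b") > 1000)
        readBack.explain()  // PushedFilters in the scan node should now mention a.b
        println(readBack.count())

        spark.stop()
      }
    }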
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index 7f85fd2..497b823 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -82,7 +82,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with
SharedSparkSession
* Writes `data` to a Parquet file, reads it back and check file contents.
*/
protected def checkParquetFile[T <: Product : ClassTag: TypeTag](data:
Seq[T]): Unit = {
- withParquetDataFrame(data)(r => checkAnswer(r, data.map(Row.fromTuple)))
+ withParquetDataFrame(data.toDF())(r => checkAnswer(r,
data.map(Row.fromTuple)))
}
test("basic data types (without binary)") {
@@ -94,7 +94,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with
SharedSparkSession
test("raw binary") {
val data = (1 to 4).map(i => Tuple1(Array.fill(3)(i.toByte)))
- withParquetDataFrame(data) { df =>
+ withParquetDataFrame(data.toDF()) { df =>
assertResult(data.map(_._1.mkString(",")).sorted) {
df.collect().map(_.getAs[Array[Byte]](0).mkString(",")).sorted
}
@@ -197,7 +197,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest
with SharedSparkSession
testStandardAndLegacyModes("struct") {
val data = (1 to 4).map(i => Tuple1((i, s"val_$i")))
- withParquetDataFrame(data) { df =>
+ withParquetDataFrame(data.toDF()) { df =>
// Structs are converted to `Row`s
checkAnswer(df, data.map { case Tuple1(struct) =>
Row(Row(struct.productIterator.toSeq: _*))
@@ -214,7 +214,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest
with SharedSparkSession
)
)
}
- withParquetDataFrame(data) { df =>
+ withParquetDataFrame(data.toDF()) { df =>
// Structs are converted to `Row`s
checkAnswer(df, data.map { case Tuple1(array) =>
Row(array.map(struct => Row(struct.productIterator.toSeq: _*)))
@@ -233,7 +233,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest
with SharedSparkSession
)
)
}
- withParquetDataFrame(data) { df =>
+ withParquetDataFrame(data.toDF()) { df =>
// Structs are converted to `Row`s
checkAnswer(df, data.map { case Tuple1(array) =>
Row(array.map { case Tuple1(Tuple1(str)) => Row(Row(str))})
@@ -243,7 +243,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest
with SharedSparkSession
testStandardAndLegacyModes("nested struct with array of array as field") {
val data = (1 to 4).map(i => Tuple1((i, Seq(Seq(s"val_$i")))))
- withParquetDataFrame(data) { df =>
+ withParquetDataFrame(data.toDF()) { df =>
// Structs are converted to `Row`s
checkAnswer(df, data.map { case Tuple1(struct) =>
Row(Row(struct.productIterator.toSeq: _*))
@@ -260,7 +260,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest
with SharedSparkSession
)
)
}
- withParquetDataFrame(data) { df =>
+ withParquetDataFrame(data.toDF()) { df =>
// Structs are converted to `Row`s
checkAnswer(df, data.map { case Tuple1(m) =>
Row(m.map { case (k, v) => Row(k.productIterator.toSeq: _*) -> v })
@@ -277,7 +277,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest
with SharedSparkSession
)
)
}
- withParquetDataFrame(data) { df =>
+ withParquetDataFrame(data.toDF()) { df =>
// Structs are converted to `Row`s
checkAnswer(df, data.map { case Tuple1(m) =>
Row(m.mapValues(struct => Row(struct.productIterator.toSeq: _*)))
@@ -293,7 +293,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest
with SharedSparkSession
null.asInstanceOf[java.lang.Float],
null.asInstanceOf[java.lang.Double])
- withParquetDataFrame(allNulls :: Nil) { df =>
+ withParquetDataFrame((allNulls :: Nil).toDF()) { df =>
val rows = df.collect()
assert(rows.length === 1)
assert(rows.head === Row(Seq.fill(5)(null): _*))
@@ -306,7 +306,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest
with SharedSparkSession
None.asInstanceOf[Option[Long]],
None.asInstanceOf[Option[String]])
- withParquetDataFrame(allNones :: Nil) { df =>
+ withParquetDataFrame((allNones :: Nil).toDF()) { df =>
val rows = df.collect()
assert(rows.length === 1)
assert(rows.head === Row(Seq.fill(3)(null): _*))
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
index 828ba6a..f2dbc53 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
@@ -63,12 +63,16 @@ private[sql] trait ParquetTest extends
FileBasedDataSourceTest {
(f: String => Unit): Unit = withDataSourceFile(data)(f)
/**
- * Writes `data` to a Parquet file and reads it back as a [[DataFrame]],
+ * Writes the given DataFrame `df` to a Parquet file and reads it back as a
[[DataFrame]],
* which is then passed to `f`. The Parquet file will be deleted after `f`
returns.
*/
- protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag]
- (data: Seq[T], testVectorized: Boolean = true)
- (f: DataFrame => Unit): Unit = withDataSourceDataFrame(data,
testVectorized)(f)
+ protected def withParquetDataFrame(df: DataFrame, testVectorized: Boolean =
true)
+ (f: DataFrame => Unit): Unit = {
+ withTempPath { file =>
+ df.write.format(dataSourceName).save(file.getCanonicalPath)
+ readFile(file.getCanonicalPath, testVectorized)(f)
+ }
+ }
/**
* Writes `data` to a Parquet file, reads it back as a [[DataFrame]] and
registers it as a
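The reworked `withParquetDataFrame` above now takes an already-built DataFrame, writes it with the data source under test, and reads it back before invoking the test body. A standalone sketch of the same round-trip pattern, suitable for pasting into spark-shell (the helper name is illustrative, not part of the suite):

    import java.nio.file.Files
    import org.apache.spark.sql.{DataFrame, SparkSession}

    // Round-trips a DataFrame through Parquet and hands the re-read frame to `f`.
    def withParquetRoundTrip(spark: SparkSession)(df: DataFrame)(f: DataFrame => Unit): Unit = {
      val dir = Files.createTempDirectory("parquet-roundtrip").toFile
      try {
        df.write.mode("overwrite").parquet(dir.getCanonicalPath)
        f(spark.read.parquet(dir.getCanonicalPath))
      } finally {
        // Best-effort cleanup of the temporary directory.
        Option(dir.listFiles()).foreach(_.foreach(_.delete()))
        dir.delete()
      }
    }

    // usage: withParquetRoundTrip(spark)(Seq(1, 2, 3).toDF("id"))(_.show())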
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/sources/FiltersSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/sources/FiltersSuite.scala
index 1cb7a21..33b2db5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FiltersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FiltersSuite.scala
@@ -24,66 +24,143 @@ import org.apache.spark.SparkFunSuite
*/
class FiltersSuite extends SparkFunSuite {
- test("EqualTo references") {
- assert(EqualTo("a", "1").references.toSeq == Seq("a"))
- assert(EqualTo("a", EqualTo("b", "2")).references.toSeq == Seq("a", "b"))
+ private def withFieldNames(f: (String, Array[String]) => Unit): Unit = {
+ Seq(("a", Array("a")),
+ ("a.b", Array("a", "b")),
+ ("`a.b`.c", Array("a.b", "c")),
+ ("`a.b`.`c.d`.`e.f`", Array("a.b", "c.d", "e.f"))
+ ).foreach { case (name, fieldNames) =>
+ f(name, fieldNames)
+ }
}
- test("EqualNullSafe references") {
- assert(EqualNullSafe("a", "1").references.toSeq == Seq("a"))
- assert(EqualNullSafe("a", EqualTo("b", "2")).references.toSeq == Seq("a",
"b"))
- }
+ test("EqualTo references") { withFieldNames { (name, fieldNames) =>
+ assert(EqualTo(name, "1").references.toSeq == Seq(name))
+ assert(EqualTo(name, "1").v2references.toSeq.map(_.toSeq) ==
Seq(fieldNames.toSeq))
- test("GreaterThan references") {
- assert(GreaterThan("a", "1").references.toSeq == Seq("a"))
- assert(GreaterThan("a", EqualTo("b", "2")).references.toSeq == Seq("a",
"b"))
- }
+ assert(EqualTo(name, EqualTo("b", "2")).references.toSeq == Seq(name, "b"))
+ assert(EqualTo("b", EqualTo(name, "2")).references.toSeq == Seq("b", name))
- test("GreaterThanOrEqual references") {
- assert(GreaterThanOrEqual("a", "1").references.toSeq == Seq("a"))
- assert(GreaterThanOrEqual("a", EqualTo("b", "2")).references.toSeq ==
Seq("a", "b"))
- }
+ assert(EqualTo(name, EqualTo("b", "2")).v2references.toSeq.map(_.toSeq)
+ == Seq(fieldNames.toSeq, Seq("b")))
+ assert(EqualTo("b", EqualTo(name, "2")).v2references.toSeq.map(_.toSeq)
+ == Seq(Seq("b"), fieldNames.toSeq))
+ }}
- test("LessThan references") {
- assert(LessThan("a", "1").references.toSeq == Seq("a"))
- assert(LessThan("a", EqualTo("b", "2")).references.toSeq == Seq("a", "b"))
- }
+ test("EqualNullSafe references") { withFieldNames { (name, fieldNames) =>
+ assert(EqualNullSafe(name, "1").references.toSeq == Seq(name))
+ assert(EqualNullSafe(name, "1").v2references.toSeq.map(_.toSeq) ==
Seq(fieldNames.toSeq))
- test("LessThanOrEqual references") {
- assert(LessThanOrEqual("a", "1").references.toSeq == Seq("a"))
- assert(LessThanOrEqual("a", EqualTo("b", "2")).references.toSeq ==
Seq("a", "b"))
- }
+ assert(EqualNullSafe(name, EqualTo("b", "2")).references.toSeq ==
Seq(name, "b"))
+ assert(EqualNullSafe("b", EqualTo(name, "2")).references.toSeq == Seq("b",
name))
- test("In references") {
- assert(In("a", Array("1")).references.toSeq == Seq("a"))
- assert(In("a", Array("1", EqualTo("b", "2"))).references.toSeq == Seq("a",
"b"))
- }
+ assert(EqualNullSafe(name, EqualTo("b",
"2")).v2references.toSeq.map(_.toSeq)
+ == Seq(fieldNames.toSeq, Seq("b")))
+ assert(EqualNullSafe("b", EqualTo(name,
"2")).v2references.toSeq.map(_.toSeq)
+ == Seq(Seq("b"), fieldNames.toSeq))
+ }}
- test("IsNull references") {
- assert(IsNull("a").references.toSeq == Seq("a"))
- }
+ test("GreaterThan references") { withFieldNames { (name, fieldNames) =>
+ assert(GreaterThan(name, "1").references.toSeq == Seq(name))
+ assert(GreaterThan(name, "1").v2references.toSeq.map(_.toSeq) ==
Seq(fieldNames.toSeq))
- test("IsNotNull references") {
- assert(IsNotNull("a").references.toSeq == Seq("a"))
- }
+ assert(GreaterThan(name, EqualTo("b", "2")).references.toSeq == Seq(name,
"b"))
+ assert(GreaterThan("b", EqualTo(name, "2")).references.toSeq == Seq("b",
name))
- test("And references") {
- assert(And(EqualTo("a", "1"), EqualTo("b", "1")).references.toSeq ==
Seq("a", "b"))
- }
+ assert(GreaterThan(name, EqualTo("b", "2")).v2references.toSeq.map(_.toSeq)
+ == Seq(fieldNames.toSeq, Seq("b")))
+ assert(GreaterThan("b", EqualTo(name, "2")).v2references.toSeq.map(_.toSeq)
+ == Seq(Seq("b"), fieldNames.toSeq))
+ }}
- test("Or references") {
- assert(Or(EqualTo("a", "1"), EqualTo("b", "1")).references.toSeq ==
Seq("a", "b"))
- }
+ test("GreaterThanOrEqual references") { withFieldNames { (name, fieldNames)
=>
+ assert(GreaterThanOrEqual(name, "1").references.toSeq == Seq(name))
+ assert(GreaterThanOrEqual(name, "1").v2references.toSeq.map(_.toSeq) ==
Seq(fieldNames.toSeq))
- test("StringStartsWith references") {
- assert(StringStartsWith("a", "str").references.toSeq == Seq("a"))
- }
+ assert(GreaterThanOrEqual(name, EqualTo("b", "2")).references.toSeq ==
Seq(name, "b"))
+ assert(GreaterThanOrEqual("b", EqualTo(name, "2")).references.toSeq ==
Seq("b", name))
- test("StringEndsWith references") {
- assert(StringEndsWith("a", "str").references.toSeq == Seq("a"))
- }
+ assert(GreaterThanOrEqual(name, EqualTo("b",
"2")).v2references.toSeq.map(_.toSeq)
+ == Seq(fieldNames.toSeq, Seq("b")))
+ assert(GreaterThanOrEqual("b", EqualTo(name,
"2")).v2references.toSeq.map(_.toSeq)
+ == Seq(Seq("b"), fieldNames.toSeq))
+ }}
- test("StringContains references") {
- assert(StringContains("a", "str").references.toSeq == Seq("a"))
- }
+ test("LessThan references") { withFieldNames { (name, fieldNames) =>
+ assert(LessThan(name, "1").references.toSeq == Seq(name))
+ assert(LessThan(name, "1").v2references.toSeq.map(_.toSeq) ==
Seq(fieldNames.toSeq))
+
+ assert(LessThan("a", EqualTo("b", "2")).references.toSeq == Seq("a", "b"))
+ }}
+
+ test("LessThanOrEqual references") { withFieldNames { (name, fieldNames) =>
+ assert(LessThanOrEqual(name, "1").references.toSeq == Seq(name))
+ assert(LessThanOrEqual(name, "1").v2references.toSeq.map(_.toSeq) ==
Seq(fieldNames.toSeq))
+
+ assert(LessThanOrEqual(name, EqualTo("b", "2")).references.toSeq ==
Seq(name, "b"))
+ assert(LessThanOrEqual("b", EqualTo(name, "2")).references.toSeq ==
Seq("b", name))
+
+ assert(LessThanOrEqual(name, EqualTo("b",
"2")).v2references.toSeq.map(_.toSeq)
+ == Seq(fieldNames.toSeq, Seq("b")))
+ assert(LessThanOrEqual("b", EqualTo(name,
"2")).v2references.toSeq.map(_.toSeq)
+ == Seq(Seq("b"), fieldNames.toSeq))
+ }}
+
+ test("In references") { withFieldNames { (name, fieldNames) =>
+ assert(In(name, Array("1")).references.toSeq == Seq(name))
+ assert(In(name, Array("1")).v2references.toSeq.map(_.toSeq) ==
Seq(fieldNames.toSeq))
+
+ assert(In(name, Array("1", EqualTo("b", "2"))).references.toSeq ==
Seq(name, "b"))
+ assert(In("b", Array("1", EqualTo(name, "2"))).references.toSeq ==
Seq("b", name))
+
+ assert(In(name, Array("1", EqualTo("b",
"2"))).v2references.toSeq.map(_.toSeq)
+ == Seq(fieldNames.toSeq, Seq("b")))
+ assert(In("b", Array("1", EqualTo(name,
"2"))).v2references.toSeq.map(_.toSeq)
+ == Seq(Seq("b"), fieldNames.toSeq))
+ }}
+
+ test("IsNull references") { withFieldNames { (name, fieldNames) =>
+ assert(IsNull(name).references.toSeq == Seq(name))
+ assert(IsNull(name).v2references.toSeq.map(_.toSeq) ==
Seq(fieldNames.toSeq))
+ }}
+
+ test("IsNotNull references") { withFieldNames { (name, fieldNames) =>
+ assert(IsNotNull(name).references.toSeq == Seq(name))
+ assert(IsNotNull(name).v2references.toSeq.map(_.toSeq) ==
Seq(fieldNames.toSeq))
+ }}
+
+ test("And references") { withFieldNames { (name, fieldNames) =>
+ assert(And(EqualTo(name, "1"), EqualTo("b", "1")).references.toSeq ==
Seq(name, "b"))
+ assert(And(EqualTo("b", "1"), EqualTo(name, "1")).references.toSeq ==
Seq("b", name))
+
+ assert(And(EqualTo(name, "1"), EqualTo("b",
"1")).v2references.toSeq.map(_.toSeq) ==
+ Seq(fieldNames.toSeq, Seq("b")))
+ assert(And(EqualTo("b", "1"), EqualTo(name,
"1")).v2references.toSeq.map(_.toSeq) ==
+ Seq(Seq("b"), fieldNames.toSeq))
+ }}
+
+ test("Or references") { withFieldNames { (name, fieldNames) =>
+ assert(Or(EqualTo(name, "1"), EqualTo("b", "1")).references.toSeq ==
Seq(name, "b"))
+ assert(Or(EqualTo("b", "1"), EqualTo(name, "1")).references.toSeq ==
Seq("b", name))
+
+ assert(Or(EqualTo(name, "1"), EqualTo("b",
"1")).v2references.toSeq.map(_.toSeq) ==
+ Seq(fieldNames.toSeq, Seq("b")))
+ assert(Or(EqualTo("b", "1"), EqualTo(name,
"1")).v2references.toSeq.map(_.toSeq) ==
+ Seq(Seq("b"), fieldNames.toSeq))
+ }}
+
+ test("StringStartsWith references") { withFieldNames { (name, fieldNames) =>
+ assert(StringStartsWith(name, "str").references.toSeq == Seq(name))
+ assert(StringStartsWith(name, "str").v2references.toSeq.map(_.toSeq) ==
Seq(fieldNames.toSeq))
+ }}
+
+ test("StringEndsWith references") { withFieldNames { (name, fieldNames) =>
+ assert(StringEndsWith(name, "str").references.toSeq == Seq(name))
+ assert(StringEndsWith(name, "str").v2references.toSeq.map(_.toSeq) ==
Seq(fieldNames.toSeq))
+ }}
+
+ test("StringContains references") { withFieldNames { (name, fieldNames) =>
+ assert(StringContains(name, "str").references.toSeq == Seq(name))
+ assert(StringContains(name, "str").v2references.toSeq.map(_.toSeq) ==
Seq(fieldNames.toSeq))
+ }}
}
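The table driving `withFieldNames` above is the heart of this change: `references` still reports the flat, possibly backtick-quoted name, while the new `v2references` parses it into name parts, so a top-level column whose name contains a dot can be distinguished from a nested field. A compact restatement (a sketch; it assumes only the `v2references` method introduced by this patch):

    import org.apache.spark.sql.sources.EqualTo

    val nestedField = EqualTo("a.b", "1")       // struct column `a`, nested field `b`
    val dottedColumn = EqualTo("`a.b`.c", "1")  // top-level column named "a.b", field `c`

    assert(nestedField.references.toSeq == Seq("a.b"))
    assert(nestedField.v2references.toSeq.map(_.toSeq) == Seq(Seq("a", "b")))
    assert(dottedColumn.v2references.toSeq.map(_.toSeq) == Seq(Seq("a.b", "c")))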
diff --git
a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
index b9cbc48..f5abd30 100644
---
a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
+++
b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
@@ -65,9 +65,11 @@ private[sql] object OrcFilters extends OrcFiltersBase {
* Create ORC filter as a SearchArgument instance.
*/
def createFilter(schema: StructType, filters: Seq[Filter]):
Option[SearchArgument] = {
- val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap
+ val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) ->
f.dataType).toMap
// Combines all convertible filters using `And` to produce a single
conjunction
- val conjunctionOptional = buildTree(convertibleFilters(schema,
dataTypeMap, filters))
+ // TODO (SPARK-25557): ORC doesn't support nested predicate pushdown yet, so
nested-column filters are dropped here.
+ val newFilters = filters.filter(!_.containsNestedColumn)
+ val conjunctionOptional = buildTree(convertibleFilters(schema,
dataTypeMap, newFilters))
conjunctionOptional.map { conjunction =>
// Then tries to build a single ORC `SearchArgument` for the conjunction
predicate.
// The input predicate is fully convertible. There should not be any
empty result in the
@@ -222,48 +224,39 @@ private[sql] object OrcFilters extends OrcFiltersBase {
// Since ORC 1.5.0 (ORC-323), we need to quote for column names with `.`
characters
// in order to distinguish predicate pushdown for nested columns.
expression match {
- case EqualTo(attribute, value) if
isSearchableType(dataTypeMap(attribute)) =>
- val quotedName = quoteIfNeeded(attribute)
- val castedValue = castLiteralValue(value, dataTypeMap(attribute))
- Some(builder.startAnd().equals(quotedName, getType(attribute),
castedValue).end())
+ case EqualTo(name, value) if isSearchableType(dataTypeMap(name)) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(name))
+ Some(builder.startAnd().equals(name, getType(name), castedValue).end())
- case EqualNullSafe(attribute, value) if
isSearchableType(dataTypeMap(attribute)) =>
- val quotedName = quoteIfNeeded(attribute)
- val castedValue = castLiteralValue(value, dataTypeMap(attribute))
- Some(builder.startAnd().nullSafeEquals(quotedName, getType(attribute),
castedValue).end())
+ case EqualNullSafe(name, value) if isSearchableType(dataTypeMap(name)) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(name))
+ Some(builder.startAnd().nullSafeEquals(name, getType(name),
castedValue).end())
- case LessThan(attribute, value) if
isSearchableType(dataTypeMap(attribute)) =>
- val quotedName = quoteIfNeeded(attribute)
- val castedValue = castLiteralValue(value, dataTypeMap(attribute))
- Some(builder.startAnd().lessThan(quotedName, getType(attribute),
castedValue).end())
+ case LessThan(name, value) if isSearchableType(dataTypeMap(name)) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(name))
+ Some(builder.startAnd().lessThan(name, getType(name),
castedValue).end())
- case LessThanOrEqual(attribute, value) if
isSearchableType(dataTypeMap(attribute)) =>
- val quotedName = quoteIfNeeded(attribute)
- val castedValue = castLiteralValue(value, dataTypeMap(attribute))
- Some(builder.startAnd().lessThanEquals(quotedName, getType(attribute),
castedValue).end())
+ case LessThanOrEqual(name, value) if isSearchableType(dataTypeMap(name))
=>
+ val castedValue = castLiteralValue(value, dataTypeMap(name))
+ Some(builder.startAnd().lessThanEquals(name, getType(name),
castedValue).end())
- case GreaterThan(attribute, value) if
isSearchableType(dataTypeMap(attribute)) =>
- val quotedName = quoteIfNeeded(attribute)
- val castedValue = castLiteralValue(value, dataTypeMap(attribute))
- Some(builder.startNot().lessThanEquals(quotedName, getType(attribute),
castedValue).end())
+ case GreaterThan(name, value) if isSearchableType(dataTypeMap(name)) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(name))
+ Some(builder.startNot().lessThanEquals(name, getType(name),
castedValue).end())
- case GreaterThanOrEqual(attribute, value) if
isSearchableType(dataTypeMap(attribute)) =>
- val quotedName = quoteIfNeeded(attribute)
- val castedValue = castLiteralValue(value, dataTypeMap(attribute))
- Some(builder.startNot().lessThan(quotedName, getType(attribute),
castedValue).end())
+ case GreaterThanOrEqual(name, value) if
isSearchableType(dataTypeMap(name)) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(name))
+ Some(builder.startNot().lessThan(name, getType(name),
castedValue).end())
- case IsNull(attribute) if isSearchableType(dataTypeMap(attribute)) =>
- val quotedName = quoteIfNeeded(attribute)
- Some(builder.startAnd().isNull(quotedName, getType(attribute)).end())
+ case IsNull(name) if isSearchableType(dataTypeMap(name)) =>
+ Some(builder.startAnd().isNull(name, getType(name)).end())
- case IsNotNull(attribute) if isSearchableType(dataTypeMap(attribute)) =>
- val quotedName = quoteIfNeeded(attribute)
- Some(builder.startNot().isNull(quotedName, getType(attribute)).end())
+ case IsNotNull(name) if isSearchableType(dataTypeMap(name)) =>
+ Some(builder.startNot().isNull(name, getType(name)).end())
- case In(attribute, values) if isSearchableType(dataTypeMap(attribute)) =>
- val quotedName = quoteIfNeeded(attribute)
- val castedValues = values.map(v => castLiteralValue(v,
dataTypeMap(attribute)))
- Some(builder.startAnd().in(quotedName, getType(attribute),
+ case In(name, values) if isSearchableType(dataTypeMap(name)) =>
+ val castedValues = values.map(v => castLiteralValue(v,
dataTypeMap(name)))
+ Some(builder.startAnd().in(name, getType(name),
castedValues.map(_.asInstanceOf[AnyRef]): _*).end())
case _ => None
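The gating step added above (mirrored in the v2.3 and Hive OrcFilters copies below) simply drops any filter that touches a nested column before the ORC SearchArgument is built. A sketch of that step in isolation, using the `containsNestedColumn` method this patch adds to Filter (the column names are illustrative):

    import org.apache.spark.sql.sources.{EqualTo, Filter, IsNotNull}

    val candidates: Seq[Filter] = Seq(EqualTo("a.b", 1), IsNotNull("id"))
    // Keep only filters that reference no nested column; the predicate on `a.b`
    // stays on the Spark side until SPARK-25557 adds nested pushdown for ORC.
    val pushable = candidates.filterNot(_.containsNestedColumn)
    // pushable == Seq(IsNotNull("id"))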
diff --git
a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
index 6e9e592..675e089 100644
---
a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
+++
b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
@@ -65,9 +65,11 @@ private[sql] object OrcFilters extends OrcFiltersBase {
* Create ORC filter as a SearchArgument instance.
*/
def createFilter(schema: StructType, filters: Seq[Filter]):
Option[SearchArgument] = {
- val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap
+ val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) ->
f.dataType).toMap
// Combines all convertible filters using `And` to produce a single
conjunction
- val conjunctionOptional = buildTree(convertibleFilters(schema,
dataTypeMap, filters))
+ // TODO (SPARK-25557): ORC doesn't support nested predicate pushdown yet, so
nested-column filters are dropped here.
+ val newFilters = filters.filter(!_.containsNestedColumn)
+ val conjunctionOptional = buildTree(convertibleFilters(schema,
dataTypeMap, newFilters))
conjunctionOptional.map { conjunction =>
// Then tries to build a single ORC `SearchArgument` for the conjunction
predicate.
// The input predicate is fully convertible. There should not be any
empty result in the
@@ -222,48 +224,39 @@ private[sql] object OrcFilters extends OrcFiltersBase {
// Since ORC 1.5.0 (ORC-323), we need to quote for column names with `.`
characters
// in order to distinguish predicate pushdown for nested columns.
expression match {
- case EqualTo(attribute, value) if
isSearchableType(dataTypeMap(attribute)) =>
- val quotedName = quoteIfNeeded(attribute)
- val castedValue = castLiteralValue(value, dataTypeMap(attribute))
- Some(builder.startAnd().equals(quotedName, getType(attribute),
castedValue).end())
+ case EqualTo(name, value) if isSearchableType(dataTypeMap(name)) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(name))
+ Some(builder.startAnd().equals(name, getType(name), castedValue).end())
- case EqualNullSafe(attribute, value) if
isSearchableType(dataTypeMap(attribute)) =>
- val quotedName = quoteIfNeeded(attribute)
- val castedValue = castLiteralValue(value, dataTypeMap(attribute))
- Some(builder.startAnd().nullSafeEquals(quotedName, getType(attribute),
castedValue).end())
+ case EqualNullSafe(name, value) if isSearchableType(dataTypeMap(name)) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(name))
+ Some(builder.startAnd().nullSafeEquals(name, getType(name),
castedValue).end())
- case LessThan(attribute, value) if
isSearchableType(dataTypeMap(attribute)) =>
- val quotedName = quoteIfNeeded(attribute)
- val castedValue = castLiteralValue(value, dataTypeMap(attribute))
- Some(builder.startAnd().lessThan(quotedName, getType(attribute),
castedValue).end())
+ case LessThan(name, value) if isSearchableType(dataTypeMap(name)) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(name))
+ Some(builder.startAnd().lessThan(name, getType(name),
castedValue).end())
- case LessThanOrEqual(attribute, value) if
isSearchableType(dataTypeMap(attribute)) =>
- val quotedName = quoteIfNeeded(attribute)
- val castedValue = castLiteralValue(value, dataTypeMap(attribute))
- Some(builder.startAnd().lessThanEquals(quotedName, getType(attribute),
castedValue).end())
+ case LessThanOrEqual(name, value) if isSearchableType(dataTypeMap(name))
=>
+ val castedValue = castLiteralValue(value, dataTypeMap(name))
+ Some(builder.startAnd().lessThanEquals(name, getType(name),
castedValue).end())
- case GreaterThan(attribute, value) if
isSearchableType(dataTypeMap(attribute)) =>
- val quotedName = quoteIfNeeded(attribute)
- val castedValue = castLiteralValue(value, dataTypeMap(attribute))
- Some(builder.startNot().lessThanEquals(quotedName, getType(attribute),
castedValue).end())
+ case GreaterThan(name, value) if isSearchableType(dataTypeMap(name)) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(name))
+ Some(builder.startNot().lessThanEquals(name, getType(name),
castedValue).end())
- case GreaterThanOrEqual(attribute, value) if
isSearchableType(dataTypeMap(attribute)) =>
- val quotedName = quoteIfNeeded(attribute)
- val castedValue = castLiteralValue(value, dataTypeMap(attribute))
- Some(builder.startNot().lessThan(quotedName, getType(attribute),
castedValue).end())
+ case GreaterThanOrEqual(name, value) if
isSearchableType(dataTypeMap(name)) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(name))
+ Some(builder.startNot().lessThan(name, getType(name),
castedValue).end())
- case IsNull(attribute) if isSearchableType(dataTypeMap(attribute)) =>
- val quotedName = quoteIfNeeded(attribute)
- Some(builder.startAnd().isNull(quotedName, getType(attribute)).end())
+ case IsNull(name) if isSearchableType(dataTypeMap(name)) =>
+ Some(builder.startAnd().isNull(name, getType(name)).end())
- case IsNotNull(attribute) if isSearchableType(dataTypeMap(attribute)) =>
- val quotedName = quoteIfNeeded(attribute)
- Some(builder.startNot().isNull(quotedName, getType(attribute)).end())
+ case IsNotNull(name) if isSearchableType(dataTypeMap(name)) =>
+ Some(builder.startNot().isNull(name, getType(name)).end())
- case In(attribute, values) if isSearchableType(dataTypeMap(attribute)) =>
- val quotedName = quoteIfNeeded(attribute)
- val castedValues = values.map(v => castLiteralValue(v,
dataTypeMap(attribute)))
- Some(builder.startAnd().in(quotedName, getType(attribute),
+ case In(name, values) if isSearchableType(dataTypeMap(name)) =>
+ val castedValues = values.map(v => castLiteralValue(v,
dataTypeMap(name)))
+ Some(builder.startAnd().in(name, getType(name),
castedValues.map(_.asInstanceOf[AnyRef]): _*).end())
case _ => None
diff --git
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala
index cd1bffb..f9c5145 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala
@@ -25,6 +25,7 @@ import
org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory.newBuilder
import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded
import org.apache.spark.sql.execution.datasources.orc.{OrcFilters =>
DatasourceOrcFilters}
import org.apache.spark.sql.execution.datasources.orc.OrcFilters.buildTree
import org.apache.spark.sql.hive.HiveUtils
@@ -73,9 +74,11 @@ private[orc] object OrcFilters extends Logging {
if (HiveUtils.isHive23) {
DatasourceOrcFilters.createFilter(schema,
filters).asInstanceOf[Option[SearchArgument]]
} else {
- val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap
+ val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) ->
f.dataType).toMap
+ // TODO (SPARK-25557): ORC doesn't support nested predicate pushdown yet, so
nested-column filters are dropped here.
+ val newFilters = filters.filter(!_.containsNestedColumn)
// Combines all convertible filters using `And` to produce a single
conjunction
- val conjunctionOptional = buildTree(convertibleFilters(schema,
dataTypeMap, filters))
+ val conjunctionOptional = buildTree(convertibleFilters(schema,
dataTypeMap, newFilters))
conjunctionOptional.map { conjunction =>
// Then tries to build a single ORC `SearchArgument` for the
conjunction predicate.
// The input predicate is fully convertible. There should not be any
empty result in the