This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new b4faf063e375 [SPARK-50131][SQL] Add IN Subquery DataFrame API
b4faf063e375 is described below

commit b4faf063e37597cf09b9e210766972303fbebe94
Author: Takuya Ueshin <ues...@databricks.com>
AuthorDate: Tue Apr 15 13:55:15 2025 -0700

    [SPARK-50131][SQL] Add IN Subquery DataFrame API

    ### What changes were proposed in this pull request?

    Adds the IN subquery DataFrame API. The existing `Column.isin` now also accepts a `Dataset`/`DataFrame` and uses it as an IN subquery. For example:

    ```py
    >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob"), (8, "Mike")], ["age", "name"])
    >>> df.createOrReplaceTempView('t')
    >>> # SELECT * FROM t WHERE age IN (SELECT * FROM range(6)) ORDER BY age
    >>> df.where(df.age.isin(spark.range(6))).orderBy("age").show()
    +---+-----+
    |age| name|
    +---+-----+
    |  2|Alice|
    |  5|  Bob|
    +---+-----+
    ```

    To compare multiple values, `sf.struct` can be used to tuple them:

    ```py
    >>> from pyspark.sql import functions as sf
    >>> # SELECT * FROM t WHERE (age, name) IN (SELECT id, 'Bob' FROM range(6))
    >>> df.where(sf.struct(df.age, df.name).isin(spark.range(6).select("id", sf.lit("Bob")))).show()
    +---+----+
    |age|name|
    +---+----+
    |  5| Bob|
    +---+----+
    ```

    ### Why are the changes needed?

    The IN subquery was missing from the DataFrame API.

    ### Does this PR introduce _any_ user-facing change?

    Yes, `Column.isin` now accepts a `Dataset`/`DataFrame` and uses it as an IN subquery.

    ### How was this patch tested?

    Added the related tests.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #50470 from ueshin/issues/SPARK-50131/in_subquery.

    Authored-by: Takuya Ueshin <ues...@databricks.com>
    Signed-off-by: Takuya Ueshin <ues...@databricks.com>
---
 python/pyspark/sql/classic/column.py | 7 +
 python/pyspark/sql/column.py | 25 ++-
 python/pyspark/sql/connect/column.py | 13 ++
 python/pyspark/sql/connect/expressions.py | 24 ++-
 .../pyspark/sql/connect/proto/expressions_pb2.py | 12 +-
 .../pyspark/sql/connect/proto/expressions_pb2.pyi | 11 ++
 python/pyspark/sql/tests/test_subquery.py | 207 ++++++++++++++++++++-
 .../main/scala/org/apache/spark/sql/Column.scala | 20 ++
 .../main/scala/org/apache/spark/sql/Dataset.scala | 8 +-
 .../apache/spark/sql/internal/columnNodes.scala | 22 +++
 .../sql/catalyst/expressions/predicates.scala | 9 +-
 .../apache/spark/sql/DataFrameSubquerySuite.scala | 125 +++++++++++++
 .../main/protobuf/spark/connect/expressions.proto | 4 +
 .../org/apache/spark/sql/connect/Dataset.scala | 10 -
 .../apache/spark/sql/connect/SparkSession.scala | 7 +-
 .../spark/sql/connect/columnNodeSupport.scala | 28 +--
 .../sql/connect/planner/SparkConnectPlanner.scala | 6 +
 .../org/apache/spark/sql/classic/Dataset.scala | 12 +-
 .../spark/sql/classic/columnNodeSupport.scala | 14 +-
 .../apache/spark/sql/DataFrameSubquerySuite.scala | 124 ++++++++++++
 20 files changed, 624 insertions(+), 64 deletions(-)

diff --git a/python/pyspark/sql/classic/column.py b/python/pyspark/sql/classic/column.py
index 161f8ba4bb7a..fef65bcb5d54 100644
--- a/python/pyspark/sql/classic/column.py
+++ b/python/pyspark/sql/classic/column.py
@@ -474,6 +474,13 @@ class Column(ParentColumn):
         return Column(jc)
 
     def isin(self, *cols: Any) -> ParentColumn:
+        from pyspark.sql.classic.dataframe import DataFrame
+
+        if len(cols) == 1 and isinstance(cols[0], DataFrame):
+            df: DataFrame = cols[0]
+            jc = self._jc.isin(df._jdf)
+            return Column(jc)
+
         if len(cols) == 1 and isinstance(cols[0], (list, set)):
             cols = cast(Tuple, cols[0])
         cols = cast(
diff --git
a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index a055e4456495..db02cc80dbed 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -883,6 +883,9 @@ class Column(TableValuedFunctionArgument): .. versionchanged:: 3.4.0 Supports Spark Connect. + .. versionchanged:: 4.1.0 + Also takes a single :class:`DataFrame` to be used as IN subquery. + Parameters ---------- cols : Any @@ -900,7 +903,7 @@ class Column(TableValuedFunctionArgument): Example 1: Filter rows with names in the specified values - >>> df[df.name.isin("Bob", "Mike")].show() + >>> df[df.name.isin("Bob", "Mike")].orderBy("age").show() +---+----+ |age|name| +---+----+ @@ -925,6 +928,26 @@ class Column(TableValuedFunctionArgument): +---+----+ | 8|Mike| +---+----+ + + Example 4: Take a :class:`DataFrame` and work as IN subquery + + >>> df.where(df.age.isin(spark.range(6))).orderBy("age").show() + +---+-----+ + |age| name| + +---+-----+ + | 2|Alice| + | 5| Bob| + +---+-----+ + + Example 5: Multiple values for IN subquery + + >>> from pyspark.sql.functions import lit, struct + >>> df.where(struct(df.age, df.name).isin(spark.range(6).select("id", lit("Bob")))).show() + +---+----+ + |age|name| + +---+----+ + | 5| Bob| + +---+----+ """ ... diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py index 15d943175850..d6ed62ba4a52 100644 --- a/python/pyspark/sql/connect/column.py +++ b/python/pyspark/sql/connect/column.py @@ -47,6 +47,7 @@ from pyspark.sql.connect.expressions import ( LiteralExpression, CaseWhen, SortOrder, + SubqueryExpression, CastExpression, WindowExpression, WithField, @@ -461,6 +462,18 @@ class Column(ParentColumn): return Column(self._expr) def isin(self, *cols: Any) -> ParentColumn: + from pyspark.sql.connect.dataframe import DataFrame + + if len(cols) == 1 and isinstance(cols[0], DataFrame): + if isinstance(self._expr, UnresolvedFunction) and self._expr._name == "struct": + values = self._expr.children + else: + values = [self._expr] + + return Column( + SubqueryExpression(cols[0]._plan, subquery_type="in", in_subquery_values=values) + ) + if len(cols) == 1 and isinstance(cols[0], (list, set)): _cols = list(cols[0]) else: diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py index e5b10be41963..872770ee2291 100644 --- a/python/pyspark/sql/connect/expressions.py +++ b/python/pyspark/sql/connect/expressions.py @@ -1241,9 +1241,10 @@ class SubqueryExpression(Expression): partition_spec: Optional[Sequence["Expression"]] = None, order_spec: Optional[Sequence["SortOrder"]] = None, with_single_partition: Optional[bool] = None, + in_subquery_values: Optional[Sequence["Expression"]] = None, ) -> None: assert isinstance(subquery_type, str) - assert subquery_type in ("scalar", "exists", "table_arg") + assert subquery_type in ("scalar", "exists", "table_arg", "in") super().__init__() self._plan = plan @@ -1251,6 +1252,7 @@ class SubqueryExpression(Expression): self._partition_spec = partition_spec or [] self._order_spec = order_spec or [] self._with_single_partition = with_single_partition + self._in_subquery_values = in_subquery_values or [] def to_plan(self, session: "SparkConnectClient") -> proto.Expression: expr = self._create_proto_expression() @@ -1276,17 +1278,25 @@ class SubqueryExpression(Expression): ) if self._with_single_partition is not None: table_arg_options.with_single_partition = self._with_single_partition + elif self._subquery_type == "in": + expr.subquery_expression.subquery_type = 
proto.SubqueryExpression.SUBQUERY_TYPE_IN + expr.subquery_expression.in_subquery_values.extend( + [expr.to_plan(session) for expr in self._in_subquery_values] + ) return expr def __repr__(self) -> str: repr_parts = [f"plan={self._plan}", f"type={self._subquery_type}"] - if self._partition_spec: - repr_parts.append(f"partition_spec={self._partition_spec}") - if self._order_spec: - repr_parts.append(f"order_spec={self._order_spec}") - if self._with_single_partition is not None: - repr_parts.append(f"with_single_partition={self._with_single_partition}") + if self._subquery_type == "table_arg": + if self._partition_spec: + repr_parts.append(f"partition_spec={self._partition_spec}") + if self._order_spec: + repr_parts.append(f"order_spec={self._order_spec}") + if self._with_single_partition is not None: + repr_parts.append(f"with_single_partition={self._with_single_partition}") + elif self._subquery_type == "in": + repr_parts.append(f"values={self._in_subquery_values}") return f"SubqueryExpression({', '.join(repr_parts)})" diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py index f5fdd162a708..0cec23f4857d 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.py +++ b/python/pyspark/sql/connect/proto/expressions_pb2.py @@ -40,7 +40,7 @@ from pyspark.sql.connect.proto import common_pb2 as spark_dot_connect_dot_common DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto\x1a\x1aspark/connect/common.proto"\xf3\x34\n\nExpression\x12\x37\n\x06\x63ommon\x18\x12 \x01(\x0b\x32\x1f.spark.connect.ExpressionCommonR\x06\x63ommon\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolved [...] + b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto\x1a\x1aspark/connect/common.proto"\xf3\x34\n\nExpression\x12\x37\n\x06\x63ommon\x18\x12 \x01(\x0b\x32\x1f.spark.connect.ExpressionCommonR\x06\x63ommon\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolved [...] 
) _globals = globals() @@ -130,9 +130,9 @@ if not _descriptor._USE_C_DESCRIPTORS: _globals["_MERGEACTION_ACTIONTYPE"]._serialized_start = 8586 _globals["_MERGEACTION_ACTIONTYPE"]._serialized_end = 8753 _globals["_SUBQUERYEXPRESSION"]._serialized_start = 8770 - _globals["_SUBQUERYEXPRESSION"]._serialized_end = 9383 - _globals["_SUBQUERYEXPRESSION_TABLEARGOPTIONS"]._serialized_start = 9003 - _globals["_SUBQUERYEXPRESSION_TABLEARGOPTIONS"]._serialized_end = 9237 - _globals["_SUBQUERYEXPRESSION_SUBQUERYTYPE"]._serialized_start = 9239 - _globals["_SUBQUERYEXPRESSION_SUBQUERYTYPE"]._serialized_end = 9361 + _globals["_SUBQUERYEXPRESSION"]._serialized_end = 9479 + _globals["_SUBQUERYEXPRESSION_TABLEARGOPTIONS"]._serialized_start = 9076 + _globals["_SUBQUERYEXPRESSION_TABLEARGOPTIONS"]._serialized_end = 9310 + _globals["_SUBQUERYEXPRESSION_SUBQUERYTYPE"]._serialized_start = 9313 + _globals["_SUBQUERYEXPRESSION_SUBQUERYTYPE"]._serialized_end = 9457 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi index f6aada59a2d8..25fc04c0319e 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi +++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi @@ -1940,12 +1940,14 @@ class SubqueryExpression(google.protobuf.message.Message): SUBQUERY_TYPE_SCALAR: SubqueryExpression._SubqueryType.ValueType # 1 SUBQUERY_TYPE_EXISTS: SubqueryExpression._SubqueryType.ValueType # 2 SUBQUERY_TYPE_TABLE_ARG: SubqueryExpression._SubqueryType.ValueType # 3 + SUBQUERY_TYPE_IN: SubqueryExpression._SubqueryType.ValueType # 4 class SubqueryType(_SubqueryType, metaclass=_SubqueryTypeEnumTypeWrapper): ... SUBQUERY_TYPE_UNKNOWN: SubqueryExpression.SubqueryType.ValueType # 0 SUBQUERY_TYPE_SCALAR: SubqueryExpression.SubqueryType.ValueType # 1 SUBQUERY_TYPE_EXISTS: SubqueryExpression.SubqueryType.ValueType # 2 SUBQUERY_TYPE_TABLE_ARG: SubqueryExpression.SubqueryType.ValueType # 3 + SUBQUERY_TYPE_IN: SubqueryExpression.SubqueryType.ValueType # 4 class TableArgOptions(google.protobuf.message.Message): """Nested message for table argument options.""" @@ -2010,6 +2012,7 @@ class SubqueryExpression(google.protobuf.message.Message): PLAN_ID_FIELD_NUMBER: builtins.int SUBQUERY_TYPE_FIELD_NUMBER: builtins.int TABLE_ARG_OPTIONS_FIELD_NUMBER: builtins.int + IN_SUBQUERY_VALUES_FIELD_NUMBER: builtins.int plan_id: builtins.int """(Required) The ID of the corresponding connect plan.""" subquery_type: global___SubqueryExpression.SubqueryType.ValueType @@ -2017,12 +2020,18 @@ class SubqueryExpression(google.protobuf.message.Message): @property def table_arg_options(self) -> global___SubqueryExpression.TableArgOptions: """(Optional) Options specific to table arguments.""" + @property + def in_subquery_values( + self, + ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Expression]: + """(Optional) IN subquery values.""" def __init__( self, *, plan_id: builtins.int = ..., subquery_type: global___SubqueryExpression.SubqueryType.ValueType = ..., table_arg_options: global___SubqueryExpression.TableArgOptions | None = ..., + in_subquery_values: collections.abc.Iterable[global___Expression] | None = ..., ) -> None: ... 
def HasField( self, @@ -2035,6 +2044,8 @@ class SubqueryExpression(google.protobuf.message.Message): field_name: typing_extensions.Literal[ "_table_arg_options", b"_table_arg_options", + "in_subquery_values", + b"in_subquery_values", "plan_id", b"plan_id", "subquery_type", diff --git a/python/pyspark/sql/tests/test_subquery.py b/python/pyspark/sql/tests/test_subquery.py index 7c63ddb69458..7c87f4b46cc6 100644 --- a/python/pyspark/sql/tests/test_subquery.py +++ b/python/pyspark/sql/tests/test_subquery.py @@ -28,7 +28,7 @@ class SubqueryTestsMixin: def df1(self): return self.spark.createDataFrame( [ - (1, 1.0), + (1, 2.0), (1, 2.0), (2, 1.0), (2, 2.0), @@ -459,6 +459,211 @@ class SubqueryTestsMixin: ), ) + def test_in_subquery(self): + with self.tempView("l", "r", "t"): + self.df1.createOrReplaceTempView("l") + self.df2.createOrReplaceTempView("r") + self.spark.table("r").filter( + sf.col("c").isNotNull() & sf.col("d").isNotNull() + ).createOrReplaceTempView("t") + + with self.subTest("IN"): + assertDataFrameEqual( + self.spark.table("l").where( + sf.col("l.a").isin(self.spark.table("r").select(sf.col("c"))) + ), + self.spark.sql("""select * from l where l.a in (select c from r)"""), + ) + assertDataFrameEqual( + self.spark.table("l").where( + sf.col("l.a").isin( + self.spark.table("r") + .where(sf.col("l.b").outer() < sf.col("r.d")) + .select(sf.col("c")) + ) + ), + self.spark.sql( + """select * from l where l.a in (select c from r where l.b < r.d)""" + ), + ) + assertDataFrameEqual( + self.spark.table("l").where( + sf.col("l.a").isin(self.spark.table("r").select("c")) + & (sf.col("l.a") > sf.lit(2)) + & sf.col("l.b").isNotNull() + ), + self.spark.sql( + """ + select * from l + where l.a in (select c from r) and l.a > 2 and l.b is not null + """ + ), + ) + + with self.subTest("IN with struct"), self.tempView("ll", "rr"): + self.spark.table("l").select( + "*", sf.struct("a", "b").alias("sab") + ).createOrReplaceTempView("ll") + self.spark.table("r").select( + "*", sf.struct(sf.col("c").alias("a"), sf.col("d").alias("b")).alias("scd") + ).createOrReplaceTempView("rr") + + for col, values in [ + (sf.col("sab"), "sab"), + (sf.struct(sf.struct(sf.col("a"), sf.col("b"))), "struct(struct(a, b))"), + ]: + for df, query in [ + (self.spark.table("rr").select(sf.col("scd")), "select scd from rr"), + ( + self.spark.table("rr").select( + sf.struct(sf.col("c").alias("a"), sf.col("d").alias("b")) + ), + "select struct(c as a, d as b) from rr", + ), + ( + self.spark.table("rr").select(sf.struct(sf.col("c"), sf.col("d"))), + "select struct(c, d) from rr", + ), + ]: + sql_query = f"""select a, b from ll where {values} in ({query})""" + with self.subTest(sql_query=sql_query): + assertDataFrameEqual( + self.spark.table("ll").where(col.isin(df)).select("a", "b"), + self.spark.sql(sql_query), + ) + + with self.subTest("NOT IN"): + assertDataFrameEqual( + self.spark.table("l").where( + ~sf.col("a").isin(self.spark.table("r").select("c")) + ), + self.spark.sql("""select * from l where a not in (select c from r)"""), + ) + assertDataFrameEqual( + self.spark.table("l").where( + ~sf.col("a").isin( + self.spark.table("r").where(sf.col("c").isNotNull()).select(sf.col("c")) + ) + ), + self.spark.sql( + """select * from l where a not in (select c from r where c is not null)""" + ), + ) + assertDataFrameEqual( + self.spark.table("l").where( + ( + ~sf.struct(sf.col("a"), sf.col("b")).isin( + self.spark.table("t").select(sf.col("c"), sf.col("d")) + ) + ) + & (sf.col("a") < sf.lit(4)) + ), + self.spark.sql( + """select * 
from l where (a, b) not in (select c, d from t) and a < 4""" + ), + ) + assertDataFrameEqual( + self.spark.table("l").where( + ~sf.struct(sf.col("a"), sf.col("b")).isin( + self.spark.table("r") + .where(sf.col("c") > sf.lit(10)) + .select(sf.col("c"), sf.col("d")) + ) + ), + self.spark.sql( + """select * from l where (a, b) not in (select c, d from r where c > 10)""" + ), + ) + + with self.subTest("IN within OR"): + assertDataFrameEqual( + self.spark.table("l").where( + sf.col("l.a").isin(self.spark.table("r").select("c")) + | ( + sf.col("l.a").isin( + self.spark.table("r") + .where(sf.col("l.b").outer() < sf.col("r.d")) + .select(sf.col("c")) + ) + ) + ), + self.spark.sql( + """ + select * from l + where l.a in (select c from r) or l.a in (select c from r where l.b < r.d) + """ + ), + ) + assertDataFrameEqual( + self.spark.table("l").where( + (~sf.col("a").isin(self.spark.table("r").select(sf.col("c")))) + | ( + ~sf.col("a").isin( + self.spark.table("r") + .where(sf.col("c").isNotNull()) + .select(sf.col("c")) + ) + ) + ), + self.spark.sql( + """ + select * from l + where a not in (select c from r) + or a not in (select c from r where c is not null) + """ + ), + ) + + with self.subTest("complex IN"): + assertDataFrameEqual( + self.spark.table("l").where( + ~sf.struct(sf.col("a"), sf.col("b")).isin( + self.spark.table("r").select(sf.col("c"), sf.col("d")) + ) + ), + self.spark.sql("""select * from l where (a, b) not in (select c, d from r)"""), + ) + assertDataFrameEqual( + self.spark.table("l").where( + ( + ~sf.struct(sf.col("a"), sf.col("b")).isin( + self.spark.table("t").select(sf.col("c"), sf.col("d")) + ) + ) + & ((sf.col("a") + sf.col("b")).isNotNull()) + ), + self.spark.sql( + """ + select * from l + where (a, b) not in (select c, d from t) and (a + b) is not null + """ + ), + ) + + with self.subTest("same column in subquery"): + assertDataFrameEqual( + self.spark.table("l") + .alias("l1") + .where( + sf.col("a").isin( + self.spark.table("l") + .where(sf.col("a") < sf.lit(3)) + .groupBy(sf.col("a")) + .agg({}) + ) + ) + .select(sf.col("a")), + self.spark.sql( + """select a from l l1 where a in (select a from l where a < 3 group by a)""" + ), + ) + + with self.subTest("col IN (NULL)"): + assertDataFrameEqual( + self.spark.table("l").where(sf.col("a").isin(None)), + self.spark.sql("""SELECT * FROM l WHERE a IN (NULL)"""), + ) + def test_scalar_subquery_with_missing_outer_reference(self): with self.tempView("l", "r"): self.df1.createOrReplaceTempView("l") diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Column.scala b/sql/api/src/main/scala/org/apache/spark/sql/Column.scala index 7f5eed1eb1ad..88d597fdfbb7 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/Column.scala @@ -803,6 +803,26 @@ class Column(val node: ColumnNode) extends Logging with TableValuedFunctionArgum */ def isInCollection(values: java.lang.Iterable[_]): Column = isInCollection(values.asScala) + /** + * A boolean expression that is evaluated to true if the value of this expression is contained + * by the provided Dataset/DataFrame. + * + * @group subquery + * @since 4.1.0 + */ + def isin(ds: Dataset[_]): Column = { + if (ds == null) { + // A single null should be handled as a value. 
+ isin(Seq(ds): _*) + } else { + val values = node match { + case internal.UnresolvedFunction("struct", arguments, _, _, _, _) => arguments + case _ => Seq(node) + } + Column(internal.SubqueryExpression(ds, internal.SubqueryType.IN(values))) + } + } + /** * SQL like expression. Returns a boolean column based on a SQL LIKE match. * diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala index 4952fa36f66e..c287ad69de2e 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1758,7 +1758,9 @@ abstract class Dataset[T] extends Serializable { * @group subquery * @since 4.0.0 */ - def scalar(): Column + def scalar(): Column = { + Column(internal.SubqueryExpression(this, internal.SubqueryType.SCALAR)) + } /** * Return a `Column` object for an EXISTS Subquery. @@ -1771,7 +1773,9 @@ abstract class Dataset[T] extends Serializable { * @group subquery * @since 4.0.0 */ - def exists(): Column + def exists(): Column = { + Column(internal.SubqueryExpression(this, internal.SubqueryType.EXISTS)) + } /** * Define (named) metrics to observe on the Dataset. This method returns an 'observed' Dataset diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala index 463307409839..4a7339165cb5 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala @@ -20,6 +20,7 @@ import java.util.concurrent.atomic.AtomicLong import ColumnNode._ +import org.apache.spark.sql.Dataset import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} import org.apache.spark.sql.catalyst.util.AttributeNameParser import org.apache.spark.sql.errors.DataTypeErrorsBase @@ -651,3 +652,24 @@ private[sql] case class LazyExpression( override def sql: String = "lazy" + argumentsToSql(Seq(child)) override private[internal] def children: Seq[ColumnNodeLike] = Seq(child) } + +sealed trait SubqueryType + +object SubqueryType { + case object SCALAR extends SubqueryType + case object EXISTS extends SubqueryType + case class IN(values: Seq[ColumnNode]) extends SubqueryType +} + +case class SubqueryExpression( + ds: Dataset[_], + subqueryType: SubqueryType, + override val origin: Origin = CurrentOrigin.get) + extends ColumnNode { + override def sql: String = subqueryType match { + case SubqueryType.SCALAR => s"($ds)" + case SubqueryType.IN(values) => s"(${values.map(_.sql).mkString(",")}) IN ($ds)" + case _ => s"$subqueryType ($ds)" + } + override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 3e23f0551125..c31c72bc1148 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -23,7 +23,7 @@ import scala.collection.immutable.TreeSet import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedPlanId} import 
org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference import org.apache.spark.sql.catalyst.expressions.Cast._ @@ -419,6 +419,13 @@ case class InSubquery(values: Seq[Expression], query: ListQuery) copy(values = newChildren.dropRight(1), query = newChildren.last.asInstanceOf[ListQuery]) } +case class UnresolvedInSubqueryPlanId(values: Seq[Expression], planId: Long) + extends UnresolvedPlanId { + + override def withPlan(plan: LogicalPlan): Expression = { + InSubquery(values, ListQuery(plan)) + } +} /** * Evaluates to `true` if `list` contains `value`. diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala index ba77879a5a80..9e93cd4442d5 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala @@ -45,10 +45,13 @@ class DataFrameSubquerySuite extends QueryTest with RemoteSparkSession { row((null, 5.0)), row((6, null))).toDF("c", "d") + lazy val t = r.filter($"c".isNotNull && $"d".isNotNull) + override def beforeAll(): Unit = { super.beforeAll() l.createOrReplaceTempView("l") r.createOrReplaceTempView("r") + t.createOrReplaceTempView("t") } test("noop outer()") { @@ -318,6 +321,128 @@ class DataFrameSubquerySuite extends QueryTest with RemoteSparkSession { sql("select a, (select sum(d) from r where c = a) from l")) } + test("IN predicate subquery") { + checkAnswer( + spark.table("l").where($"l.a".isin(spark.table("r").select($"c"))), + sql("select * from l where l.a in (select c from r)")) + + checkAnswer( + spark + .table("l") + .where($"l.a".isin(spark.table("r").where($"l.b".outer() < $"r.d").select($"c"))), + sql("select * from l where l.a in (select c from r where l.b < r.d)")) + + checkAnswer( + spark + .table("l") + .where($"l.a".isin(spark.table("r").select("c")) && $"l.a" > 2 && $"l.b".isNotNull), + sql("select * from l where l.a in (select c from r) and l.a > 2 and l.b is not null")) + } + + test("IN predicate subquery with struct") { + withTempView("ll", "rr") { + spark.table("l").select($"*", struct("a", "b").alias("sab")).createOrReplaceTempView("ll") + spark + .table("r") + .select($"*", struct($"c".as("a"), $"d".as("b")).alias("scd")) + .createOrReplaceTempView("rr") + + for ((col, values) <- Seq( + ($"sab", "sab"), + (struct(struct($"a", $"b")), "struct(struct(a, b))")); + (df, query) <- Seq( + (spark.table("rr").select($"scd"), "select scd from rr"), + ( + spark.table("rr").select(struct($"c".as("a"), $"d".as("b"))), + "select struct(c as a, d as b) from rr"), + (spark.table("rr").select(struct($"c", $"d")), "select struct(c, d) from rr"))) { + checkAnswer( + spark.table("ll").where(col.isin(df)).select($"a", $"b"), + sql(s"select a, b from ll where $values in ($query)")) + } + } + } + + test("NOT IN predicate subquery") { + checkAnswer( + spark.table("l").where(!$"a".isin(spark.table("r").select($"c"))), + sql("select * from l where a not in (select c from r)")) + + checkAnswer( + spark.table("l").where(!$"a".isin(spark.table("r").where($"c".isNotNull).select($"c"))), + sql("select * from l where a not in (select c from r where c is not null)")) + + checkAnswer( + spark + .table("l") + .where(!struct($"a", $"b").isin(spark.table("t").select($"c", $"d")) && $"a" < lit(4)), + sql("select * from l where (a, b) not in 
(select c, d from t) and a < 4")) + + // Empty sub-query + checkAnswer( + spark + .table("l") + .where( + !struct($"a", $"b").isin(spark.table("r").where($"c" > lit(10)).select($"c", $"d"))), + sql("select * from l where (a, b) not in (select c, d from r where c > 10)")) + } + + test("IN predicate subquery within OR") { + checkAnswer( + spark + .table("l") + .where($"l.a".isin(spark.table("r").select("c")) + || $"l.a".isin(spark.table("r").where($"l.b".outer() < $"r.d").select($"c"))), + sql( + "select * from l where l.a in (select c from r)" + + " or l.a in (select c from r where l.b < r.d)")) + + checkAnswer( + spark + .table("l") + .where(!$"a".isin(spark.table("r").select("c")) + || !$"a".isin(spark.table("r").where($"c".isNotNull).select($"c"))), + sql( + "select * from l where a not in (select c from r)" + + " or a not in (select c from r where c is not null)")) + } + + test("complex IN predicate subquery") { + checkAnswer( + spark.table("l").where(!struct($"a", $"b").isin(spark.table("r").select($"c", $"d"))), + sql("select * from l where (a, b) not in (select c, d from r)")) + + checkAnswer( + spark + .table("l") + .where(!struct($"a", $"b").isin(spark.table("t").select($"c", $"d")) + && ($"a" + $"b").isNotNull), + sql("select * from l where (a, b) not in (select c, d from t) and (a + b) is not null")) + } + + test("same column in subquery and outer table") { + checkAnswer( + spark + .table("l") + .as("l1") + .where( + $"a".isin( + spark + .table("l") + .where($"a" < lit(3)) + .groupBy($"a") + .agg(Map.empty[String, String]))) + .select($"a"), + sql("select a from l l1 where a in (select a from l where a < 3 group by a)")) + } + + test("col IN (NULL)") { + checkAnswer(spark.table("l").where($"a".isin(null)), sql("SELECT * FROM l WHERE a IN (NULL)")) + checkAnswer( + spark.table("l").where(!$"a".isin(null)), + sql("SELECT * FROM l WHERE a NOT IN (NULL)")) + } + private def table1() = { sql("CREATE VIEW t1(c1, c2) AS VALUES (0, 1), (1, 2)") spark.table("t1") diff --git a/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto b/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto index 4dcaa9a40142..df907a84868f 100644 --- a/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto +++ b/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto @@ -483,11 +483,15 @@ message SubqueryExpression { // (Optional) Options specific to table arguments. optional TableArgOptions table_arg_options = 3; + // (Optional) IN subquery values. + repeated Expression in_subquery_values = 4; + enum SubqueryType { SUBQUERY_TYPE_UNKNOWN = 0; SUBQUERY_TYPE_SCALAR = 1; SUBQUERY_TYPE_EXISTS = 2; SUBQUERY_TYPE_TABLE_ARG = 3; + SUBQUERY_TYPE_IN = 4; } // Nested message for table argument options. 
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala index 419ac3b7f74a..ec169ba114a3 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala @@ -626,16 +626,6 @@ class Dataset[T] private[sql] ( def transpose(): DataFrame = buildTranspose(Seq.empty) - /** @inheritdoc */ - def scalar(): Column = { - Column(SubqueryExpressionNode(plan.getRoot, SubqueryType.SCALAR)) - } - - /** @inheritdoc */ - def exists(): Column = { - Column(SubqueryExpressionNode(plan.getRoot, SubqueryType.EXISTS)) - } - /** @inheritdoc */ def limit(n: Int): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { builder => builder.getLimitBuilder diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala index cf7b6a1ddd6d..739b0318759e 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala @@ -49,10 +49,11 @@ import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, BoxedLongEncoder, UnboundRowEncoder} import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult} import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration import org.apache.spark.sql.connect.client.arrow.ArrowSerializer -import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf} +import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf, SubqueryExpression} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ExecutionListenerManager @@ -420,8 +421,8 @@ class SparkSession private[sql] ( @DeveloperApi def newDataset[T](encoder: AgnosticEncoder[T], cols: Seq[Column])( f: proto.Relation.Builder => Unit): Dataset[T] = { - val references = cols.flatMap(_.node.collect { case n: SubqueryExpressionNode => - n.relation + val references: Seq[proto.Relation] = cols.flatMap(_.node.collect { + case n: SubqueryExpression => n.ds.plan.getRoot }) val builder = proto.Relation.newBuilder() diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala index 54f45a434826..1e798387726b 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala @@ -27,10 +27,11 @@ import org.apache.spark.connect.proto.Expression.Window.WindowFrame.{FrameBounda import org.apache.spark.sql.{functions, Column, Encoder} import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.common.DataTypeProtoConverter import 
org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProtoBuilder import org.apache.spark.sql.expressions.{Aggregator, UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction} -import org.apache.spark.sql.internal.{Alias, CaseWhenOtherwise, Cast, ColumnNode, ColumnNodeLike, InvokeInlineUserDefinedFunction, LambdaFunction, LazyExpression, Literal, SortOrder, SqlExpression, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedNamedLambdaVariable, UnresolvedRegex, UnresolvedStar, UpdateFields, Window, WindowFrame} +import org.apache.spark.sql.internal.{Alias, CaseWhenOtherwise, Cast, ColumnNode, ColumnNodeLike, InvokeInlineUserDefinedFunction, LambdaFunction, LazyExpression, Literal, SortOrder, SqlExpression, SubqueryExpression, SubqueryType, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedNamedLambdaVariable, UnresolvedRegex, UnresolvedStar, UpdateFields, Window, WindowFrame} /** * Converter for [[ColumnNode]] to [[proto.Expression]] conversions. @@ -218,11 +219,15 @@ object ColumnNodeToProtoConverter extends (ColumnNode => proto.Expression) { case LazyExpression(child, _) => return apply(child, e) - case SubqueryExpressionNode(relation, subqueryType, _) => + case SubqueryExpression(ds, subqueryType, _) => + val relation = ds.plan.getRoot val b = builder.getSubqueryExpressionBuilder b.setSubqueryType(subqueryType match { case SubqueryType.SCALAR => proto.SubqueryExpression.SubqueryType.SUBQUERY_TYPE_SCALAR case SubqueryType.EXISTS => proto.SubqueryExpression.SubqueryType.SUBQUERY_TYPE_EXISTS + case SubqueryType.IN(values) => + b.addAllInSubqueryValues(values.map(value => apply(value, e)).asJava) + proto.SubqueryExpression.SubqueryType.SUBQUERY_TYPE_IN }) assert(relation.hasCommon && relation.getCommon.hasPlanId) b.setPlanId(relation.getCommon.getPlanId) @@ -311,22 +316,3 @@ case class ProtoColumnNode( override def sql: String = expr.toString override def children: Seq[ColumnNodeLike] = Seq.empty } - -sealed trait SubqueryType - -object SubqueryType { - case object SCALAR extends SubqueryType - case object EXISTS extends SubqueryType -} - -case class SubqueryExpressionNode( - relation: proto.Relation, - subqueryType: SubqueryType, - override val origin: Origin = CurrentOrigin.get) - extends ColumnNode { - override def sql: String = subqueryType match { - case SubqueryType.SCALAR => s"($relation)" - case _ => s"$subqueryType ($relation)" - } - override def children: Seq[ColumnNodeLike] = Seq.empty -} diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 734eb394ca68..a7a76e334e47 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -4025,6 +4025,12 @@ class SparkConnectPlanner( } else { UnresolvedTableArgPlanId(planId) } + case proto.SubqueryExpression.SubqueryType.SUBQUERY_TYPE_IN => + UnresolvedInSubqueryPlanId( + getSubqueryExpression.getInSubqueryValuesList.asScala.map { value => + transformExpression(value) + }.toSeq, + planId) case other => throw InvalidPlanInput(s"Unknown SubqueryType $other") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala index 366cc3f4b7d7..8765e1bfc7c6 100644 
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, ProductEncoder, StructEncoder} -import org.apache.spark.sql.catalyst.expressions.{ScalarSubquery => ScalarSubqueryExpr, _} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions} import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils} import org.apache.spark.sql.catalyst.plans._ @@ -1062,16 +1062,6 @@ class Dataset[T] private[sql]( ) } - /** @inheritdoc */ - def scalar(): Column = { - Column(ExpressionColumnNode(ScalarSubqueryExpr(logicalPlan))) - } - - /** @inheritdoc */ - def exists(): Column = { - Column(ExpressionColumnNode(Exists(logicalPlan))) - } - /** @inheritdoc */ @scala.annotation.varargs def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = withSameTypedPlan { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/columnNodeSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/columnNodeSupport.scala index 5766535ac5da..5aec32c572da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/columnNodeSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/columnNodeSupport.scala @@ -28,11 +28,12 @@ import org.apache.spark.sql.catalyst.parser.{ParserInterface, ParserUtils} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils} +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, ScalaUDAF, TypedAggregateExpression} import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedAggregator} -import org.apache.spark.sql.internal.{Alias, CaseWhenOtherwise, Cast, ColumnNode, ColumnNodeLike, InvokeInlineUserDefinedFunction, LambdaFunction, LazyExpression, Literal, SortOrder, SQLConf, SqlExpression, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedNamedLambdaVariable, UnresolvedRegex, UnresolvedStar, UpdateFields, Window, WindowFrame} +import org.apache.spark.sql.internal.{Alias, CaseWhenOtherwise, Cast, ColumnNode, ColumnNodeLike, InvokeInlineUserDefinedFunction, LambdaFunction, LazyExpression, Literal, SortOrder, SQLConf, SqlExpression, SubqueryExpression, SubqueryType, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedNamedLambdaVariable, UnresolvedRegex, UnresolvedStar, UpdateFields, Window, WindowFrame} import org.apache.spark.sql.types.{DataType, NullType} /** @@ -192,6 +193,17 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres case l: LazyExpression => analysis.LazyExpression(apply(l.child)) + case SubqueryExpression(ds, subqueryType, _) => + subqueryType match { + case SubqueryType.SCALAR => + expressions.ScalarSubquery(ds.logicalPlan) + case SubqueryType.EXISTS => + expressions.Exists(ds.logicalPlan) + case SubqueryType.IN(values) => + 
expressions.InSubquery( + values.map(value => apply(value)), expressions.ListQuery(ds.logicalPlan)) + } + case node => throw SparkException.internalError("Unsupported ColumnNode: " + node) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala index f5782124243c..6e900c3c8b5f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala @@ -47,10 +47,13 @@ class DataFrameSubquerySuite extends QueryTest with SharedSparkSession { row((null, 5.0)), row((6, null))).toDF("c", "d") + lazy val t = r.filter($"c".isNotNull && $"d".isNotNull) + protected override def beforeAll(): Unit = { super.beforeAll() l.createOrReplaceTempView("l") r.createOrReplaceTempView("r") + t.createOrReplaceTempView("t") } test("noop outer()") { @@ -378,6 +381,127 @@ class DataFrameSubquerySuite extends QueryTest with SharedSparkSession { ) } + test("IN predicate subquery") { + checkAnswer( + spark.table("l").where($"l.a".isin(spark.table("r").select($"c"))), + sql("select * from l where l.a in (select c from r)")) + + checkAnswer( + spark + .table("l") + .where($"l.a".isin(spark.table("r").where($"l.b".outer() < $"r.d").select($"c"))), + sql("select * from l where l.a in (select c from r where l.b < r.d)")) + + checkAnswer( + spark + .table("l") + .where($"l.a".isin(spark.table("r").select("c")) && $"l.a" > 2 && $"l.b".isNotNull), + sql("select * from l where l.a in (select c from r) and l.a > 2 and l.b is not null")) + } + + test("IN predicate subquery with struct") { + withTempView("ll", "rr") { + spark.table("l").select($"*", struct("a", "b").alias("sab")).createOrReplaceTempView("ll") + spark + .table("r") + .select($"*", struct($"c".as("a"), $"d".as("b")).alias("scd")) + .createOrReplaceTempView("rr") + + for ((col, values) <- Seq( + ($"sab", "sab"), + (struct(struct($"a", $"b")), "struct(struct(a, b))")); + (df, query) <- Seq( + (spark.table("rr").select($"scd"), "select scd from rr"), + ( + spark.table("rr").select(struct($"c".as("a"), $"d".as("b"))), + "select struct(c as a, d as b) from rr"), + (spark.table("rr").select(struct($"c", $"d")), "select struct(c, d) from rr"))) { + checkAnswer( + spark.table("ll").where(col.isin(df)).select($"a", $"b"), + sql(s"select a, b from ll where $values in ($query)")) + } + } + } + + test("NOT IN predicate subquery") { + checkAnswer( + spark.table("l").where(!$"a".isin(spark.table("r").select($"c"))), + sql("select * from l where a not in (select c from r)")) + + checkAnswer( + spark.table("l").where(!$"a".isin(spark.table("r").where($"c".isNotNull).select($"c"))), + sql("select * from l where a not in (select c from r where c is not null)")) + + checkAnswer( + spark + .table("l") + .where(!struct($"a", $"b").isin(spark.table("t").select($"c", $"d")) && $"a" < 4), + sql("select * from l where (a, b) not in (select c, d from t) and a < 4")) + + // Empty sub-query + checkAnswer( + spark + .table("l") + .where(!struct($"a", $"b").isin(spark.table("r").where($"c" > 10).select($"c", $"d"))), + sql("select * from l where (a, b) not in (select c, d from r where c > 10)")) + } + + test("IN predicate subquery within OR") { + checkAnswer( + spark + .table("l") + .where($"l.a".isin(spark.table("r").select("c")) + || $"l.a".isin(spark.table("r").where($"l.b".outer() < $"r.d").select($"c"))), + sql( + "select * from l where l.a in (select c from r)" + + " or l.a in (select c 
from r where l.b < r.d)")) + + checkAnswer( + spark + .table("l") + .where(!$"a".isin(spark.table("r").select("c")) + || !$"a".isin(spark.table("r").where($"c".isNotNull).select($"c"))), + sql( + "select * from l where a not in (select c from r)" + + " or a not in (select c from r where c is not null)")) + } + + test("complex IN predicate subquery") { + checkAnswer( + spark.table("l").where(!struct($"a", $"b").isin(spark.table("r").select($"c", $"d"))), + sql("select * from l where (a, b) not in (select c, d from r)")) + + checkAnswer( + spark + .table("l") + .where(!struct($"a", $"b").isin(spark.table("t").select($"c", $"d")) + && ($"a" + $"b").isNotNull), + sql("select * from l where (a, b) not in (select c, d from t) and (a + b) is not null")) + } + + test("same column in subquery and outer table") { + checkAnswer( + spark + .table("l") + .as("l1") + .where( + $"a".isin( + spark + .table("l") + .where($"a" < lit(3)) + .groupBy($"a") + .agg(Map.empty[String, String]))) + .select($"a"), + sql("select a from l l1 where a in (select a from l where a < 3 group by a)")) + } + + test("col IN (NULL)") { + checkAnswer(spark.table("l").where($"a".isin(null)), sql("SELECT * FROM l WHERE a IN (NULL)")) + checkAnswer( + spark.table("l").where(!$"a".isin(null)), + sql("SELECT * FROM l WHERE a NOT IN (NULL)")) + } + private def table1() = { sql("CREATE VIEW t1(c1, c2) AS VALUES (0, 1), (1, 2)") spark.table("t1") --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org
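For readers following along in Scala, the snippet below mirrors the PySpark examples from the PR description using the `Column.isin(Dataset)` overload added by this commit. It is a minimal sketch, not part of the patch: it assumes a Spark build that already contains this change (4.1.0-SNAPSHOT), and the object name, application name, and local master are illustrative only.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{lit, struct}

object InSubqueryExample {
  def main(args: Array[String]): Unit = {
    // Assumes a Spark build that includes this commit (Column.isin(Dataset) is new in 4.1.0).
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("in-subquery-example")
      .getOrCreate()
    import spark.implicits._

    // Same data as the PySpark examples in the PR description.
    val df = Seq((2, "Alice"), (5, "Bob"), (8, "Mike")).toDF("age", "name")

    // SELECT * FROM t WHERE age IN (SELECT * FROM range(6)) ORDER BY age
    df.where($"age".isin(spark.range(6))).orderBy("age").show()

    // SELECT * FROM t WHERE (age, name) IN (SELECT id, 'Bob' FROM range(6))
    df.where(struct($"age", $"name").isin(spark.range(6).select($"id", lit("Bob")))).show()

    spark.stop()
  }
}
```

As in the new `DataFrameSubquerySuite` tests above, wrapping multiple columns in `struct(...)` before calling `isin` makes the subquery behave like a multi-column `(a, b) IN (SELECT ...)` predicate.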