This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b4faf063e375 [SPARK-50131][SQL] Add IN Subquery DataFrame API
b4faf063e375 is described below

commit b4faf063e37597cf09b9e210766972303fbebe94
Author: Takuya Ueshin <ues...@databricks.com>
AuthorDate: Tue Apr 15 13:55:15 2025 -0700

    [SPARK-50131][SQL] Add IN Subquery DataFrame API
    
    ### What changes were proposed in this pull request?
    
    Adds IN Subquery DataFrame API.
    
    The existing `Column.isin` now also accepts a `Dataset`/`DataFrame` and uses it as an IN subquery.
    
    For example:
    
    ```py
    >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob"), (8, "Mike")], ["age", "name"])
    >>> df.createOrReplaceTempView('t')
    
    >>> # SELECT * FROM t WHERE age IN (SELECT * FROM range(6)) ORDER BY age
    >>> df.where(df.age.isin(spark.range(6))).orderBy("age").show()
    +---+-----+
    |age| name|
    +---+-----+
    |  2|Alice|
    |  5|  Bob|
    +---+-----+
    ```
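
    Negation composes with the same API; a short sketch using the same `df`
    (NOT IN behavior is covered by the tests added in this patch):

    ```py
    >>> # SELECT * FROM t WHERE age NOT IN (SELECT * FROM range(6))
    >>> df.where(~df.age.isin(spark.range(6))).show()
    +---+----+
    |age|name|
    +---+----+
    |  8|Mike|
    +---+----+
    ```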
    
    To compare multiple values, use `sf.struct` to tuple them:
    
    ```py
    >>> from pyspark.sql import functions as sf
    
    >>> # SELECT * FROM t WHERE (age, name) IN (SELECT id, 'Bob' FROM range(6))
    >>> df.where(sf.struct(df.age, df.name).isin(spark.range(6).select("id", sf.lit("Bob")))).show()
    +---+----+
    |age|name|
    +---+----+
    |  5| Bob|
    +---+----+
    ```
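
    The Scala API gains the matching `Column.isin(ds: Dataset[_])` overload. A minimal
    sketch of the equivalent usage (assuming an active `spark` session with
    `spark.implicits._` in scope; the sample data is illustrative):

    ```scala
    import org.apache.spark.sql.functions.{lit, struct}
    import spark.implicits._

    val df = Seq((2, "Alice"), (5, "Bob"), (8, "Mike")).toDF("age", "name")

    // SELECT * FROM df WHERE age IN (SELECT id FROM range(6))
    df.where($"age".isin(spark.range(6))).orderBy("age").show()

    // SELECT * FROM df WHERE (age, name) IN (SELECT id, 'Bob' FROM range(6))
    df.where(struct($"age", $"name").isin(spark.range(6).select($"id", lit("Bob")))).show()
    ```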
    
    ### Why are the changes needed?
    
    The IN subquery was missing from the DataFrame API.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, `Column.isin` now accepts a `Dataset`/`DataFrame` and uses it as an IN subquery.
    
    ### How was this patch tested?
    
    Added the related tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #50470 from ueshin/issues/SPARK-50131/in_subquery.
    
    Authored-by: Takuya Ueshin <ues...@databricks.com>
    Signed-off-by: Takuya Ueshin <ues...@databricks.com>
---
 python/pyspark/sql/classic/column.py               |   7 +
 python/pyspark/sql/column.py                       |  25 ++-
 python/pyspark/sql/connect/column.py               |  13 ++
 python/pyspark/sql/connect/expressions.py          |  24 ++-
 .../pyspark/sql/connect/proto/expressions_pb2.py   |  12 +-
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  |  11 ++
 python/pyspark/sql/tests/test_subquery.py          | 207 ++++++++++++++++++++-
 .../main/scala/org/apache/spark/sql/Column.scala   |  20 ++
 .../main/scala/org/apache/spark/sql/Dataset.scala  |   8 +-
 .../apache/spark/sql/internal/columnNodes.scala    |  22 +++
 .../sql/catalyst/expressions/predicates.scala      |   9 +-
 .../apache/spark/sql/DataFrameSubquerySuite.scala  | 125 +++++++++++++
 .../main/protobuf/spark/connect/expressions.proto  |   4 +
 .../org/apache/spark/sql/connect/Dataset.scala     |  10 -
 .../apache/spark/sql/connect/SparkSession.scala    |   7 +-
 .../spark/sql/connect/columnNodeSupport.scala      |  28 +--
 .../sql/connect/planner/SparkConnectPlanner.scala  |   6 +
 .../org/apache/spark/sql/classic/Dataset.scala     |  12 +-
 .../spark/sql/classic/columnNodeSupport.scala      |  14 +-
 .../apache/spark/sql/DataFrameSubquerySuite.scala  | 124 ++++++++++++
 20 files changed, 624 insertions(+), 64 deletions(-)

diff --git a/python/pyspark/sql/classic/column.py 
b/python/pyspark/sql/classic/column.py
index 161f8ba4bb7a..fef65bcb5d54 100644
--- a/python/pyspark/sql/classic/column.py
+++ b/python/pyspark/sql/classic/column.py
@@ -474,6 +474,13 @@ class Column(ParentColumn):
         return Column(jc)
 
     def isin(self, *cols: Any) -> ParentColumn:
+        from pyspark.sql.classic.dataframe import DataFrame
+
+        if len(cols) == 1 and isinstance(cols[0], DataFrame):
+            df: DataFrame = cols[0]
+            jc = self._jc.isin(df._jdf)
+            return Column(jc)
+
         if len(cols) == 1 and isinstance(cols[0], (list, set)):
             cols = cast(Tuple, cols[0])
         cols = cast(
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index a055e4456495..db02cc80dbed 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -883,6 +883,9 @@ class Column(TableValuedFunctionArgument):
         .. versionchanged:: 3.4.0
             Supports Spark Connect.
 
+        .. versionchanged:: 4.1.0
+            Also takes a single :class:`DataFrame` to be used as IN subquery.
+
         Parameters
         ----------
         cols : Any
@@ -900,7 +903,7 @@ class Column(TableValuedFunctionArgument):
 
         Example 1: Filter rows with names in the specified values
 
-        >>> df[df.name.isin("Bob", "Mike")].show()
+        >>> df[df.name.isin("Bob", "Mike")].orderBy("age").show()
         +---+----+
         |age|name|
         +---+----+
@@ -925,6 +928,26 @@ class Column(TableValuedFunctionArgument):
         +---+----+
         |  8|Mike|
         +---+----+
+
+        Example 4: Take a :class:`DataFrame` and work as IN subquery
+
+        >>> df.where(df.age.isin(spark.range(6))).orderBy("age").show()
+        +---+-----+
+        |age| name|
+        +---+-----+
+        |  2|Alice|
+        |  5|  Bob|
+        +---+-----+
+
+        Example 5: Multiple values for IN subquery
+
+        >>> from pyspark.sql.functions import lit, struct
+        >>> df.where(struct(df.age, df.name).isin(spark.range(6).select("id", 
lit("Bob")))).show()
+        +---+----+
+        |age|name|
+        +---+----+
+        |  5| Bob|
+        +---+----+
         """
         ...
 
diff --git a/python/pyspark/sql/connect/column.py 
b/python/pyspark/sql/connect/column.py
index 15d943175850..d6ed62ba4a52 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -47,6 +47,7 @@ from pyspark.sql.connect.expressions import (
     LiteralExpression,
     CaseWhen,
     SortOrder,
+    SubqueryExpression,
     CastExpression,
     WindowExpression,
     WithField,
@@ -461,6 +462,18 @@ class Column(ParentColumn):
         return Column(self._expr)
 
     def isin(self, *cols: Any) -> ParentColumn:
+        from pyspark.sql.connect.dataframe import DataFrame
+
+        if len(cols) == 1 and isinstance(cols[0], DataFrame):
+            if isinstance(self._expr, UnresolvedFunction) and self._expr._name 
== "struct":
+                values = self._expr.children
+            else:
+                values = [self._expr]
+
+            return Column(
+                SubqueryExpression(cols[0]._plan, subquery_type="in", 
in_subquery_values=values)
+            )
+
         if len(cols) == 1 and isinstance(cols[0], (list, set)):
             _cols = list(cols[0])
         else:
diff --git a/python/pyspark/sql/connect/expressions.py 
b/python/pyspark/sql/connect/expressions.py
index e5b10be41963..872770ee2291 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -1241,9 +1241,10 @@ class SubqueryExpression(Expression):
         partition_spec: Optional[Sequence["Expression"]] = None,
         order_spec: Optional[Sequence["SortOrder"]] = None,
         with_single_partition: Optional[bool] = None,
+        in_subquery_values: Optional[Sequence["Expression"]] = None,
     ) -> None:
         assert isinstance(subquery_type, str)
-        assert subquery_type in ("scalar", "exists", "table_arg")
+        assert subquery_type in ("scalar", "exists", "table_arg", "in")
 
         super().__init__()
         self._plan = plan
@@ -1251,6 +1252,7 @@ class SubqueryExpression(Expression):
         self._partition_spec = partition_spec or []
         self._order_spec = order_spec or []
         self._with_single_partition = with_single_partition
+        self._in_subquery_values = in_subquery_values or []
 
     def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
         expr = self._create_proto_expression()
@@ -1276,17 +1278,25 @@ class SubqueryExpression(Expression):
                 )
             if self._with_single_partition is not None:
                 table_arg_options.with_single_partition = 
self._with_single_partition
+        elif self._subquery_type == "in":
+            expr.subquery_expression.subquery_type = 
proto.SubqueryExpression.SUBQUERY_TYPE_IN
+            expr.subquery_expression.in_subquery_values.extend(
+                [expr.to_plan(session) for expr in self._in_subquery_values]
+            )
 
         return expr
 
     def __repr__(self) -> str:
         repr_parts = [f"plan={self._plan}", f"type={self._subquery_type}"]
 
-        if self._partition_spec:
-            repr_parts.append(f"partition_spec={self._partition_spec}")
-        if self._order_spec:
-            repr_parts.append(f"order_spec={self._order_spec}")
-        if self._with_single_partition is not None:
-            
repr_parts.append(f"with_single_partition={self._with_single_partition}")
+        if self._subquery_type == "table_arg":
+            if self._partition_spec:
+                repr_parts.append(f"partition_spec={self._partition_spec}")
+            if self._order_spec:
+                repr_parts.append(f"order_spec={self._order_spec}")
+            if self._with_single_partition is not None:
+                
repr_parts.append(f"with_single_partition={self._with_single_partition}")
+        elif self._subquery_type == "in":
+            repr_parts.append(f"values={self._in_subquery_values}")
 
         return f"SubqueryExpression({', '.join(repr_parts)})"
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py 
b/python/pyspark/sql/connect/proto/expressions_pb2.py
index f5fdd162a708..0cec23f4857d 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -40,7 +40,7 @@ from pyspark.sql.connect.proto import common_pb2 as 
spark_dot_connect_dot_common
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto\x1a\x1aspark/connect/common.proto"\xf3\x34\n\nExpression\x12\x37\n\x06\x63ommon\x18\x12
 
\x01(\x0b\x32\x1f.spark.connect.ExpressionCommonR\x06\x63ommon\x12=\n\x07literal\x18\x01
 
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
 
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolved 
[...]
+    
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto\x1a\x1aspark/connect/common.proto"\xf3\x34\n\nExpression\x12\x37\n\x06\x63ommon\x18\x12
 
\x01(\x0b\x32\x1f.spark.connect.ExpressionCommonR\x06\x63ommon\x12=\n\x07literal\x18\x01
 
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
 
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolved 
[...]
 )
 
 _globals = globals()
@@ -130,9 +130,9 @@ if not _descriptor._USE_C_DESCRIPTORS:
     _globals["_MERGEACTION_ACTIONTYPE"]._serialized_start = 8586
     _globals["_MERGEACTION_ACTIONTYPE"]._serialized_end = 8753
     _globals["_SUBQUERYEXPRESSION"]._serialized_start = 8770
-    _globals["_SUBQUERYEXPRESSION"]._serialized_end = 9383
-    _globals["_SUBQUERYEXPRESSION_TABLEARGOPTIONS"]._serialized_start = 9003
-    _globals["_SUBQUERYEXPRESSION_TABLEARGOPTIONS"]._serialized_end = 9237
-    _globals["_SUBQUERYEXPRESSION_SUBQUERYTYPE"]._serialized_start = 9239
-    _globals["_SUBQUERYEXPRESSION_SUBQUERYTYPE"]._serialized_end = 9361
+    _globals["_SUBQUERYEXPRESSION"]._serialized_end = 9479
+    _globals["_SUBQUERYEXPRESSION_TABLEARGOPTIONS"]._serialized_start = 9076
+    _globals["_SUBQUERYEXPRESSION_TABLEARGOPTIONS"]._serialized_end = 9310
+    _globals["_SUBQUERYEXPRESSION_SUBQUERYTYPE"]._serialized_start = 9313
+    _globals["_SUBQUERYEXPRESSION_SUBQUERYTYPE"]._serialized_end = 9457
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi 
b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index f6aada59a2d8..25fc04c0319e 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -1940,12 +1940,14 @@ class 
SubqueryExpression(google.protobuf.message.Message):
         SUBQUERY_TYPE_SCALAR: SubqueryExpression._SubqueryType.ValueType  # 1
         SUBQUERY_TYPE_EXISTS: SubqueryExpression._SubqueryType.ValueType  # 2
         SUBQUERY_TYPE_TABLE_ARG: SubqueryExpression._SubqueryType.ValueType  # 
3
+        SUBQUERY_TYPE_IN: SubqueryExpression._SubqueryType.ValueType  # 4
 
     class SubqueryType(_SubqueryType, metaclass=_SubqueryTypeEnumTypeWrapper): 
...
     SUBQUERY_TYPE_UNKNOWN: SubqueryExpression.SubqueryType.ValueType  # 0
     SUBQUERY_TYPE_SCALAR: SubqueryExpression.SubqueryType.ValueType  # 1
     SUBQUERY_TYPE_EXISTS: SubqueryExpression.SubqueryType.ValueType  # 2
     SUBQUERY_TYPE_TABLE_ARG: SubqueryExpression.SubqueryType.ValueType  # 3
+    SUBQUERY_TYPE_IN: SubqueryExpression.SubqueryType.ValueType  # 4
 
     class TableArgOptions(google.protobuf.message.Message):
         """Nested message for table argument options."""
@@ -2010,6 +2012,7 @@ class SubqueryExpression(google.protobuf.message.Message):
     PLAN_ID_FIELD_NUMBER: builtins.int
     SUBQUERY_TYPE_FIELD_NUMBER: builtins.int
     TABLE_ARG_OPTIONS_FIELD_NUMBER: builtins.int
+    IN_SUBQUERY_VALUES_FIELD_NUMBER: builtins.int
     plan_id: builtins.int
     """(Required) The ID of the corresponding connect plan."""
     subquery_type: global___SubqueryExpression.SubqueryType.ValueType
@@ -2017,12 +2020,18 @@ class 
SubqueryExpression(google.protobuf.message.Message):
     @property
     def table_arg_options(self) -> global___SubqueryExpression.TableArgOptions:
         """(Optional) Options specific to table arguments."""
+    @property
+    def in_subquery_values(
+        self,
+    ) -> 
google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Expression]:
+        """(Optional) IN subquery values."""
     def __init__(
         self,
         *,
         plan_id: builtins.int = ...,
         subquery_type: global___SubqueryExpression.SubqueryType.ValueType = 
...,
         table_arg_options: global___SubqueryExpression.TableArgOptions | None 
= ...,
+        in_subquery_values: collections.abc.Iterable[global___Expression] | 
None = ...,
     ) -> None: ...
     def HasField(
         self,
@@ -2035,6 +2044,8 @@ class SubqueryExpression(google.protobuf.message.Message):
         field_name: typing_extensions.Literal[
             "_table_arg_options",
             b"_table_arg_options",
+            "in_subquery_values",
+            b"in_subquery_values",
             "plan_id",
             b"plan_id",
             "subquery_type",
diff --git a/python/pyspark/sql/tests/test_subquery.py 
b/python/pyspark/sql/tests/test_subquery.py
index 7c63ddb69458..7c87f4b46cc6 100644
--- a/python/pyspark/sql/tests/test_subquery.py
+++ b/python/pyspark/sql/tests/test_subquery.py
@@ -28,7 +28,7 @@ class SubqueryTestsMixin:
     def df1(self):
         return self.spark.createDataFrame(
             [
-                (1, 1.0),
+                (1, 2.0),
                 (1, 2.0),
                 (2, 1.0),
                 (2, 2.0),
@@ -459,6 +459,211 @@ class SubqueryTestsMixin:
                     ),
                 )
 
+    def test_in_subquery(self):
+        with self.tempView("l", "r", "t"):
+            self.df1.createOrReplaceTempView("l")
+            self.df2.createOrReplaceTempView("r")
+            self.spark.table("r").filter(
+                sf.col("c").isNotNull() & sf.col("d").isNotNull()
+            ).createOrReplaceTempView("t")
+
+            with self.subTest("IN"):
+                assertDataFrameEqual(
+                    self.spark.table("l").where(
+                        
sf.col("l.a").isin(self.spark.table("r").select(sf.col("c")))
+                    ),
+                    self.spark.sql("""select * from l where l.a in (select c 
from r)"""),
+                )
+                assertDataFrameEqual(
+                    self.spark.table("l").where(
+                        sf.col("l.a").isin(
+                            self.spark.table("r")
+                            .where(sf.col("l.b").outer() < sf.col("r.d"))
+                            .select(sf.col("c"))
+                        )
+                    ),
+                    self.spark.sql(
+                        """select * from l where l.a in (select c from r where 
l.b < r.d)"""
+                    ),
+                )
+                assertDataFrameEqual(
+                    self.spark.table("l").where(
+                        sf.col("l.a").isin(self.spark.table("r").select("c"))
+                        & (sf.col("l.a") > sf.lit(2))
+                        & sf.col("l.b").isNotNull()
+                    ),
+                    self.spark.sql(
+                        """
+                        select * from l
+                        where l.a in (select c from r) and l.a > 2 and l.b is 
not null
+                        """
+                    ),
+                )
+
+            with self.subTest("IN with struct"), self.tempView("ll", "rr"):
+                self.spark.table("l").select(
+                    "*", sf.struct("a", "b").alias("sab")
+                ).createOrReplaceTempView("ll")
+                self.spark.table("r").select(
+                    "*", sf.struct(sf.col("c").alias("a"), 
sf.col("d").alias("b")).alias("scd")
+                ).createOrReplaceTempView("rr")
+
+                for col, values in [
+                    (sf.col("sab"), "sab"),
+                    (sf.struct(sf.struct(sf.col("a"), sf.col("b"))), 
"struct(struct(a, b))"),
+                ]:
+                    for df, query in [
+                        (self.spark.table("rr").select(sf.col("scd")), "select 
scd from rr"),
+                        (
+                            self.spark.table("rr").select(
+                                sf.struct(sf.col("c").alias("a"), 
sf.col("d").alias("b"))
+                            ),
+                            "select struct(c as a, d as b) from rr",
+                        ),
+                        (
+                            
self.spark.table("rr").select(sf.struct(sf.col("c"), sf.col("d"))),
+                            "select struct(c, d) from rr",
+                        ),
+                    ]:
+                        sql_query = f"""select a, b from ll where {values} in 
({query})"""
+                        with self.subTest(sql_query=sql_query):
+                            assertDataFrameEqual(
+                                
self.spark.table("ll").where(col.isin(df)).select("a", "b"),
+                                self.spark.sql(sql_query),
+                            )
+
+            with self.subTest("NOT IN"):
+                assertDataFrameEqual(
+                    self.spark.table("l").where(
+                        ~sf.col("a").isin(self.spark.table("r").select("c"))
+                    ),
+                    self.spark.sql("""select * from l where a not in (select c 
from r)"""),
+                )
+                assertDataFrameEqual(
+                    self.spark.table("l").where(
+                        ~sf.col("a").isin(
+                            
self.spark.table("r").where(sf.col("c").isNotNull()).select(sf.col("c"))
+                        )
+                    ),
+                    self.spark.sql(
+                        """select * from l where a not in (select c from r 
where c is not null)"""
+                    ),
+                )
+                assertDataFrameEqual(
+                    self.spark.table("l").where(
+                        (
+                            ~sf.struct(sf.col("a"), sf.col("b")).isin(
+                                self.spark.table("t").select(sf.col("c"), 
sf.col("d"))
+                            )
+                        )
+                        & (sf.col("a") < sf.lit(4))
+                    ),
+                    self.spark.sql(
+                        """select * from l where (a, b) not in (select c, d 
from t) and a < 4"""
+                    ),
+                )
+                assertDataFrameEqual(
+                    self.spark.table("l").where(
+                        ~sf.struct(sf.col("a"), sf.col("b")).isin(
+                            self.spark.table("r")
+                            .where(sf.col("c") > sf.lit(10))
+                            .select(sf.col("c"), sf.col("d"))
+                        )
+                    ),
+                    self.spark.sql(
+                        """select * from l where (a, b) not in (select c, d 
from r where c > 10)"""
+                    ),
+                )
+
+            with self.subTest("IN within OR"):
+                assertDataFrameEqual(
+                    self.spark.table("l").where(
+                        sf.col("l.a").isin(self.spark.table("r").select("c"))
+                        | (
+                            sf.col("l.a").isin(
+                                self.spark.table("r")
+                                .where(sf.col("l.b").outer() < sf.col("r.d"))
+                                .select(sf.col("c"))
+                            )
+                        )
+                    ),
+                    self.spark.sql(
+                        """
+                        select * from l
+                        where l.a in (select c from r) or l.a in (select c 
from r where l.b < r.d)
+                        """
+                    ),
+                )
+                assertDataFrameEqual(
+                    self.spark.table("l").where(
+                        
(~sf.col("a").isin(self.spark.table("r").select(sf.col("c"))))
+                        | (
+                            ~sf.col("a").isin(
+                                self.spark.table("r")
+                                .where(sf.col("c").isNotNull())
+                                .select(sf.col("c"))
+                            )
+                        )
+                    ),
+                    self.spark.sql(
+                        """
+                        select * from l
+                        where a not in (select c from r)
+                        or a not in (select c from r where c is not null)
+                        """
+                    ),
+                )
+
+            with self.subTest("complex IN"):
+                assertDataFrameEqual(
+                    self.spark.table("l").where(
+                        ~sf.struct(sf.col("a"), sf.col("b")).isin(
+                            self.spark.table("r").select(sf.col("c"), 
sf.col("d"))
+                        )
+                    ),
+                    self.spark.sql("""select * from l where (a, b) not in 
(select c, d from r)"""),
+                )
+                assertDataFrameEqual(
+                    self.spark.table("l").where(
+                        (
+                            ~sf.struct(sf.col("a"), sf.col("b")).isin(
+                                self.spark.table("t").select(sf.col("c"), 
sf.col("d"))
+                            )
+                        )
+                        & ((sf.col("a") + sf.col("b")).isNotNull())
+                    ),
+                    self.spark.sql(
+                        """
+                        select * from l
+                        where (a, b) not in (select c, d from t) and (a + b) 
is not null
+                        """
+                    ),
+                )
+
+            with self.subTest("same column in subquery"):
+                assertDataFrameEqual(
+                    self.spark.table("l")
+                    .alias("l1")
+                    .where(
+                        sf.col("a").isin(
+                            self.spark.table("l")
+                            .where(sf.col("a") < sf.lit(3))
+                            .groupBy(sf.col("a"))
+                            .agg({})
+                        )
+                    )
+                    .select(sf.col("a")),
+                    self.spark.sql(
+                        """select a from l l1 where a in (select a from l 
where a < 3 group by a)"""
+                    ),
+                )
+
+            with self.subTest("col IN (NULL)"):
+                assertDataFrameEqual(
+                    self.spark.table("l").where(sf.col("a").isin(None)),
+                    self.spark.sql("""SELECT * FROM l WHERE a IN (NULL)"""),
+                )
+
     def test_scalar_subquery_with_missing_outer_reference(self):
         with self.tempView("l", "r"):
             self.df1.createOrReplaceTempView("l")
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Column.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/Column.scala
index 7f5eed1eb1ad..88d597fdfbb7 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Column.scala
@@ -803,6 +803,26 @@ class Column(val node: ColumnNode) extends Logging with 
TableValuedFunctionArgum
    */
   def isInCollection(values: java.lang.Iterable[_]): Column = 
isInCollection(values.asScala)
 
+  /**
+   * A boolean expression that is evaluated to true if the value of this 
expression is contained
+   * by the provided Dataset/DataFrame.
+   *
+   * @group subquery
+   * @since 4.1.0
+   */
+  def isin(ds: Dataset[_]): Column = {
+    if (ds == null) {
+      // A single null should be handled as a value.
+      isin(Seq(ds): _*)
+    } else {
+      val values = node match {
+        case internal.UnresolvedFunction("struct", arguments, _, _, _, _) => 
arguments
+        case _ => Seq(node)
+      }
+      Column(internal.SubqueryExpression(ds, internal.SubqueryType.IN(values)))
+    }
+  }
+
   /**
    * SQL like expression. Returns a boolean column based on a SQL LIKE match.
    *
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala
index 4952fa36f66e..c287ad69de2e 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1758,7 +1758,9 @@ abstract class Dataset[T] extends Serializable {
    * @group subquery
    * @since 4.0.0
    */
-  def scalar(): Column
+  def scalar(): Column = {
+    Column(internal.SubqueryExpression(this, internal.SubqueryType.SCALAR))
+  }
 
   /**
    * Return a `Column` object for an EXISTS Subquery.
@@ -1771,7 +1773,9 @@ abstract class Dataset[T] extends Serializable {
    * @group subquery
    * @since 4.0.0
    */
-  def exists(): Column
+  def exists(): Column = {
+    Column(internal.SubqueryExpression(this, internal.SubqueryType.EXISTS))
+  }
 
   /**
    * Define (named) metrics to observe on the Dataset. This method returns an 
'observed' Dataset
diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala
index 463307409839..4a7339165cb5 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala
@@ -20,6 +20,7 @@ import java.util.concurrent.atomic.AtomicLong
 
 import ColumnNode._
 
+import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
 import org.apache.spark.sql.catalyst.util.AttributeNameParser
 import org.apache.spark.sql.errors.DataTypeErrorsBase
@@ -651,3 +652,24 @@ private[sql] case class LazyExpression(
   override def sql: String = "lazy" + argumentsToSql(Seq(child))
   override private[internal] def children: Seq[ColumnNodeLike] = Seq(child)
 }
+
+sealed trait SubqueryType
+
+object SubqueryType {
+  case object SCALAR extends SubqueryType
+  case object EXISTS extends SubqueryType
+  case class IN(values: Seq[ColumnNode]) extends SubqueryType
+}
+
+case class SubqueryExpression(
+    ds: Dataset[_],
+    subqueryType: SubqueryType,
+    override val origin: Origin = CurrentOrigin.get)
+    extends ColumnNode {
+  override def sql: String = subqueryType match {
+    case SubqueryType.SCALAR => s"($ds)"
+    case SubqueryType.IN(values) => s"(${values.map(_.sql).mkString(",")}) IN 
($ds)"
+    case _ => s"$subqueryType ($ds)"
+  }
+  override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
+}
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 3e23f0551125..c31c72bc1148 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -23,7 +23,7 @@ import scala.collection.immutable.TreeSet
 import org.apache.spark.SparkUnsupportedOperationException
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, 
UnresolvedPlanId}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference
 import org.apache.spark.sql.catalyst.expressions.Cast._
@@ -419,6 +419,13 @@ case class InSubquery(values: Seq[Expression], query: 
ListQuery)
     copy(values = newChildren.dropRight(1), query = 
newChildren.last.asInstanceOf[ListQuery])
 }
 
+case class UnresolvedInSubqueryPlanId(values: Seq[Expression], planId: Long)
+  extends UnresolvedPlanId {
+
+  override def withPlan(plan: LogicalPlan): Expression = {
+    InSubquery(values, ListQuery(plan))
+  }
+}
 
 /**
  * Evaluates to `true` if `list` contains `value`.
diff --git 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
index ba77879a5a80..9e93cd4442d5 100644
--- 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
+++ 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
@@ -45,10 +45,13 @@ class DataFrameSubquerySuite extends QueryTest with 
RemoteSparkSession {
     row((null, 5.0)),
     row((6, null))).toDF("c", "d")
 
+  lazy val t = r.filter($"c".isNotNull && $"d".isNotNull)
+
   override def beforeAll(): Unit = {
     super.beforeAll()
     l.createOrReplaceTempView("l")
     r.createOrReplaceTempView("r")
+    t.createOrReplaceTempView("t")
   }
 
   test("noop outer()") {
@@ -318,6 +321,128 @@ class DataFrameSubquerySuite extends QueryTest with 
RemoteSparkSession {
       sql("select a, (select sum(d) from r where c = a) from l"))
   }
 
+  test("IN predicate subquery") {
+    checkAnswer(
+      spark.table("l").where($"l.a".isin(spark.table("r").select($"c"))),
+      sql("select * from l where l.a in (select c from r)"))
+
+    checkAnswer(
+      spark
+        .table("l")
+        .where($"l.a".isin(spark.table("r").where($"l.b".outer() < 
$"r.d").select($"c"))),
+      sql("select * from l where l.a in (select c from r where l.b < r.d)"))
+
+    checkAnswer(
+      spark
+        .table("l")
+        .where($"l.a".isin(spark.table("r").select("c")) && $"l.a" > 2 && 
$"l.b".isNotNull),
+      sql("select * from l where l.a in (select c from r) and l.a > 2 and l.b 
is not null"))
+  }
+
+  test("IN predicate subquery with struct") {
+    withTempView("ll", "rr") {
+      spark.table("l").select($"*", struct("a", 
"b").alias("sab")).createOrReplaceTempView("ll")
+      spark
+        .table("r")
+        .select($"*", struct($"c".as("a"), $"d".as("b")).alias("scd"))
+        .createOrReplaceTempView("rr")
+
+      for ((col, values) <- Seq(
+          ($"sab", "sab"),
+          (struct(struct($"a", $"b")), "struct(struct(a, b))"));
+        (df, query) <- Seq(
+          (spark.table("rr").select($"scd"), "select scd from rr"),
+          (
+            spark.table("rr").select(struct($"c".as("a"), $"d".as("b"))),
+            "select struct(c as a, d as b) from rr"),
+          (spark.table("rr").select(struct($"c", $"d")), "select struct(c, d) 
from rr"))) {
+        checkAnswer(
+          spark.table("ll").where(col.isin(df)).select($"a", $"b"),
+          sql(s"select a, b from ll where $values in ($query)"))
+      }
+    }
+  }
+
+  test("NOT IN predicate subquery") {
+    checkAnswer(
+      spark.table("l").where(!$"a".isin(spark.table("r").select($"c"))),
+      sql("select * from l where a not in (select c from r)"))
+
+    checkAnswer(
+      
spark.table("l").where(!$"a".isin(spark.table("r").where($"c".isNotNull).select($"c"))),
+      sql("select * from l where a not in (select c from r where c is not 
null)"))
+
+    checkAnswer(
+      spark
+        .table("l")
+        .where(!struct($"a", $"b").isin(spark.table("t").select($"c", $"d")) 
&& $"a" < lit(4)),
+      sql("select * from l where (a, b) not in (select c, d from t) and a < 
4"))
+
+    // Empty sub-query
+    checkAnswer(
+      spark
+        .table("l")
+        .where(
+          !struct($"a", $"b").isin(spark.table("r").where($"c" > 
lit(10)).select($"c", $"d"))),
+      sql("select * from l where (a, b) not in (select c, d from r where c > 
10)"))
+  }
+
+  test("IN predicate subquery within OR") {
+    checkAnswer(
+      spark
+        .table("l")
+        .where($"l.a".isin(spark.table("r").select("c"))
+          || $"l.a".isin(spark.table("r").where($"l.b".outer() < 
$"r.d").select($"c"))),
+      sql(
+        "select * from l where l.a in (select c from r)" +
+          " or l.a in (select c from r where l.b < r.d)"))
+
+    checkAnswer(
+      spark
+        .table("l")
+        .where(!$"a".isin(spark.table("r").select("c"))
+          || !$"a".isin(spark.table("r").where($"c".isNotNull).select($"c"))),
+      sql(
+        "select * from l where a not in (select c from r)" +
+          " or a not in (select c from r where c is not null)"))
+  }
+
+  test("complex IN predicate subquery") {
+    checkAnswer(
+      spark.table("l").where(!struct($"a", 
$"b").isin(spark.table("r").select($"c", $"d"))),
+      sql("select * from l where (a, b) not in (select c, d from r)"))
+
+    checkAnswer(
+      spark
+        .table("l")
+        .where(!struct($"a", $"b").isin(spark.table("t").select($"c", $"d"))
+          && ($"a" + $"b").isNotNull),
+      sql("select * from l where (a, b) not in (select c, d from t) and (a + 
b) is not null"))
+  }
+
+  test("same column in subquery and outer table") {
+    checkAnswer(
+      spark
+        .table("l")
+        .as("l1")
+        .where(
+          $"a".isin(
+            spark
+              .table("l")
+              .where($"a" < lit(3))
+              .groupBy($"a")
+              .agg(Map.empty[String, String])))
+        .select($"a"),
+      sql("select a from l l1 where a in (select a from l where a < 3 group by 
a)"))
+  }
+
+  test("col IN (NULL)") {
+    checkAnswer(spark.table("l").where($"a".isin(null)), sql("SELECT * FROM l 
WHERE a IN (NULL)"))
+    checkAnswer(
+      spark.table("l").where(!$"a".isin(null)),
+      sql("SELECT * FROM l WHERE a NOT IN (NULL)"))
+  }
+
   private def table1() = {
     sql("CREATE VIEW t1(c1, c2) AS VALUES (0, 1), (1, 2)")
     spark.table("t1")
diff --git 
a/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto 
b/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 4dcaa9a40142..df907a84868f 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -483,11 +483,15 @@ message SubqueryExpression {
   // (Optional) Options specific to table arguments.
   optional TableArgOptions table_arg_options = 3;
 
+  // (Optional) IN subquery values.
+  repeated Expression in_subquery_values = 4;
+
   enum SubqueryType {
     SUBQUERY_TYPE_UNKNOWN = 0;
     SUBQUERY_TYPE_SCALAR = 1;
     SUBQUERY_TYPE_EXISTS = 2;
     SUBQUERY_TYPE_TABLE_ARG = 3;
+    SUBQUERY_TYPE_IN = 4;
   }
 
   // Nested message for table argument options.
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala
index 419ac3b7f74a..ec169ba114a3 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala
@@ -626,16 +626,6 @@ class Dataset[T] private[sql] (
   def transpose(): DataFrame =
     buildTranspose(Seq.empty)
 
-  /** @inheritdoc */
-  def scalar(): Column = {
-    Column(SubqueryExpressionNode(plan.getRoot, SubqueryType.SCALAR))
-  }
-
-  /** @inheritdoc */
-  def exists(): Column = {
-    Column(SubqueryExpressionNode(plan.getRoot, SubqueryType.EXISTS))
-  }
-
   /** @inheritdoc */
   def limit(n: Int): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { 
builder =>
     builder.getLimitBuilder
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
index cf7b6a1ddd6d..739b0318759e 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
@@ -49,10 +49,11 @@ import org.apache.spark.sql.catalyst.{JavaTypeInference, 
ScalaReflection}
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
 import 
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, 
BoxedLongEncoder, UnboundRowEncoder}
 import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
+import org.apache.spark.sql.connect.ConnectConversions._
 import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, 
SparkConnectClient, SparkResult}
 import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
 import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
-import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf}
+import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf, 
SubqueryExpression}
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.ExecutionListenerManager
@@ -420,8 +421,8 @@ class SparkSession private[sql] (
   @DeveloperApi
   def newDataset[T](encoder: AgnosticEncoder[T], cols: Seq[Column])(
       f: proto.Relation.Builder => Unit): Dataset[T] = {
-    val references = cols.flatMap(_.node.collect { case n: 
SubqueryExpressionNode =>
-      n.relation
+    val references: Seq[proto.Relation] = cols.flatMap(_.node.collect {
+      case n: SubqueryExpression => n.ds.plan.getRoot
     })
 
     val builder = proto.Relation.newBuilder()
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala
index 54f45a434826..1e798387726b 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala
@@ -27,10 +27,11 @@ import 
org.apache.spark.connect.proto.Expression.Window.WindowFrame.{FrameBounda
 import org.apache.spark.sql.{functions, Column, Encoder}
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
+import org.apache.spark.sql.connect.ConnectConversions._
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter
 import 
org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProtoBuilder
 import org.apache.spark.sql.expressions.{Aggregator, 
UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction}
-import org.apache.spark.sql.internal.{Alias, CaseWhenOtherwise, Cast, 
ColumnNode, ColumnNodeLike, InvokeInlineUserDefinedFunction, LambdaFunction, 
LazyExpression, Literal, SortOrder, SqlExpression, UnresolvedAttribute, 
UnresolvedExtractValue, UnresolvedFunction, UnresolvedNamedLambdaVariable, 
UnresolvedRegex, UnresolvedStar, UpdateFields, Window, WindowFrame}
+import org.apache.spark.sql.internal.{Alias, CaseWhenOtherwise, Cast, 
ColumnNode, ColumnNodeLike, InvokeInlineUserDefinedFunction, LambdaFunction, 
LazyExpression, Literal, SortOrder, SqlExpression, SubqueryExpression, 
SubqueryType, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, 
UnresolvedNamedLambdaVariable, UnresolvedRegex, UnresolvedStar, UpdateFields, 
Window, WindowFrame}
 
 /**
  * Converter for [[ColumnNode]] to [[proto.Expression]] conversions.
@@ -218,11 +219,15 @@ object ColumnNodeToProtoConverter extends (ColumnNode => 
proto.Expression) {
       case LazyExpression(child, _) =>
         return apply(child, e)
 
-      case SubqueryExpressionNode(relation, subqueryType, _) =>
+      case SubqueryExpression(ds, subqueryType, _) =>
+        val relation = ds.plan.getRoot
         val b = builder.getSubqueryExpressionBuilder
         b.setSubqueryType(subqueryType match {
           case SubqueryType.SCALAR => 
proto.SubqueryExpression.SubqueryType.SUBQUERY_TYPE_SCALAR
           case SubqueryType.EXISTS => 
proto.SubqueryExpression.SubqueryType.SUBQUERY_TYPE_EXISTS
+          case SubqueryType.IN(values) =>
+            b.addAllInSubqueryValues(values.map(value => apply(value, 
e)).asJava)
+            proto.SubqueryExpression.SubqueryType.SUBQUERY_TYPE_IN
         })
         assert(relation.hasCommon && relation.getCommon.hasPlanId)
         b.setPlanId(relation.getCommon.getPlanId)
@@ -311,22 +316,3 @@ case class ProtoColumnNode(
   override def sql: String = expr.toString
   override def children: Seq[ColumnNodeLike] = Seq.empty
 }
-
-sealed trait SubqueryType
-
-object SubqueryType {
-  case object SCALAR extends SubqueryType
-  case object EXISTS extends SubqueryType
-}
-
-case class SubqueryExpressionNode(
-    relation: proto.Relation,
-    subqueryType: SubqueryType,
-    override val origin: Origin = CurrentOrigin.get)
-    extends ColumnNode {
-  override def sql: String = subqueryType match {
-    case SubqueryType.SCALAR => s"($relation)"
-    case _ => s"$subqueryType ($relation)"
-  }
-  override def children: Seq[ColumnNodeLike] = Seq.empty
-}
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 734eb394ca68..a7a76e334e47 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -4025,6 +4025,12 @@ class SparkConnectPlanner(
         } else {
           UnresolvedTableArgPlanId(planId)
         }
+      case proto.SubqueryExpression.SubqueryType.SUBQUERY_TYPE_IN =>
+        UnresolvedInSubqueryPlanId(
+          getSubqueryExpression.getInSubqueryValuesList.asScala.map { value =>
+            transformExpression(value)
+          }.toSeq,
+          planId)
       case other => throw InvalidPlanInput(s"Unknown SubqueryType $other")
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
index 366cc3f4b7d7..8765e1bfc7c6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
@@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
 import org.apache.spark.sql.catalyst.encoders._
 import 
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, 
ProductEncoder, StructEncoder}
-import org.apache.spark.sql.catalyst.expressions.{ScalarSubquery => 
ScalarSubqueryExpr, _}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions}
 import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans._
@@ -1062,16 +1062,6 @@ class Dataset[T] private[sql](
     )
   }
 
-  /** @inheritdoc */
-  def scalar(): Column = {
-    Column(ExpressionColumnNode(ScalarSubqueryExpr(logicalPlan)))
-  }
-
-  /** @inheritdoc */
-  def exists(): Column = {
-    Column(ExpressionColumnNode(Exists(logicalPlan)))
-  }
-
   /** @inheritdoc */
   @scala.annotation.varargs
   def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = 
withSameTypedPlan {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/classic/columnNodeSupport.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/classic/columnNodeSupport.scala
index 5766535ac5da..5aec32c572da 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/classic/columnNodeSupport.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/classic/columnNodeSupport.scala
@@ -28,11 +28,12 @@ import 
org.apache.spark.sql.catalyst.parser.{ParserInterface, ParserUtils}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
 import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils}
+import org.apache.spark.sql.classic.ClassicConversions._
 import org.apache.spark.sql.execution.SparkSqlParser
 import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, ScalaUDAF, 
TypedAggregateExpression}
 import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin
 import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, 
UserDefinedAggregateFunction, UserDefinedAggregator}
-import org.apache.spark.sql.internal.{Alias, CaseWhenOtherwise, Cast, 
ColumnNode, ColumnNodeLike, InvokeInlineUserDefinedFunction, LambdaFunction, 
LazyExpression, Literal, SortOrder, SQLConf, SqlExpression, 
UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, 
UnresolvedNamedLambdaVariable, UnresolvedRegex, UnresolvedStar, UpdateFields, 
Window, WindowFrame}
+import org.apache.spark.sql.internal.{Alias, CaseWhenOtherwise, Cast, 
ColumnNode, ColumnNodeLike, InvokeInlineUserDefinedFunction, LambdaFunction, 
LazyExpression, Literal, SortOrder, SQLConf, SqlExpression, SubqueryExpression, 
SubqueryType, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, 
UnresolvedNamedLambdaVariable, UnresolvedRegex, UnresolvedStar, UpdateFields, 
Window, WindowFrame}
 import org.apache.spark.sql.types.{DataType, NullType}
 
 /**
@@ -192,6 +193,17 @@ private[sql] trait ColumnNodeToExpressionConverter extends 
(ColumnNode => Expres
         case l: LazyExpression =>
           analysis.LazyExpression(apply(l.child))
 
+        case SubqueryExpression(ds, subqueryType, _) =>
+          subqueryType match {
+            case SubqueryType.SCALAR =>
+              expressions.ScalarSubquery(ds.logicalPlan)
+            case SubqueryType.EXISTS =>
+              expressions.Exists(ds.logicalPlan)
+            case SubqueryType.IN(values) =>
+              expressions.InSubquery(
+                values.map(value => apply(value)), 
expressions.ListQuery(ds.logicalPlan))
+          }
+
         case node =>
           throw SparkException.internalError("Unsupported ColumnNode: " + node)
       }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
index f5782124243c..6e900c3c8b5f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
@@ -47,10 +47,13 @@ class DataFrameSubquerySuite extends QueryTest with 
SharedSparkSession {
     row((null, 5.0)),
     row((6, null))).toDF("c", "d")
 
+  lazy val t = r.filter($"c".isNotNull && $"d".isNotNull)
+
   protected override def beforeAll(): Unit = {
     super.beforeAll()
     l.createOrReplaceTempView("l")
     r.createOrReplaceTempView("r")
+    t.createOrReplaceTempView("t")
   }
 
   test("noop outer()") {
@@ -378,6 +381,127 @@ class DataFrameSubquerySuite extends QueryTest with 
SharedSparkSession {
     )
   }
 
+  test("IN predicate subquery") {
+    checkAnswer(
+      spark.table("l").where($"l.a".isin(spark.table("r").select($"c"))),
+      sql("select * from l where l.a in (select c from r)"))
+
+    checkAnswer(
+      spark
+        .table("l")
+        .where($"l.a".isin(spark.table("r").where($"l.b".outer() < 
$"r.d").select($"c"))),
+      sql("select * from l where l.a in (select c from r where l.b < r.d)"))
+
+    checkAnswer(
+      spark
+        .table("l")
+        .where($"l.a".isin(spark.table("r").select("c")) && $"l.a" > 2 && 
$"l.b".isNotNull),
+      sql("select * from l where l.a in (select c from r) and l.a > 2 and l.b 
is not null"))
+  }
+
+  test("IN predicate subquery with struct") {
+    withTempView("ll", "rr") {
+      spark.table("l").select($"*", struct("a", 
"b").alias("sab")).createOrReplaceTempView("ll")
+      spark
+        .table("r")
+        .select($"*", struct($"c".as("a"), $"d".as("b")).alias("scd"))
+        .createOrReplaceTempView("rr")
+
+      for ((col, values) <- Seq(
+          ($"sab", "sab"),
+          (struct(struct($"a", $"b")), "struct(struct(a, b))"));
+        (df, query) <- Seq(
+          (spark.table("rr").select($"scd"), "select scd from rr"),
+          (
+            spark.table("rr").select(struct($"c".as("a"), $"d".as("b"))),
+            "select struct(c as a, d as b) from rr"),
+          (spark.table("rr").select(struct($"c", $"d")), "select struct(c, d) 
from rr"))) {
+        checkAnswer(
+          spark.table("ll").where(col.isin(df)).select($"a", $"b"),
+          sql(s"select a, b from ll where $values in ($query)"))
+      }
+    }
+  }
+
+  test("NOT IN predicate subquery") {
+    checkAnswer(
+      spark.table("l").where(!$"a".isin(spark.table("r").select($"c"))),
+      sql("select * from l where a not in (select c from r)"))
+
+    checkAnswer(
+      
spark.table("l").where(!$"a".isin(spark.table("r").where($"c".isNotNull).select($"c"))),
+      sql("select * from l where a not in (select c from r where c is not 
null)"))
+
+    checkAnswer(
+      spark
+        .table("l")
+        .where(!struct($"a", $"b").isin(spark.table("t").select($"c", $"d")) 
&& $"a" < 4),
+      sql("select * from l where (a, b) not in (select c, d from t) and a < 
4"))
+
+    // Empty sub-query
+    checkAnswer(
+      spark
+        .table("l")
+        .where(!struct($"a", $"b").isin(spark.table("r").where($"c" > 
10).select($"c", $"d"))),
+      sql("select * from l where (a, b) not in (select c, d from r where c > 
10)"))
+  }
+
+  test("IN predicate subquery within OR") {
+    checkAnswer(
+      spark
+        .table("l")
+        .where($"l.a".isin(spark.table("r").select("c"))
+          || $"l.a".isin(spark.table("r").where($"l.b".outer() < 
$"r.d").select($"c"))),
+      sql(
+        "select * from l where l.a in (select c from r)" +
+          " or l.a in (select c from r where l.b < r.d)"))
+
+    checkAnswer(
+      spark
+        .table("l")
+        .where(!$"a".isin(spark.table("r").select("c"))
+          || !$"a".isin(spark.table("r").where($"c".isNotNull).select($"c"))),
+      sql(
+        "select * from l where a not in (select c from r)" +
+          " or a not in (select c from r where c is not null)"))
+  }
+
+  test("complex IN predicate subquery") {
+    checkAnswer(
+      spark.table("l").where(!struct($"a", 
$"b").isin(spark.table("r").select($"c", $"d"))),
+      sql("select * from l where (a, b) not in (select c, d from r)"))
+
+    checkAnswer(
+      spark
+        .table("l")
+        .where(!struct($"a", $"b").isin(spark.table("t").select($"c", $"d"))
+          && ($"a" + $"b").isNotNull),
+      sql("select * from l where (a, b) not in (select c, d from t) and (a + 
b) is not null"))
+  }
+
+  test("same column in subquery and outer table") {
+    checkAnswer(
+      spark
+        .table("l")
+        .as("l1")
+        .where(
+          $"a".isin(
+            spark
+              .table("l")
+              .where($"a" < lit(3))
+              .groupBy($"a")
+              .agg(Map.empty[String, String])))
+        .select($"a"),
+      sql("select a from l l1 where a in (select a from l where a < 3 group by 
a)"))
+  }
+
+  test("col IN (NULL)") {
+    checkAnswer(spark.table("l").where($"a".isin(null)), sql("SELECT * FROM l 
WHERE a IN (NULL)"))
+    checkAnswer(
+      spark.table("l").where(!$"a".isin(null)),
+      sql("SELECT * FROM l WHERE a NOT IN (NULL)"))
+  }
+
   private def table1() = {
     sql("CREATE VIEW t1(c1, c2) AS VALUES (0, 1), (1, 2)")
     spark.table("t1")

