This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new fbd620989d7 [SPARK-42042][CONNECT][PYTHON] `DataFrameReader` should support StructType schema
fbd620989d7 is described below

commit fbd620989d77e810c1066889840b237a7eca920a
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Fri Jan 13 14:56:09 2023 +0800

    [SPARK-42042][CONNECT][PYTHON] `DataFrameReader` should support StructType schema
    
    ### What changes were proposed in this pull request?
    `DataFrameReader` should support a `StructType` schema in addition to a schema string: the Python client serializes the `StructType` to its JSON representation and sends it in the proto `schema` string field, and the server parses that string as DDL first, falling back to JSON.
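    
    A minimal sketch of the client-side conversion (illustrative only, not the full implementation):
    
    ```python
    from pyspark.sql.types import IntegerType, StringType, StructField, StructType
    
    schema = StructType(
        [
            StructField("age", IntegerType()),
            StructField("name", StringType()),
        ]
    )
    # The client sends schema.json() in the proto `schema` string field,
    # a JSON string of the form '{"fields":[...],"type":"struct"}'; the server
    # tries StructType.fromDDL first and falls back to DataType.fromJson.
    print(schema.json())
    ```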
    
    ### Why are the changes needed?
    For parity with the vanilla PySpark `DataFrameReader`, whose `schema` method already accepts a `StructType`.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, `DataFrameReader.schema` (and the `schema` parameter of `load` and `json`) now accepts a `StructType` as well as a DDL string.
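    
    For example (illustrative; `spark` is a Spark Connect session and `path` points to JSON data):
    
    ```python
    from pyspark.sql.types import IntegerType, StringType, StructField, StructType
    
    # Both forms are now accepted by DataFrameReader.schema:
    spark.read.schema("age INT, name STRING").json(path)
    spark.read.schema(
        StructType([StructField("age", IntegerType()), StructField("name", StringType())])
    ).json(path)
    ```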
    
    ### How was this patch tested?
    Updated unit tests.
    
    Closes #39545 from zhengruifeng/connect_io_struct_type_schema.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../main/protobuf/spark/connect/relations.proto    |  2 ++
 .../sql/connect/planner/SparkConnectPlanner.scala  | 11 ++++--
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  5 ++-
 python/pyspark/sql/connect/readwriter.py           | 14 +++++---
 .../sql/tests/connect/test_connect_basic.py        | 39 +++++++++++++++-------
 5 files changed, 51 insertions(+), 20 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 1ddf9f9c0db..f029273c48b 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -119,6 +119,8 @@ message Read {
     string format = 1;
 
     // (Optional) If not set, Spark will infer the schema.
+    //
+    // This schema string should be either DDL-formatted or JSON-formatted.
     optional string schema = 2;
 
     // Options for the data source. The context of this map varies based on the
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 4c6c63729d5..512ac8efe2e 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -658,8 +658,15 @@ class SparkConnectPlanner(session: SparkSession) {
         val reader = session.read
         reader.format(rel.getDataSource.getFormat)
         localMap.foreach { case (key, value) => reader.option(key, value) }
-        if (rel.getDataSource.getSchema != null && !rel.getDataSource.getSchema.isEmpty) {
-          reader.schema(rel.getDataSource.getSchema)
+        if (rel.getDataSource.hasSchema && rel.getDataSource.getSchema.nonEmpty) {
+
+          DataType.parseTypeWithFallback(
+            rel.getDataSource.getSchema,
+            StructType.fromDDL,
+            fallbackParser = DataType.fromJson) match {
+            case s: StructType => reader.schema(s)
+            case other => throw InvalidPlanInput(s"Invalid schema $other")
+          }
         }
         reader.load().queryExecution.analyzed
      case _ => throw InvalidPlanInput("Does not support " + rel.getReadTypeCase.name())
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 282bac5dff1..04512a4c891 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -557,7 +557,10 @@ class Read(google.protobuf.message.Message):
         format: builtins.str
         """(Required) Supported formats include: parquet, orc, text, json, 
parquet, csv, avro."""
         schema: builtins.str
-        """(Optional) If not set, Spark will infer the schema."""
+        """(Optional) If not set, Spark will infer the schema.
+
+        This schema string should be either DDL-formatted or JSON-formatted.
+        """
         @property
         def options(
             self,
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index f256e5d3118..62a082cc90a 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -69,9 +69,13 @@ class DataFrameReader(OptionUtils):
 
     format.__doc__ = PySparkDataFrameReader.format.__doc__
 
-    # TODO(SPARK-40539): support StructType in python client and support schema as StructType.
-    def schema(self, schema: str) -> "DataFrameReader":
-        self._schema = schema
+    def schema(self, schema: Union[StructType, str]) -> "DataFrameReader":
+        if isinstance(schema, StructType):
+            self._schema = schema.json()
+        elif isinstance(schema, str):
+            self._schema = schema
+        else:
+            raise TypeError(f"schema must be a StructType or str, but got {schema}")
         return self
 
     schema.__doc__ = PySparkDataFrameReader.schema.__doc__
@@ -93,7 +97,7 @@ class DataFrameReader(OptionUtils):
         self,
         path: Optional[str] = None,
         format: Optional[str] = None,
-        schema: Optional[str] = None,
+        schema: Optional[Union[StructType, str]] = None,
         **options: "OptionalPrimitiveType",
     ) -> "DataFrame":
         if format is not None:
@@ -122,7 +126,7 @@ class DataFrameReader(OptionUtils):
     def json(
         self,
         path: str,
-        schema: Optional[str] = None,
+        schema: Optional[Union[StructType, str]] = None,
         primitivesAsString: Optional[Union[bool, str]] = None,
         prefersDecimal: Optional[Union[bool, str]] = None,
         allowComments: Optional[Union[bool, str]] = None,
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 8d9becf259a..94b9854c7b3 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -214,10 +214,21 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             ).format("json").save(d)
             # Read the JSON file as a DataFrame.
             self.assert_eq(self.connect.read.json(d).toPandas(), self.spark.read.json(d).toPandas())
-            self.assert_eq(
-                self.connect.read.json(path=d, schema="age INT, name STRING").toPandas(),
-                self.spark.read.json(path=d, schema="age INT, name STRING").toPandas(),
-            )
+
+            for schema in [
+                "age INT, name STRING",
+                StructType(
+                    [
+                        StructField("age", IntegerType()),
+                        StructField("name", StringType()),
+                    ]
+                ),
+            ]:
+                self.assert_eq(
+                    self.connect.read.json(path=d, schema=schema).toPandas(),
+                    self.spark.read.json(path=d, schema=schema).toPandas(),
+                )
+
             self.assert_eq(
                 self.connect.read.json(path=d, primitivesAsString=True).toPandas(),
                 self.spark.read.json(path=d, primitivesAsString=True).toPandas(),
@@ -1677,14 +1688,18 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         shutil.rmtree(tmpPath)
         writeDf.write.text(tmpPath)
 
-        readDf = self.connect.read.format("text").schema("id STRING").load(path=tmpPath)
-        expectResult = writeDf.collect()
-        pandasResult = readDf.toPandas()
-        if pandasResult is None:
-            self.assertTrue(False, "Empty pandas dataframe")
-        else:
-            actualResult = pandasResult.values.tolist()
-            self.assertEqual(len(expectResult), len(actualResult))
+        for schema in [
+            "id STRING",
+            StructType([StructField("id", StringType())]),
+        ]:
+            readDf = self.connect.read.format("text").schema(schema).load(path=tmpPath)
+            expectResult = writeDf.collect()
+            pandasResult = readDf.toPandas()
+            if pandasResult is None:
+                self.assertTrue(False, "Empty pandas dataframe")
+            else:
+                actualResult = pandasResult.values.tolist()
+                self.assertEqual(len(expectResult), len(actualResult))
 
     def test_simple_read_without_schema(self) -> None:
         """SPARK-41300: Schema not set when reading CSV."""


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
