This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 30190edb2df7 [SPARK-46955][PS] Implement `Frame.to_stata`
30190edb2df7 is described below
commit 30190edb2df7e6cc15a5db7b070cd9dde11e2106
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Feb 5 09:14:05 2024 +0800
[SPARK-46955][PS] Implement `Frame.to_stata`
### What changes were proposed in this pull request?
Implement `Frame.to_stata`
### Why are the changes needed?
For pandas parity.
### Does this PR introduce _any_ user-facing change?
Yes, `Frame.to_stata` is now supported:
```
In [5]: df = pd.DataFrame({'animal': ['falcon', 'parrot', 'falcon', 'parrot'], 'speed': [350, 18, 361, 15]})
In [6]: psdf = ps.from_pandas(df)
In [7]: df.to_stata('/tmp/animals_1.dta')
In [8]: psdf.to_stata('/tmp/animals_2.dta')
In [9]: pd.read_stata('/tmp/animals_1.dta')
Out[9]:
   index  animal  speed
0      0  falcon    350
1      1  parrot     18
2      2  falcon    361
3      3  parrot     15
In [10]: pd.read_stata('/tmp/animals_2.dta')
Out[10]:
   index  animal  speed
0      0  falcon    350
1      1  parrot     18
2      2  falcon    361
3      3  parrot     15
```
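Since the arguments are forwarded to `pandas.DataFrame.to_stata`, its other keyword arguments should be accepted as well; a minimal sketch (the path and options below are illustrative, not part of this patch):
```
In [11]: psdf.to_stata('/tmp/animals_3.dta', write_index=False, data_label='animal speeds')
```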
### How was this patch tested?
Added unit tests.
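The new suite can be run locally with the standard PySpark test runner; a sketch, shown for illustration only:
```
python/run-tests --testnames pyspark.pandas.tests.io.test_stata
```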
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #44996 from zhengruifeng/ps_to_stata.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
dev/sparktestsupport/modules.py | 2 +
.../docs/source/reference/pyspark.pandas/frame.rst | 1 +
python/pyspark/pandas/frame.py | 82 ++++++++++++++++++++++
python/pyspark/pandas/missing/frame.py | 1 -
.../pandas/tests/connect/io/test_parity_stata.py | 42 +++++++++++
python/pyspark/pandas/tests/io/test_stata.py | 67 ++++++++++++++++++
6 files changed, 194 insertions(+), 1 deletion(-)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 233dcf4e54b6..2ed2144fa64b 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -817,6 +817,7 @@ pyspark_pandas = Module(
"pyspark.pandas.tests.io.test_io",
"pyspark.pandas.tests.io.test_csv",
"pyspark.pandas.tests.io.test_feather",
+ "pyspark.pandas.tests.io.test_stata",
"pyspark.pandas.tests.io.test_dataframe_conversion",
"pyspark.pandas.tests.io.test_dataframe_spark_io",
"pyspark.pandas.tests.io.test_series_conversion",
@@ -1299,6 +1300,7 @@ pyspark_pandas_connect_part3 = Module(
"pyspark.pandas.tests.connect.io.test_parity_io",
"pyspark.pandas.tests.connect.io.test_parity_csv",
"pyspark.pandas.tests.connect.io.test_parity_feather",
+ "pyspark.pandas.tests.connect.io.test_parity_stata",
"pyspark.pandas.tests.connect.io.test_parity_dataframe_conversion",
"pyspark.pandas.tests.connect.io.test_parity_dataframe_spark_io",
"pyspark.pandas.tests.connect.io.test_parity_series_conversion",
diff --git a/python/docs/source/reference/pyspark.pandas/frame.rst b/python/docs/source/reference/pyspark.pandas/frame.rst
index 564ddb607a19..336fd262f611 100644
--- a/python/docs/source/reference/pyspark.pandas/frame.rst
+++ b/python/docs/source/reference/pyspark.pandas/frame.rst
@@ -284,6 +284,7 @@ Serialization / IO / Conversion
DataFrame.to_spark
DataFrame.to_string
DataFrame.to_feather
+ DataFrame.to_stata
DataFrame.to_json
DataFrame.to_dict
DataFrame.to_excel
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 3b3565f7ea9f..e857344a6098 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -2683,6 +2683,88 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
self._to_internal_pandas(), self.to_feather, pd.DataFrame.to_feather, args
)
+ def to_stata(
+ self,
+ path: Union[str, IO[str]],
+ *,
+ convert_dates: Optional[Dict] = None,
+ write_index: bool = True,
+ byteorder: Optional[str] = None,
+ time_stamp: Optional[datetime.datetime] = None,
+ data_label: Optional[str] = None,
+ variable_labels: Optional[Dict] = None,
+ version: Optional[int] = 114,
+ convert_strl: Optional[Sequence[Name]] = None,
+ compression: str = "infer",
+ storage_options: Optional[str] = None,
+ value_labels: Optional[Dict] = None,
+ ) -> None:
+ """
+ Export DataFrame object to Stata dta format.
+
+        .. note:: This method should only be used if the resulting DataFrame is expected
+            to be small, as all the data is loaded into the driver's memory.
+
+ .. versionadded:: 4.0.0
+
+ Parameters
+ ----------
+ path : str, path object, or buffer
+            String, path object (implementing ``os.PathLike[str]``), or file-like
+ object implementing a binary ``write()`` function.
+ convert_dates : dict
+ Dictionary mapping columns containing datetime types to stata
+ internal format to use when writing the dates. Options are 'tc',
+ 'td', 'tm', 'tw', 'th', 'tq', 'ty'. Column can be either an integer
+ or a name. Datetime columns that do not have a conversion type
+ specified will be converted to 'tc'. Raises NotImplementedError if
+ a datetime column has timezone information.
+ write_index : bool
+ Write the index to Stata dataset.
+ byteorder : str
+ Can be ">", "<", "little", or "big". default is `sys.byteorder`.
+ time_stamp : datetime
+ A datetime to use as file creation date. Default is the current
+ time.
+ data_label : str, optional
+ A label for the data set. Must be 80 characters or smaller.
+ variable_labels : dict
+ Dictionary containing columns as keys and variable labels as
+ values. Each label must be 80 characters or smaller.
+        version : {114, 117, 118, 119, None}, default 114
+ Version to use in the output dta file. Set to None to let pandas
+ decide between 118 or 119 formats depending on the number of
+ columns in the frame. Version 114 can be read by Stata 10 and
+ later. Version 117 can be read by Stata 13 or later. Version 118
+ is supported in Stata 14 and later. Version 119 is supported in
+ Stata 15 and later. Version 114 limits string variables to 244
+ characters or fewer while versions 117 and later allow strings
+ with lengths up to 2,000,000 characters. Versions 118 and 119
+ support Unicode characters, and version 119 supports more than
+ 32,767 variables.
+ convert_strl : list, optional
+ List of column names to convert to string columns to Stata StrL
+ format. Only available if version is 117. Storing strings in the
+ StrL format can produce smaller dta files if strings have more than
+ 8 characters and values are repeated.
+ value_labels : dict of dicts
+            Dictionary containing columns as keys and dictionaries of column value
+ to labels as values. Labels for a single variable must be 32,000
+ characters or smaller.
+
+ Examples
+ --------
+        >>> df = ps.DataFrame({'animal': ['falcon', 'parrot', 'falcon', 'parrot'],
+ ... 'speed': [350, 18, 361, 15]})
+ >>> df.to_stata('animals.dta') # doctest: +SKIP
+ """
+        # Make sure locals() call is at the top of the function so we don't capture local variables.
+ args = locals()
+
+ return validate_arguments_and_invoke_function(
+            self._to_internal_pandas(), self.to_stata, pd.DataFrame.to_stata, args
+ )
+
def transpose(self) -> "DataFrame":
"""
Transpose index and columns.
diff --git a/python/pyspark/pandas/missing/frame.py b/python/pyspark/pandas/missing/frame.py
index fdb6cec7c0f9..bdfa7574dc3d 100644
--- a/python/pyspark/pandas/missing/frame.py
+++ b/python/pyspark/pandas/missing/frame.py
@@ -44,7 +44,6 @@ class MissingPandasLikeDataFrame:
set_axis = _unsupported_function("set_axis")
to_period = _unsupported_function("to_period")
to_sql = _unsupported_function("to_sql")
- to_stata = _unsupported_function("to_stata")
to_timestamp = _unsupported_function("to_timestamp")
tz_convert = _unsupported_function("tz_convert")
tz_localize = _unsupported_function("tz_localize")
diff --git a/python/pyspark/pandas/tests/connect/io/test_parity_stata.py b/python/pyspark/pandas/tests/connect/io/test_parity_stata.py
new file mode 100644
index 000000000000..d7a74d7399a5
--- /dev/null
+++ b/python/pyspark/pandas/tests/connect/io/test_parity_stata.py
@@ -0,0 +1,42 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark.pandas.tests.io.test_stata import StataMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
+
+
+class StataParityTests(
+ StataMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
+ TestUtils,
+):
+ pass
+
+
+if __name__ == "__main__":
+    from pyspark.pandas.tests.connect.io.test_parity_stata import *  # noqa: F401
+
+ try:
+ import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/io/test_stata.py b/python/pyspark/pandas/tests/io/test_stata.py
new file mode 100644
index 000000000000..6fe7cf13513c
--- /dev/null
+++ b/python/pyspark/pandas/tests/io/test_stata.py
@@ -0,0 +1,67 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+import pandas as pd
+
+from pyspark import pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
+
+
+class StataMixin:
+ @property
+ def pdf(self):
+ return pd.DataFrame(
+ {"animal": ["falcon", "parrot", "falcon", "parrot"], "speed":
[350, 18, 361, 15]}
+ )
+
+ @property
+ def psdf(self):
+ return ps.from_pandas(self.pdf)
+
+    def test_to_stata(self):
+ with self.temp_dir() as dirpath:
+ path1 = f"{dirpath}/file1.dta"
+ path2 = f"{dirpath}/file2.dta"
+
+ self.pdf.to_stata(path1)
+ self.psdf.to_stata(path2)
+
+ self.assert_eq(
+ pd.read_stata(path1),
+ pd.read_stata(path2),
+ )
+
+
+class StataTests(
+ StataMixin,
+ PandasOnSparkTestCase,
+ TestUtils,
+):
+ pass
+
+
+if __name__ == "__main__":
+ from pyspark.pandas.tests.io.test_stata import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]