This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new dc186c5e6b6 [SPARK-43773][CONNECT][PYTHON] Implement
'levenshtein(str1, str2[, threshold])' functions in python client
dc186c5e6b6 is described below
commit dc186c5e6b6bdb63345081ee9f70b8c102792cdd
Author: panbingkun <[email protected]>
AuthorDate: Sun May 28 08:38:32 2023 +0800
[SPARK-43773][CONNECT][PYTHON] Implement 'levenshtein(str1,
str2[, threshold])' functions in python client
### What changes were proposed in this pull request?
The pr aims to implement 'levenshtein(str1, str2[, threshold])' functions
in python client
### Why are the changes needed?
After adding a max distance argument to the levenshtein() function, we have
already implemented it on the Scala side, so we need to align it on `pyspark`.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- Manual testing
python/run-tests --testnames 'python.pyspark.sql.tests.test_functions
FunctionsTests.test_levenshtein_function'
- Pass GA
Closes #41296 from panbingkun/SPARK-43773.
Lead-authored-by: panbingkun <[email protected]>
Co-authored-by: panbingkun <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/connect/functions.py | 9 +++++++--
python/pyspark/sql/functions.py | 19 +++++++++++++++++--
.../sql/tests/connect/test_connect_function.py | 5 +++++
python/pyspark/sql/tests/test_functions.py | 7 +++++++
4 files changed, 36 insertions(+), 4 deletions(-)
diff --git a/python/pyspark/sql/connect/functions.py
b/python/pyspark/sql/connect/functions.py
index b7d7bc937cf..d3a05d6a1c6 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -1878,8 +1878,13 @@ def substring_index(str: "ColumnOrName", delim: str,
count: int) -> Column:
substring_index.__doc__ = pysparkfuncs.substring_index.__doc__
-def levenshtein(left: "ColumnOrName", right: "ColumnOrName") -> Column:
- return _invoke_function_over_columns("levenshtein", left, right)
+def levenshtein(
+ left: "ColumnOrName", right: "ColumnOrName", threshold: Optional[int] =
None
+) -> Column:
+ if threshold is None:
+ return _invoke_function_over_columns("levenshtein", left, right)
+ else:
+ return _invoke_function("levenshtein", _to_col(left), _to_col(right),
lit(threshold))
levenshtein.__doc__ = pysparkfuncs.levenshtein.__doc__
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index e9b71f7d617..fe35f12c402 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -6594,7 +6594,9 @@ def substring_index(str: "ColumnOrName", delim: str,
count: int) -> Column:
@try_remote_functions
-def levenshtein(left: "ColumnOrName", right: "ColumnOrName") -> Column:
+def levenshtein(
+ left: "ColumnOrName", right: "ColumnOrName", threshold: Optional[int] =
None
+) -> Column:
"""Computes the Levenshtein distance of the two given strings.
.. versionadded:: 1.5.0
@@ -6608,6 +6610,12 @@ def levenshtein(left: "ColumnOrName", right:
"ColumnOrName") -> Column:
first column value.
right : :class:`~pyspark.sql.Column` or str
second column value.
+ threshold : int, optional
+ if set when the levenshtein distance of the two given strings
+ less than or equal to a given threshold then return result distance,
or -1
+
+ .. versionchanged:: 3.5.0
+ Added ``threshold`` argument.
Returns
-------
@@ -6619,8 +6627,15 @@ def levenshtein(left: "ColumnOrName", right:
"ColumnOrName") -> Column:
>>> df0 = spark.createDataFrame([('kitten', 'sitting',)], ['l', 'r'])
>>> df0.select(levenshtein('l', 'r').alias('d')).collect()
[Row(d=3)]
+ >>> df0.select(levenshtein('l', 'r', 2).alias('d')).collect()
+ [Row(d=-1)]
"""
- return _invoke_function_over_columns("levenshtein", left, right)
+ if threshold is None:
+ return _invoke_function_over_columns("levenshtein", left, right)
+ else:
+ return _invoke_function(
+ "levenshtein", _to_java_column(left), _to_java_column(right),
threshold
+ )
@try_remote_functions
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py
b/python/pyspark/sql/tests/connect/test_connect_function.py
index e274635d3c6..3e3b4dd5b16 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -1924,6 +1924,11 @@ class SparkConnectFunctionTests(ReusedConnectTestCase,
PandasOnSparkTestUtils, S
cdf.select(CF.levenshtein(cdf.b, cdf.c)).toPandas(),
sdf.select(SF.levenshtein(sdf.b, sdf.c)).toPandas(),
)
+ self.assert_eq(
+ cdf.select(CF.levenshtein(cdf.b, cdf.c, 1)).toPandas(),
+ sdf.select(SF.levenshtein(sdf.b, sdf.c, 1)).toPandas(),
+ )
+
self.assert_eq(
cdf.select(CF.locate("e", cdf.b)).toPandas(),
sdf.select(SF.locate("e", sdf.b)).toPandas(),
diff --git a/python/pyspark/sql/tests/test_functions.py
b/python/pyspark/sql/tests/test_functions.py
index 9067de34633..72c6c365b80 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -377,6 +377,13 @@ class FunctionsTestsMixin:
actual = df.select(F.array_contains(df.data, "1").alias("b")).collect()
self.assertEqual([Row(b=True), Row(b=False)], actual)
+ def test_levenshtein_function(self):
+ df = self.spark.createDataFrame([("kitten", "sitting")], ["l", "r"])
+ actual_without_threshold = df.select(F.levenshtein(df.l,
df.r).alias("b")).collect()
+ self.assertEqual([Row(b=3)], actual_without_threshold)
+ actual_with_threshold = df.select(F.levenshtein(df.l, df.r,
2).alias("b")).collect()
+ self.assertEqual([Row(b=-1)], actual_with_threshold)
+
def test_between_function(self):
df = self.spark.createDataFrame(
[Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]