Repository: spark Updated Branches: refs/heads/branch-2.3 86457a16d -> 32429256f
[SPARK-24208][SQL] Fix attribute deduplication for FlatMapGroupsInPandas A self-join on a dataset which contains a `FlatMapGroupsInPandas` fails because of duplicate attributes. This happens because we are not dealing with this specific case in our `dedupAttr` rules. The PR fix the issue by adding the management of the specific case added UT + manual tests Author: Marco Gaido <[email protected]> Author: Marco Gaido <[email protected]> Closes #21737 from mgaido91/SPARK-24208. (cherry picked from commit ebf4bfb966389342bfd9bdb8e3b612828c18730c) Signed-off-by: Xiao Li <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/32429256 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/32429256 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/32429256 Branch: refs/heads/branch-2.3 Commit: 32429256f3e659c648462e5b2740747645740c97 Parents: 86457a1 Author: Marco Gaido <[email protected]> Authored: Wed Jul 11 09:29:19 2018 -0700 Committer: Xiao Li <[email protected]> Committed: Wed Jul 11 09:35:44 2018 -0700 ---------------------------------------------------------------------- python/pyspark/sql/tests.py | 16 ++++++++++++++++ .../spark/sql/catalyst/analysis/Analyzer.scala | 4 ++++ .../org/apache/spark/sql/GroupedDatasetSuite.scala | 12 ++++++++++++ 3 files changed, 32 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/32429256/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index aa7d8eb..6bfb329 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4691,6 +4691,22 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase): result = df.groupby('time').apply(foo_udf).sort('time') self.assertPandasEqual(df.toPandas(), result.toPandas()) + def test_self_join_with_pandas(self): + import pyspark.sql.functions as F + + @F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP) + def dummy_pandas_udf(df): + return df[['key', 'col']] + + df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'), + Row(key=2, col='C')]) + dfWithPandas = df.groupBy('key').apply(dummy_pandas_udf) + + # this was throwing an AnalysisException before SPARK-24208 + res = dfWithPandas.alias('temp0').join(dfWithPandas.alias('temp1'), + F.col('temp0.key') == F.col('temp1.key')) + self.assertEquals(res.count(), 5) + if __name__ == "__main__": from pyspark.sql.tests import * http://git-wip-us.apache.org/repos/asf/spark/blob/32429256/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8597d83..a584cb8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -723,6 +723,10 @@ class Analyzer( if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))) + case oldVersion @ FlatMapGroupsInPandas(_, _, output, _) + if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => + (oldVersion, oldVersion.copy(output = output.map(_.newInstance()))) + case oldVersion: Generate if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty => val newOutput = oldVersion.generatorOutput.map(_.newInstance()) http://git-wip-us.apache.org/repos/asf/spark/blob/32429256/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala index 218a1b7..9699fad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala @@ -93,4 +93,16 @@ class GroupedDatasetSuite extends QueryTest with SharedSQLContext { } datasetWithUDF.unpersist(true) } + + test("SPARK-24208: analysis fails on self-join with FlatMapGroupsInPandas") { + val df = datasetWithUDF.groupBy("s").flatMapGroupsInPandas(PythonUDF( + "pyUDF", + null, + StructType(Seq(StructField("s", LongType))), + Seq.empty, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + true)) + val df1 = df.alias("temp0").join(df.alias("temp1"), $"temp0.s" === $"temp1.s") + df1.queryExecution.assertAnalyzed() + } } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
