Repository: spark Updated Branches: refs/heads/master 7143f6e97 -> 5553198fe
[SPARK-7156][SQL] Addressed follow up comments for randomSplit small fixes regarding comments in PR #5761 cc rxin Author: Burak Yavuz <[email protected]> Closes #5795 from brkyvz/split-followup and squashes the following commits: 369c522 [Burak Yavuz] changed wording a little 1ea456f [Burak Yavuz] Addressed follow up comments Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5553198f Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5553198f Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5553198f Branch: refs/heads/master Commit: 5553198fe521fb38b600b7687f7780d89a6e1cb9 Parents: 7143f6e Author: Burak Yavuz <[email protected]> Authored: Wed Apr 29 19:13:47 2015 -0700 Committer: Reynold Xin <[email protected]> Committed: Wed Apr 29 19:13:47 2015 -0700 ---------------------------------------------------------------------- python/pyspark/sql/dataframe.py | 7 ++++++- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/5553198f/python/pyspark/sql/dataframe.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 3074af3..5908ebc 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -437,6 +437,10 @@ class DataFrame(object): def randomSplit(self, weights, seed=None): """Randomly splits this :class:`DataFrame` with the provided weights. + :param weights: list of doubles as weights with which to split the DataFrame. Weights will + be normalized if they don't sum up to 1.0. + :param seed: The seed for sampling. + >>> splits = df4.randomSplit([1.0, 2.0], 24) >>> splits[0].count() 1 @@ -445,7 +449,8 @@ class DataFrame(object): 3 """ for w in weights: - assert w >= 0.0, "Negative weight value: %s" % w + if w < 0.0: + raise ValueError("Weights must be positive. Found weight value: %s" % w) seed = seed if seed is not None else random.randint(0, sys.maxsize) rdd_array = self._jdf.randomSplit(_to_seq(self.sql_ctx._sc, weights), long(seed)) return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array] http://git-wip-us.apache.org/repos/asf/spark/blob/5553198f/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 0d02e14..2669300 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -752,7 +752,7 @@ class DataFrame private[sql]( * @param seed Seed for sampling. * @group dfops */ - def randomSplit(weights: List[Double], seed: Long): Array[DataFrame] = { + private[spark] def randomSplit(weights: List[Double], seed: Long): Array[DataFrame] = { randomSplit(weights.toArray, seed) } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
