Repository: spark
Updated Branches:
  refs/heads/master 111d6b9b8 -> 0401cbaa8
[SPARK-7157][SQL] add sampleBy to DataFrame

Add `sampleBy` to DataFrame.

rxin

Author: Xiangrui Meng <[email protected]>

Closes #6769 from mengxr/SPARK-7157 and squashes the following commits:

991f26f [Xiangrui Meng] fix seed
4a14834 [Xiangrui Meng] move sampleBy to stat
832f7cc [Xiangrui Meng] add sampleBy to DataFrame

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0401cbaa
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0401cbaa
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0401cbaa

Branch: refs/heads/master
Commit: 0401cbaa8ee51c71f43604f338b65022a479da0a
Parents: 111d6b9
Author: Xiangrui Meng <[email protected]>
Authored: Tue Jun 23 17:46:29 2015 -0700
Committer: Reynold Xin <[email protected]>
Committed: Tue Jun 23 17:46:29 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py                 | 40 ++++++++++++++++++++
 .../spark/sql/DataFrameStatFunctions.scala      | 24 ++++++++++++
 .../apache/spark/sql/DataFrameStatSuite.scala   | 12 +++++-
 3 files changed, 74 insertions(+), 2 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/0401cbaa/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 152b873..213338d 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -448,6 +448,41 @@ class DataFrame(object):
         rdd = self._jdf.sample(withReplacement, fraction, long(seed))
         return DataFrame(rdd, self.sql_ctx)
 
+    @since(1.5)
+    def sampleBy(self, col, fractions, seed=None):
+        """
+        Returns a stratified sample without replacement based on the
+        fraction given on each stratum.
+
+        :param col: column that defines strata
+        :param fractions:
+            sampling fraction for each stratum. If a stratum is not
+            specified, we treat its fraction as zero.
+        :param seed: random seed
+        :return: a new DataFrame that represents the stratified sample
+
+        >>> from pyspark.sql.functions import col
+        >>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key"))
+        >>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)
+        >>> sampled.groupBy("key").count().orderBy("key").show()
+        +---+-----+
+        |key|count|
+        +---+-----+
+        |  0|    5|
+        |  1|    8|
+        +---+-----+
+        """
+        if not isinstance(col, str):
+            raise ValueError("col must be a string, but got %r" % type(col))
+        if not isinstance(fractions, dict):
+            raise ValueError("fractions must be a dict but got %r" % type(fractions))
+        for k, v in fractions.items():
+            if not isinstance(k, (float, int, long, basestring)):
+                raise ValueError("key must be float, int, long, or string, but got %r" % type(k))
+            fractions[k] = float(v)
+        seed = seed if seed is not None else random.randint(0, sys.maxsize)
+        return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx)
+
     @since(1.4)
     def randomSplit(self, weights, seed=None):
         """Randomly splits this :class:`DataFrame` with the provided weights.
@@ -1322,6 +1357,11 @@ class DataFrameStatFunctions(object):
 
     freqItems.__doc__ = DataFrame.freqItems.__doc__
 
+    def sampleBy(self, col, fractions, seed=None):
+        return self.df.sampleBy(col, fractions, seed)
+
+    sampleBy.__doc__ = DataFrame.sampleBy.__doc__
+
 
 def _test():
     import doctest

http://git-wip-us.apache.org/repos/asf/spark/blob/0401cbaa/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index edb9ed7..955d287 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql
 
+import java.util.UUID
+
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.execution.stat._
 
@@ -163,4 +165,26 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
   def freqItems(cols: Seq[String]): DataFrame = {
     FrequentItems.singlePassFreqItems(df, cols, 0.01)
   }
+
+  /**
+   * Returns a stratified sample without replacement based on the fraction given on each stratum.
+   * @param col column that defines strata
+   * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
+   *                  its fraction as zero.
+   * @param seed random seed
+   * @return a new [[DataFrame]] that represents the stratified sample
+   *
+   * @since 1.5.0
+   */
+  def sampleBy(col: String, fractions: Map[Any, Double], seed: Long): DataFrame = {
+    require(fractions.values.forall(p => p >= 0.0 && p <= 1.0),
+      s"Fractions must be in [0, 1], but got $fractions.")
+    import org.apache.spark.sql.functions.rand
+    val c = Column(col)
+    val r = rand(seed).as("rand_" + UUID.randomUUID().toString.take(8))
+    val expr = fractions.toSeq.map { case (k, v) =>
+      (c === k) && (r < v)
+    }.reduce(_ || _) || false
+    df.filter(expr)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/0401cbaa/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 0d3ff89..3dd4688 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -19,9 +19,9 @@ package org.apache.spark.sql
 
 import org.scalatest.Matchers._
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.functions.col
 
-class DataFrameStatSuite extends SparkFunSuite {
+class DataFrameStatSuite extends QueryTest {
   private val sqlCtx = org.apache.spark.sql.test.TestSQLContext
   import sqlCtx.implicits._
 
@@ -98,4 +98,12 @@ class DataFrameStatSuite extends SparkFunSuite {
     val items2 = singleColResults.collect().head
     items2.getSeq[Double](0) should contain (-1.0)
   }
+
+  test("sampleBy") {
+    val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key"))
+    val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
+    checkAnswer(
+      sampled.groupBy("key").count().orderBy("key"),
+      Seq(Row(0, 4), Row(1, 9)))
+  }
 }
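The Scala implementation above works by attaching one uniform random column to each row and keeping a row in stratum k only when its draw falls below fractions(k); a stratum missing from the map matches no disjunct of the filter, so its rows are dropped (its fraction is effectively zero). A minimal standalone sketch of the same per-stratum Bernoulli technique, using the sqlContext from the doctest above (variable names here are illustrative, not part of the commit):

import org.apache.spark.sql.functions.{col, rand}

// Build strata 0, 1, 2 from id % 3, as in the doctest and test above.
val df = sqlContext.range(0, 100).select((col("id") % 3).as("key"))

// Hand-built equivalent of the generated predicate: keep roughly 10% of
// stratum 0 and roughly 20% of stratum 1; stratum 2 appears in no
// disjunct, so all of its rows are filtered out.
val r = rand(0L)
val keep = (col("key") === 0 && r < 0.1) || (col("key") === 1 && r < 0.2)
df.filter(keep).groupBy("key").count().orderBy("key").show()

// The same sample through the API added by this commit:
df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)

Because each row is compared against an independent uniform draw, the stratum counts are only approximately fractions(k) times the stratum size, which is why the doctest and the test assert small, seed-dependent counts rather than exact proportions.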
