Repository: spark
Updated Branches:
  refs/heads/master ca71cc8c8 -> df3266951


[SPARK-7157][SQL] add sampleBy to DataFrame

This was previously committed but then reverted due to test failures (see #6769).
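
For context, here is the new API in use (a minimal sketch mirroring the Scala
test added below; assumes a sqlContext is in scope):

    import org.apache.spark.sql.functions.col

    // Stratified sample on the "key" column: keep ~10% of stratum 0 and
    // ~20% of stratum 1; strata missing from the map are sampled at 0.0.
    val df = sqlContext.range(0, 100).select((col("id") % 3).as("key"))
    val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)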

Author: Xiangrui Meng <[email protected]>

Closes #7755 from rxin/SPARK-7157 and squashes the following commits:

fbf9044 [Xiangrui Meng] fix python test
542bd37 [Xiangrui Meng] update test
604fe6d [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7157
f051afd [Xiangrui Meng] use udf instead of building expression
f4e9425 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7157
8fb990b [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7157
103beb3 [Xiangrui Meng] add Java-friendly sampleBy
991f26f [Xiangrui Meng] fix seed
4a14834 [Xiangrui Meng] move sampleBy to stat
832f7cc [Xiangrui Meng] add sampleBy to DataFrame


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/df326695
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/df326695
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/df326695

Branch: refs/heads/master
Commit: df32669514afc0223ecdeca30fbfbe0b40baef3a
Parents: ca71cc8
Author: Xiangrui Meng <[email protected]>
Authored: Thu Jul 30 17:16:03 2015 -0700
Committer: Reynold Xin <[email protected]>
Committed: Thu Jul 30 17:16:03 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py                 | 41 +++++++++++++++++++
 .../spark/sql/DataFrameStatFunctions.scala      | 42 ++++++++++++++++++++
 .../apache/spark/sql/JavaDataFrameSuite.java    |  9 +++++
 .../apache/spark/sql/DataFrameStatSuite.scala   | 12 +++++-
 4 files changed, 102 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/df326695/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index d76e051..0f3480c 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -441,6 +441,42 @@ class DataFrame(object):
         rdd = self._jdf.sample(withReplacement, fraction, long(seed))
         return DataFrame(rdd, self.sql_ctx)
 
+    @since(1.5)
+    def sampleBy(self, col, fractions, seed=None):
+        """
+        Returns a stratified sample without replacement based on the
+        fraction given for each stratum.
+
+        :param col: column that defines strata
+        :param fractions:
+            sampling fraction for each stratum. If a stratum is not
+            specified, we treat its fraction as zero.
+        :param seed: random seed
+        :return: a new DataFrame that represents the stratified sample
+
+        >>> from pyspark.sql.functions import col
+        >>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key"))
+        >>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)
+        >>> sampled.groupBy("key").count().orderBy("key").show()
+        +---+-----+
+        |key|count|
+        +---+-----+
+        |  0|    3|
+        |  1|    8|
+        +---+-----+
+
+        """
+        if not isinstance(col, str):
+            raise ValueError("col must be a string, but got %r" % type(col))
+        if not isinstance(fractions, dict):
+            raise ValueError("fractions must be a dict but got %r" % 
type(fractions))
+        for k, v in fractions.items():
+            if not isinstance(k, (float, int, long, basestring)):
+                raise ValueError("key must be float, int, long, or string, but 
got %r" % type(k))
+            fractions[k] = float(v)
+        seed = seed if seed is not None else random.randint(0, sys.maxsize)
+        return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx)
+
     @since(1.4)
     def randomSplit(self, weights, seed=None):
         """Randomly splits this :class:`DataFrame` with the provided weights.
@@ -1314,6 +1350,11 @@ class DataFrameStatFunctions(object):
 
     freqItems.__doc__ = DataFrame.freqItems.__doc__
 
+    def sampleBy(self, col, fractions, seed=None):
+        return self.df.sampleBy(col, fractions, seed)
+
+    sampleBy.__doc__ = DataFrame.sampleBy.__doc__
+
 
 def _test():
     import doctest

http://git-wip-us.apache.org/repos/asf/spark/blob/df326695/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 4ec5808..2e68e35 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -17,6 +17,10 @@
 
 package org.apache.spark.sql
 
+import java.{util => ju, lang => jl}
+
+import scala.collection.JavaConverters._
+
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.execution.stat._
 
@@ -166,4 +170,42 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
   def freqItems(cols: Seq[String]): DataFrame = {
     FrequentItems.singlePassFreqItems(df, cols, 0.01)
   }
+
+  /**
+   * Returns a stratified sample without replacement based on the fraction given for each stratum.
+   * @param col column that defines strata
+   * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
+   *                  its fraction as zero.
+   * @param seed random seed
+   * @tparam T stratum type
+   * @return a new [[DataFrame]] that represents the stratified sample
+   *
+   * @since 1.5.0
+   */
+  def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = {
+    require(fractions.values.forall(p => p >= 0.0 && p <= 1.0),
+      s"Fractions must be in [0, 1], but got $fractions.")
+    import org.apache.spark.sql.functions.{rand, udf}
+    val c = Column(col)
+    val r = rand(seed)
+    val f = udf { (stratum: Any, x: Double) =>
+      x < fractions.getOrElse(stratum.asInstanceOf[T], 0.0)
+    }
+    df.filter(f(c, r))
+  }
+
+  /**
+   * Returns a stratified sample without replacement based on the fraction given for each stratum.
+   * @param col column that defines strata
+   * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
+   *                  its fraction as zero.
+   * @param seed random seed
+   * @tparam T stratum type
+   * @return a new [[DataFrame]] that represents the stratified sample
+   *
+   * @since 1.5.0
+   */
+  def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
+    sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
+  }
 }
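
The "use udf instead of building expression" commit above refers to the filter
in sampleBy: rand(seed) gives each row an independent uniform draw in [0, 1),
and the row is kept only if the draw falls below its stratum's fraction. A
standalone sketch of the same Bernoulli rule on plain Scala collections
(illustrative names, not part of the commit):

    import scala.util.Random

    // Keep a row iff an independent U(0,1) draw falls below its stratum's
    // fraction; strata absent from the map default to 0.0 and are dropped.
    def stratifiedSample[T, R](rows: Seq[R], stratumOf: R => T,
        fractions: Map[T, Double], seed: Long): Seq[R] = {
      val rng = new Random(seed)
      rows.filter(r => rng.nextDouble() < fractions.getOrElse(stratumOf(r), 0.0))
    }

Note that the DataFrame version seeds rand per partition, so the exact rows
sampled also depend on how the data is partitioned.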

http://git-wip-us.apache.org/repos/asf/spark/blob/df326695/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 9e61d06..2c669bb 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -226,4 +226,13 @@ public class JavaDataFrameSuite {
     Double result = df.stat().cov("a", "b");
     Assert.assertTrue(Math.abs(result) < 1e-6);
   }
+
+  @Test
+  public void testSampleBy() {
+    DataFrame df = context.range(0, 100).select(col("id").mod(3).as("key"));
+    DataFrame sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
+    Row[] actual = sampled.groupBy("key").count().orderBy("key").collect();
+    Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)};
+    Assert.assertArrayEquals(expected, actual);
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/df326695/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 7ba4ba7..07a675e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -21,9 +21,9 @@ import java.util.Random
 
 import org.scalatest.Matchers._
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.functions.col
 
-class DataFrameStatSuite extends SparkFunSuite  {
+class DataFrameStatSuite extends QueryTest {
 
   private val sqlCtx = org.apache.spark.sql.test.TestSQLContext
   import sqlCtx.implicits._
@@ -130,4 +130,12 @@ class DataFrameStatSuite extends SparkFunSuite  {
     val items2 = singleColResults.collect().head
     items2.getSeq[Double](0) should contain (-1.0)
   }
+
+  test("sampleBy") {
+    val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key"))
+    val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
+    checkAnswer(
+      sampled.groupBy("key").count().orderBy("key"),
+      Seq(Row(0, 5), Row(1, 8)))
+  }
 }

