This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f571f2ede89 [SPARK-40900][SQL] Reimplement `frequentItems` with dataframe operations
f571f2ede89 is described below

commit f571f2ede895b576079eb13b5af47d70710aa301
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed Oct 26 09:47:33 2022 +0900

    [SPARK-40900][SQL] Reimplement `frequentItems` with dataframe operations
    
    ### What changes were proposed in this pull request?
    Reimplement `frequentItems` with dataframe operations
    
    ### Why are the changes needed?
    1. the SQL plan is no longer truncated;
    2. SQL optimizations such as column pruning can now apply (see the sketch below)
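    
    As a rough illustration (not part of this patch; the column names and support value below are made up), `freqItems` is now expressed as a regular aggregation over the selected columns, so unreferenced columns should be prunable:
    
    ```scala
    // Build a small frame with one column that freqItems never touches.
    val df = spark.range(0, 1000)
      .selectExpr("id % 3 AS a", "id % 5 AS b", "id AS unused")
    
    // Approximate frequent items for `a` and `b` with support 0.4.
    val freq = df.stat.freqItems(Seq("a", "b"), 0.4)
    freq.show(truncate = false)
    // Since the result is an ordinary aggregate over `a` and `b`,
    // the optimizer should be able to prune `unused` from the scan.
    ```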
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing UTs and a manual check (sketched below).
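    
    A quick manual check, e.g. in spark-shell (illustrative only; the column name and support value are arbitrary), is to inspect the plan: the result should now come from a regular aggregate over the input plan rather than a `LocalRelation` built from an RDD `treeAggregate`:
    
    ```scala
    val df = spark.range(0, 100).selectExpr("id % 4 AS x")
    // The plan should show an aggregate over the range, not a bare LocalRelation.
    df.stat.freqItems(Seq("x"), 0.3).explain()
    df.stat.freqItems(Seq("x"), 0.3).show(truncate = false)
    ```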
    
    Closes #38375 from zhengruifeng/sql_stat_freq_item.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../spark/sql/execution/stat/FrequentItems.scala   | 213 ++++++++++++++-------
 1 file changed, 140 insertions(+), 73 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
index bcd226f95f8..50092571e85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
@@ -17,55 +17,22 @@
 
 package org.apache.spark.sql.execution.stat
 
-import scala.collection.mutable.{Map => MutableMap}
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+
+import scala.collection.mutable
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
-import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.{functions, Column, DataFrame}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
+import org.apache.spark.sql.catalyst.trees.UnaryLike
+import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
 
 object FrequentItems extends Logging {
 
-  /** A helper class wrapping `MutableMap[Any, Long]` for simplicity. */
-  private class FreqItemCounter(size: Int) extends Serializable {
-    val baseMap: MutableMap[Any, Long] = MutableMap.empty[Any, Long]
-    /**
-     * Add a new example to the counts if it exists, otherwise deduct the count
-     * from existing items.
-     */
-    def add(key: Any, count: Long): this.type = {
-      if (baseMap.contains(key))  {
-        baseMap(key) += count
-      } else {
-        if (baseMap.size < size) {
-          baseMap += key -> count
-        } else {
-          val minCount = if (baseMap.values.isEmpty) 0 else baseMap.values.min
-          val remainder = count - minCount
-          if (remainder >= 0) {
-            baseMap += key -> count // something will get kicked out, so we can add this
-            baseMap.retain((k, v) => v > minCount)
-            baseMap.transform((k, v) => v - minCount)
-          } else {
-            baseMap.transform((k, v) => v - count)
-          }
-        }
-      }
-      this
-    }
-
-    /**
-     * Merge two maps of counts.
-     * @param other The map containing the counts for that partition
-     */
-    def merge(other: FreqItemCounter): this.type = {
-      other.baseMap.foreach { case (k, v) =>
-        add(k, v)
-      }
-      this
-    }
-  }
-
   /**
    * Finding frequent items for columns, possibly with false positives. Using the
    * frequent element count algorithm described in
@@ -85,42 +52,142 @@ object FrequentItems extends Logging {
       cols: Seq[String],
       support: Double): DataFrame = {
     require(support >= 1e-4 && support <= 1.0, s"Support must be in [1e-4, 1], but got $support.")
-    val numCols = cols.length
+
     // number of max items to keep counts for
     val sizeOfMap = (1 / support).toInt
-    val countMaps = Seq.tabulate(numCols)(i => new FreqItemCounter(sizeOfMap))
-
-    val freqItems = df.select(cols.map(Column(_)) : _*).rdd.treeAggregate(countMaps)(
-      seqOp = (counts, row) => {
-        var i = 0
-        while (i < numCols) {
-          val thisMap = counts(i)
-          val key = row.get(i)
-          thisMap.add(key, 1L)
-          i += 1
-        }
-        counts
-      },
-      combOp = (baseCounts, counts) => {
-        var i = 0
-        while (i < numCols) {
-          baseCounts(i).merge(counts(i))
-          i += 1
+
+    val frequentItemCols = cols.map { col =>
+      val aggExpr = new CollectFrequentItems(functions.col(col).expr, sizeOfMap)
+      Column(aggExpr.toAggregateExpression(isDistinct = false)).as(s"${col}_freqItems")
+    }
+
+    df.select(frequentItemCols: _*)
+  }
+}
+
+case class CollectFrequentItems(
+    child: Expression,
+    size: Int,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[mutable.Map[Any, Long]]
+  with UnaryLike[Expression] {
+  require(size > 0)
+
+  def this(child: Expression, size: Int) = this(child, size, 0, 0)
+
+  // Returns empty array for empty inputs
+  override def nullable: Boolean = false
+
+  override def dataType: DataType = ArrayType(child.dataType, containsNull = child.nullable)
+
+  override def prettyName: String = "collect_frequent_items"
+
+  override def createAggregationBuffer(): mutable.Map[Any, Long] =
+    mutable.Map.empty[Any, Long]
+
+  private def add(map: mutable.Map[Any, Long], key: Any, count: Long): mutable.Map[Any, Long] = {
+    if (map.contains(key)) {
+      map(key) += count
+    } else {
+      if (map.size < size) {
+        map += key -> count
+      } else {
+        val minCount = if (map.values.isEmpty) 0 else map.values.min
+        val remainder = count - minCount
+        if (remainder >= 0) {
+          map += key -> count // something will get kicked out, so we can add this
+          map.retain((k, v) => v > minCount)
+          map.transform((k, v) => v - minCount)
+        } else {
+          map.transform((k, v) => v - count)
         }
-        baseCounts
       }
-    )
-    val justItems = freqItems.map(m => m.baseMap.keys.toArray)
-    val resultRow = Row(justItems : _*)
+    }
+    map
+  }
+
+  override def update(
+      buffer: mutable.Map[Any, Long],
+      input: InternalRow): mutable.Map[Any, Long] = {
+    val key = child.eval(input)
+    if (key != null) {
+      this.add(buffer, InternalRow.copyValue(key), 1L)
+    } else {
+      this.add(buffer, key, 1L)
+    }
+  }
+
+  override def merge(
+      buffer: mutable.Map[Any, Long],
+      input: mutable.Map[Any, Long]): mutable.Map[Any, Long] = {
+    val otherIter = input.iterator
+    while (otherIter.hasNext) {
+      val (key, count) = otherIter.next
+      add(buffer, key, count)
+    }
+    buffer
+  }
 
-    val outputCols = cols.map { name =>
-      val originalField = df.resolve(name)
+  override def eval(buffer: mutable.Map[Any, Long]): Any =
+    new GenericArrayData(buffer.keys.toArray)
 
-      // append frequent Items to the column name for easy debugging
-      StructField(name + "_freqItems", ArrayType(originalField.dataType, originalField.nullable))
-    }.toArray
+  private lazy val projection =
+    UnsafeProjection.create(Array[DataType](child.dataType, LongType))
 
-    val schema = StructType(outputCols).toAttributes
-    Dataset.ofRows(df.sparkSession, LocalRelation.fromExternalRows(schema, Seq(resultRow)))
+  override def serialize(map: mutable.Map[Any, Long]): Array[Byte] = {
+    val buffer = new Array[Byte](4 << 10) // 4K
+    val bos = new ByteArrayOutputStream()
+    val out = new DataOutputStream(bos)
+    Utils.tryWithSafeFinally {
+      // Write pairs in counts map to byte buffer.
+      map.foreach { case (key, count) =>
+        val row = InternalRow.apply(key, count)
+        val unsafeRow = projection.apply(row)
+        out.writeInt(unsafeRow.getSizeInBytes)
+        unsafeRow.writeToStream(out, buffer)
+      }
+      out.writeInt(-1)
+      out.flush()
+
+      bos.toByteArray
+    } {
+      out.close()
+      bos.close()
+    }
   }
+
+  override def deserialize(bytes: Array[Byte]): mutable.Map[Any, Long] = {
+    val bis = new ByteArrayInputStream(bytes)
+    val ins = new DataInputStream(bis)
+    Utils.tryWithSafeFinally {
+      val map = mutable.Map.empty[Any, Long]
+      // Read unsafeRow size and content in bytes.
+      var sizeOfNextRow = ins.readInt()
+      while (sizeOfNextRow >= 0) {
+        val bs = new Array[Byte](sizeOfNextRow)
+        ins.readFully(bs)
+        val row = new UnsafeRow(2)
+        row.pointTo(bs, sizeOfNextRow)
+        // Insert the pairs into counts map.
+        val key = row.get(0, child.dataType)
+        val count = row.get(1, LongType).asInstanceOf[Long]
+        map.update(key, count)
+        sizeOfNextRow = ins.readInt()
+      }
+
+      map
+    } {
+      ins.close()
+      bis.close()
+    }
+  }
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override protected def withNewChildInternal(newChild: Expression): Expression =
+    copy(child = newChild)
 }
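
The core of `CollectFrequentItems` is the bounded counting step in `add` above, the same scheme the removed `FreqItemCounter` used: keep at most `size` candidate keys, and when the map is full either evict keys whose discounted count falls to the current minimum or discount every existing count. A minimal standalone sketch of that step, assuming plain Scala 2.13 and no Spark dependency:

    import scala.collection.mutable

    // Sketch of the bounded counting step (not the patched code itself).
    def add(map: mutable.Map[Any, Long], key: Any, count: Long, size: Int): Unit = {
      if (map.contains(key)) {
        map(key) += count
      } else if (map.size < size) {
        map += key -> count
      } else {
        val minCount = if (map.values.isEmpty) 0L else map.values.min
        if (count - minCount >= 0) {
          map += key -> count // something will get kicked out, so we can add this
          map.filterInPlace((_, v) => v > minCount)    // Scala 2.13 in-place ops
          map.mapValuesInPlace((_, v) => v - minCount)
        } else {
          map.mapValuesInPlace((_, v) => v - count)
        }
      }
    }

    // Tiny usage example: with size = 2, the surviving keys are the
    // frequent-item candidates (here "a" and "b").
    val counts = mutable.Map.empty[Any, Long]
    Seq("a", "a", "b", "a", "c", "a", "b").foreach(k => add(counts, k, 1L, size = 2))
    println(counts.keys.toSeq)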


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
