This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new f571f2ede89  [SPARK-40900][SQL] Reimplement `frequentItems` with dataframe operations
f571f2ede89 is described below

commit f571f2ede895b576079eb13b5af47d70710aa301
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed Oct 26 09:47:33 2022 +0900

    [SPARK-40900][SQL] Reimplement `frequentItems` with dataframe operations

    ### What changes were proposed in this pull request?
    Reimplement `frequentItems` with dataframe operations.

    ### Why are the changes needed?
    1, the SQL plan is no longer truncated;
    2, SQL optimizations such as column pruning can now be applied.

    ### Does this PR introduce _any_ user-facing change?
    No

    ### How was this patch tested?
    Existing UTs and manual checks.

    Closes #38375 from zhengruifeng/sql_stat_freq_item.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
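For context, the rewritten path is what backs the public `DataFrame.stat.freqItems` API. A minimal
sketch of how it is exercised (assuming a spark-shell style `spark` session; the column names and
the 0.3 support value below are illustrative, not taken from this patch):

    import org.apache.spark.sql.functions.col

    // Two columns with a few heavily repeated values.
    val df = spark.range(0, 1000).select(
      (col("id") % 3).as("a"),
      (col("id") % 5).as("b"))

    // One result row with array columns a_freqItems / b_freqItems holding the items whose
    // frequency is at least the given support (false positives are possible by design).
    val freq = df.stat.freqItems(Seq("a", "b"), 0.3)
    freq.show(truncate = false)

    // After this change the result is produced by an ordinary aggregate expression, so the
    // optimizer (e.g. column pruning) sees a regular plan.
    freq.explain()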
---
 .../spark/sql/execution/stat/FrequentItems.scala  | 213 ++++++++++++++-------
 1 file changed, 140 insertions(+), 73 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
index bcd226f95f8..50092571e85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
@@ -17,55 +17,22 @@
 
 package org.apache.spark.sql.execution.stat
 
-import scala.collection.mutable.{Map => MutableMap}
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+
+import scala.collection.mutable
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
-import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.{functions, Column, DataFrame}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
+import org.apache.spark.sql.catalyst.trees.UnaryLike
+import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
 
 object FrequentItems extends Logging {
 
-  /** A helper class wrapping `MutableMap[Any, Long]` for simplicity. */
-  private class FreqItemCounter(size: Int) extends Serializable {
-    val baseMap: MutableMap[Any, Long] = MutableMap.empty[Any, Long]
-    /**
-     * Add a new example to the counts if it exists, otherwise deduct the count
-     * from existing items.
-     */
-    def add(key: Any, count: Long): this.type = {
-      if (baseMap.contains(key)) {
-        baseMap(key) += count
-      } else {
-        if (baseMap.size < size) {
-          baseMap += key -> count
-        } else {
-          val minCount = if (baseMap.values.isEmpty) 0 else baseMap.values.min
-          val remainder = count - minCount
-          if (remainder >= 0) {
-            baseMap += key -> count // something will get kicked out, so we can add this
-            baseMap.retain((k, v) => v > minCount)
-            baseMap.transform((k, v) => v - minCount)
-          } else {
-            baseMap.transform((k, v) => v - count)
-          }
-        }
-      }
-      this
-    }
-
-    /**
-     * Merge two maps of counts.
-     * @param other The map containing the counts for that partition
-     */
-    def merge(other: FreqItemCounter): this.type = {
-      other.baseMap.foreach { case (k, v) =>
-        add(k, v)
-      }
-      this
-    }
-  }
-
   /**
    * Finding frequent items for columns, possibly with false positives. Using the
    * frequent element count algorithm described in
@@ -85,42 +52,142 @@ object FrequentItems extends Logging {
       cols: Seq[String],
       support: Double): DataFrame = {
     require(support >= 1e-4 && support <= 1.0, s"Support must be in [1e-4, 1], but got $support.")
-    val numCols = cols.length
+
     // number of max items to keep counts for
     val sizeOfMap = (1 / support).toInt
-    val countMaps = Seq.tabulate(numCols)(i => new FreqItemCounter(sizeOfMap))
-
-    val freqItems = df.select(cols.map(Column(_)) : _*).rdd.treeAggregate(countMaps)(
-      seqOp = (counts, row) => {
-        var i = 0
-        while (i < numCols) {
-          val thisMap = counts(i)
-          val key = row.get(i)
-          thisMap.add(key, 1L)
-          i += 1
-        }
-        counts
-      },
-      combOp = (baseCounts, counts) => {
-        var i = 0
-        while (i < numCols) {
-          baseCounts(i).merge(counts(i))
-          i += 1
-        }
-        baseCounts
-      }
-    )
-    val justItems = freqItems.map(m => m.baseMap.keys.toArray)
-    val resultRow = Row(justItems : _*)
-
-    val outputCols = cols.map { name =>
-      val originalField = df.resolve(name)
-
-      // append frequent Items to the column name for easy debugging
-      StructField(name + "_freqItems", ArrayType(originalField.dataType, originalField.nullable))
-    }.toArray
-
-    val schema = StructType(outputCols).toAttributes
-    Dataset.ofRows(df.sparkSession, LocalRelation.fromExternalRows(schema, Seq(resultRow)))
-  }
+
+    val frequentItemCols = cols.map { col =>
+      val aggExpr = new CollectFrequentItems(functions.col(col).expr, sizeOfMap)
+      Column(aggExpr.toAggregateExpression(isDistinct = false)).as(s"${col}_freqItems")
+    }
+
+    df.select(frequentItemCols: _*)
+  }
+}
+
+case class CollectFrequentItems(
+    child: Expression,
+    size: Int,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[mutable.Map[Any, Long]]
+  with UnaryLike[Expression] {
+  require(size > 0)
+
+  def this(child: Expression, size: Int) = this(child, size, 0, 0)
+
+  // Returns empty array for empty inputs
+  override def nullable: Boolean = false
+
+  override def dataType: DataType = ArrayType(child.dataType, containsNull = child.nullable)
+
+  override def prettyName: String = "collect_frequent_items"
+
+  override def createAggregationBuffer(): mutable.Map[Any, Long] =
+    mutable.Map.empty[Any, Long]
+
+  private def add(map: mutable.Map[Any, Long], key: Any, count: Long): mutable.Map[Any, Long] = {
+    if (map.contains(key)) {
+      map(key) += count
+    } else {
+      if (map.size < size) {
+        map += key -> count
+      } else {
+        val minCount = if (map.values.isEmpty) 0 else map.values.min
+        val remainder = count - minCount
+        if (remainder >= 0) {
+          map += key -> count // something will get kicked out, so we can add this
+          map.retain((k, v) => v > minCount)
+          map.transform((k, v) => v - minCount)
+        } else {
+          map.transform((k, v) => v - count)
+        }
+      }
+    }
+    map
+  }
+
+  override def update(
+      buffer: mutable.Map[Any, Long],
+      input: InternalRow): mutable.Map[Any, Long] = {
+    val key = child.eval(input)
+    if (key != null) {
+      this.add(buffer, InternalRow.copyValue(key), 1L)
+    } else {
+      this.add(buffer, key, 1L)
+    }
+  }
+
+  override def merge(
+      buffer: mutable.Map[Any, Long],
+      input: mutable.Map[Any, Long]): mutable.Map[Any, Long] = {
+    val otherIter = input.iterator
+    while (otherIter.hasNext) {
+      val (key, count) = otherIter.next
+      add(buffer, key, count)
+    }
+    buffer
+  }
+
+  override def eval(buffer: mutable.Map[Any, Long]): Any =
+    new GenericArrayData(buffer.keys.toArray)
+
+  private lazy val projection =
+    UnsafeProjection.create(Array[DataType](child.dataType, LongType))
+
+  override def serialize(map: mutable.Map[Any, Long]): Array[Byte] = {
+    val buffer = new Array[Byte](4 << 10) // 4K
+    val bos = new ByteArrayOutputStream()
+    val out = new DataOutputStream(bos)
+    Utils.tryWithSafeFinally {
+      // Write pairs in counts map to byte buffer.
+      map.foreach { case (key, count) =>
+        val row = InternalRow.apply(key, count)
+        val unsafeRow = projection.apply(row)
+        out.writeInt(unsafeRow.getSizeInBytes)
+        unsafeRow.writeToStream(out, buffer)
+      }
+      out.writeInt(-1)
+      out.flush()
+
+      bos.toByteArray
+    } {
+      out.close()
+      bos.close()
+    }
+  }
+
+  override def deserialize(bytes: Array[Byte]): mutable.Map[Any, Long] = {
+    val bis = new ByteArrayInputStream(bytes)
+    val ins = new DataInputStream(bis)
+    Utils.tryWithSafeFinally {
+      val map = mutable.Map.empty[Any, Long]
+      // Read unsafeRow size and content in bytes.
+      var sizeOfNextRow = ins.readInt()
+      while (sizeOfNextRow >= 0) {
+        val bs = new Array[Byte](sizeOfNextRow)
+        ins.readFully(bs)
+        val row = new UnsafeRow(2)
+        row.pointTo(bs, sizeOfNextRow)
+        // Insert the pairs into counts map.
+        val key = row.get(0, child.dataType)
+        val count = row.get(1, LongType).asInstanceOf[Long]
+        map.update(key, count)
+        sizeOfNextRow = ins.readInt()
+      }
+
+      map
+    } {
+      ins.close()
+      bis.close()
+    }
+  }
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override protected def withNewChildInternal(newChild: Expression): Expression =
+    copy(child = newChild)
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org