Repository: spark Updated Branches: refs/heads/master 4cb2ff9d8 -> e319ac92e
[SPARK-24962][SQL] Refactor CodeGenerator.createUnsafeArray, ArraySetLike, and ArrayDistinct ## What changes were proposed in this pull request? This PR integrates handling of `UnsafeArrayData` and `GenericArrayData` into one. The current `CodeGenerator.createUnsafeArray` handles only allocation of `UnsafeArrayData`. This PR introduces a new method `createArrayData` that returns code to allocate `UnsafeArrayData` or `GenericArrayData` and to assign a value into the allocated array. This PR also reduces the size of generated code by calling a runtime helper. This PR replaces `createUnsafeArray` with `createArrayData`. This PR also refactors `ArraySetLike` so that it can be used for `ArrayDistinct`, too. This PR also refactors `ArrayDistinct` to use `ArrayBuilder`. ## How was this patch tested? Existing tests Closes #21912 from kiszk/SPARK-24962. Lead-authored-by: Kazuaki Ishizaki <[email protected]> Co-authored-by: Takuya UESHIN <[email protected]> Signed-off-by: Wenchen Fan <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e319ac92 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e319ac92 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e319ac92 Branch: refs/heads/master Commit: e319ac92e597d95eba5b787bb7a5d5499bb3f87c Parents: 4cb2ff9 Author: Kazuaki Ishizaki <[email protected]> Authored: Tue Sep 4 15:26:34 2018 +0800 Committer: Wenchen Fan <[email protected]> Committed: Tue Sep 4 15:26:34 2018 +0800 ---------------------------------------------------------------------- .../catalyst/expressions/UnsafeArrayData.java | 22 +- .../expressions/codegen/CodeGenerator.scala | 150 +-- .../expressions/collectionOperations.scala | 926 +++++++------------ .../spark/sql/catalyst/util/ArrayData.scala | 27 + 4 files changed, 464 insertions(+), 661 deletions(-) ---------------------------------------------------------------------- 
http://git-wip-us.apache.org/repos/asf/spark/blob/e319ac92/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index cf2a5ed..9e7b15d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -473,13 +473,27 @@ public final class UnsafeArrayData extends ArrayData { return result; } - public static UnsafeArrayData forPrimitiveArray(int offset, int length, int elementSize) { - return fromPrimitiveArray(null, offset, length, elementSize); + public static UnsafeArrayData createFreshArray(int length, int elementSize) { + final long headerInBytes = calculateHeaderPortionInBytes(length); + final long valueRegionInBytes = (long)elementSize * length; + final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8; + if (totalSizeInLongs > Integer.MAX_VALUE / 8) { + throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " + + "it's too big."); + } + + final long[] data = new long[(int)totalSizeInLongs]; + + Platform.putLong(data, Platform.LONG_ARRAY_OFFSET, length); + + UnsafeArrayData result = new UnsafeArrayData(); + result.pointTo(data, Platform.LONG_ARRAY_OFFSET, (int)totalSizeInLongs * 8); + return result; } - public static boolean shouldUseGenericArrayData(int elementSize, int length) { + public static boolean shouldUseGenericArrayData(int elementSize, long length) { final long headerInBytes = calculateHeaderPortionInBytes(length); - final long valueRegionInBytes = (long)elementSize * length; + final long valueRegionInBytes = elementSize * length; final long totalSizeInLongs = 
(headerInBytes + valueRegionInBytes + 7) / 8; return totalSizeInLongs > Integer.MAX_VALUE / 8; } http://git-wip-us.apache.org/repos/asf/spark/blob/e319ac92/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index b8f0976..d5857e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -39,7 +39,7 @@ import org.apache.spark.metrics.source.CodegenMetrics import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform @@ -747,73 +747,6 @@ class CodegenContext { } /** - * Generates code creating a [[UnsafeArrayData]]. 
- * - * @param arrayName name of the array to create - * @param numElements code representing the number of elements the array should contain - * @param elementType data type of the elements in the array - * @param additionalErrorMessage string to include in the error message - */ - def createUnsafeArray( - arrayName: String, - numElements: String, - elementType: DataType, - additionalErrorMessage: String): String = { - val arraySize = freshName("size") - val arrayBytes = freshName("arrayBytes") - - s""" - |long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( - | $numElements, - | ${elementType.defaultSize}); - |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | throw new RuntimeException("Unsuccessful try create array with " + $arraySize + - | " bytes of data due to exceeding the limit " + - | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} bytes for UnsafeArrayData." + - | "$additionalErrorMessage"); - |} - |byte[] $arrayBytes = new byte[(int)$arraySize]; - |UnsafeArrayData $arrayName = new UnsafeArrayData(); - |Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements); - |$arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize); - """.stripMargin - } - - /** - * Generates code creating a [[UnsafeArrayData]]. The generated code executes - * a provided fallback when the size of backing array would exceed the array size limit. 
- * @param arrayName a name of the array to create - * @param numElements a piece of code representing the number of elements the array should contain - * @param elementSize a size of an element in bytes - * @param bodyCode a function generating code that fills up the [[UnsafeArrayData]] - * and getting the backing array as a parameter - * @param fallbackCode a piece of code executed when the array size limit is exceeded - */ - def createUnsafeArrayWithFallback( - arrayName: String, - numElements: String, - elementSize: Int, - bodyCode: String => String, - fallbackCode: String): String = { - val arraySize = freshName("size") - val arrayBytes = freshName("arrayBytes") - s""" - |final long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( - | $numElements, - | $elementSize); - |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | $fallbackCode - |} else { - | final byte[] $arrayBytes = new byte[(int)$arraySize]; - | UnsafeArrayData $arrayName = new UnsafeArrayData(); - | Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements); - | $arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize); - | ${bodyCode(arrayBytes)} - |} - """.stripMargin - } - - /** * Generates code to do null safe execution, i.e. only execute the code when the input is not * null by adding null check if necessary. * @@ -1491,6 +1424,59 @@ object CodeGenerator extends Logging { } /** + * Generates code creating a [[UnsafeArrayData]] or [[GenericArrayData]] based on + * given parameters. 
+ * + * @param arrayName name of the array to create + * @param elementType data type of the elements in source array + * @param numElements code representing the number of elements the array should contain + * @param additionalErrorMessage string to include in the error message + * + * @return code representing the allocation of [[ArrayData]] + */ + def createArrayData( + arrayName: String, + elementType: DataType, + numElements: String, + additionalErrorMessage: String): String = { + val elementSize = if (CodeGenerator.isPrimitiveType(elementType)) { + elementType.defaultSize + } else { + -1 + } + s""" + |ArrayData $arrayName = ArrayData.allocateArrayData( + | $elementSize, $numElements, "$additionalErrorMessage"); + """.stripMargin + } + + /** + * Generates assignment code for an [[ArrayData]] + * + * @param dstArray name of the array to be assigned + * @param elementType data type of the elements in destination and source arrays + * @param srcArray name of the array to be read + * @param needNullCheck value which shows whether a nullcheck is required for the returning + * assignment + * @param dstArrayIndex an index variable to access each element of destination array + * @param srcArrayIndex an index variable to access each element of source array + * + * @return code representing an assignment to each element of the [[ArrayData]], which requires + * a pair of destination and source loop index variables + */ + def createArrayAssignment( + dstArray: String, + elementType: DataType, + srcArray: String, + dstArrayIndex: String, + srcArrayIndex: String, + needNullCheck: Boolean): String = { + CodeGenerator.setArrayElement(dstArray, elementType, dstArrayIndex, + CodeGenerator.getValue(srcArray, elementType, srcArrayIndex), + if (needNullCheck) Some(s"$srcArray.isNullAt($srcArrayIndex)") else None) + } + + /** * Returns the code to update a column in Row for a given DataType. 
*/ def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = { @@ -1559,6 +1545,34 @@ object CodeGenerator extends Logging { } /** + * Generates code of setter for an [[ArrayData]]. + */ + def setArrayElement( + array: String, + elementType: DataType, + i: String, + value: String, + isNull: Option[String] = None): String = { + val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType) + val setFunc = if (isPrimitiveType) { + s"set${CodeGenerator.primitiveTypeName(elementType)}" + } else { + "update" + } + if (isNull.isDefined && isPrimitiveType) { + s""" + |if (${isNull.get}) { + | $array.setNullAt($i); + |} else { + | $array.$setFunc($i, $value); + |} + """.stripMargin + } else { + s"$array.$setFunc($i, $value);" + } + } + + /** * Returns the specialized code to set a given value in a column vector for a given `DataType` * that could potentially be nullable. */ http://git-wip-us.apache.org/repos/asf/spark/blob/e319ac92/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 5e4f48e..ea6fccc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -372,7 +372,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp val values = childMap.valueArray() val length = childMap.numElements() val resultData = new Array[AnyRef](length) - var i = 0; + var i = 0 while (i < length) { val key = keys.get(i, childDataType.keyType) val value = values.get(i, childDataType.valueType) @@ -385,107 +385,123 @@ case class MapEntries(child: 
Expression) extends UnaryExpression with ExpectsInp override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, c => { + val arrayData = ctx.freshName("arrayData") val numElements = ctx.freshName("numElements") val keys = ctx.freshName("keys") val values = ctx.freshName("values") val isKeyPrimitive = CodeGenerator.isPrimitiveType(childDataType.keyType) val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType) - val code = if (isKeyPrimitive && isValuePrimitive) { - genCodeForPrimitiveElements(ctx, keys, values, ev.value, numElements) + + val wordSize = UnsafeRow.WORD_SIZE + val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2 + val (isPrimitive, elementSize) = if (isKeyPrimitive && isValuePrimitive) { + (true, structSize + wordSize) + } else { + (false, -1) + } + + val allocation = + s""" + |ArrayData $arrayData = ArrayData.allocateArrayData( + | $elementSize, $numElements, " $prettyName failed."); + """.stripMargin + + val code = if (isPrimitive) { + val genCodeForPrimitive = genCodeForPrimitiveElements( + ctx, arrayData, keys, values, ev.value, numElements, structSize) + s""" + |if ($arrayData instanceof UnsafeArrayData) { + | $genCodeForPrimitive + |} else { + | ${genCodeForAnyElements(ctx, arrayData, keys, values, ev.value, numElements)} + |} + """.stripMargin } else { - genCodeForAnyElements(ctx, keys, values, ev.value, numElements) + s"${genCodeForAnyElements(ctx, arrayData, keys, values, ev.value, numElements)}" } + s""" |final int $numElements = $c.numElements(); |final ArrayData $keys = $c.keyArray(); |final ArrayData $values = $c.valueArray(); + |$allocation |$code """.stripMargin }) } - private def getKey(varName: String) = CodeGenerator.getValue(varName, childDataType.keyType, "z") + private def getKey(varName: String, index: String) = + CodeGenerator.getValue(varName, childDataType.keyType, index) - private def getValue(varName: String) = { - 
CodeGenerator.getValue(varName, childDataType.valueType, "z") - } + private def getValue(varName: String, index: String) = + CodeGenerator.getValue(varName, childDataType.valueType, index) private def genCodeForPrimitiveElements( ctx: CodegenContext, + arrayData: String, keys: String, values: String, - arrayData: String, - numElements: String): String = { - val unsafeRow = ctx.freshName("unsafeRow") + resultArrayData: String, + numElements: String, + structSize: Int): String = { val unsafeArrayData = ctx.freshName("unsafeArrayData") + val baseObject = ctx.freshName("baseObject") + val unsafeRow = ctx.freshName("unsafeRow") val structsOffset = ctx.freshName("structsOffset") + val offset = ctx.freshName("offset") + val z = ctx.freshName("z") val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes" val baseOffset = Platform.BYTE_ARRAY_OFFSET val wordSize = UnsafeRow.WORD_SIZE - val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2 - val structSizeAsLong = structSize + "L" - val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType) - val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.valueType) + val structSizeAsLong = s"${structSize}L" - val valueAssignment = s"$unsafeRow.set$valueTypeName(1, ${getValue(values)});" - val valueAssignmentChecked = if (childDataType.valueContainsNull) { - s""" - |if ($values.isNullAt(z)) { - | $unsafeRow.setNullAt(1); - |} else { - | $valueAssignment - |} - """.stripMargin - } else { - valueAssignment - } + val setKey = CodeGenerator.setColumn(unsafeRow, childDataType.keyType, 0, getKey(keys, z)) - val assignmentLoop = (byteArray: String) => - s""" - |final int $structsOffset = $calculateHeader($numElements) + $numElements * $wordSize; - |UnsafeRow $unsafeRow = new UnsafeRow(2); - |for (int z = 0; z < $numElements; z++) { - | long offset = $structsOffset + z * $structSizeAsLong; - | $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong); - | 
$unsafeRow.pointTo($byteArray, $baseOffset + offset, $structSize); - | $unsafeRow.set$keyTypeName(0, ${getKey(keys)}); - | $valueAssignmentChecked - |} - |$arrayData = $unsafeArrayData; - """.stripMargin + val valueAssignmentChecked = CodeGenerator.createArrayAssignment( + unsafeRow, childDataType.valueType, values, "1", z, childDataType.valueContainsNull) - ctx.createUnsafeArrayWithFallback( - unsafeArrayData, - numElements, - structSize + wordSize, - assignmentLoop, - genCodeForAnyElements(ctx, keys, values, arrayData, numElements)) + s""" + |UnsafeArrayData $unsafeArrayData = (UnsafeArrayData)$arrayData; + |Object $baseObject = $unsafeArrayData.getBaseObject(); + |final int $structsOffset = $calculateHeader($numElements) + $numElements * $wordSize; + |UnsafeRow $unsafeRow = new UnsafeRow(2); + |for (int $z = 0; $z < $numElements; $z++) { + | long $offset = $structsOffset + $z * $structSizeAsLong; + | $unsafeArrayData.setLong($z, ($offset << 32) + $structSizeAsLong); + | $unsafeRow.pointTo($baseObject, $baseOffset + $offset, $structSize); + | $setKey; + | $valueAssignmentChecked + |} + |$resultArrayData = $arrayData; + """.stripMargin } private def genCodeForAnyElements( ctx: CodegenContext, + arrayData: String, keys: String, values: String, - arrayData: String, + resultArrayData: String, numElements: String): String = { - val genericArrayClass = classOf[GenericArrayData].getName - val rowClass = classOf[GenericInternalRow].getName - val data = ctx.freshName("internalRowArray") - + val z = ctx.freshName("z") val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType) val getValueWithCheck = if (childDataType.valueContainsNull && isValuePrimitive) { - s"$values.isNullAt(z) ? null : (Object)${getValue(values)}" + s"$values.isNullAt($z) ? 
null : (Object)${getValue(values, z)}" } else { - getValue(values) + getValue(values, z) } + val rowClass = classOf[GenericInternalRow].getName + val genericArrayDataClass = classOf[GenericArrayData].getName + val genericArrayData = ctx.freshName("genericArrayData") + val rowObject = s"new $rowClass(new Object[]{${getKey(keys, z)}, $getValueWithCheck})" s""" - |final Object[] $data = new Object[$numElements]; - |for (int z = 0; z < $numElements; z++) { - | $data[z] = new $rowClass(new Object[]{${getKey(keys)}, $getValueWithCheck}); + |$genericArrayDataClass $genericArrayData = ($genericArrayDataClass)$arrayData; + |for (int $z = 0; $z < $numElements; $z++) { + | $genericArrayData.update($z, $rowObject); |} - |$arrayData = new $genericArrayClass($data); + |$resultArrayData = $arrayData; """.stripMargin } @@ -610,20 +626,14 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres val finKeysName = ctx.freshName("finalKeys") val finValsName = ctx.freshName("finalValues") - val keyConcat = if (CodeGenerator.isPrimitiveType(keyType)) { - genCodeForPrimitiveArrays(ctx, keyType, false) - } else { - genCodeForNonPrimitiveArrays(ctx, keyType) - } + val keyConcat = genCodeForArrays(ctx, keyType, false) val valueConcat = if (valueType.sameType(keyType) && !(CodeGenerator.isPrimitiveType(valueType) && dataType.valueContainsNull)) { keyConcat - } else if (CodeGenerator.isPrimitiveType(valueType)) { - genCodeForPrimitiveArrays(ctx, valueType, dataType.valueContainsNull) } else { - genCodeForNonPrimitiveArrays(ctx, valueType) + genCodeForArrays(ctx, valueType, dataType.valueContainsNull) } val keyArgsName = ctx.freshName("keyArgs") @@ -662,7 +672,7 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres """.stripMargin) } - private def genCodeForPrimitiveArrays( + private def genCodeForArrays( ctx: CodegenContext, elementType: DataType, checkForNull: Boolean): String = { @@ -670,35 +680,23 @@ case class MapConcat(children: 
Seq[Expression]) extends ComplexTypeMergingExpres val arrayData = ctx.freshName("arrayData") val argsName = ctx.freshName("args") val numElemName = ctx.freshName("numElements") - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) - - val setterCode1 = - s""" - |$arrayData.set$primitiveValueTypeName( - | $counter, - | ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")} - |);""".stripMargin + val y = ctx.freshName("y") + val z = ctx.freshName("z") - val setterCode = if (checkForNull) { - s""" - |if ($argsName[y].isNullAt(z)) { - | $arrayData.setNullAt($counter); - |} else { - | $setterCode1 - |}""".stripMargin - } else { - setterCode1 - } + val allocation = CodeGenerator.createArrayData( + arrayData, elementType, numElemName, s" $prettyName failed.") + val assignment = CodeGenerator.createArrayAssignment( + arrayData, elementType, s"$argsName[$y]", counter, z, checkForNull) val concat = ctx.freshName("concat") val concatDef = s""" |private ArrayData $concat(ArrayData[] $argsName, int $numElemName) { - | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} + | $allocation | int $counter = 0; - | for (int y = 0; y < ${children.length}; y++) { - | for (int z = 0; z < $argsName[y].numElements(); z++) { - | $setterCode + | for (int $y = 0; $y < ${children.length}; $y++) { + | for (int $z = 0; $z < $argsName[$y].numElements(); $z++) { + | $assignment | $counter++; | } | } @@ -709,32 +707,6 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres ctx.addNewFunction(concat, concatDef) } - private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { - val genericArrayClass = classOf[GenericArrayData].getName - val arrayData = ctx.freshName("arrayObjects") - val counter = ctx.freshName("counter") - val argsName = ctx.freshName("args") - val numElemName = ctx.freshName("numElements") - - val concat = ctx.freshName("concat") - val concatDef = - s""" - 
|private ArrayData $concat(ArrayData[] $argsName, int $numElemName) { - | Object[] $arrayData = new Object[$numElemName]; - | int $counter = 0; - | for (int y = 0; y < ${children.length}; y++) { - | for (int z = 0; z < $argsName[y].numElements(); z++) { - | $arrayData[$counter] = ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")}; - | $counter++; - | } - | } - | return new $genericArrayClass($arrayData); - |} - """.stripMargin - - ctx.addNewFunction(concat, concatDef) - } - override def prettyName: String = "map_concat" } @@ -867,25 +839,12 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { val valueSize = dataType.valueType.defaultSize val kByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $keySize)" val vByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $valueSize)" - val keyTypeName = CodeGenerator.primitiveTypeName(dataType.keyType) - val valueTypeName = CodeGenerator.primitiveTypeName(dataType.valueType) - val keyAssignment = (key: String, idx: String) => s"$keyArrayData.set$keyTypeName($idx, $key);" - val valueAssignment = (entry: String, idx: String) => { - val value = CodeGenerator.getValue(entry, dataType.valueType, "1") - val valueNullUnsafeAssignment = s"$valueArrayData.set$valueTypeName($idx, $value);" - if (dataType.valueContainsNull) { - s""" - |if ($entry.isNullAt(1)) { - | $valueArrayData.setNullAt($idx); - |} else { - | $valueNullUnsafeAssignment - |} - """.stripMargin - } else { - valueNullUnsafeAssignment - } - } + val keyAssignment = (key: String, idx: String) => + CodeGenerator.setArrayElement(keyArrayData, dataType.keyType, idx, key) + val valueAssignment = (entry: String, idx: String) => + CodeGenerator.createArrayAssignment( + valueArrayData, dataType.valueType, entry, idx, "1", dataType.valueContainsNull) val assignmentLoop = genCodeForAssignmentLoop( ctx, childVariable, @@ -1263,40 +1222,15 @@ case class Shuffle(child: Expression, randomSeed: 
Option[Long] = None) ctx.addPartitionInitializationStatement( s"$rand = new $randomClass(${randomSeed.get}L + partitionIndex);") - val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType) - val numElements = ctx.freshName("numElements") val arrayData = ctx.freshName("arrayData") - - val initialization = if (isPrimitiveType) { - ctx.createUnsafeArray(arrayData, numElements, elementType, s" $prettyName failed.") - } else { - val arrayDataClass = classOf[GenericArrayData].getName() - s"$arrayDataClass $arrayData = new $arrayDataClass(new Object[$numElements]);" - } - val indices = ctx.freshName("indices") val i = ctx.freshName("i") - val getValue = CodeGenerator.getValue(childName, elementType, s"$indices[$i]") - - val setFunc = if (isPrimitiveType) { - s"set${CodeGenerator.primitiveTypeName(elementType)}" - } else { - "update" - } - - val assignment = if (isPrimitiveType && dataType.asInstanceOf[ArrayType].containsNull) { - s""" - |if ($childName.isNullAt($indices[$i])) { - | $arrayData.setNullAt($i); - |} else { - | $arrayData.$setFunc($i, $getValue); - |} - """.stripMargin - } else { - s"$arrayData.$setFunc($i, $getValue);" - } + val initialization = CodeGenerator.createArrayData( + arrayData, elementType, numElements, s" $prettyName failed.") + val assignment = CodeGenerator.createArrayAssignment(arrayData, elementType, childName, + i, s"$indices[$i]", dataType.asInstanceOf[ArrayType].containsNull) s""" |int $numElements = $childName.numElements(); @@ -1354,40 +1288,16 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = { - val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType) - val numElements = ctx.freshName("numElements") val arrayData = ctx.freshName("arrayData") - val initialization = if (isPrimitiveType) { - ctx.createUnsafeArray(arrayData, numElements, elementType, s" $prettyName failed.") - } else { - val 
arrayDataClass = classOf[GenericArrayData].getName - s"$arrayDataClass $arrayData = new $arrayDataClass(new Object[$numElements]);" - } - val i = ctx.freshName("i") val j = ctx.freshName("j") - val getValue = CodeGenerator.getValue(childName, elementType, i) - - val setFunc = if (isPrimitiveType) { - s"set${CodeGenerator.primitiveTypeName(elementType)}" - } else { - "update" - } - - val assignment = if (isPrimitiveType && dataType.asInstanceOf[ArrayType].containsNull) { - s""" - |if ($childName.isNullAt($i)) { - | $arrayData.setNullAt($j); - |} else { - | $arrayData.$setFunc($j, $getValue); - |} - """.stripMargin - } else { - s"$arrayData.$setFunc($j, $getValue);" - } + val initialization = CodeGenerator.createArrayData( + arrayData, elementType, numElements, s" $prettyName failed.") + val assignment = CodeGenerator.createArrayAssignment( + arrayData, elementType, childName, i, j, dataType.asInstanceOf[ArrayType].containsNull) s""" |final int $numElements = $childName.numElements(); @@ -1803,38 +1713,24 @@ case class Slice(x: Expression, start: Expression, length: Expression) resLength: String): String = { val values = ctx.freshName("values") val i = ctx.freshName("i") - val getValue = CodeGenerator.getValue(inputArray, elementType, s"$i + $startIdx") - if (!CodeGenerator.isPrimitiveType(elementType)) { - val arrayClass = classOf[GenericArrayData].getName - s""" - |Object[] $values; - |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { - | $values = new Object[0]; - |} else { - | $values = new Object[$resLength]; - | for (int $i = 0; $i < $resLength; $i ++) { - | $values[$i] = $getValue; - | } - |} - |${ev.value} = new $arrayClass($values); - """.stripMargin - } else { - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) - s""" - |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { - | $resLength = 0; - |} - |${ctx.createUnsafeArray(values, resLength, elementType, s" $prettyName failed.")} - |for (int $i = 0; $i < 
$resLength; $i ++) { - | if ($inputArray.isNullAt($i + $startIdx)) { - | $values.setNullAt($i); - | } else { - | $values.set$primitiveValueTypeName($i, $getValue); - | } - |} - |${ev.value} = $values; - """.stripMargin - } + val genericArrayData = classOf[GenericArrayData].getName + + val allocation = CodeGenerator.createArrayData( + values, elementType, resLength, s" $prettyName failed.") + val assignment = CodeGenerator.createArrayAssignment(values, elementType, inputArray, + i, s"$i + $startIdx", dataType.asInstanceOf[ArrayType].containsNull) + + s""" + |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { + | ${ev.value} = new $genericArrayData(new Object[0]); + |} else { + | $allocation + | for (int $i = 0; $i < $resLength; $i ++) { + | $assignment + | } + | ${ev.value} = $values; + |} + """.stripMargin } } @@ -2452,11 +2348,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio case StringType => ("UTF8String.concat", s"UTF8String[] $args = new UTF8String[${evals.length}];") case ArrayType(elementType, containsNull) => - val concat = if (CodeGenerator.isPrimitiveType(elementType)) { - genCodeForPrimitiveArrays(ctx, elementType, containsNull) - } else { - genCodeForNonPrimitiveArrays(ctx, elementType) - } + val concat = genCodeForArrays(ctx, elementType, containsNull) (concat, s"ArrayData[] $args = new ArrayData[${evals.length}];") } @@ -2475,62 +2367,44 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { val numElements = ctx.freshName("numElements") + val z = ctx.freshName("z") val code = s""" |long $numElements = 0L; - |for (int z = 0; z < ${children.length}; z++) { - | $numElements += args[z].numElements(); - |} - |if ($numElements > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements + - | " elements due to exceeding 
the array size limit" + - | " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); + |for (int $z = 0; $z < ${children.length}; $z++) { + | $numElements += args[$z].numElements(); |} """.stripMargin (code, numElements) } - private def genCodeForPrimitiveArrays( + private def genCodeForArrays( ctx: CodegenContext, elementType: DataType, checkForNull: Boolean): String = { val counter = ctx.freshName("counter") val arrayData = ctx.freshName("arrayData") + val y = ctx.freshName("y") + val z = ctx.freshName("z") val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) - - val setterCode = - s""" - |$arrayData.set$primitiveValueTypeName( - | $counter, - | ${CodeGenerator.getValue(s"args[y]", elementType, "z")} - |); - """.stripMargin - - val nullSafeSetterCode = if (checkForNull) { - s""" - |if (args[y].isNullAt(z)) { - | $arrayData.setNullAt($counter); - |} else { - | $setterCode - |} - """.stripMargin - } else { - setterCode - } + val initialization = CodeGenerator.createArrayData( + arrayData, elementType, numElemName, s" $prettyName failed.") + val assignment = CodeGenerator.createArrayAssignment( + arrayData, elementType, s"args[$y]", counter, z, + dataType.asInstanceOf[ArrayType].containsNull) val concat = ctx.freshName("concat") val concatDef = s""" |private ArrayData $concat(ArrayData[] args) { | $numElemCode - | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} + | $initialization | int $counter = 0; - | for (int y = 0; y < ${children.length}; y++) { - | for (int z = 0; z < args[y].numElements(); z++) { - | $nullSafeSetterCode + | for (int $y = 0; $y < ${children.length}; $y++) { + | for (int $z = 0; $z < args[$y].numElements(); $z++) { + | $assignment | $counter++; | } | } @@ -2541,33 +2415,6 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio ctx.addNewFunction(concat, concatDef) } - private def 
genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { - val genericArrayClass = classOf[GenericArrayData].getName - val arrayData = ctx.freshName("arrayObjects") - val counter = ctx.freshName("counter") - - val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) - - val concat = ctx.freshName("concat") - val concatDef = - s""" - |private ArrayData $concat(ArrayData[] args) { - | $numElemCode - | Object[] $arrayData = new Object[(int)$numElemName]; - | int $counter = 0; - | for (int y = 0; y < ${children.length}; y++) { - | for (int z = 0; z < args[y].numElements(); z++) { - | $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")}; - | $counter++; - | } - | } - | return new $genericArrayClass($arrayData); - |} - """.stripMargin - - ctx.addNewFunction(concat, concatDef) - } - override def toString: String = s"concat(${children.mkString(", ")})" override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})" @@ -2630,11 +2477,7 @@ case class Flatten(child: Expression) extends UnaryExpression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, c => { - val code = if (CodeGenerator.isPrimitiveType(elementType)) { - genCodeForFlattenOfPrimitiveElements(ctx, c, ev.value) - } else { - genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value) - } + val code = genCodeForFlatten(ctx, c, ev.value) ctx.nullArrayElementsSaveExec(childDataType.containsNull, ev.isNull, c)(code) }) } @@ -2648,41 +2491,36 @@ case class Flatten(child: Expression) extends UnaryExpression { |for (int z = 0; z < $childVariableName.numElements(); z++) { | $variableName += $childVariableName.getArray(z).numElements(); |} - |if ($variableName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + - | $variableName + " elements due to exceeding the array size limit" + - | " 
${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); - |} """.stripMargin (code, variableName) } - private def genCodeForFlattenOfPrimitiveElements( + private def genCodeForFlatten( ctx: CodegenContext, childVariableName: String, arrayDataName: String): String = { val counter = ctx.freshName("counter") val tempArrayDataName = ctx.freshName("tempArrayData") + val k = ctx.freshName("k") + val l = ctx.freshName("l") + val arr = ctx.freshName("arr") val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + val allocation = CodeGenerator.createArrayData( + tempArrayDataName, elementType, numElemName, s" $prettyName failed.") + val assignment = CodeGenerator.createArrayAssignment( + tempArrayDataName, elementType, arr, counter, l, + dataType.asInstanceOf[ArrayType].containsNull) s""" |$numElemCode - |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, s" $prettyName failed.")} + |$allocation |int $counter = 0; - |for (int k = 0; k < $childVariableName.numElements(); k++) { - | ArrayData arr = $childVariableName.getArray(k); - | for (int l = 0; l < arr.numElements(); l++) { - | if (arr.isNullAt(l)) { - | $tempArrayDataName.setNullAt($counter); - | } else { - | $tempArrayDataName.set$primitiveValueTypeName( - | $counter, - | ${CodeGenerator.getValue("arr", elementType, "l")} - | ); - | } + |for (int $k = 0; $k < $childVariableName.numElements(); $k++) { + | ArrayData $arr = $childVariableName.getArray($k); + | for (int $l = 0; $l < $arr.numElements(); $l++) { + | $assignment | $counter++; | } |} @@ -2690,30 +2528,6 @@ case class Flatten(child: Expression) extends UnaryExpression { """.stripMargin } - private def genCodeForFlattenOfNonPrimitiveElements( - ctx: CodegenContext, - childVariableName: String, - arrayDataName: String): String = { - val genericArrayClass = classOf[GenericArrayData].getName - val arrayName = ctx.freshName("arrayObject") - val counter 
= ctx.freshName("counter") - val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) - - s""" - |$numElemCode - |Object[] $arrayName = new Object[(int)$numElemName]; - |int $counter = 0; - |for (int k = 0; k < $childVariableName.numElements(); k++) { - | ArrayData arr = $childVariableName.getArray(k); - | for (int l = 0; l < arr.numElements(); l++) { - | $arrayName[$counter] = ${CodeGenerator.getValue("arr", elementType, "l")}; - | $counter++; - | } - |} - |$arrayDataName = new $genericArrayClass($arrayName); - """.stripMargin - } - override def prettyName: String = "flatten" } @@ -3155,11 +2969,7 @@ case class ArrayRepeat(left: Expression, right: Expression) val count = rightGen.value val et = dataType.elementType - val coreLogic = if (CodeGenerator.isPrimitiveType(et)) { - genCodeForPrimitiveElement(ctx, et, element, count, leftGen.isNull, ev.value) - } else { - genCodeForNonPrimitiveElement(ctx, element, count, leftGen.isNull, ev.value) - } + val coreLogic = genCodeForElement(ctx, et, element, count, leftGen.isNull, ev.value) val resultCode = nullElementsProtection(ev, rightGen.isNull, coreLogic) ev.copy(code = @@ -3198,17 +3008,12 @@ case class ArrayRepeat(left: Expression, right: Expression) |if ($count > 0) { | $numElements = $count; |} - |if ($numElements > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | throw new RuntimeException("Unsuccessful try to create array with " + $numElements + - | " elements due to exceeding the array size limit" + - | " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); - |} """.stripMargin (numElements, numElementsCode) } - private def genCodeForPrimitiveElement( + private def genCodeForElement( ctx: CodegenContext, elementType: DataType, element: String, @@ -3216,48 +3021,30 @@ case class ArrayRepeat(left: Expression, right: Expression) leftIsNull: String, arrayDataName: String): String = { val tempArrayDataName = ctx.freshName("tempArrayData") - val primitiveValueTypeName = 
CodeGenerator.primitiveTypeName(elementType) - val errorMessage = s" $prettyName failed." + val k = ctx.freshName("k") val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count) + val allocation = CodeGenerator.createArrayData( + tempArrayDataName, elementType, numElemName, s" $prettyName failed.") + val assignment = + CodeGenerator.setArrayElement(tempArrayDataName, elementType, k, element) + s""" |$numElemCode - |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, errorMessage)} + |$allocation |if (!$leftIsNull) { - | for (int k = 0; k < $tempArrayDataName.numElements(); k++) { - | $tempArrayDataName.set$primitiveValueTypeName(k, $element); + | for (int $k = 0; $k < $tempArrayDataName.numElements(); $k++) { + | $assignment | } |} else { - | for (int k = 0; k < $tempArrayDataName.numElements(); k++) { - | $tempArrayDataName.setNullAt(k); + | for (int $k = 0; $k < $tempArrayDataName.numElements(); $k++) { + | $tempArrayDataName.setNullAt($k); | } |} |$arrayDataName = $tempArrayDataName; """.stripMargin } - private def genCodeForNonPrimitiveElement( - ctx: CodegenContext, - element: String, - count: String, - leftIsNull: String, - arrayDataName: String): String = { - val genericArrayClass = classOf[GenericArrayData].getName - val arrayName = ctx.freshName("arrayObject") - val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count) - - s""" - |$numElemCode - |Object[] $arrayName = new Object[(int)$numElemName]; - |if (!$leftIsNull) { - | for (int k = 0; k < $numElemName; k++) { - | $arrayName[k] = $element; - | } - |} - |$arrayDataName = new $genericArrayClass($arrayName); - """.stripMargin - } - } /** @@ -3339,50 +3126,117 @@ case class ArrayRemove(left: Expression, right: Expression) val pos = ctx.freshName("pos") val getValue = CodeGenerator.getValue(inputArray, elementType, i) val isEqual = ctx.genEqual(elementType, value, getValue) - if (!CodeGenerator.isPrimitiveType(elementType)) { - val arrayClass = 
classOf[GenericArrayData].getName + + val allocation = CodeGenerator.createArrayData( + values, elementType, newArraySize, s" $prettyName failed.") + val assignment = CodeGenerator.createArrayAssignment( + values, elementType, inputArray, pos, i, false) + + s""" + |$allocation + |int $pos = 0; + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $values.setNullAt($pos); + | $pos = $pos + 1; + | } + | else { + | if (!($isEqual)) { + | $assignment + | $pos = $pos + 1; + | } + | } + |} + |${ev.value} = $values; + """.stripMargin + } + + override def prettyName: String = "array_remove" +} + +/** + * Will become common base class for [[ArrayDistinct]], [[ArrayUnion]], [[ArrayIntersect]], + * and [[ArrayExcept]]. + */ +trait ArraySetLike { + protected def dt: DataType + protected def et: DataType + + @transient protected lazy val canUseSpecializedHashSet = et match { + case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true + case _ => false + } + + @transient protected lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(et) + + protected def genGetValue(array: String, i: String): String = + CodeGenerator.getValue(array, et, i) + + @transient protected lazy val (hsPostFix, hsTypeName) = { + val ptName = CodeGenerator.primitiveTypeName(et) + et match { + // we cast byte/short to int when writing to the hash set. + case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int") + case LongType => ("$mcJ$sp", ptName) + case FloatType => ("$mcF$sp", ptName) + case DoubleType => ("$mcD$sp", ptName) + } + } + + // we cast byte/short to int when writing to the hash set. + @transient protected lazy val hsValueCast = et match { + case ByteType | ShortType => "(int) " + case _ => "" + } + + // When hitting a null value, put a null holder in the ArrayBuilder. Finally we will + // convert ArrayBuilder to ArrayData and setNull on the slot with null holder. 
+ @transient protected lazy val nullValueHolder = et match { + case ByteType => "(byte) 0" + case ShortType => "(short) 0" + case _ => "0" + } + + protected def withResultArrayNullCheck( + body: String, + value: String, + nullElementIndex: String): String = { + if (dt.asInstanceOf[ArrayType].containsNull) { s""" - |int $pos = 0; - |Object[] $values = new Object[$newArraySize]; - |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { - | if ($inputArray.isNullAt($i)) { - | $values[$pos] = null; - | $pos = $pos + 1; - | } - | else { - | if (!($isEqual)) { - | $values[$pos] = $getValue; - | $pos = $pos + 1; - | } - | } + |$body + |if ($nullElementIndex >= 0) { + | // result has null element + | $value.setNullAt($nullElementIndex); |} - |${ev.value} = new $arrayClass($values); """.stripMargin } else { - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) - s""" - |${ctx.createUnsafeArray(values, newArraySize, elementType, s" $prettyName failed.")} - |int $pos = 0; - |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { - | if ($inputArray.isNullAt($i)) { - | $values.setNullAt($pos); - | $pos = $pos + 1; - | } - | else { - | if (!($isEqual)) { - | $values.set$primitiveValueTypeName($pos, $getValue); - | $pos = $pos + 1; - | } - | } - |} - |${ev.value} = $values; - """.stripMargin + body } } - override def prettyName: String = "array_remove" + def buildResultArray( + builder: String, + value : String, + size : String, + nullElementIndex : String): String = withResultArrayNullCheck( + s""" + |if ($size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Cannot create array with " + $size + + | " elements of data due to exceeding the limit " + + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for ArrayData."); + |} + | + |if (!UnsafeArrayData.shouldUseGenericArrayData(${et.defaultSize}, $size)) { + | $value = UnsafeArrayData.fromPrimitiveArray($builder.result()); + |} else { + | $value = new 
${classOf[GenericArrayData].getName}($builder.result()); + |} + """.stripMargin, value, nullElementIndex) + } + /** * Removes duplicate values from the array. */ @@ -3394,7 +3248,7 @@ case class ArrayRemove(left: Expression, right: Expression) [1,2,3,null] """, since = "2.4.0") case class ArrayDistinct(child: Expression) - extends UnaryExpression with ExpectsInputTypes { + extends UnaryExpression with ArraySetLike with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) @@ -3402,8 +3256,8 @@ case class ArrayDistinct(child: Expression) @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType - @transient private lazy val ordering: Ordering[Any] = - TypeUtils.getInterpretedOrdering(elementType) + override protected def dt: DataType = dataType + override protected def et: DataType = elementType override def checkInputDataTypes(): TypeCheckResult = { super.checkInputDataTypes() match { @@ -3413,28 +3267,6 @@ case class ArrayDistinct(child: Expression) } } - @transient protected lazy val canUseSpecializedHashSet = elementType match { - case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true - case _ => false - } - - @transient protected lazy val (hsPostFix, hsTypeName) = { - val ptName = CodeGenerator.primitiveTypeName(elementType) - elementType match { - // we cast byte/short to int when writing to the hash set. - case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int") - case LongType => ("$mcJ$sp", ptName) - case FloatType => ("$mcF$sp", ptName) - case DoubleType => ("$mcD$sp", ptName) - } - } - - // we cast byte/short to int when writing to the hash set. 
- @transient protected lazy val hsValueCast = elementType match { - case ByteType | ShortType => "(int) " - case _ => "" - } - override def nullSafeEval(array: Any): Any = { val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) doEvaluation(data) @@ -3471,28 +3303,73 @@ case class ArrayDistinct(child: Expression) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val i = ctx.freshName("i") + val value = ctx.freshName("value") + val size = ctx.freshName("size") + if (canUseSpecializedHashSet) { + val jt = CodeGenerator.javaType(elementType) + val ptName = CodeGenerator.primitiveTypeName(jt) + nullSafeCodeGen(ctx, ev, (array) => { - val i = ctx.freshName("i") - val sizeOfDistinctArray = ctx.freshName("sizeOfDistinctArray") val foundNullElement = ctx.freshName("foundNullElement") + val nullElementIndex = ctx.freshName("nullElementIndex") + val builder = ctx.freshName("builder") val openHashSet = classOf[OpenHashSet[_]].getName - val hs = ctx.freshName("hs") val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" - val getValue = CodeGenerator.getValue(array, elementType, i) + val hashSet = ctx.freshName("hashSet") + val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName + val arrayBuilderClass = s"$arrayBuilder$$of$ptName" + + // Only need to track null element index when array's element is nullable. 
+ val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |boolean $foundNullElement = false; + |int $nullElementIndex = -1; + """.stripMargin + } else { + "" + } + + def withArrayNullAssignment(body: String) = + if (dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array.isNullAt($i)) { + | if (!$foundNullElement) { + | $nullElementIndex = $size; + | $foundNullElement = true; + | $size++; + | $builder.$$plus$$eq($nullValueHolder); + | } + |} else { + | $body + |} + """.stripMargin + } else { + body + } + + val processArray = withArrayNullAssignment( + s""" + |$jt $value = ${genGetValue(array, i)}; + |if (!$hashSet.contains($hsValueCast$value)) { + | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | break; + | } + | $hashSet.add$hsPostFix($hsValueCast$value); + | $builder.$$plus$$eq($value); + |} + """.stripMargin) + s""" - |int $sizeOfDistinctArray = 0; - |boolean $foundNullElement = false; - |$openHashSet $hs = new $openHashSet($classTag); - |for (int $i = 0; $i < $array.numElements(); $i ++) { - | if ($array.isNullAt($i)) { - | $foundNullElement = true; - | } else { - | $hs.add$hsPostFix($hsValueCast$getValue); - | } + |$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag); + |$declareNullTrackVariables + |$arrayBuilderClass $builder = new $arrayBuilderClass(); + |int $size = 0; + |for (int $i = 0; $i < $array.numElements(); $i++) { + | $processArray |} - |$sizeOfDistinctArray = $hs.size() + ($foundNullElement ? 
1 : 0); - |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)} + |${buildResultArray(builder, ev.value, size, nullElementIndex)} """.stripMargin }) } else { @@ -3503,73 +3380,16 @@ case class ArrayDistinct(child: Expression) } } - private def setNull( - foundNullElement: String, - distinctArray: String, - pos: String): String = { - val setNullValue = s"$distinctArray.setNullAt($pos)" - s""" - |if (!($foundNullElement)) { - | $setNullValue; - | $pos = $pos + 1; - | $foundNullElement = true; - |} - """.stripMargin - } - - private def setValue( - hs: String, - distinctArray: String, - pos: String, - getValue1: String, - primitiveValueTypeName: String): String = { - s""" - |if (!($hs.contains$hsPostFix($hsValueCast$getValue1))) { - | $hs.add$hsPostFix($hsValueCast$getValue1); - | $distinctArray.set$primitiveValueTypeName($pos, $getValue1); - | $pos = $pos + 1; - |} - """.stripMargin - } - - def genCodeForResult( - ctx: CodegenContext, - ev: ExprCode, - inputArray: String, - size: String): String = { - val distinctArray = ctx.freshName("distinctArray") - val i = ctx.freshName("i") - val pos = ctx.freshName("pos") - val getValue1 = CodeGenerator.getValue(inputArray, elementType, i) - val foundNullElement = ctx.freshName("foundNullElement") - val hs = ctx.freshName("hs") - val openHashSet = classOf[OpenHashSet[_]].getName - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) - val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" - - s""" - |${ctx.createUnsafeArray(distinctArray, size, elementType, s" $prettyName failed.")} - |int $pos = 0; - |boolean $foundNullElement = false; - |$openHashSet $hs = new $openHashSet($classTag); - |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { - | if ($inputArray.isNullAt($i)) { - | ${setNull(foundNullElement, distinctArray, pos)} - | } else { - | ${setValue(hs, distinctArray, pos, getValue1, primitiveValueTypeName)} - | } - |} - |${ev.value} = $distinctArray; - """.stripMargin - } - 
override def prettyName: String = "array_distinct" } /** * Will become common base class for [[ArrayUnion]], [[ArrayIntersect]], and [[ArrayExcept]]. */ -abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { +trait ArrayBinaryLike extends BinaryArrayExpressionWithImplicitCast with ArraySetLike { + override protected def dt: DataType = dataType + override protected def et: DataType = elementType + override def checkInputDataTypes(): TypeCheckResult = { val typeCheckResult = super.checkInputDataTypes() if (typeCheckResult.isSuccess) { @@ -3579,81 +3399,9 @@ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { typeCheckResult } } - - @transient protected lazy val ordering: Ordering[Any] = - TypeUtils.getInterpretedOrdering(elementType) - - @transient protected lazy val canUseSpecializedHashSet = elementType match { - case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true - case _ => false - } - - protected def genGetValue(array: String, i: String): String = - CodeGenerator.getValue(array, elementType, i) - - @transient protected lazy val (hsPostFix, hsTypeName) = { - val ptName = CodeGenerator.primitiveTypeName (elementType) - elementType match { - // we cast byte/short to int when writing to the hash set. - case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int") - case LongType => ("$mcJ$sp", ptName) - case FloatType => ("$mcF$sp", ptName) - case DoubleType => ("$mcD$sp", ptName) - } - } - - // we cast byte/short to int when writing to the hash set. - @transient protected lazy val hsValueCast = elementType match { - case ByteType | ShortType => "(int) " - case _ => "" - } - - // When hitting a null value, put a null holder in the ArrayBuilder. Finally we will - // convert ArrayBuilder to ArrayData and setNull on the slot with null holder. 
- @transient protected lazy val nullValueHolder = elementType match { - case ByteType => "(byte) 0" - case ShortType => "(short) 0" - case _ => "0" - } - - protected def withResultArrayNullCheck( - body: String, - value: String, - nullElementIndex: String): String = { - if (dataType.asInstanceOf[ArrayType].containsNull) { - s""" - |$body - |if ($nullElementIndex >= 0) { - | // result has null element - | $value.setNullAt($nullElementIndex); - |} - """.stripMargin - } else { - body - } - } - - def buildResultArray( - builder: String, - value : String, - size : String, - nullElementIndex : String): String = withResultArrayNullCheck( - s""" - |if ($size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | throw new RuntimeException("Cannot create array with " + $size + - | " bytes of data due to exceeding the limit " + - | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for ArrayData."); - |} - | - |if (!UnsafeArrayData.shouldUseGenericArrayData(${elementType.defaultSize}, $size)) { - | $value = UnsafeArrayData.fromPrimitiveArray($builder.result()); - |} else { - | $value = new ${classOf[GenericArrayData].getName}($builder.result()); - |} - """.stripMargin, value, nullElementIndex) } -object ArraySetLike { +object ArrayBinaryLike { def throwUnionLengthOverflowException(length: Int): Unit = { throw new RuntimeException(s"Unsuccessful try to union arrays with $length " + s"elements due to exceeding the array size limit " + @@ -3676,7 +3424,7 @@ object ArraySetLike { array(1, 2, 3, 5) """, since = "2.4.0") -case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike +case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLike with ComplexTypeMergingExpression { @transient lazy val evalUnion: (ArrayData, ArrayData) => ArrayData = { @@ -3697,7 +3445,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike val elem = array.get(i, elementType) if (!hs.contains(elem)) { if (arrayBuffer.size > 
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size) + ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size) } arrayBuffer += elem hs.add(elem) @@ -3732,7 +3480,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike } if (!found) { if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length) + ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.length) } arrayBuffer += elem } @@ -3864,7 +3612,7 @@ object ArrayUnion { } if (!found) { if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length) + ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.length) } arrayBuffer += elem } @@ -3887,7 +3635,7 @@ object ArrayUnion { array(1, 3) """, since = "2.4.0") -case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetLike +case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBinaryLike with ComplexTypeMergingExpression { override def dataType: DataType = { dataTypeCheck @@ -4128,7 +3876,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL array(2) """, since = "2.4.0") -case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike +case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryLike with ComplexTypeMergingExpression { override def dataType: DataType = { http://git-wip-us.apache.org/repos/asf/spark/blob/e319ac92/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala index 104b428..4da8ce0 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala @@ -22,6 +22,8 @@ import scala.reflect.ClassTag import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, UnsafeArrayData} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.array.ByteArrayMethods object ArrayData { def toArrayData(input: Any): ArrayData = input match { @@ -34,6 +36,31 @@ object ArrayData { case a: Array[Double] => UnsafeArrayData.fromPrimitiveArray(a) case other => new GenericArrayData(other) } + + + /** + * Allocate [[UnsafeArrayData]] or [[GenericArrayData]] based on given parameters. + * + * @param elementSize a size of an element in bytes. If less than zero, the type of an element is + * non-primitive type + * @param numElements the number of elements the array should contain + * @param additionalErrorMessage string to include in the error message + */ + def allocateArrayData( + elementSize: Int, + numElements: Long, + additionalErrorMessage: String): ArrayData = { + if (elementSize >= 0 && !UnsafeArrayData.shouldUseGenericArrayData(elementSize, numElements)) { + UnsafeArrayData.createFreshArray(numElements.toInt, elementSize) + } else if (numElements <= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toLong) { + new GenericArrayData(new Array[Any](numElements.toInt)) + } else { + throw new RuntimeException(s"Cannot create array with $numElements " + + "elements of data due to exceeding the limit " + + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for ArrayData. " + + additionalErrorMessage) + } + } } abstract class ArrayData extends SpecializedGetters with Serializable { --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
