Repository: spark
Updated Branches:
  refs/heads/master 4cb2ff9d8 -> e319ac92e


[SPARK-24962][SQL] Refactor CodeGenerator.createUnsafeArray, ArraySetLike, and 
ArrayDistinct

## What changes were proposed in this pull request?

This PR integrates handling of `UnsafeArrayData` and `GenericArrayData` into 
one. The current `CodeGenerator.createUnsafeArray` handles only allocation of 
`UnsafeArrayData`.
This PR introduces a new method `createArrayData` that returns code that 
allocates an `UnsafeArrayData` or a `GenericArrayData`, along with a companion 
method that generates the code to assign values into the allocated array.

This PR also reduces the size of the generated code by calling a runtime helper.

This PR replaces `createUnsafeArray` with `createArrayData`. This PR also 
refactors `ArraySetLike` so that it can be used for `ArrayDistinct`, too.
This PR also refactors `ArrayDistinct` to use `ArrayBuilder`.

## How was this patch tested?

Existing tests

Closes #21912 from kiszk/SPARK-24962.

Lead-authored-by: Kazuaki Ishizaki <[email protected]>
Co-authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e319ac92
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e319ac92
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e319ac92

Branch: refs/heads/master
Commit: e319ac92e597d95eba5b787bb7a5d5499bb3f87c
Parents: 4cb2ff9
Author: Kazuaki Ishizaki <[email protected]>
Authored: Tue Sep 4 15:26:34 2018 +0800
Committer: Wenchen Fan <[email protected]>
Committed: Tue Sep 4 15:26:34 2018 +0800

----------------------------------------------------------------------
 .../catalyst/expressions/UnsafeArrayData.java   |  22 +-
 .../expressions/codegen/CodeGenerator.scala     | 150 +--
 .../expressions/collectionOperations.scala      | 926 +++++++------------
 .../spark/sql/catalyst/util/ArrayData.scala     |  27 +
 4 files changed, 464 insertions(+), 661 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e319ac92/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index cf2a5ed..9e7b15d 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -473,13 +473,27 @@ public final class UnsafeArrayData extends ArrayData {
     return result;
   }
 
-  public static UnsafeArrayData forPrimitiveArray(int offset, int length, int 
elementSize) {
-    return fromPrimitiveArray(null, offset, length, elementSize);
+  public static UnsafeArrayData createFreshArray(int length, int elementSize) {
+    final long headerInBytes = calculateHeaderPortionInBytes(length);
+    final long valueRegionInBytes = (long)elementSize * length;
+    final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8;
+    if (totalSizeInLongs > Integer.MAX_VALUE / 8) {
+      throw new UnsupportedOperationException("Cannot convert this array to 
unsafe format as " +
+        "it's too big.");
+    }
+
+    final long[] data = new long[(int)totalSizeInLongs];
+
+    Platform.putLong(data, Platform.LONG_ARRAY_OFFSET, length);
+
+    UnsafeArrayData result = new UnsafeArrayData();
+    result.pointTo(data, Platform.LONG_ARRAY_OFFSET, (int)totalSizeInLongs * 
8);
+    return result;
   }
 
-  public static boolean shouldUseGenericArrayData(int elementSize, int length) 
{
+  public static boolean shouldUseGenericArrayData(int elementSize, long 
length) {
     final long headerInBytes = calculateHeaderPortionInBytes(length);
-    final long valueRegionInBytes = (long)elementSize * length;
+    final long valueRegionInBytes = elementSize * length;
     final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8;
     return totalSizeInLongs > Integer.MAX_VALUE / 8;
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/e319ac92/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index b8f0976..d5857e0 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -39,7 +39,7 @@ import org.apache.spark.metrics.source.CodegenMetrics
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, 
MapData}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.Platform
@@ -747,73 +747,6 @@ class CodegenContext {
   }
 
   /**
-   * Generates code creating a [[UnsafeArrayData]].
-   *
-   * @param arrayName name of the array to create
-   * @param numElements code representing the number of elements the array 
should contain
-   * @param elementType data type of the elements in the array
-   * @param additionalErrorMessage string to include in the error message
-   */
-  def createUnsafeArray(
-      arrayName: String,
-      numElements: String,
-      elementType: DataType,
-      additionalErrorMessage: String): String = {
-    val arraySize = freshName("size")
-    val arrayBytes = freshName("arrayBytes")
-
-    s"""
-       |long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
-       |  $numElements,
-       |  ${elementType.defaultSize});
-       |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
-       |  throw new RuntimeException("Unsuccessful try create array with " + 
$arraySize +
-       |    " bytes of data due to exceeding the limit " +
-       |    "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} bytes for 
UnsafeArrayData." +
-       |    "$additionalErrorMessage");
-       |}
-       |byte[] $arrayBytes = new byte[(int)$arraySize];
-       |UnsafeArrayData $arrayName = new UnsafeArrayData();
-       |Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, 
$numElements);
-       |$arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, 
(int)$arraySize);
-      """.stripMargin
-  }
-
-  /**
-   * Generates code creating a [[UnsafeArrayData]]. The generated code executes
-   * a provided fallback when the size of backing array would exceed the array 
size limit.
-   * @param arrayName a name of the array to create
-   * @param numElements a piece of code representing the number of elements 
the array should contain
-   * @param elementSize a size of an element in bytes
-   * @param bodyCode a function generating code that fills up the 
[[UnsafeArrayData]]
-   *                 and getting the backing array as a parameter
-   * @param fallbackCode a piece of code executed when the array size limit is 
exceeded
-   */
-  def createUnsafeArrayWithFallback(
-      arrayName: String,
-      numElements: String,
-      elementSize: Int,
-      bodyCode: String => String,
-      fallbackCode: String): String = {
-    val arraySize = freshName("size")
-    val arrayBytes = freshName("arrayBytes")
-    s"""
-       |final long $arraySize = 
UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
-       |  $numElements,
-       |  $elementSize);
-       |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
-       |  $fallbackCode
-       |} else {
-       |  final byte[] $arrayBytes = new byte[(int)$arraySize];
-       |  UnsafeArrayData $arrayName = new UnsafeArrayData();
-       |  Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, 
$numElements);
-       |  $arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, 
(int)$arraySize);
-       |  ${bodyCode(arrayBytes)}
-       |}
-     """.stripMargin
-  }
-
-  /**
    * Generates code to do null safe execution, i.e. only execute the code when 
the input is not
    * null by adding null check if necessary.
    *
@@ -1491,6 +1424,59 @@ object CodeGenerator extends Logging {
   }
 
   /**
+   * Generates code creating a [[UnsafeArrayData]] or [[GenericArrayData]] 
based on
+   * given parameters.
+   *
+   * @param arrayName name of the array to create
+   * @param elementType data type of the elements in source array
+   * @param numElements code representing the number of elements the array 
should contain
+   * @param additionalErrorMessage string to include in the error message
+   *
+   * @return code representing the allocation of [[ArrayData]]
+   */
+  def createArrayData(
+      arrayName: String,
+      elementType: DataType,
+      numElements: String,
+      additionalErrorMessage: String): String = {
+    val elementSize = if (CodeGenerator.isPrimitiveType(elementType)) {
+      elementType.defaultSize
+    } else {
+      -1
+    }
+    s"""
+       |ArrayData $arrayName = ArrayData.allocateArrayData(
+       |  $elementSize, $numElements, "$additionalErrorMessage");
+     """.stripMargin
+  }
+
+  /**
+   * Generates assignment code for an [[ArrayData]]
+   *
+   * @param dstArray name of the array to be assigned
+   * @param elementType data type of the elements in destination and source 
arrays
+   * @param srcArray name of the array to be read
+   * @param needNullCheck value which shows whether a nullcheck is required 
for the returning
+   *                      assignment
+   * @param dstArrayIndex an index variable to access each element of 
destination array
+   * @param srcArrayIndex an index variable to access each element of source 
array
+   *
+   * @return code representing an assignment to each element of the 
[[ArrayData]], which requires
+   *         a pair of destination and source loop index variables
+   */
+  def createArrayAssignment(
+      dstArray: String,
+      elementType: DataType,
+      srcArray: String,
+      dstArrayIndex: String,
+      srcArrayIndex: String,
+      needNullCheck: Boolean): String = {
+    CodeGenerator.setArrayElement(dstArray, elementType, dstArrayIndex,
+      CodeGenerator.getValue(srcArray, elementType, srcArrayIndex),
+      if (needNullCheck) Some(s"$srcArray.isNullAt($srcArrayIndex)") else None)
+  }
+
+  /**
    * Returns the code to update a column in Row for a given DataType.
    */
   def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): 
String = {
@@ -1559,6 +1545,34 @@ object CodeGenerator extends Logging {
   }
 
   /**
+   * Generates code of setter for an [[ArrayData]].
+   */
+  def setArrayElement(
+      array: String,
+      elementType: DataType,
+      i: String,
+      value: String,
+      isNull: Option[String] = None): String = {
+    val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType)
+    val setFunc = if (isPrimitiveType) {
+      s"set${CodeGenerator.primitiveTypeName(elementType)}"
+    } else {
+      "update"
+    }
+    if (isNull.isDefined && isPrimitiveType) {
+      s"""
+         |if (${isNull.get}) {
+         |  $array.setNullAt($i);
+         |} else {
+         |  $array.$setFunc($i, $value);
+         |}
+       """.stripMargin
+    } else {
+      s"$array.$setFunc($i, $value);"
+    }
+  }
+
+  /**
    * Returns the specialized code to set a given value in a column vector for 
a given `DataType`
    * that could potentially be nullable.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/e319ac92/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 5e4f48e..ea6fccc 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -372,7 +372,7 @@ case class MapEntries(child: Expression) extends 
UnaryExpression with ExpectsInp
     val values = childMap.valueArray()
     val length = childMap.numElements()
     val resultData = new Array[AnyRef](length)
-    var i = 0;
+    var i = 0
     while (i < length) {
       val key = keys.get(i, childDataType.keyType)
       val value = values.get(i, childDataType.valueType)
@@ -385,107 +385,123 @@ case class MapEntries(child: Expression) extends 
UnaryExpression with ExpectsInp
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = {
     nullSafeCodeGen(ctx, ev, c => {
+      val arrayData = ctx.freshName("arrayData")
       val numElements = ctx.freshName("numElements")
       val keys = ctx.freshName("keys")
       val values = ctx.freshName("values")
       val isKeyPrimitive = CodeGenerator.isPrimitiveType(childDataType.keyType)
       val isValuePrimitive = 
CodeGenerator.isPrimitiveType(childDataType.valueType)
-      val code = if (isKeyPrimitive && isValuePrimitive) {
-        genCodeForPrimitiveElements(ctx, keys, values, ev.value, numElements)
+
+      val wordSize = UnsafeRow.WORD_SIZE
+      val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2
+      val (isPrimitive, elementSize) = if (isKeyPrimitive && isValuePrimitive) 
{
+        (true, structSize + wordSize)
+      } else {
+        (false, -1)
+      }
+
+      val allocation =
+        s"""
+           |ArrayData $arrayData = ArrayData.allocateArrayData(
+           |  $elementSize, $numElements, " $prettyName failed.");
+         """.stripMargin
+
+      val code = if (isPrimitive) {
+        val genCodeForPrimitive = genCodeForPrimitiveElements(
+          ctx, arrayData, keys, values, ev.value, numElements, structSize)
+        s"""
+           |if ($arrayData instanceof UnsafeArrayData) {
+           |  $genCodeForPrimitive
+           |} else {
+           |  ${genCodeForAnyElements(ctx, arrayData, keys, values, ev.value, 
numElements)}
+           |}
+         """.stripMargin
       } else {
-        genCodeForAnyElements(ctx, keys, values, ev.value, numElements)
+        s"${genCodeForAnyElements(ctx, arrayData, keys, values, ev.value, 
numElements)}"
       }
+
       s"""
          |final int $numElements = $c.numElements();
          |final ArrayData $keys = $c.keyArray();
          |final ArrayData $values = $c.valueArray();
+         |$allocation
          |$code
        """.stripMargin
     })
   }
 
-  private def getKey(varName: String) = CodeGenerator.getValue(varName, 
childDataType.keyType, "z")
+  private def getKey(varName: String, index: String) =
+    CodeGenerator.getValue(varName, childDataType.keyType, index)
 
-  private def getValue(varName: String) = {
-    CodeGenerator.getValue(varName, childDataType.valueType, "z")
-  }
+  private def getValue(varName: String, index: String) =
+    CodeGenerator.getValue(varName, childDataType.valueType, index)
 
   private def genCodeForPrimitiveElements(
       ctx: CodegenContext,
+      arrayData: String,
       keys: String,
       values: String,
-      arrayData: String,
-      numElements: String): String = {
-    val unsafeRow = ctx.freshName("unsafeRow")
+      resultArrayData: String,
+      numElements: String,
+      structSize: Int): String = {
     val unsafeArrayData = ctx.freshName("unsafeArrayData")
+    val baseObject = ctx.freshName("baseObject")
+    val unsafeRow = ctx.freshName("unsafeRow")
     val structsOffset = ctx.freshName("structsOffset")
+    val offset = ctx.freshName("offset")
+    val z = ctx.freshName("z")
     val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes"
 
     val baseOffset = Platform.BYTE_ARRAY_OFFSET
     val wordSize = UnsafeRow.WORD_SIZE
-    val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2
-    val structSizeAsLong = structSize + "L"
-    val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType)
-    val valueTypeName = 
CodeGenerator.primitiveTypeName(childDataType.valueType)
+    val structSizeAsLong = s"${structSize}L"
 
-    val valueAssignment = s"$unsafeRow.set$valueTypeName(1, 
${getValue(values)});"
-    val valueAssignmentChecked = if (childDataType.valueContainsNull) {
-      s"""
-         |if ($values.isNullAt(z)) {
-         |  $unsafeRow.setNullAt(1);
-         |} else {
-         |  $valueAssignment
-         |}
-       """.stripMargin
-    } else {
-      valueAssignment
-    }
+    val setKey = CodeGenerator.setColumn(unsafeRow, childDataType.keyType, 0, 
getKey(keys, z))
 
-    val assignmentLoop = (byteArray: String) =>
-      s"""
-         |final int $structsOffset = $calculateHeader($numElements) + 
$numElements * $wordSize;
-         |UnsafeRow $unsafeRow = new UnsafeRow(2);
-         |for (int z = 0; z < $numElements; z++) {
-         |  long offset = $structsOffset + z * $structSizeAsLong;
-         |  $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong);
-         |  $unsafeRow.pointTo($byteArray, $baseOffset + offset, $structSize);
-         |  $unsafeRow.set$keyTypeName(0, ${getKey(keys)});
-         |  $valueAssignmentChecked
-         |}
-         |$arrayData = $unsafeArrayData;
-       """.stripMargin
+    val valueAssignmentChecked = CodeGenerator.createArrayAssignment(
+      unsafeRow, childDataType.valueType, values, "1", z, 
childDataType.valueContainsNull)
 
-    ctx.createUnsafeArrayWithFallback(
-      unsafeArrayData,
-      numElements,
-      structSize + wordSize,
-      assignmentLoop,
-      genCodeForAnyElements(ctx, keys, values, arrayData, numElements))
+    s"""
+       |UnsafeArrayData $unsafeArrayData = (UnsafeArrayData)$arrayData;
+       |Object $baseObject = $unsafeArrayData.getBaseObject();
+       |final int $structsOffset = $calculateHeader($numElements) + 
$numElements * $wordSize;
+       |UnsafeRow $unsafeRow = new UnsafeRow(2);
+       |for (int $z = 0; $z < $numElements; $z++) {
+       |  long $offset = $structsOffset + $z * $structSizeAsLong;
+       |  $unsafeArrayData.setLong($z, ($offset << 32) + $structSizeAsLong);
+       |  $unsafeRow.pointTo($baseObject, $baseOffset + $offset, $structSize);
+       |  $setKey;
+       |  $valueAssignmentChecked
+       |}
+       |$resultArrayData = $arrayData;
+     """.stripMargin
   }
 
   private def genCodeForAnyElements(
       ctx: CodegenContext,
+      arrayData: String,
       keys: String,
       values: String,
-      arrayData: String,
+      resultArrayData: String,
       numElements: String): String = {
-    val genericArrayClass = classOf[GenericArrayData].getName
-    val rowClass = classOf[GenericInternalRow].getName
-    val data = ctx.freshName("internalRowArray")
-
+    val z = ctx.freshName("z")
     val isValuePrimitive = 
CodeGenerator.isPrimitiveType(childDataType.valueType)
     val getValueWithCheck = if (childDataType.valueContainsNull && 
isValuePrimitive) {
-      s"$values.isNullAt(z) ? null : (Object)${getValue(values)}"
+      s"$values.isNullAt($z) ? null : (Object)${getValue(values, z)}"
     } else {
-      getValue(values)
+      getValue(values, z)
     }
 
+    val rowClass = classOf[GenericInternalRow].getName
+    val genericArrayDataClass = classOf[GenericArrayData].getName
+    val genericArrayData = ctx.freshName("genericArrayData")
+    val rowObject = s"new $rowClass(new Object[]{${getKey(keys, z)}, 
$getValueWithCheck})"
     s"""
-       |final Object[] $data = new Object[$numElements];
-       |for (int z = 0; z < $numElements; z++) {
-       |  $data[z] = new $rowClass(new Object[]{${getKey(keys)}, 
$getValueWithCheck});
+       |$genericArrayDataClass $genericArrayData = 
($genericArrayDataClass)$arrayData;
+       |for (int $z = 0; $z < $numElements; $z++) {
+       |  $genericArrayData.update($z, $rowObject);
        |}
-       |$arrayData = new $genericArrayClass($data);
+       |$resultArrayData = $arrayData;
      """.stripMargin
   }
 
@@ -610,20 +626,14 @@ case class MapConcat(children: Seq[Expression]) extends 
ComplexTypeMergingExpres
     val finKeysName = ctx.freshName("finalKeys")
     val finValsName = ctx.freshName("finalValues")
 
-    val keyConcat = if (CodeGenerator.isPrimitiveType(keyType)) {
-      genCodeForPrimitiveArrays(ctx, keyType, false)
-    } else {
-      genCodeForNonPrimitiveArrays(ctx, keyType)
-    }
+    val keyConcat = genCodeForArrays(ctx, keyType, false)
 
     val valueConcat =
       if (valueType.sameType(keyType) &&
           !(CodeGenerator.isPrimitiveType(valueType) && 
dataType.valueContainsNull)) {
         keyConcat
-      } else if (CodeGenerator.isPrimitiveType(valueType)) {
-        genCodeForPrimitiveArrays(ctx, valueType, dataType.valueContainsNull)
       } else {
-        genCodeForNonPrimitiveArrays(ctx, valueType)
+        genCodeForArrays(ctx, valueType, dataType.valueContainsNull)
       }
 
     val keyArgsName = ctx.freshName("keyArgs")
@@ -662,7 +672,7 @@ case class MapConcat(children: Seq[Expression]) extends 
ComplexTypeMergingExpres
       """.stripMargin)
   }
 
-  private def genCodeForPrimitiveArrays(
+  private def genCodeForArrays(
       ctx: CodegenContext,
       elementType: DataType,
       checkForNull: Boolean): String = {
@@ -670,35 +680,23 @@ case class MapConcat(children: Seq[Expression]) extends 
ComplexTypeMergingExpres
     val arrayData = ctx.freshName("arrayData")
     val argsName = ctx.freshName("args")
     val numElemName = ctx.freshName("numElements")
-    val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
-
-    val setterCode1 =
-      s"""
-         |$arrayData.set$primitiveValueTypeName(
-         |  $counter,
-         |  ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")}
-         |);""".stripMargin
+    val y = ctx.freshName("y")
+    val z = ctx.freshName("z")
 
-    val setterCode = if (checkForNull) {
-      s"""
-         |if ($argsName[y].isNullAt(z)) {
-         |  $arrayData.setNullAt($counter);
-         |} else {
-         |  $setterCode1
-         |}""".stripMargin
-    } else {
-      setterCode1
-    }
+    val allocation = CodeGenerator.createArrayData(
+      arrayData, elementType, numElemName, s" $prettyName failed.")
+    val assignment = CodeGenerator.createArrayAssignment(
+      arrayData, elementType, s"$argsName[$y]", counter, z, checkForNull)
 
     val concat = ctx.freshName("concat")
     val concatDef =
       s"""
          |private ArrayData $concat(ArrayData[] $argsName, int $numElemName) {
-         |  ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" 
$prettyName failed.")}
+         |  $allocation
          |  int $counter = 0;
-         |  for (int y = 0; y < ${children.length}; y++) {
-         |    for (int z = 0; z < $argsName[y].numElements(); z++) {
-         |      $setterCode
+         |  for (int $y = 0; $y < ${children.length}; $y++) {
+         |    for (int $z = 0; $z < $argsName[$y].numElements(); $z++) {
+         |      $assignment
          |      $counter++;
          |    }
          |  }
@@ -709,32 +707,6 @@ case class MapConcat(children: Seq[Expression]) extends 
ComplexTypeMergingExpres
     ctx.addNewFunction(concat, concatDef)
   }
 
-  private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: 
DataType): String = {
-    val genericArrayClass = classOf[GenericArrayData].getName
-    val arrayData = ctx.freshName("arrayObjects")
-    val counter = ctx.freshName("counter")
-    val argsName = ctx.freshName("args")
-    val numElemName = ctx.freshName("numElements")
-
-    val concat = ctx.freshName("concat")
-    val concatDef =
-      s"""
-         |private ArrayData $concat(ArrayData[] $argsName, int $numElemName) {
-         |  Object[] $arrayData = new Object[$numElemName];
-         |  int $counter = 0;
-         |  for (int y = 0; y < ${children.length}; y++) {
-         |    for (int z = 0; z < $argsName[y].numElements(); z++) {
-         |      $arrayData[$counter] = 
${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")};
-         |      $counter++;
-         |    }
-         |  }
-         |  return new $genericArrayClass($arrayData);
-         |}
-       """.stripMargin
-
-    ctx.addNewFunction(concat, concatDef)
-  }
-
   override def prettyName: String = "map_concat"
 }
 
@@ -867,25 +839,12 @@ case class MapFromEntries(child: Expression) extends 
UnaryExpression {
     val valueSize = dataType.valueType.defaultSize
     val kByteSize = 
s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $keySize)"
     val vByteSize = 
s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $valueSize)"
-    val keyTypeName = CodeGenerator.primitiveTypeName(dataType.keyType)
-    val valueTypeName = CodeGenerator.primitiveTypeName(dataType.valueType)
 
-    val keyAssignment = (key: String, idx: String) => 
s"$keyArrayData.set$keyTypeName($idx, $key);"
-    val valueAssignment = (entry: String, idx: String) => {
-      val value = CodeGenerator.getValue(entry, dataType.valueType, "1")
-      val valueNullUnsafeAssignment = 
s"$valueArrayData.set$valueTypeName($idx, $value);"
-      if (dataType.valueContainsNull) {
-        s"""
-           |if ($entry.isNullAt(1)) {
-           |  $valueArrayData.setNullAt($idx);
-           |} else {
-           |  $valueNullUnsafeAssignment
-           |}
-         """.stripMargin
-      } else {
-        valueNullUnsafeAssignment
-      }
-    }
+    val keyAssignment = (key: String, idx: String) =>
+      CodeGenerator.setArrayElement(keyArrayData, dataType.keyType, idx, key)
+    val valueAssignment = (entry: String, idx: String) =>
+      CodeGenerator.createArrayAssignment(
+        valueArrayData, dataType.valueType, entry, idx, "1", 
dataType.valueContainsNull)
     val assignmentLoop = genCodeForAssignmentLoop(
       ctx,
       childVariable,
@@ -1263,40 +1222,15 @@ case class Shuffle(child: Expression, randomSeed: 
Option[Long] = None)
     ctx.addPartitionInitializationStatement(
       s"$rand = new $randomClass(${randomSeed.get}L + partitionIndex);")
 
-    val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType)
-
     val numElements = ctx.freshName("numElements")
     val arrayData = ctx.freshName("arrayData")
-
-    val initialization = if (isPrimitiveType) {
-      ctx.createUnsafeArray(arrayData, numElements, elementType, s" 
$prettyName failed.")
-    } else {
-      val arrayDataClass = classOf[GenericArrayData].getName()
-      s"$arrayDataClass $arrayData = new $arrayDataClass(new 
Object[$numElements]);"
-    }
-
     val indices = ctx.freshName("indices")
     val i = ctx.freshName("i")
 
-    val getValue = CodeGenerator.getValue(childName, elementType, 
s"$indices[$i]")
-
-    val setFunc = if (isPrimitiveType) {
-      s"set${CodeGenerator.primitiveTypeName(elementType)}"
-    } else {
-      "update"
-    }
-
-    val assignment = if (isPrimitiveType && 
dataType.asInstanceOf[ArrayType].containsNull) {
-      s"""
-         |if ($childName.isNullAt($indices[$i])) {
-         |  $arrayData.setNullAt($i);
-         |} else {
-         |  $arrayData.$setFunc($i, $getValue);
-         |}
-       """.stripMargin
-    } else {
-      s"$arrayData.$setFunc($i, $getValue);"
-    }
+    val initialization = CodeGenerator.createArrayData(
+      arrayData, elementType, numElements, s" $prettyName failed.")
+    val assignment = CodeGenerator.createArrayAssignment(arrayData, 
elementType, childName,
+      i, s"$indices[$i]", dataType.asInstanceOf[ArrayType].containsNull)
 
     s"""
        |int $numElements = $childName.numElements();
@@ -1354,40 +1288,16 @@ case class Reverse(child: Expression) extends 
UnaryExpression with ImplicitCastI
 
   private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: 
String): String = {
 
-    val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType)
-
     val numElements = ctx.freshName("numElements")
     val arrayData = ctx.freshName("arrayData")
 
-    val initialization = if (isPrimitiveType) {
-      ctx.createUnsafeArray(arrayData, numElements, elementType, s" 
$prettyName failed.")
-    } else {
-      val arrayDataClass = classOf[GenericArrayData].getName
-      s"$arrayDataClass $arrayData = new $arrayDataClass(new 
Object[$numElements]);"
-    }
-
     val i = ctx.freshName("i")
     val j = ctx.freshName("j")
 
-    val getValue = CodeGenerator.getValue(childName, elementType, i)
-
-    val setFunc = if (isPrimitiveType) {
-      s"set${CodeGenerator.primitiveTypeName(elementType)}"
-    } else {
-      "update"
-    }
-
-    val assignment = if (isPrimitiveType && 
dataType.asInstanceOf[ArrayType].containsNull) {
-      s"""
-         |if ($childName.isNullAt($i)) {
-         |  $arrayData.setNullAt($j);
-         |} else {
-         |  $arrayData.$setFunc($j, $getValue);
-         |}
-       """.stripMargin
-    } else {
-      s"$arrayData.$setFunc($j, $getValue);"
-    }
+    val initialization = CodeGenerator.createArrayData(
+      arrayData, elementType, numElements, s" $prettyName failed.")
+    val assignment = CodeGenerator.createArrayAssignment(
+      arrayData, elementType, childName, i, j, 
dataType.asInstanceOf[ArrayType].containsNull)
 
     s"""
        |final int $numElements = $childName.numElements();
@@ -1803,38 +1713,24 @@ case class Slice(x: Expression, start: Expression, 
length: Expression)
       resLength: String): String = {
     val values = ctx.freshName("values")
     val i = ctx.freshName("i")
-    val getValue = CodeGenerator.getValue(inputArray, elementType, s"$i + 
$startIdx")
-    if (!CodeGenerator.isPrimitiveType(elementType)) {
-      val arrayClass = classOf[GenericArrayData].getName
-      s"""
-         |Object[] $values;
-         |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) {
-         |  $values = new Object[0];
-         |} else {
-         |  $values = new Object[$resLength];
-         |  for (int $i = 0; $i < $resLength; $i ++) {
-         |    $values[$i] = $getValue;
-         |  }
-         |}
-         |${ev.value} = new $arrayClass($values);
-       """.stripMargin
-    } else {
-      val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
-      s"""
-         |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) {
-         |  $resLength = 0;
-         |}
-         |${ctx.createUnsafeArray(values, resLength, elementType, s" 
$prettyName failed.")}
-         |for (int $i = 0; $i < $resLength; $i ++) {
-         |  if ($inputArray.isNullAt($i + $startIdx)) {
-         |    $values.setNullAt($i);
-         |  } else {
-         |    $values.set$primitiveValueTypeName($i, $getValue);
-         |  }
-         |}
-         |${ev.value} = $values;
-       """.stripMargin
-    }
+    val genericArrayData = classOf[GenericArrayData].getName
+
+    val allocation = CodeGenerator.createArrayData(
+      values, elementType, resLength, s" $prettyName failed.")
+    val assignment = CodeGenerator.createArrayAssignment(values, elementType, 
inputArray,
+      i, s"$i + $startIdx", dataType.asInstanceOf[ArrayType].containsNull)
+
+    s"""
+       |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) {
+       |  ${ev.value} = new $genericArrayData(new Object[0]);
+       |} else {
+       |  $allocation
+       |  for (int $i = 0; $i < $resLength; $i ++) {
+       |    $assignment
+       |  }
+       |  ${ev.value} = $values;
+       |}
+     """.stripMargin
   }
 }
 
@@ -2452,11 +2348,7 @@ case class Concat(children: Seq[Expression]) extends 
ComplexTypeMergingExpressio
       case StringType =>
         ("UTF8String.concat", s"UTF8String[] $args = new 
UTF8String[${evals.length}];")
       case ArrayType(elementType, containsNull) =>
-        val concat = if (CodeGenerator.isPrimitiveType(elementType)) {
-          genCodeForPrimitiveArrays(ctx, elementType, containsNull)
-        } else {
-          genCodeForNonPrimitiveArrays(ctx, elementType)
-        }
+        val concat = genCodeForArrays(ctx, elementType, containsNull)
         (concat, s"ArrayData[] $args = new ArrayData[${evals.length}];")
     }
 
@@ -2475,62 +2367,44 @@ case class Concat(children: Seq[Expression]) extends 
ComplexTypeMergingExpressio
 
   private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, 
String) = {
     val numElements = ctx.freshName("numElements")
+    val z = ctx.freshName("z")
     val code = s"""
         |long $numElements = 0L;
-        |for (int z = 0; z < ${children.length}; z++) {
-        |  $numElements += args[z].numElements();
-        |}
-        |if ($numElements > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
-        |  throw new RuntimeException("Unsuccessful try to concat arrays with 
" + $numElements +
-        |    " elements due to exceeding the array size limit" +
-        |    " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
+        |for (int $z = 0; $z < ${children.length}; $z++) {
+        |  $numElements += args[$z].numElements();
         |}
       """.stripMargin
 
     (code, numElements)
   }
 
-  private def genCodeForPrimitiveArrays(
+  private def genCodeForArrays(
       ctx: CodegenContext,
       elementType: DataType,
       checkForNull: Boolean): String = {
     val counter = ctx.freshName("counter")
     val arrayData = ctx.freshName("arrayData")
+    val y = ctx.freshName("y")
+    val z = ctx.freshName("z")
 
     val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
 
-    val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
-
-    val setterCode =
-      s"""
-         |$arrayData.set$primitiveValueTypeName(
-         |  $counter,
-         |  ${CodeGenerator.getValue(s"args[y]", elementType, "z")}
-         |);
-       """.stripMargin
-
-    val nullSafeSetterCode = if (checkForNull) {
-      s"""
-         |if (args[y].isNullAt(z)) {
-         |  $arrayData.setNullAt($counter);
-         |} else {
-         |  $setterCode
-         |}
-       """.stripMargin
-    } else {
-      setterCode
-    }
+    val initialization = CodeGenerator.createArrayData(
+      arrayData, elementType, numElemName, s" $prettyName failed.")
+    val assignment = CodeGenerator.createArrayAssignment(
+      arrayData, elementType, s"args[$y]", counter, z,
+      dataType.asInstanceOf[ArrayType].containsNull)
 
     val concat = ctx.freshName("concat")
     val concatDef =
       s"""
          |private ArrayData $concat(ArrayData[] args) {
          |  $numElemCode
-         |  ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" 
$prettyName failed.")}
+         |  $initialization
          |  int $counter = 0;
-         |  for (int y = 0; y < ${children.length}; y++) {
-         |    for (int z = 0; z < args[y].numElements(); z++) {
-         |      $nullSafeSetterCode
+         |  for (int $y = 0; $y < ${children.length}; $y++) {
+         |    for (int $z = 0; $z < args[$y].numElements(); $z++) {
+         |      $assignment
          |      $counter++;
          |    }
          |  }
@@ -2541,33 +2415,6 @@ case class Concat(children: Seq[Expression]) extends 
ComplexTypeMergingExpressio
     ctx.addNewFunction(concat, concatDef)
   }
 
-  private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: 
DataType): String = {
-    val genericArrayClass = classOf[GenericArrayData].getName
-    val arrayData = ctx.freshName("arrayObjects")
-    val counter = ctx.freshName("counter")
-
-    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
-
-    val concat = ctx.freshName("concat")
-    val concatDef =
-      s"""
-         |private ArrayData $concat(ArrayData[] args) {
-         |  $numElemCode
-         |  Object[] $arrayData = new Object[(int)$numElemName];
-         |  int $counter = 0;
-         |  for (int y = 0; y < ${children.length}; y++) {
-         |    for (int z = 0; z < args[y].numElements(); z++) {
-         |      $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", 
elementType, "z")};
-         |      $counter++;
-         |    }
-         |  }
-         |  return new $genericArrayClass($arrayData);
-         |}
-       """.stripMargin
-
-    ctx.addNewFunction(concat, concatDef)
-  }
-
   override def toString: String = s"concat(${children.mkString(", ")})"
 
   override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})"
@@ -2630,11 +2477,7 @@ case class Flatten(child: Expression) extends 
UnaryExpression {
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     nullSafeCodeGen(ctx, ev, c => {
-      val code = if (CodeGenerator.isPrimitiveType(elementType)) {
-        genCodeForFlattenOfPrimitiveElements(ctx, c, ev.value)
-      } else {
-        genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value)
-      }
+      val code = genCodeForFlatten(ctx, c, ev.value)
       ctx.nullArrayElementsSaveExec(childDataType.containsNull, ev.isNull, 
c)(code)
     })
   }
@@ -2648,41 +2491,36 @@ case class Flatten(child: Expression) extends 
UnaryExpression {
       |for (int z = 0; z < $childVariableName.numElements(); z++) {
       |  $variableName += $childVariableName.getArray(z).numElements();
       |}
-      |if ($variableName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
-      |  throw new RuntimeException("Unsuccessful try to flatten an array of 
arrays with " +
-      |    $variableName + " elements due to exceeding the array size limit" +
-      |    " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
-      |}
       """.stripMargin
     (code, variableName)
   }
 
-  private def genCodeForFlattenOfPrimitiveElements(
+  private def genCodeForFlatten(
       ctx: CodegenContext,
       childVariableName: String,
       arrayDataName: String): String = {
     val counter = ctx.freshName("counter")
     val tempArrayDataName = ctx.freshName("tempArrayData")
+    val k = ctx.freshName("k")
+    val l = ctx.freshName("l")
+    val arr = ctx.freshName("arr")
 
     val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, 
childVariableName)
 
-    val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
+    val allocation = CodeGenerator.createArrayData(
+      tempArrayDataName, elementType, numElemName, s" $prettyName failed.")
+    val assignment = CodeGenerator.createArrayAssignment(
+      tempArrayDataName, elementType, arr, counter, l,
+      dataType.asInstanceOf[ArrayType].containsNull)
 
     s"""
     |$numElemCode
-    |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, s" 
$prettyName failed.")}
+    |$allocation
     |int $counter = 0;
-    |for (int k = 0; k < $childVariableName.numElements(); k++) {
-    |  ArrayData arr = $childVariableName.getArray(k);
-    |  for (int l = 0; l < arr.numElements(); l++) {
-    |   if (arr.isNullAt(l)) {
-    |     $tempArrayDataName.setNullAt($counter);
-    |   } else {
-    |     $tempArrayDataName.set$primitiveValueTypeName(
-    |       $counter,
-    |       ${CodeGenerator.getValue("arr", elementType, "l")}
-    |     );
-    |   }
+    |for (int $k = 0; $k < $childVariableName.numElements(); $k++) {
+    |  ArrayData $arr = $childVariableName.getArray($k);
+    |  for (int $l = 0; $l < $arr.numElements(); $l++) {
+    |   $assignment
     |   $counter++;
     | }
     |}
@@ -2690,30 +2528,6 @@ case class Flatten(child: Expression) extends 
UnaryExpression {
     """.stripMargin
   }
 
-  private def genCodeForFlattenOfNonPrimitiveElements(
-      ctx: CodegenContext,
-      childVariableName: String,
-      arrayDataName: String): String = {
-    val genericArrayClass = classOf[GenericArrayData].getName
-    val arrayName = ctx.freshName("arrayObject")
-    val counter = ctx.freshName("counter")
-    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, 
childVariableName)
-
-    s"""
-    |$numElemCode
-    |Object[] $arrayName = new Object[(int)$numElemName];
-    |int $counter = 0;
-    |for (int k = 0; k < $childVariableName.numElements(); k++) {
-    |  ArrayData arr = $childVariableName.getArray(k);
-    |  for (int l = 0; l < arr.numElements(); l++) {
-    |    $arrayName[$counter] = ${CodeGenerator.getValue("arr", elementType, 
"l")};
-    |    $counter++;
-    |  }
-    |}
-    |$arrayDataName = new $genericArrayClass($arrayName);
-    """.stripMargin
-  }
-
   override def prettyName: String = "flatten"
 }
 
@@ -3155,11 +2969,7 @@ case class ArrayRepeat(left: Expression, right: 
Expression)
     val count = rightGen.value
     val et = dataType.elementType
 
-    val coreLogic = if (CodeGenerator.isPrimitiveType(et)) {
-      genCodeForPrimitiveElement(ctx, et, element, count, leftGen.isNull, 
ev.value)
-    } else {
-      genCodeForNonPrimitiveElement(ctx, element, count, leftGen.isNull, 
ev.value)
-    }
+    val coreLogic = genCodeForElement(ctx, et, element, count, leftGen.isNull, 
ev.value)
     val resultCode = nullElementsProtection(ev, rightGen.isNull, coreLogic)
 
     ev.copy(code =
@@ -3198,17 +3008,12 @@ case class ArrayRepeat(left: Expression, right: 
Expression)
          |if ($count > 0) {
          |  $numElements = $count;
          |}
-         |if ($numElements > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
-         |  throw new RuntimeException("Unsuccessful try to create array with 
" + $numElements +
-         |    " elements due to exceeding the array size limit" +
-         |    " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
-         |}
        """.stripMargin
 
     (numElements, numElementsCode)
   }
 
-  private def genCodeForPrimitiveElement(
+  private def genCodeForElement(
       ctx: CodegenContext,
       elementType: DataType,
       element: String,
@@ -3216,48 +3021,30 @@ case class ArrayRepeat(left: Expression, right: 
Expression)
       leftIsNull: String,
       arrayDataName: String): String = {
     val tempArrayDataName = ctx.freshName("tempArrayData")
-    val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
-    val errorMessage = s" $prettyName failed."
+    val k = ctx.freshName("k")
     val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count)
 
+    val allocation = CodeGenerator.createArrayData(
+      tempArrayDataName, elementType, numElemName, s" $prettyName failed.")
+    val assignment =
+      CodeGenerator.setArrayElement(tempArrayDataName, elementType, k, element)
+
     s"""
        |$numElemCode
-       |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, 
errorMessage)}
+       |$allocation
        |if (!$leftIsNull) {
-       |  for (int k = 0; k < $tempArrayDataName.numElements(); k++) {
-       |    $tempArrayDataName.set$primitiveValueTypeName(k, $element);
+       |  for (int $k = 0; $k < $tempArrayDataName.numElements(); $k++) {
+       |    $assignment
        |  }
        |} else {
-       |  for (int k = 0; k < $tempArrayDataName.numElements(); k++) {
-       |    $tempArrayDataName.setNullAt(k);
+       |  for (int $k = 0; $k < $tempArrayDataName.numElements(); $k++) {
+       |    $tempArrayDataName.setNullAt($k);
        |  }
        |}
        |$arrayDataName = $tempArrayDataName;
      """.stripMargin
   }
 
-  private def genCodeForNonPrimitiveElement(
-      ctx: CodegenContext,
-      element: String,
-      count: String,
-      leftIsNull: String,
-      arrayDataName: String): String = {
-    val genericArrayClass = classOf[GenericArrayData].getName
-    val arrayName = ctx.freshName("arrayObject")
-    val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count)
-
-    s"""
-       |$numElemCode
-       |Object[] $arrayName = new Object[(int)$numElemName];
-       |if (!$leftIsNull) {
-       |  for (int k = 0; k < $numElemName; k++) {
-       |    $arrayName[k] = $element;
-       |  }
-       |}
-       |$arrayDataName = new $genericArrayClass($arrayName);
-     """.stripMargin
-  }
-
 }
 
 /**
@@ -3339,50 +3126,117 @@ case class ArrayRemove(left: Expression, right: 
Expression)
     val pos = ctx.freshName("pos")
     val getValue = CodeGenerator.getValue(inputArray, elementType, i)
     val isEqual = ctx.genEqual(elementType, value, getValue)
-    if (!CodeGenerator.isPrimitiveType(elementType)) {
-      val arrayClass = classOf[GenericArrayData].getName
+
+    val allocation = CodeGenerator.createArrayData(
+      values, elementType, newArraySize, s" $prettyName failed.")
+    val assignment = CodeGenerator.createArrayAssignment(
+      values, elementType, inputArray, pos, i, false)
+
+    s"""
+       |$allocation
+       |int $pos = 0;
+       |for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
+       |  if ($inputArray.isNullAt($i)) {
+       |    $values.setNullAt($pos);
+       |    $pos = $pos + 1;
+       |  }
+       |  else {
+       |    if (!($isEqual)) {
+       |      $assignment
+       |      $pos = $pos + 1;
+       |    }
+       |  }
+       |}
+       |${ev.value} = $values;
+     """.stripMargin
+  }
+
+  override def prettyName: String = "array_remove"
+}
+
+/**
+ * Will become common base class for [[ArrayDistinct]], [[ArrayUnion]], [[ArrayIntersect]],
+ * and [[ArrayExcept]].
+ */
+trait ArraySetLike {
+  protected def dt: DataType
+  protected def et: DataType
+
+  @transient protected lazy val canUseSpecializedHashSet = et match {
+    case ByteType | ShortType | IntegerType | LongType | FloatType | 
DoubleType => true
+    case _ => false
+  }
+
+  @transient protected lazy val ordering: Ordering[Any] =
+    TypeUtils.getInterpretedOrdering(et)
+
+  protected def genGetValue(array: String, i: String): String =
+    CodeGenerator.getValue(array, et, i)
+
+  @transient protected lazy val (hsPostFix, hsTypeName) = {
+    val ptName = CodeGenerator.primitiveTypeName(et)
+    et match {
+      // we cast byte/short to int when writing to the hash set.
+      case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int")
+      case LongType => ("$mcJ$sp", ptName)
+      case FloatType => ("$mcF$sp", ptName)
+      case DoubleType => ("$mcD$sp", ptName)
+    }
+  }
+
+  // we cast byte/short to int when writing to the hash set.
+  @transient protected lazy val hsValueCast = et match {
+    case ByteType | ShortType => "(int) "
+    case _ => ""
+  }
+
+  // When hitting a null value, put a null holder in the ArrayBuilder. Finally we will
+  // convert ArrayBuilder to ArrayData and setNull on the slot with null holder.
+  @transient protected lazy val nullValueHolder = et match {
+    case ByteType => "(byte) 0"
+    case ShortType => "(short) 0"
+    case _ => "0"
+  }
+
+  protected def withResultArrayNullCheck(
+      body: String,
+      value: String,
+      nullElementIndex: String): String = {
+    if (dt.asInstanceOf[ArrayType].containsNull) {
       s"""
-         |int $pos = 0;
-         |Object[] $values = new Object[$newArraySize];
-         |for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
-         |  if ($inputArray.isNullAt($i)) {
-         |    $values[$pos] = null;
-         |    $pos = $pos + 1;
-         |  }
-         |  else {
-         |    if (!($isEqual)) {
-         |      $values[$pos] = $getValue;
-         |      $pos = $pos + 1;
-         |    }
-         |  }
+         |$body
+         |if ($nullElementIndex >= 0) {
+         |  // result has null element
+         |  $value.setNullAt($nullElementIndex);
          |}
-         |${ev.value} = new $arrayClass($values);
        """.stripMargin
     } else {
-      val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
-      s"""
-         |${ctx.createUnsafeArray(values, newArraySize, elementType, s" 
$prettyName failed.")}
-         |int $pos = 0;
-         |for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
-         |  if ($inputArray.isNullAt($i)) {
-         |      $values.setNullAt($pos);
-         |      $pos = $pos + 1;
-         |  }
-         |  else {
-         |    if (!($isEqual)) {
-         |      $values.set$primitiveValueTypeName($pos, $getValue);
-         |      $pos = $pos + 1;
-         |    }
-         |  }
-         |}
-         |${ev.value} = $values;
-       """.stripMargin
+      body
     }
   }
 
-  override def prettyName: String = "array_remove"
+  def buildResultArray(
+      builder: String,
+      value : String,
+      size : String,
+      nullElementIndex : String): String = withResultArrayNullCheck(
+    s"""
+       |if ($size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+       |  throw new RuntimeException("Cannot create array with " + $size +
+       |  " elements of data due to exceeding the limit " +
+       |  "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for 
ArrayData.");
+       |}
+       |
+       |if (!UnsafeArrayData.shouldUseGenericArrayData(${et.defaultSize}, 
$size)) {
+       |  $value = UnsafeArrayData.fromPrimitiveArray($builder.result());
+       |} else {
+       |  $value = new ${classOf[GenericArrayData].getName}($builder.result());
+       |}
+     """.stripMargin, value, nullElementIndex)
+
 }
 
+
 /**
  * Removes duplicate values from the array.
  */
@@ -3394,7 +3248,7 @@ case class ArrayRemove(left: Expression, right: 
Expression)
        [1,2,3,null]
   """, since = "2.4.0")
 case class ArrayDistinct(child: Expression)
-  extends UnaryExpression with ExpectsInputTypes {
+  extends UnaryExpression with ArraySetLike with ExpectsInputTypes {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
 
@@ -3402,8 +3256,8 @@ case class ArrayDistinct(child: Expression)
 
   @transient private lazy val elementType: DataType = 
dataType.asInstanceOf[ArrayType].elementType
 
-  @transient private lazy val ordering: Ordering[Any] =
-    TypeUtils.getInterpretedOrdering(elementType)
+  override protected def dt: DataType = dataType
+  override protected def et: DataType = elementType
 
   override def checkInputDataTypes(): TypeCheckResult = {
     super.checkInputDataTypes() match {
@@ -3413,28 +3267,6 @@ case class ArrayDistinct(child: Expression)
     }
   }
 
-  @transient protected lazy val canUseSpecializedHashSet = elementType match {
-    case ByteType | ShortType | IntegerType | LongType | FloatType | 
DoubleType => true
-    case _ => false
-  }
-
-  @transient protected lazy val (hsPostFix, hsTypeName) = {
-    val ptName = CodeGenerator.primitiveTypeName(elementType)
-    elementType match {
-      // we cast byte/short to int when writing to the hash set.
-      case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int")
-      case LongType => ("$mcJ$sp", ptName)
-      case FloatType => ("$mcF$sp", ptName)
-      case DoubleType => ("$mcD$sp", ptName)
-    }
-  }
-
-  // we cast byte/short to int when writing to the hash set.
-  @transient protected lazy val hsValueCast = elementType match {
-    case ByteType | ShortType => "(int) "
-    case _ => ""
-  }
-
   override def nullSafeEval(array: Any): Any = {
     val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
     doEvaluation(data)
@@ -3471,28 +3303,73 @@ case class ArrayDistinct(child: Expression)
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val i = ctx.freshName("i")
+    val value = ctx.freshName("value")
+    val size = ctx.freshName("size")
+
     if (canUseSpecializedHashSet) {
+      val jt = CodeGenerator.javaType(elementType)
+      val ptName = CodeGenerator.primitiveTypeName(jt)
+
       nullSafeCodeGen(ctx, ev, (array) => {
-        val i = ctx.freshName("i")
-        val sizeOfDistinctArray = ctx.freshName("sizeOfDistinctArray")
         val foundNullElement = ctx.freshName("foundNullElement")
+        val nullElementIndex = ctx.freshName("nullElementIndex")
+        val builder = ctx.freshName("builder")
         val openHashSet = classOf[OpenHashSet[_]].getName
-        val hs = ctx.freshName("hs")
         val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
-        val getValue = CodeGenerator.getValue(array, elementType, i)
+        val hashSet = ctx.freshName("hashSet")
+        val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
+        val arrayBuilderClass = s"$arrayBuilder$$of$ptName"
+
+        // Only need to track null element index when array's element is nullable.
+        val declareNullTrackVariables = if 
(dataType.asInstanceOf[ArrayType].containsNull) {
+          s"""
+             |boolean $foundNullElement = false;
+             |int $nullElementIndex = -1;
+           """.stripMargin
+        } else {
+          ""
+        }
+
+        def withArrayNullAssignment(body: String) =
+          if (dataType.asInstanceOf[ArrayType].containsNull) {
+            s"""
+               |if ($array.isNullAt($i)) {
+               |  if (!$foundNullElement) {
+               |    $nullElementIndex = $size;
+               |    $foundNullElement = true;
+               |    $size++;
+               |    $builder.$$plus$$eq($nullValueHolder);
+               |  }
+               |} else {
+               |  $body
+               |}
+             """.stripMargin
+          } else {
+            body
+          }
+
+        val processArray = withArrayNullAssignment(
+          s"""
+             |$jt $value = ${genGetValue(array, i)};
+             |if (!$hashSet.contains($hsValueCast$value)) {
+             |  if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+             |    break;
+             |  }
+             |  $hashSet.add$hsPostFix($hsValueCast$value);
+             |  $builder.$$plus$$eq($value);
+             |}
+           """.stripMargin)
+
         s"""
-           |int $sizeOfDistinctArray = 0;
-           |boolean $foundNullElement = false;
-           |$openHashSet $hs = new $openHashSet($classTag);
-           |for (int $i = 0; $i < $array.numElements(); $i ++) {
-           |  if ($array.isNullAt($i)) {
-           |    $foundNullElement = true;
-           |  } else {
-           |    $hs.add$hsPostFix($hsValueCast$getValue);
-           |  }
+           |$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag);
+           |$declareNullTrackVariables
+           |$arrayBuilderClass $builder = new $arrayBuilderClass();
+           |int $size = 0;
+           |for (int $i = 0; $i < $array.numElements(); $i++) {
+           |  $processArray
            |}
-           |$sizeOfDistinctArray = $hs.size() + ($foundNullElement ? 1 : 0);
-           |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)}
+           |${buildResultArray(builder, ev.value, size, nullElementIndex)}
          """.stripMargin
       })
     } else {
@@ -3503,73 +3380,16 @@ case class ArrayDistinct(child: Expression)
     }
   }
 
-  private def setNull(
-      foundNullElement: String,
-      distinctArray: String,
-      pos: String): String = {
-    val setNullValue = s"$distinctArray.setNullAt($pos)"
-    s"""
-       |if (!($foundNullElement)) {
-       |  $setNullValue;
-       |  $pos = $pos + 1;
-       |  $foundNullElement = true;
-       |}
-    """.stripMargin
-  }
-
-  private def setValue(
-      hs: String,
-      distinctArray: String,
-      pos: String,
-      getValue1: String,
-      primitiveValueTypeName: String): String = {
-    s"""
-       |if (!($hs.contains$hsPostFix($hsValueCast$getValue1))) {
-       |  $hs.add$hsPostFix($hsValueCast$getValue1);
-       |  $distinctArray.set$primitiveValueTypeName($pos, $getValue1);
-       |  $pos = $pos + 1;
-       |}
-    """.stripMargin
-  }
-
-  def genCodeForResult(
-      ctx: CodegenContext,
-      ev: ExprCode,
-      inputArray: String,
-      size: String): String = {
-    val distinctArray = ctx.freshName("distinctArray")
-    val i = ctx.freshName("i")
-    val pos = ctx.freshName("pos")
-    val getValue1 = CodeGenerator.getValue(inputArray, elementType, i)
-    val foundNullElement = ctx.freshName("foundNullElement")
-    val hs = ctx.freshName("hs")
-    val openHashSet = classOf[OpenHashSet[_]].getName
-    val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
-    val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
-
-    s"""
-       |${ctx.createUnsafeArray(distinctArray, size, elementType, s" 
$prettyName failed.")}
-       |int $pos = 0;
-       |boolean $foundNullElement = false;
-       |$openHashSet $hs = new $openHashSet($classTag);
-       |for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
-       |  if ($inputArray.isNullAt($i)) {
-       |    ${setNull(foundNullElement, distinctArray, pos)}
-       |  } else {
-       |    ${setValue(hs, distinctArray, pos, getValue1, 
primitiveValueTypeName)}
-       |  }
-       |}
-       |${ev.value} = $distinctArray;
-    """.stripMargin
-  }
-
   override def prettyName: String = "array_distinct"
 }
 
 /**
  * Will become common base class for [[ArrayUnion]], [[ArrayIntersect]], and [[ArrayExcept]].
  */
-abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast {
+trait ArrayBinaryLike extends BinaryArrayExpressionWithImplicitCast with 
ArraySetLike {
+  override protected def dt: DataType = dataType
+  override protected def et: DataType = elementType
+
   override def checkInputDataTypes(): TypeCheckResult = {
     val typeCheckResult = super.checkInputDataTypes()
     if (typeCheckResult.isSuccess) {
@@ -3579,81 +3399,9 @@ abstract class ArraySetLike extends 
BinaryArrayExpressionWithImplicitCast {
       typeCheckResult
     }
   }
-
-  @transient protected lazy val ordering: Ordering[Any] =
-    TypeUtils.getInterpretedOrdering(elementType)
-
-  @transient protected lazy val canUseSpecializedHashSet = elementType match {
-    case ByteType | ShortType | IntegerType | LongType | FloatType | 
DoubleType => true
-    case _ => false
-  }
-
-  protected def genGetValue(array: String, i: String): String =
-    CodeGenerator.getValue(array, elementType, i)
-
-  @transient protected lazy val (hsPostFix, hsTypeName) = {
-    val ptName = CodeGenerator.primitiveTypeName (elementType)
-    elementType match {
-      // we cast byte/short to int when writing to the hash set.
-      case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int")
-      case LongType => ("$mcJ$sp", ptName)
-      case FloatType => ("$mcF$sp", ptName)
-      case DoubleType => ("$mcD$sp", ptName)
-    }
-  }
-
-  // we cast byte/short to int when writing to the hash set.
-  @transient protected lazy val hsValueCast = elementType match {
-    case ByteType | ShortType => "(int) "
-    case _ => ""
-  }
-
-  // When hitting a null value, put a null holder in the ArrayBuilder. Finally we will
-  // convert ArrayBuilder to ArrayData and setNull on the slot with null holder.
-  @transient protected lazy val nullValueHolder = elementType match {
-    case ByteType => "(byte) 0"
-    case ShortType => "(short) 0"
-    case _ => "0"
-  }
-
-  protected def withResultArrayNullCheck(
-      body: String,
-      value: String,
-      nullElementIndex: String): String = {
-    if (dataType.asInstanceOf[ArrayType].containsNull) {
-      s"""
-         |$body
-         |if ($nullElementIndex >= 0) {
-         |  // result has null element
-         |  $value.setNullAt($nullElementIndex);
-         |}
-       """.stripMargin
-    } else {
-      body
-    }
-  }
-
-  def buildResultArray(
-      builder: String,
-      value : String,
-      size : String,
-      nullElementIndex : String): String = withResultArrayNullCheck(
-    s"""
-       |if ($size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
-       |  throw new RuntimeException("Cannot create array with " + $size +
-       |  " bytes of data due to exceeding the limit " +
-       |  "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for 
ArrayData.");
-       |}
-       |
-       |if 
(!UnsafeArrayData.shouldUseGenericArrayData(${elementType.defaultSize}, $size)) 
{
-       |  $value = UnsafeArrayData.fromPrimitiveArray($builder.result());
-       |} else {
-       |  $value = new ${classOf[GenericArrayData].getName}($builder.result());
-       |}
-     """.stripMargin, value, nullElementIndex)
 }
 
-object ArraySetLike {
+object ArrayBinaryLike {
   def throwUnionLengthOverflowException(length: Int): Unit = {
     throw new RuntimeException(s"Unsuccessful try to union arrays with $length 
" +
       s"elements due to exceeding the array size limit " +
@@ -3676,7 +3424,7 @@ object ArraySetLike {
        array(1, 2, 3, 5)
   """,
   since = "2.4.0")
-case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
+case class ArrayUnion(left: Expression, right: Expression) extends 
ArrayBinaryLike
   with ComplexTypeMergingExpression {
 
   @transient lazy val evalUnion: (ArrayData, ArrayData) => ArrayData = {
@@ -3697,7 +3445,7 @@ case class ArrayUnion(left: Expression, right: 
Expression) extends ArraySetLike
               val elem = array.get(i, elementType)
               if (!hs.contains(elem)) {
                 if (arrayBuffer.size > 
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
-                  
ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size)
+                  
ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size)
                 }
                 arrayBuffer += elem
                 hs.add(elem)
@@ -3732,7 +3480,7 @@ case class ArrayUnion(left: Expression, right: 
Expression) extends ArraySetLike
           }
           if (!found) {
             if (arrayBuffer.length > 
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
-              
ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length)
+              
ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.length)
             }
             arrayBuffer += elem
           }
@@ -3864,7 +3612,7 @@ object ArrayUnion {
       }
       if (!found) {
         if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
-          ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length)
+          ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.length)
         }
         arrayBuffer += elem
       }
@@ -3887,7 +3635,7 @@ object ArrayUnion {
        array(1, 3)
   """,
   since = "2.4.0")
-case class ArrayIntersect(left: Expression, right: Expression) extends 
ArraySetLike
+case class ArrayIntersect(left: Expression, right: Expression) extends 
ArrayBinaryLike
   with ComplexTypeMergingExpression {
   override def dataType: DataType = {
     dataTypeCheck
@@ -4128,7 +3876,7 @@ case class ArrayIntersect(left: Expression, right: 
Expression) extends ArraySetL
        array(2)
   """,
   since = "2.4.0")
-case class ArrayExcept(left: Expression, right: Expression) extends 
ArraySetLike
+case class ArrayExcept(left: Expression, right: Expression) extends 
ArrayBinaryLike
   with ComplexTypeMergingExpression {
 
   override def dataType: DataType = {

http://git-wip-us.apache.org/repos/asf/spark/blob/e319ac92/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
index 104b428..4da8ce0 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
@@ -22,6 +22,8 @@ import scala.reflect.ClassTag
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, 
UnsafeArrayData}
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.array.ByteArrayMethods
 
 object ArrayData {
   def toArrayData(input: Any): ArrayData = input match {
@@ -34,6 +36,31 @@ object ArrayData {
     case a: Array[Double] => UnsafeArrayData.fromPrimitiveArray(a)
     case other => new GenericArrayData(other)
   }
+
+
+  /**
+   * Allocate [[UnsafeArrayData]] or [[GenericArrayData]] based on given parameters.
+   *
+   * @param elementSize a size of an element in bytes. If less than zero, the type of an element is
+   *                    non-primitive type
+   * @param numElements the number of elements the array should contain
+   * @param additionalErrorMessage string to include in the error message
+   */
+  def allocateArrayData(
+      elementSize: Int,
+      numElements: Long,
+      additionalErrorMessage: String): ArrayData = {
+    if (elementSize >= 0 && 
!UnsafeArrayData.shouldUseGenericArrayData(elementSize, numElements)) {
+      UnsafeArrayData.createFreshArray(numElements.toInt, elementSize)
+    } else if (numElements <= 
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toLong) {
+      new GenericArrayData(new Array[Any](numElements.toInt))
+    } else {
+      throw new RuntimeException(s"Cannot create array with $numElements " +
+        "elements of data due to exceeding the limit " +
+        s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for ArrayData. 
" +
+        additionalErrorMessage)
+    }
+  }
 }
 
 abstract class ArrayData extends SpecializedGetters with Serializable {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to