aokolnychyi commented on code in PR #7646:
URL: https://github.com/apache/iceberg/pull/7646#discussion_r1198247245


##########
spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala:
##########
@@ -315,143 +312,104 @@ object RewriteMergeIntoTable extends 
RewriteRowLevelIcebergCommand with Predicat
 
     // build a plan to write the row delta to the table
     val writeRelation = relation.copy(table = operationTable)
-    val projections = buildMergeDeltaProjections(mergeRows, rowAttrs, 
rowIdAttrs, metadataAttrs)
+    val projections = buildDeltaProjections(mergeRows, rowAttrs, rowIdAttrs, 
metadataAttrs)
     WriteIcebergDelta(writeRelation, mergeRows, relation, projections)
   }
 
   private def actionCondition(action: MergeAction): Expression = {
     action.condition.getOrElse(TrueLiteral)
   }
 
-  private def actionOutput(
+  private def matchedActionOutput(
       clause: MergeAction,
-      metadataAttrs: Seq[Attribute]): Seq[Expression] = {
+      metadataAttrs: Seq[Attribute]): Seq[Seq[Expression]] = {
 
     clause match {
       case u: UpdateAction =>
-        u.assignments.map(_.value) ++ metadataAttrs
+        Seq(u.assignments.map(_.value) ++ metadataAttrs)
 
       case _: DeleteAction =>
         Nil
 
+      case other =>
+        throw new AnalysisException(s"Unexpected WHEN MATCHED action: $other")
+    }
+  }
+
+  private def notMatchedActionOutput(
+      clause: MergeAction,
+      metadataAttrs: Seq[Attribute]): Seq[Expression] = {
+
+    clause match {
       case i: InsertAction =>
         i.assignments.map(_.value) ++ metadataAttrs.map(attr => Literal(null, 
attr.dataType))
 
       case other =>
-        throw new AnalysisException(s"Unexpected action: $other")
+        throw new AnalysisException(s"Unexpected WHEN NOT MATCHED action: 
$other")
     }
   }
 
-  private def deltaActionOutput(
+  private def matchedDeltaActionOutput(
       action: MergeAction,
-      deleteRowValues: Seq[Expression],
-      metadataAttrs: Seq[Attribute]): Seq[Expression] = {
+      rowAttrs: Seq[Attribute],
+      rowIdAttrs: Seq[Attribute],
+      metadataAttrs: Seq[Attribute]): Seq[Seq[Expression]] = {
 
     action match {
       case u: UpdateAction =>
-        Seq(Literal(UPDATE_OPERATION)) ++ u.assignments.map(_.value) ++ 
metadataAttrs
+        val delete = deltaDeleteOutput(rowAttrs, rowIdAttrs, metadataAttrs)
+        val insert = deltaInsertOutput(u.assignments.map(_.value), 
metadataAttrs)
+        Seq(delete, insert)
 
       case _: DeleteAction =>
-        Seq(Literal(DELETE_OPERATION)) ++ deleteRowValues ++ metadataAttrs
+        val delete = deltaDeleteOutput(rowAttrs, rowIdAttrs, metadataAttrs)
+        Seq(delete)
+
+      case other =>
+        throw new AnalysisException(s"Unexpected WHEN MATCHED action: $other")
+    }
+  }
+
+  private def notMatchedDeltaActionOutput(
+      action: MergeAction,
+      metadataAttrs: Seq[Attribute]): Seq[Expression] = {
 
+    action match {
       case i: InsertAction =>
-        val metadataAttrValues = metadataAttrs.map(attr => Literal(null, 
attr.dataType))
-        Seq(Literal(INSERT_OPERATION)) ++ i.assignments.map(_.value) ++ 
metadataAttrValues
+        deltaInsertOutput(i.assignments.map(_.value), metadataAttrs)
 
       case other =>
-        throw new AnalysisException(s"Unexpected action: $other")
+        throw new AnalysisException(s"Unexpected WHEN NOT MATCHED action: 
$other")
     }
   }
 
   private def buildMergeRowsOutput(
-      matchedOutputs: Seq[Seq[Expression]],
+      matchedOutputs: Seq[Seq[Seq[Expression]]],
       notMatchedOutputs: Seq[Seq[Expression]],
       attrs: Seq[Attribute]): Seq[Attribute] = {
 
-    // collect all outputs from matched and not matched actions (ignoring 
DELETEs)
-    val outputs = matchedOutputs.filter(_.nonEmpty) ++ 
notMatchedOutputs.filter(_.nonEmpty)
-
-    // build a correct nullability map for output attributes
-    // an attribute is nullable if at least one matched or not matched action 
may produce null
-    val nullabilityMap = attrs.indices.map { index =>
-      index -> outputs.exists(output => output(index).nullable)
-    }.toMap
-
-    attrs.zipWithIndex.map { case (attr, index) =>
-      AttributeReference(attr.name, attr.dataType, nullabilityMap(index), 
attr.metadata)()
-    }
+    // collect all outputs from matched and not matched actions (ignoring 
actions that discard rows)
+    val outputs = matchedOutputs.flatten.filter(_.nonEmpty) ++ 
notMatchedOutputs.filter(_.nonEmpty)
+    buildMergingOutput(outputs, attrs)
   }
 
   private def isCardinalityCheckNeeded(actions: Seq[MergeAction]): Boolean = 
actions match {
     case Seq(DeleteAction(None)) => false
     case _ => true
   }
 
-  private def buildDeltaDeleteRowValues(
-      rowAttrs: Seq[Attribute],
-      rowIdAttrs: Seq[Attribute]): Seq[Expression] = {
-
-    // nullify all row attrs that are not part of the row ID
-    val rowIdAttSet = AttributeSet(rowIdAttrs)
-    rowAttrs.map {
-      case attr if rowIdAttSet.contains(attr) => attr
-      case attr => Literal(null, attr.dataType)
-    }
-  }
-
   private def resolveAttrRef(ref: NamedReference, plan: LogicalPlan): 
AttributeReference = {
     V2ExpressionUtils.resolveRef[AttributeReference](ref, plan)
   }
 
-  private def buildMergeDeltaProjections(
+  private def buildDeltaProjections(
       mergeRows: MergeRows,
       rowAttrs: Seq[Attribute],
       rowIdAttrs: Seq[Attribute],
       metadataAttrs: Seq[Attribute]): WriteDeltaProjections = {
 
-    val outputAttrs = mergeRows.output

Review Comment:
   This logic was moved into the parent so that it can be reused.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to