From 995ed2ad31a24cb36e20beae2aa36d3e58fc6298 Mon Sep 17 00:00:00 2001
From: Joel Jakobsson <github@compiler.org>
Date: Sun, 7 Jul 2024 19:21:35 +0200
Subject: [PATCH] Optimize mul_var() for var1ndigits >= 8

The idea is to reduce the "n" in O(n^2) by a factor of two.

This is achieved by first converting the (ndigits) number of int16 NBASE digits,
to (ndigits/2) number of int32 NBASE^2 digits, as well as upgrading the
int32 variables to int64-variables so that the products and carry values fit.

The existing multiplication algorithm is then executed without change.

Finally, the int32 NBASE^2 result digits are converted back to twice the number
of int16 NBASE digits.

This adds overhead of approximately 4 * O(n), due to the conversion.
Benchmark indicates it's a win when var1 is at least 8 ndigits.
---
 src/backend/utils/adt/numeric.c | 243 ++++++++++++++++++++++++++++++++
 1 file changed, 243 insertions(+)

diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c
index 5510a203b0..ddfc71feda 100644
--- a/src/backend/utils/adt/numeric.c
+++ b/src/backend/utils/adt/numeric.c
@@ -101,6 +101,8 @@ typedef signed char NumericDigit;
 typedef int16 NumericDigit;
 #endif
 
+#define SQUARE_NBASE	(NBASE * NBASE)
+
 /*
  * The Numeric type as stored on disk.
  *
@@ -551,6 +553,8 @@ static void sub_var(const NumericVar *var1, const NumericVar *var2,
 static void mul_var(const NumericVar *var1, const NumericVar *var2,
 					NumericVar *result,
 					int rscale);
+static void mul_var_large(const NumericVar *var1, const NumericVar *var2,
+						  NumericVar *result, int rscale);
 static void div_var(const NumericVar *var1, const NumericVar *var2,
 					NumericVar *result,
 					int rscale, bool round);
@@ -8715,6 +8719,16 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
 		return;
 	}
 
+	/*
+	 * If var1 has at least 8 digits, delegate to mul_var_large()
+	 * which uses a multiplication algorithm faster for large multiplicands.
+	 */
+	if (var1ndigits >= 8)
+	{
+		mul_var_large(var1, var2, result, rscale);
+		return;
+	}
+
 	/* Determine result sign and (maximum possible) weight */
 	if (var1->sign == var2->sign)
 		res_sign = NUMERIC_POS;
@@ -8864,6 +8878,235 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
 	strip_var(result);
 }
 
+/*
+ * mul_var_large() -
+ *
+ *	Special-case multiplication function used when var1 has at least 8 digits,
+ *	that reduces the "n" in O(n^2) by a factor of two.
+ *
+ *	This is achieved by first converting the (ndigits) number of int16 NBASE
+ *	digits, to (ndigits/2) number of int32 NBASE^2 digits, as well as upgrading
+ *	the int32 variables to int64-variables so that the products and carry
+ *	values fit.
+ *
+ *	The existing multiplication algorithm is then executed without change.
+ *
+ *	Finally, the int32 NBASE^2 result digits are converted back to twice
+ *	the number of int16 NBASE digits.
+ *
+ *	This adds overhead of approximately 4 * O(n), due to the conversion,
+ *	which seems to be a win when var1 has at least 8 digits.
+ */
+static void
+mul_var_large(const NumericVar *var1, const NumericVar *var2,
+			  NumericVar *result, int rscale)
+{
+	int			res_ndigits;
+	int			res_sign;
+	int			res_weight;
+	int			maxdigits;
+	int64	   *dig;
+	int64		carry;
+	int64		maxdig;
+	int64		newdig;
+	int			var1ndigits = (var1->ndigits + 1) / 2;
+	int			var2ndigits = (var2->ndigits + 1) / 2;
+	int64 	   *var1digits;
+	int64	   *var2digits;
+	int		   *res_digits;
+	int			i,
+				i1,
+				i2;
+
+	/* Check preconditions */
+	Assert(var1->ndigits >= 8);
+	Assert(var2->ndigits >= var1->ndigits);
+
+	/* Determine result sign */
+	if (var1->sign == var2->sign)
+		res_sign = NUMERIC_POS;
+	else
+		res_sign = NUMERIC_NEG;
+
+	/*
+	 * Determine the number of result digits to compute.  If the exact result
+	 * would have more than rscale fractional digits, truncate the computation
+	 * with MUL_GUARD_DIGITS guard digits, i.e., ignore input digits that
+	 * would only contribute to the right of that.  (This will give the exact
+	 * rounded-to-rscale answer unless carries out of the ignored positions
+	 * would have propagated through more than MUL_GUARD_DIGITS digits.)
+	 *
+	 * Additionally, determine the (maximum possible) weight of the result,
+	 * considering the base conversion and the ceiling division by 2
+	 * of the number of digits.
+	 *
+	 * Note: an exact computation could not produce more than var1ndigits +
+	 * var2ndigits digits, but we allocate one extra output digit in case
+	 * rscale-driven rounding produces a carry out of the highest exact digit.
+	 */
+	res_ndigits = var1ndigits + var2ndigits + 1;
+	res_weight = var1->weight + var2->weight + 2 +
+				 ((res_ndigits * 2) - (var1->ndigits + var2->ndigits + 1));
+	maxdigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS +
+		MUL_GUARD_DIGITS;
+	res_ndigits = Min(res_ndigits, maxdigits);
+
+	if (res_ndigits < 3)
+	{
+		/* All input digits will be ignored; so result is zero */
+		zero_var(result);
+		result->dscale = rscale;
+		return;
+	}
+
+	/*
+	 * We do the arithmetic in an array "dig[]" of signed int64's.  Since
+	 * PG_INT64_MAX is noticeably larger than SQUARE_NBASE*SQUARE_NBASE, this
+	 * gives us headroom to avoid normalizing carries immediately.
+	 *
+	 * maxdig tracks the maximum possible value of any dig[] entry; when this
+	 * threatens to exceed PG_INT64_MAX, we take the time to propagate carries.
+	 * Furthermore, we need to ensure that overflow doesn't occur during the
+	 * carry propagation passes either.  The carry values could be as much as
+	 * PG_INT64_MAX/SQUARE_NBASE, so really we must normalize when digits
+	 * threaten to exceed PG_INT64_MAX - PG_INT64_MAX/SQUARE_NBASE.
+	 *
+	 * To avoid overflow in maxdig itself, it actually represents the max
+	 * possible value divided by SQUARE_NBASE-1, ie, at the top of the loop it
+	 * is known that no dig[] entry exceeds maxdig * (SQUARE_NBASE-1).
+	 *
+	 * The allocated dig[] array will both be used to write the result,
+	 * as well as the result of the base conversion of var1 and var2.
+	 */
+	dig = (int64 *) palloc0((res_ndigits + var1ndigits + var2ndigits) *
+							sizeof(int64));
+	maxdig = 0;
+	var1digits = dig + res_ndigits;
+	var2digits = dig + res_ndigits + var1ndigits;
+
+	/*
+	 * Base conversion of var1 and var2 from NBASE to SQUARE_NBASE.
+	 */
+	i1 = 0; i2 = 0;
+	if (var1->ndigits % 2 != 0)
+		var1digits[i1++] = (int64) var1->digits[i2++];
+	for (; i1 < var1ndigits; i1++, i2 += 2)
+		var1digits[i1] = (int64) var1->digits[i2] * NBASE + var1->digits[i2+1];
+
+	i1 = 0; i2 = 0;
+	if (var2->ndigits % 2 != 0)
+		var2digits[i1++] = (int64) var2->digits[i2++];
+	for (; i1 < var2ndigits; i1++, i2 += 2)
+		var2digits[i1] = (int64) var2->digits[i2] * NBASE + var2->digits[i2+1];
+
+	/*
+	 * The least significant digits of var1 should be ignored if they don't
+	 * contribute directly to the first res_ndigits digits of the result that
+	 * we are computing.
+	 *
+	 * Digit i1 of var1 and digit i2 of var2 are multiplied and added to digit
+	 * i1+i2+2 of the accumulator array, so we need only consider digits of
+	 * var1 for which i1 <= res_ndigits - 3.
+	 */
+	for (i1 = Min(var1ndigits - 1, res_ndigits - 3); i1 >= 0; i1--)
+	{
+		int64 var1digit = var1digits[i1];
+
+		if (var1digit == 0)
+			continue;
+
+		/* Time to normalize? */
+		maxdig += var1digit;
+		if (maxdig > (PG_INT64_MAX - PG_INT64_MAX / SQUARE_NBASE) /
+					 (SQUARE_NBASE - 1))
+		{
+			/* Yes, do it */
+			carry = 0;
+			for (i = res_ndigits - 1; i >= 0; i--)
+			{
+				newdig = dig[i] + carry;
+				if (newdig >= SQUARE_NBASE)
+				{
+					carry = newdig / SQUARE_NBASE;
+					newdig -= carry * SQUARE_NBASE;
+				}
+				else
+					carry = 0;
+				dig[i] = newdig;
+			}
+			Assert(carry == 0);
+			/* Reset maxdig to indicate new worst-case */
+			maxdig = 1 + var1digit;
+		}
+
+		/*
+		 * Add the appropriate multiple of var2 into the accumulator.
+		 *
+		 * As above, digits of var2 can be ignored if they don't contribute,
+		 * so we only include digits for which i1+i2+2 < res_ndigits.
+		 *
+		 * This inner loop is the performance bottleneck for multiplication,
+		 * so we want to keep it simple enough so that it can be
+		 * auto-vectorized.  Accordingly, process the digits left-to-right
+		 * even though schoolbook multiplication would suggest right-to-left.
+		 * Since we aren't propagating carries in this loop, the order does
+		 * not matter.
+		 */
+		{
+			int			i2limit = Min(var2ndigits, res_ndigits - i1 - 2);
+			int64	   *dig_i1_2 = &dig[i1 + 2];
+
+			for (i2 = 0; i2 < i2limit; i2++)
+				dig_i1_2[i2] += var1digit * var2digits[i2];
+		}
+	}
+
+	/*
+	 * Now we do a final carry propagation pass to normalize the result, which
+	 * we combine with storing the result digits into the output. Note that
+	 * this is still done at full precision w/guard digits.
+	 */
+	res_digits = (int *) palloc0(res_ndigits * sizeof(int));
+	carry = 0;
+	for (i = res_ndigits - 1; i >= 0; i--)
+	{
+		newdig = dig[i] + carry;
+		if (newdig >= SQUARE_NBASE)
+		{
+			carry = newdig / SQUARE_NBASE;
+			newdig -= carry * SQUARE_NBASE;
+		}
+		else
+			carry = 0;
+		res_digits[i] = newdig;
+	}
+	Assert(carry == 0);
+
+	/*
+	 * Base conversion of res_digits from SQUARE_NBASE to NBASE.
+	 */
+	alloc_var(result, res_ndigits * 2);
+	for (i = 0; i < res_ndigits; i++)
+	{
+		int q = res_digits[i];
+		result->digits[i*2] = q / NBASE;
+		result->digits[i*2 + 1] = q % NBASE;
+	}
+
+	pfree(dig);
+
+	/*
+	 * Finally, round the result to the requested precision.
+	 */
+	result->weight = res_weight;
+	result->sign = res_sign;
+
+	/* Round to target rscale (and set result->dscale) */
+	round_var(result, rscale);
+
+	/* Strip leading and trailing zeroes */
+	strip_var(result);
+}
 
 /*
  * div_var() -
-- 
2.45.1

