ChrisHegarty commented on code in PR #12311: URL: https://github.com/apache/lucene/pull/12311#discussion_r1201940187
########## lucene/core/src/java20/org/apache/lucene/util/VectorUtilPanamaProvider.java: ########## @@ -0,0 +1,455 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.util; + +import jdk.incubator.vector.ByteVector; +import jdk.incubator.vector.FloatVector; +import jdk.incubator.vector.IntVector; +import jdk.incubator.vector.ShortVector; +import jdk.incubator.vector.Vector; +import jdk.incubator.vector.VectorOperators; +import jdk.incubator.vector.VectorShape; +import jdk.incubator.vector.VectorSpecies; + +/** A VectorUtil provider implementation that leverages the Panama Vector API. */ +final class VectorUtilPanamaProvider implements VectorUtilProvider { + + static final VectorSpecies<Float> SPECIES = FloatVector.SPECIES_PREFERRED; + + VectorUtilPanamaProvider() {} + + @Override + public float dotProduct(float[] a, float[] b) { + int i = 0; + float res = 0; + // if the array size is large (> 2x platform vector size), its worth the overhead to vectorize + if (a.length > 2 * SPECIES.length()) { + // vector loop is unrolled 4x (4 accumulators in parallel) + FloatVector acc1 = FloatVector.zero(SPECIES); + FloatVector acc2 = FloatVector.zero(SPECIES); + FloatVector acc3 = FloatVector.zero(SPECIES); + FloatVector acc4 = FloatVector.zero(SPECIES); + int upperBound = SPECIES.loopBound(a.length - 3 * SPECIES.length()); + for (; i < upperBound; i += 4 * SPECIES.length()) { + FloatVector va = FloatVector.fromArray(SPECIES, a, i); + FloatVector vb = FloatVector.fromArray(SPECIES, b, i); + acc1 = acc1.add(va.mul(vb)); + FloatVector vc = FloatVector.fromArray(SPECIES, a, i + SPECIES.length()); + FloatVector vd = FloatVector.fromArray(SPECIES, b, i + SPECIES.length()); + acc2 = acc2.add(vc.mul(vd)); + FloatVector ve = FloatVector.fromArray(SPECIES, a, i + 2 * SPECIES.length()); + FloatVector vf = FloatVector.fromArray(SPECIES, b, i + 2 * SPECIES.length()); + acc3 = acc3.add(ve.mul(vf)); + FloatVector vg = FloatVector.fromArray(SPECIES, a, i + 3 * SPECIES.length()); + FloatVector vh = FloatVector.fromArray(SPECIES, b, i + 3 * SPECIES.length()); + acc4 = acc4.add(vg.mul(vh)); + } + // vector tail: less scalar computations for unaligned sizes, esp with big vector sizes + upperBound = SPECIES.loopBound(a.length); + for (; i < upperBound; i += SPECIES.length()) { + FloatVector va = FloatVector.fromArray(SPECIES, a, i); + FloatVector vb = FloatVector.fromArray(SPECIES, b, i); + acc1 = acc1.add(va.mul(vb)); + } + // reduce + FloatVector res1 = acc1.add(acc2); + FloatVector res2 = acc3.add(acc4); + res += res1.add(res2).reduceLanes(VectorOperators.ADD); + } + + for (; i < a.length; i++) { + res += b[i] * a[i]; + } + return res; + } + + @Override + public float cosine(float[] a, float[] b) { + int i = 0; + float sum = 0; + float norm1 = 0; + float norm2 = 0; + // if the array size is large (> 2x platform vector size), its worth the overhead to vectorize + if (a.length > 2 * SPECIES.length()) { + // vector loop is unrolled 4x (4 accumulators in parallel) + FloatVector sum1 = FloatVector.zero(SPECIES); + FloatVector sum2 = FloatVector.zero(SPECIES); + FloatVector sum3 = FloatVector.zero(SPECIES); + FloatVector sum4 = FloatVector.zero(SPECIES); + FloatVector norm1_1 = FloatVector.zero(SPECIES); + FloatVector norm1_2 = FloatVector.zero(SPECIES); + FloatVector norm1_3 = FloatVector.zero(SPECIES); + FloatVector norm1_4 = FloatVector.zero(SPECIES); + FloatVector norm2_1 = FloatVector.zero(SPECIES); + FloatVector norm2_2 = FloatVector.zero(SPECIES); + FloatVector norm2_3 = FloatVector.zero(SPECIES); + FloatVector norm2_4 = FloatVector.zero(SPECIES); + int upperBound = SPECIES.loopBound(a.length - 3 * SPECIES.length()); + for (; i < upperBound; i += 4 * SPECIES.length()) { + FloatVector va = FloatVector.fromArray(SPECIES, a, i); + FloatVector vb = FloatVector.fromArray(SPECIES, b, i); + sum1 = sum1.add(va.mul(vb)); + norm1_1 = norm1_1.add(va.mul(va)); + norm2_1 = norm2_1.add(vb.mul(vb)); + FloatVector vc = FloatVector.fromArray(SPECIES, a, i + SPECIES.length()); + FloatVector vd = FloatVector.fromArray(SPECIES, b, i + SPECIES.length()); + sum2 = sum2.add(vc.mul(vd)); + norm1_2 = norm1_2.add(vc.mul(vc)); + norm2_2 = norm2_2.add(vd.mul(vd)); + FloatVector ve = FloatVector.fromArray(SPECIES, a, i + 2 * SPECIES.length()); + FloatVector vf = FloatVector.fromArray(SPECIES, b, i + 2 * SPECIES.length()); + sum3 = sum3.add(ve.mul(vf)); + norm1_3 = norm1_3.add(ve.mul(ve)); + norm2_3 = norm2_3.add(vf.mul(vf)); + FloatVector vg = FloatVector.fromArray(SPECIES, a, i + 3 * SPECIES.length()); + FloatVector vh = FloatVector.fromArray(SPECIES, b, i + 3 * SPECIES.length()); + sum4 = sum4.add(vg.mul(vh)); + norm1_4 = norm1_4.add(vg.mul(vg)); + norm2_4 = norm2_4.add(vh.mul(vh)); + } + // vector tail: less scalar computations for unaligned sizes, esp with big vector sizes + upperBound = SPECIES.loopBound(a.length); + for (; i < upperBound; i += SPECIES.length()) { + FloatVector va = FloatVector.fromArray(SPECIES, a, i); + FloatVector vb = FloatVector.fromArray(SPECIES, b, i); + sum1 = sum1.add(va.mul(vb)); + norm1_1 = norm1_1.add(va.mul(va)); + norm2_1 = norm2_1.add(vb.mul(vb)); + } + // reduce + FloatVector sumres1 = sum1.add(sum2); + FloatVector sumres2 = sum3.add(sum4); + FloatVector norm1res1 = norm1_1.add(norm1_2); + FloatVector norm1res2 = norm1_3.add(norm1_4); + FloatVector norm2res1 = norm2_1.add(norm2_2); + FloatVector norm2res2 = norm2_3.add(norm2_4); + sum += sumres1.add(sumres2).reduceLanes(VectorOperators.ADD); + norm1 += norm1res1.add(norm1res2).reduceLanes(VectorOperators.ADD); + norm2 += norm2res1.add(norm2res2).reduceLanes(VectorOperators.ADD); + } + + for (; i < a.length; i++) { + float elem1 = a[i]; + float elem2 = b[i]; + sum += elem1 * elem2; + norm1 += elem1 * elem1; + norm2 += elem2 * elem2; + } + return (float) (sum / Math.sqrt(norm1 * norm2)); + } + + @Override + public float squareDistance(float[] a, float[] b) { + int i = 0; + float res = 0; + // if the array size is large (> 2x platform vector size), its worth the overhead to vectorize + if (a.length > 2 * SPECIES.length()) { + // vector loop is unrolled 4x (4 accumulators in parallel) + FloatVector acc1 = FloatVector.zero(SPECIES); + FloatVector acc2 = FloatVector.zero(SPECIES); + FloatVector acc3 = FloatVector.zero(SPECIES); + FloatVector acc4 = FloatVector.zero(SPECIES); + int upperBound = SPECIES.loopBound(a.length - 3 * SPECIES.length()); + for (; i < upperBound; i += 4 * SPECIES.length()) { + FloatVector va = FloatVector.fromArray(SPECIES, a, i); + FloatVector vb = FloatVector.fromArray(SPECIES, b, i); + FloatVector diff1 = va.sub(vb); + acc1 = acc1.add(diff1.mul(diff1)); + FloatVector vc = FloatVector.fromArray(SPECIES, a, i + SPECIES.length()); + FloatVector vd = FloatVector.fromArray(SPECIES, b, i + SPECIES.length()); + FloatVector diff2 = vc.sub(vd); + acc2 = acc2.add(diff2.mul(diff2)); + FloatVector ve = FloatVector.fromArray(SPECIES, a, i + 2 * SPECIES.length()); + FloatVector vf = FloatVector.fromArray(SPECIES, b, i + 2 * SPECIES.length()); + FloatVector diff3 = ve.sub(vf); + acc3 = acc3.add(diff3.mul(diff3)); + FloatVector vg = FloatVector.fromArray(SPECIES, a, i + 3 * SPECIES.length()); + FloatVector vh = FloatVector.fromArray(SPECIES, b, i + 3 * SPECIES.length()); + FloatVector diff4 = vg.sub(vh); + acc4 = acc4.add(diff4.mul(diff4)); + } + // vector tail: less scalar computations for unaligned sizes, esp with big vector sizes + upperBound = SPECIES.loopBound(a.length); + for (; i < upperBound; i += SPECIES.length()) { + FloatVector va = FloatVector.fromArray(SPECIES, a, i); + FloatVector vb = FloatVector.fromArray(SPECIES, b, i); + FloatVector diff = va.sub(vb); + acc1 = acc1.add(diff.mul(diff)); + } + // reduce + FloatVector res1 = acc1.add(acc2); + FloatVector res2 = acc3.add(acc4); + res += res1.add(res2).reduceLanes(VectorOperators.ADD); + } + + for (; i < a.length; i++) { + float diff = a[i] - b[i]; + res += diff * diff; + } + return res; + } + + // Binary functions, these all follow a general pattern like this: + // + // short intermediate = a * b; + // int accumulator = accumulator + intermediate; + // + // 256 or 512 bit vectors can process 64 or 128 bits at a time, respectively + // intermediate results use 128 or 256 bit vectors, respectively + // final accumulator uses 256 or 512 bit vectors, respectively + // + // We also support 128 bit vectors, using two 128 bit accumulators. + // This is slower but still faster than not vectorizing at all. + + static final VectorSpecies<Byte> PREFERRED_BYTE_SPECIES; + static final VectorSpecies<Short> PREFERRED_SHORT_SPECIES; + + static { + if (IntVector.SPECIES_PREFERRED.vectorBitSize() >= 256) { + PREFERRED_BYTE_SPECIES = + ByteVector.SPECIES_MAX.withShape( + VectorShape.forBitSize(IntVector.SPECIES_PREFERRED.vectorBitSize() >> 2)); + PREFERRED_SHORT_SPECIES = + ShortVector.SPECIES_MAX.withShape( + VectorShape.forBitSize(IntVector.SPECIES_PREFERRED.vectorBitSize() >> 1)); + } else { + PREFERRED_BYTE_SPECIES = null; + PREFERRED_SHORT_SPECIES = null; + } + } + + static final int INT_SPECIES_PREFERRED_BIT_SIZE = IntVector.SPECIES_PREFERRED.vectorBitSize(); + + @Override + public int dotProduct(byte[] a, byte[] b) { + int i = 0; + int res = 0; + // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit + // vectors + if (a.length >= 16 && INT_SPECIES_PREFERRED_BIT_SIZE >= 128) { + // compute vectorized dot product consistent with VPDPBUSD instruction, acts like: + // int sum = 0; + // for (...) { + // short product = (short) (x[i] * y[i]); + // sum += product; + // } + if (INT_SPECIES_PREFERRED_BIT_SIZE >= 256) { + // optimized 256/512 bit implementation, processes 8/16 bytes at a time + int upperBound = PREFERRED_BYTE_SPECIES.loopBound(a.length); + IntVector acc = IntVector.zero(IntVector.SPECIES_PREFERRED); + for (; i < upperBound; i += PREFERRED_BYTE_SPECIES.length()) { + ByteVector va8 = ByteVector.fromArray(PREFERRED_BYTE_SPECIES, a, i); + ByteVector vb8 = ByteVector.fromArray(PREFERRED_BYTE_SPECIES, b, i); + Vector<Short> va16 = va8.convertShape(VectorOperators.B2S, PREFERRED_SHORT_SPECIES, 0); + Vector<Short> vb16 = vb8.convertShape(VectorOperators.B2S, PREFERRED_SHORT_SPECIES, 0); + Vector<Short> prod16 = va16.mul(vb16); + Vector<Integer> prod32 = + prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0); + acc = acc.add(prod32); + } + // reduce + res += acc.reduceLanes(VectorOperators.ADD); + } else { + // 128-bit implementation, which must "split up" vectors due to widening conversions + int upperBound = ByteVector.SPECIES_64.loopBound(a.length); + IntVector acc1 = IntVector.zero(IntVector.SPECIES_128); + IntVector acc2 = IntVector.zero(IntVector.SPECIES_128); + for (; i < upperBound; i += ByteVector.SPECIES_64.length()) { + ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i); + ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i); + // expand each byte vector into short vector and multiply + Vector<Short> va16 = va8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0); + Vector<Short> vb16 = vb8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0); + Vector<Short> prod16 = va16.mul(vb16); + // split each short vector into two int vectors and add + Vector<Integer> prod32_1 = + prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0); + Vector<Integer> prod32_2 = + prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1); + acc1 = acc1.add(prod32_1); + acc2 = acc2.add(prod32_2); + } + // reduce + res += acc1.add(acc2).reduceLanes(VectorOperators.ADD); + } + } + + for (; i < a.length; i++) { + res += b[i] * a[i]; + } + return res; + } + + @Override + public float cosine(byte[] a, byte[] b) { + int i = 0; + int sum = 0; + int norm1 = 0; + int norm2 = 0; + // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit + // vectors + if (a.length >= 16 && INT_SPECIES_PREFERRED_BIT_SIZE >= 128) { + // acts like: + // int sum = 0; + // for (...) { + // short difference = (short) (x[i] - y[i]); + // sum += (int) difference * (int) difference; + // } Review Comment: Done. ########## lucene/core/src/java20/org/apache/lucene/util/VectorUtilPanamaProvider.java: ########## @@ -0,0 +1,455 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.util; + +import jdk.incubator.vector.ByteVector; +import jdk.incubator.vector.FloatVector; +import jdk.incubator.vector.IntVector; +import jdk.incubator.vector.ShortVector; +import jdk.incubator.vector.Vector; +import jdk.incubator.vector.VectorOperators; +import jdk.incubator.vector.VectorShape; +import jdk.incubator.vector.VectorSpecies; + +/** A VectorUtil provider implementation that leverages the Panama Vector API. */ +final class VectorUtilPanamaProvider implements VectorUtilProvider { + + static final VectorSpecies<Float> SPECIES = FloatVector.SPECIES_PREFERRED; + + VectorUtilPanamaProvider() {} + + @Override + public float dotProduct(float[] a, float[] b) { + int i = 0; + float res = 0; + // if the array size is large (> 2x platform vector size), its worth the overhead to vectorize + if (a.length > 2 * SPECIES.length()) { + // vector loop is unrolled 4x (4 accumulators in parallel) + FloatVector acc1 = FloatVector.zero(SPECIES); + FloatVector acc2 = FloatVector.zero(SPECIES); + FloatVector acc3 = FloatVector.zero(SPECIES); + FloatVector acc4 = FloatVector.zero(SPECIES); + int upperBound = SPECIES.loopBound(a.length - 3 * SPECIES.length()); + for (; i < upperBound; i += 4 * SPECIES.length()) { + FloatVector va = FloatVector.fromArray(SPECIES, a, i); + FloatVector vb = FloatVector.fromArray(SPECIES, b, i); + acc1 = acc1.add(va.mul(vb)); + FloatVector vc = FloatVector.fromArray(SPECIES, a, i + SPECIES.length()); + FloatVector vd = FloatVector.fromArray(SPECIES, b, i + SPECIES.length()); + acc2 = acc2.add(vc.mul(vd)); + FloatVector ve = FloatVector.fromArray(SPECIES, a, i + 2 * SPECIES.length()); + FloatVector vf = FloatVector.fromArray(SPECIES, b, i + 2 * SPECIES.length()); + acc3 = acc3.add(ve.mul(vf)); + FloatVector vg = FloatVector.fromArray(SPECIES, a, i + 3 * SPECIES.length()); + FloatVector vh = FloatVector.fromArray(SPECIES, b, i + 3 * SPECIES.length()); + acc4 = acc4.add(vg.mul(vh)); + } + // vector tail: less scalar computations for unaligned sizes, esp with big vector sizes + upperBound = SPECIES.loopBound(a.length); + for (; i < upperBound; i += SPECIES.length()) { + FloatVector va = FloatVector.fromArray(SPECIES, a, i); + FloatVector vb = FloatVector.fromArray(SPECIES, b, i); + acc1 = acc1.add(va.mul(vb)); + } + // reduce + FloatVector res1 = acc1.add(acc2); + FloatVector res2 = acc3.add(acc4); + res += res1.add(res2).reduceLanes(VectorOperators.ADD); + } + + for (; i < a.length; i++) { + res += b[i] * a[i]; + } + return res; + } + + @Override + public float cosine(float[] a, float[] b) { + int i = 0; + float sum = 0; + float norm1 = 0; + float norm2 = 0; + // if the array size is large (> 2x platform vector size), its worth the overhead to vectorize + if (a.length > 2 * SPECIES.length()) { + // vector loop is unrolled 4x (4 accumulators in parallel) + FloatVector sum1 = FloatVector.zero(SPECIES); + FloatVector sum2 = FloatVector.zero(SPECIES); + FloatVector sum3 = FloatVector.zero(SPECIES); + FloatVector sum4 = FloatVector.zero(SPECIES); + FloatVector norm1_1 = FloatVector.zero(SPECIES); + FloatVector norm1_2 = FloatVector.zero(SPECIES); + FloatVector norm1_3 = FloatVector.zero(SPECIES); + FloatVector norm1_4 = FloatVector.zero(SPECIES); + FloatVector norm2_1 = FloatVector.zero(SPECIES); + FloatVector norm2_2 = FloatVector.zero(SPECIES); + FloatVector norm2_3 = FloatVector.zero(SPECIES); + FloatVector norm2_4 = FloatVector.zero(SPECIES); + int upperBound = SPECIES.loopBound(a.length - 3 * SPECIES.length()); + for (; i < upperBound; i += 4 * SPECIES.length()) { + FloatVector va = FloatVector.fromArray(SPECIES, a, i); + FloatVector vb = FloatVector.fromArray(SPECIES, b, i); + sum1 = sum1.add(va.mul(vb)); + norm1_1 = norm1_1.add(va.mul(va)); + norm2_1 = norm2_1.add(vb.mul(vb)); + FloatVector vc = FloatVector.fromArray(SPECIES, a, i + SPECIES.length()); + FloatVector vd = FloatVector.fromArray(SPECIES, b, i + SPECIES.length()); + sum2 = sum2.add(vc.mul(vd)); + norm1_2 = norm1_2.add(vc.mul(vc)); + norm2_2 = norm2_2.add(vd.mul(vd)); + FloatVector ve = FloatVector.fromArray(SPECIES, a, i + 2 * SPECIES.length()); + FloatVector vf = FloatVector.fromArray(SPECIES, b, i + 2 * SPECIES.length()); + sum3 = sum3.add(ve.mul(vf)); + norm1_3 = norm1_3.add(ve.mul(ve)); + norm2_3 = norm2_3.add(vf.mul(vf)); + FloatVector vg = FloatVector.fromArray(SPECIES, a, i + 3 * SPECIES.length()); + FloatVector vh = FloatVector.fromArray(SPECIES, b, i + 3 * SPECIES.length()); + sum4 = sum4.add(vg.mul(vh)); + norm1_4 = norm1_4.add(vg.mul(vg)); + norm2_4 = norm2_4.add(vh.mul(vh)); + } + // vector tail: less scalar computations for unaligned sizes, esp with big vector sizes + upperBound = SPECIES.loopBound(a.length); + for (; i < upperBound; i += SPECIES.length()) { + FloatVector va = FloatVector.fromArray(SPECIES, a, i); + FloatVector vb = FloatVector.fromArray(SPECIES, b, i); + sum1 = sum1.add(va.mul(vb)); + norm1_1 = norm1_1.add(va.mul(va)); + norm2_1 = norm2_1.add(vb.mul(vb)); + } + // reduce + FloatVector sumres1 = sum1.add(sum2); + FloatVector sumres2 = sum3.add(sum4); + FloatVector norm1res1 = norm1_1.add(norm1_2); + FloatVector norm1res2 = norm1_3.add(norm1_4); + FloatVector norm2res1 = norm2_1.add(norm2_2); + FloatVector norm2res2 = norm2_3.add(norm2_4); + sum += sumres1.add(sumres2).reduceLanes(VectorOperators.ADD); + norm1 += norm1res1.add(norm1res2).reduceLanes(VectorOperators.ADD); + norm2 += norm2res1.add(norm2res2).reduceLanes(VectorOperators.ADD); + } + + for (; i < a.length; i++) { + float elem1 = a[i]; + float elem2 = b[i]; + sum += elem1 * elem2; + norm1 += elem1 * elem1; + norm2 += elem2 * elem2; + } + return (float) (sum / Math.sqrt(norm1 * norm2)); + } + + @Override + public float squareDistance(float[] a, float[] b) { + int i = 0; + float res = 0; + // if the array size is large (> 2x platform vector size), its worth the overhead to vectorize + if (a.length > 2 * SPECIES.length()) { + // vector loop is unrolled 4x (4 accumulators in parallel) + FloatVector acc1 = FloatVector.zero(SPECIES); + FloatVector acc2 = FloatVector.zero(SPECIES); + FloatVector acc3 = FloatVector.zero(SPECIES); + FloatVector acc4 = FloatVector.zero(SPECIES); + int upperBound = SPECIES.loopBound(a.length - 3 * SPECIES.length()); + for (; i < upperBound; i += 4 * SPECIES.length()) { + FloatVector va = FloatVector.fromArray(SPECIES, a, i); + FloatVector vb = FloatVector.fromArray(SPECIES, b, i); + FloatVector diff1 = va.sub(vb); + acc1 = acc1.add(diff1.mul(diff1)); + FloatVector vc = FloatVector.fromArray(SPECIES, a, i + SPECIES.length()); + FloatVector vd = FloatVector.fromArray(SPECIES, b, i + SPECIES.length()); + FloatVector diff2 = vc.sub(vd); + acc2 = acc2.add(diff2.mul(diff2)); + FloatVector ve = FloatVector.fromArray(SPECIES, a, i + 2 * SPECIES.length()); + FloatVector vf = FloatVector.fromArray(SPECIES, b, i + 2 * SPECIES.length()); + FloatVector diff3 = ve.sub(vf); + acc3 = acc3.add(diff3.mul(diff3)); + FloatVector vg = FloatVector.fromArray(SPECIES, a, i + 3 * SPECIES.length()); + FloatVector vh = FloatVector.fromArray(SPECIES, b, i + 3 * SPECIES.length()); + FloatVector diff4 = vg.sub(vh); + acc4 = acc4.add(diff4.mul(diff4)); + } + // vector tail: less scalar computations for unaligned sizes, esp with big vector sizes + upperBound = SPECIES.loopBound(a.length); + for (; i < upperBound; i += SPECIES.length()) { + FloatVector va = FloatVector.fromArray(SPECIES, a, i); + FloatVector vb = FloatVector.fromArray(SPECIES, b, i); + FloatVector diff = va.sub(vb); + acc1 = acc1.add(diff.mul(diff)); + } + // reduce + FloatVector res1 = acc1.add(acc2); + FloatVector res2 = acc3.add(acc4); + res += res1.add(res2).reduceLanes(VectorOperators.ADD); + } + + for (; i < a.length; i++) { + float diff = a[i] - b[i]; + res += diff * diff; + } + return res; + } + + // Binary functions, these all follow a general pattern like this: + // + // short intermediate = a * b; + // int accumulator = accumulator + intermediate; + // + // 256 or 512 bit vectors can process 64 or 128 bits at a time, respectively + // intermediate results use 128 or 256 bit vectors, respectively + // final accumulator uses 256 or 512 bit vectors, respectively + // + // We also support 128 bit vectors, using two 128 bit accumulators. + // This is slower but still faster than not vectorizing at all. + + static final VectorSpecies<Byte> PREFERRED_BYTE_SPECIES; + static final VectorSpecies<Short> PREFERRED_SHORT_SPECIES; + + static { + if (IntVector.SPECIES_PREFERRED.vectorBitSize() >= 256) { + PREFERRED_BYTE_SPECIES = + ByteVector.SPECIES_MAX.withShape( + VectorShape.forBitSize(IntVector.SPECIES_PREFERRED.vectorBitSize() >> 2)); + PREFERRED_SHORT_SPECIES = + ShortVector.SPECIES_MAX.withShape( + VectorShape.forBitSize(IntVector.SPECIES_PREFERRED.vectorBitSize() >> 1)); + } else { + PREFERRED_BYTE_SPECIES = null; + PREFERRED_SHORT_SPECIES = null; + } + } + + static final int INT_SPECIES_PREFERRED_BIT_SIZE = IntVector.SPECIES_PREFERRED.vectorBitSize(); + + @Override + public int dotProduct(byte[] a, byte[] b) { + int i = 0; + int res = 0; + // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit + // vectors + if (a.length >= 16 && INT_SPECIES_PREFERRED_BIT_SIZE >= 128) { + // compute vectorized dot product consistent with VPDPBUSD instruction, acts like: + // int sum = 0; + // for (...) { + // short product = (short) (x[i] * y[i]); + // sum += product; + // } + if (INT_SPECIES_PREFERRED_BIT_SIZE >= 256) { + // optimized 256/512 bit implementation, processes 8/16 bytes at a time + int upperBound = PREFERRED_BYTE_SPECIES.loopBound(a.length); + IntVector acc = IntVector.zero(IntVector.SPECIES_PREFERRED); + for (; i < upperBound; i += PREFERRED_BYTE_SPECIES.length()) { + ByteVector va8 = ByteVector.fromArray(PREFERRED_BYTE_SPECIES, a, i); + ByteVector vb8 = ByteVector.fromArray(PREFERRED_BYTE_SPECIES, b, i); + Vector<Short> va16 = va8.convertShape(VectorOperators.B2S, PREFERRED_SHORT_SPECIES, 0); + Vector<Short> vb16 = vb8.convertShape(VectorOperators.B2S, PREFERRED_SHORT_SPECIES, 0); + Vector<Short> prod16 = va16.mul(vb16); + Vector<Integer> prod32 = + prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0); + acc = acc.add(prod32); + } + // reduce + res += acc.reduceLanes(VectorOperators.ADD); + } else { + // 128-bit implementation, which must "split up" vectors due to widening conversions + int upperBound = ByteVector.SPECIES_64.loopBound(a.length); + IntVector acc1 = IntVector.zero(IntVector.SPECIES_128); + IntVector acc2 = IntVector.zero(IntVector.SPECIES_128); + for (; i < upperBound; i += ByteVector.SPECIES_64.length()) { + ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i); + ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i); + // expand each byte vector into short vector and multiply + Vector<Short> va16 = va8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0); + Vector<Short> vb16 = vb8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0); + Vector<Short> prod16 = va16.mul(vb16); + // split each short vector into two int vectors and add + Vector<Integer> prod32_1 = + prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0); + Vector<Integer> prod32_2 = + prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1); + acc1 = acc1.add(prod32_1); + acc2 = acc2.add(prod32_2); + } + // reduce + res += acc1.add(acc2).reduceLanes(VectorOperators.ADD); + } + } + + for (; i < a.length; i++) { + res += b[i] * a[i]; + } + return res; + } + + @Override + public float cosine(byte[] a, byte[] b) { + int i = 0; + int sum = 0; + int norm1 = 0; + int norm2 = 0; + // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit + // vectors + if (a.length >= 16 && INT_SPECIES_PREFERRED_BIT_SIZE >= 128) { + // acts like: + // int sum = 0; + // for (...) { + // short difference = (short) (x[i] - y[i]); + // sum += (int) difference * (int) difference; + // } Review Comment: Done. [11e6634](https://github.com/apache/lucene/pull/12311/commits/11e66348adc0304a5df996bfde8253acdbd15806) -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org For additional commands, e-mail: issues-h...@lucene.apache.org