From bebb4f6f60d7cac9a6073922652f7a6e96df41d6 Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Thu, 14 Nov 2024 15:08:26 -0600 Subject: [PATCH] Make PVUS#assembleAndSum use SimdOps; optimize SimdOps assembleAndSum --- .../jbellis/jvector/vector/PanamaVectorUtilSupport.java | 6 +----- .../java/io/github/jbellis/jvector/vector/SimdOps.java | 9 ++++----- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java index e71d780a..060ca9e1 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java @@ -92,11 +92,7 @@ public VectorFloat sub(VectorFloat a, int aOffset, VectorFloat b, int b @Override public float assembleAndSum(VectorFloat data, int dataBase, ByteSequence baseOffsets) { - float sum = 0f; - for (int i = 0; i < baseOffsets.length(); i++) { - sum += data.get(dataBase * i + Byte.toUnsignedInt(baseOffsets.get(i))); - } - return sum; + return SimdOps.assembleAndSum(((ArrayVectorFloat) data).get(), dataBase, ((ArrayByteSequence) baseOffsets).get()); } @Override diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java index 55ef28f3..60c5170b 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java @@ -525,11 +525,10 @@ static float assembleAndSum512(float[] data, int dataBase, byte[] baseOffsets) { FloatVector sum = FloatVector.zero(FloatVector.SPECIES_512); int i = 0; int limit = ByteVector.SPECIES_128.loopBound(baseOffsets.length); + var scale = IntVector.zero(IntVector.SPECIES_512).addIndex(dataBase); for (; i < limit; i += ByteVector.SPECIES_128.length()) { - var scale = IntVector.zero(IntVector.SPECIES_512).addIndex(1).add(i).mul(dataBase); - - ByteVector.fromArray(ByteVector.SPECIES_128, baseOffsets, i) + ByteVector.fromArray(ByteVector.SPECIES_128, baseOffsets, i * dataBase) .convertShape(VectorOperators.B2I, IntVector.SPECIES_512, 0) .lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_512) .reinterpretAsInts() @@ -553,11 +552,11 @@ static float assembleAndSum256(float[] data, int dataBase, byte[] baseOffsets) { FloatVector sum = FloatVector.zero(FloatVector.SPECIES_256); int i = 0; int limit = ByteVector.SPECIES_64.loopBound(baseOffsets.length); + var scale = IntVector.zero(IntVector.SPECIES_256).addIndex(dataBase); for (; i < limit; i += ByteVector.SPECIES_64.length()) { - var scale = IntVector.zero(IntVector.SPECIES_256).addIndex(1).add(i).mul(dataBase); - ByteVector.fromArray(ByteVector.SPECIES_64, baseOffsets, i) + ByteVector.fromArray(ByteVector.SPECIES_64, baseOffsets, i * dataBase) .convertShape(VectorOperators.B2I, IntVector.SPECIES_256, 0) .lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_256) .reinterpretAsInts()