/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.internal.vectorization;

import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.ByteOrder;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.Vector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorShape;
import jdk.incubator.vector.VectorSpecies;
import org.apache.lucene.internal.vectorization.PanamaVectorConstants;
import org.apache.lucene.internal.vectorization.PanamaVectorUtilSupport;

public final class MemorySegmentBulkVectorOps {
    /**
     * Float vector species whose bit size comes from
     * {@code PanamaVectorConstants.PREFERRED_VECTOR_BITSIZE}. The kernels in this class step
     * through their inputs in chunks of this species' lane count.
     */
    // NOTE(review): the (int) cast is retained from the decompiled output; the declared type of
    // PREFERRED_VECTOR_BITSIZE is not visible from this file.
    static final VectorSpecies<Float> FLOAT_SPECIES =
        VectorSpecies.of(float.class, VectorShape.forBitSize((int) PanamaVectorConstants.PREFERRED_VECTOR_BITSIZE));

    /** Byte order used for every vector load from a memory segment. */
    static final ByteOrder LE = ByteOrder.LITTLE_ENDIAN;

    /** Unaligned, little-endian float layout used for the scalar (tail) reads. */
    static final ValueLayout.OfFloat LAYOUT_LE_FLOAT = ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(LE);

    /** Shared instance of the dot-product kernels. */
    public static final DotProduct DOT_INSTANCE = new DotProduct();

    /** Shared instance of the cosine kernels. */
    public static final Cosine COS_INSTANCE = new Cosine();

    /** Shared instance of the squared-distance kernels. */
    public static final SqrDistance SQR_INSTANCE = new SqrDistance();

    /** Not instantiable: this class only exposes static constants and nested kernel classes. */
    private MemorySegmentBulkVectorOps() {}

    /**
     * Dot-product kernels over little-endian {@code float} data stored in a {@link MemorySegment}.
     *
     * <p>Every {@code long} vector argument is a byte offset into the segment, and
     * {@code elementCount} is the number of floats in each vector. The main loop processes
     * {@code FLOAT_SPECIES.length()} floats per iteration; the remaining elements are handled one
     * at a time. All accumulation, vector and scalar, goes through
     * {@code PanamaVectorUtilSupport.fma} (a multiply-accumulate helper defined outside this
     * file).
     */
    public static final class DotProduct {
        private DotProduct() {
        }

        /**
         * Scores an on-heap query against four vectors stored in {@code dataSeg}.
         *
         * @param dataSeg segment holding the four data vectors
         * @param scores output array; indices 0 to 3 receive the scores for {@code d1} to {@code d4}
         * @param q query vector, read from index 0 up to {@code elementCount}
         * @param d1 byte offset of the first data vector
         * @param d2 byte offset of the second data vector
         * @param d3 byte offset of the third data vector
         * @param d4 byte offset of the fourth data vector
         * @param elementCount number of floats in each vector
         */
        public void dotProductBulk(MemorySegment dataSeg, float[] scores, float[] q, long d1, long d2, long d3, long d4, int elementCount) {
            // -1 is the internal marker telling the implementation to read the query from the array.
            this.dotProductBulkImpl(dataSeg, scores, q, -1L, d1, d2, d3, d4, elementCount);
        }

        /**
         * Scores a query stored in {@code seg} against four vectors stored in the same segment.
         *
         * @param seg segment holding the query and the four data vectors
         * @param scores output array; indices 0 to 3 receive the scores for {@code d1} to {@code d4}
         * @param q byte offset of the query vector; the value {@code -1} is reserved internally to
         *     mean "query is on heap" and must not be passed here
         * @param d1 byte offset of the first data vector
         * @param d2 byte offset of the second data vector
         * @param d3 byte offset of the third data vector
         * @param d4 byte offset of the fourth data vector
         * @param elementCount number of floats in each vector
         */
        public void dotProductBulk(MemorySegment seg, float[] scores, long q, long d1, long d2, long d3, long d4, int elementCount) {
            this.dotProductBulkImpl(seg, scores, null, q, d1, d2, d3, d4, elementCount);
        }

        /**
         * Computes the dot product of two vectors stored in the same segment.
         *
         * @param seg segment holding both vectors
         * @param q byte offset of the first vector
         * @param d byte offset of the second vector
         * @param elementCount number of floats in each vector
         * @return the accumulated sum of element-wise products
         */
        public float dotProduct(MemorySegment seg, long q, long d, int elementCount) {
            FloatVector acc = FloatVector.zero(FLOAT_SPECIES);
            final int limit = FLOAT_SPECIES.loopBound(elementCount);
            int i = 0;
            for (; i < limit; i += FLOAT_SPECIES.length()) {
                final long offset = (long) i * Float.BYTES;
                FloatVector qv = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, q + offset, LE);
                FloatVector dv = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, d + offset, LE);
                acc = PanamaVectorUtilSupport.fma(qv, dv, acc);
            }
            float score = acc.reduceLanes(VectorOperators.ADD);
            for (; i < elementCount; i++) {
                final long offset = (long) i * Float.BYTES;
                // Accumulate the tail through the same helper as the vector loop and the bulk
                // variant, so single and bulk scoring treat tail elements the same way.
                score = PanamaVectorUtilSupport.fma(
                    seg.get(LAYOUT_LE_FLOAT, q + offset), seg.get(LAYOUT_LE_FLOAT, d + offset), score);
            }
            return score;
        }

        /**
         * Shared implementation of both bulk overloads: one pass over the query, four running sums.
         *
         * <p>When {@code qOffset} is {@code -1} the query is read from {@code qArray}; otherwise it
         * is read from {@code seg} at byte offset {@code qOffset} and {@code qArray} is unused.
         */
        private void dotProductBulkImpl(MemorySegment seg, float[] scores, float[] qArray, long qOffset, long d1, long d2, long d3, long d4, int elementCount) {
            final boolean queryOnHeap = qOffset == -1L;
            FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES);
            FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES);
            FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES);
            FloatVector acc4 = FloatVector.zero(FLOAT_SPECIES);
            final int limit = FLOAT_SPECIES.loopBound(elementCount);
            int i = 0;
            for (; i < limit; i += FLOAT_SPECIES.length()) {
                final long offset = (long) i * Float.BYTES;
                FloatVector dv1 = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, d1 + offset, LE);
                FloatVector dv2 = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, d2 + offset, LE);
                FloatVector dv3 = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, d3 + offset, LE);
                FloatVector dv4 = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, d4 + offset, LE);
                FloatVector qv = queryOnHeap
                    ? FloatVector.fromArray(FLOAT_SPECIES, qArray, i)
                    : FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, qOffset + offset, LE);
                acc1 = PanamaVectorUtilSupport.fma(qv, dv1, acc1);
                acc2 = PanamaVectorUtilSupport.fma(qv, dv2, acc2);
                acc3 = PanamaVectorUtilSupport.fma(qv, dv3, acc3);
                acc4 = PanamaVectorUtilSupport.fma(qv, dv4, acc4);
            }
            float sum1 = acc1.reduceLanes(VectorOperators.ADD);
            float sum2 = acc2.reduceLanes(VectorOperators.ADD);
            float sum3 = acc3.reduceLanes(VectorOperators.ADD);
            float sum4 = acc4.reduceLanes(VectorOperators.ADD);
            // Scalar tail for the elements that do not fill a whole vector.
            for (; i < elementCount; i++) {
                final long offset = (long) i * Float.BYTES;
                final float qValue = queryOnHeap ? qArray[i] : seg.get(LAYOUT_LE_FLOAT, qOffset + offset);
                sum1 = PanamaVectorUtilSupport.fma(qValue, seg.get(LAYOUT_LE_FLOAT, d1 + offset), sum1);
                sum2 = PanamaVectorUtilSupport.fma(qValue, seg.get(LAYOUT_LE_FLOAT, d2 + offset), sum2);
                sum3 = PanamaVectorUtilSupport.fma(qValue, seg.get(LAYOUT_LE_FLOAT, d3 + offset), sum3);
                sum4 = PanamaVectorUtilSupport.fma(qValue, seg.get(LAYOUT_LE_FLOAT, d4 + offset), sum4);
            }
            scores[0] = sum1;
            scores[1] = sum2;
            scores[2] = sum3;
            scores[3] = sum4;
        }
    }

    /**
     * Cosine-similarity kernels over little-endian {@code float} data stored in a
     * {@link MemorySegment}.
     *
     * <p>Every {@code long} vector argument is a byte offset into the segment, and
     * {@code elementCount} is the number of floats in each vector. Each kernel accumulates the
     * dot product and both squared norms in a single pass, then divides in {@code double}
     * precision. There is no special handling for a zero norm.
     */
    public static final class Cosine {
        private Cosine() {
        }

        /**
         * Scores an on-heap query against four vectors stored in {@code dataSeg}.
         *
         * @param dataSeg segment holding the four data vectors
         * @param scores output array; indices 0 to 3 receive the scores for {@code d1} to {@code d4}
         * @param q query vector, read from index 0 up to {@code elementCount}
         * @param d1 byte offset of the first data vector
         * @param d2 byte offset of the second data vector
         * @param d3 byte offset of the third data vector
         * @param d4 byte offset of the fourth data vector
         * @param elementCount number of floats in each vector
         */
        public void cosineBulk(MemorySegment dataSeg, float[] scores, float[] q, long d1, long d2, long d3, long d4, int elementCount) {
            // -1 is the internal marker telling the implementation to read the query from the array.
            this.cosineBulkImpl(dataSeg, scores, q, -1L, d1, d2, d3, d4, elementCount);
        }

        /**
         * Scores a query stored in {@code seg} against four vectors stored in the same segment.
         *
         * @param seg segment holding the query and the four data vectors
         * @param scores output array; indices 0 to 3 receive the scores for {@code d1} to {@code d4}
         * @param q byte offset of the query vector; the value {@code -1} is reserved internally to
         *     mean "query is on heap"
         * @param d1 byte offset of the first data vector
         * @param d2 byte offset of the second data vector
         * @param d3 byte offset of the third data vector
         * @param d4 byte offset of the fourth data vector
         * @param elementCount number of floats in each vector
         */
        public void cosineBulk(MemorySegment seg, float[] scores, long q, long d1, long d2, long d3, long d4, int elementCount) {
            this.cosineBulkImpl(seg, scores, null, q, d1, d2, d3, d4, elementCount);
        }

        /**
         * Computes the cosine similarity of two vectors stored in the same segment.
         *
         * @param seg segment holding both vectors
         * @param q byte offset of the first vector
         * @param d byte offset of the second vector
         * @param elementCount number of floats in each vector
         * @return dot product divided by the square root of the product of the squared norms
         */
        public float cosine(MemorySegment seg, long q, long d, int elementCount) {
            FloatVector dotAcc = FloatVector.zero(FLOAT_SPECIES);
            FloatVector qSqAcc = FloatVector.zero(FLOAT_SPECIES);
            FloatVector dSqAcc = FloatVector.zero(FLOAT_SPECIES);
            final int limit = FLOAT_SPECIES.loopBound(elementCount);
            int i = 0;
            for (; i < limit; i += FLOAT_SPECIES.length()) {
                final long offset = (long) i * Float.BYTES;
                FloatVector qv = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, q + offset, LE);
                FloatVector dv = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, d + offset, LE);
                dotAcc = PanamaVectorUtilSupport.fma(qv, dv, dotAcc);
                qSqAcc = PanamaVectorUtilSupport.fma(qv, qv, qSqAcc);
                dSqAcc = PanamaVectorUtilSupport.fma(dv, dv, dSqAcc);
            }
            float dot = dotAcc.reduceLanes(VectorOperators.ADD);
            float qNorm = qSqAcc.reduceLanes(VectorOperators.ADD);
            float dNorm = dSqAcc.reduceLanes(VectorOperators.ADD);
            // Scalar tail for the elements that do not fill a whole vector.
            for (; i < elementCount; i++) {
                final long offset = (long) i * Float.BYTES;
                final float qValue = seg.get(LAYOUT_LE_FLOAT, q + offset);
                final float dValue = seg.get(LAYOUT_LE_FLOAT, d + offset);
                dot = PanamaVectorUtilSupport.fma(qValue, dValue, dot);
                qNorm = PanamaVectorUtilSupport.fma(qValue, qValue, qNorm);
                dNorm = PanamaVectorUtilSupport.fma(dValue, dValue, dNorm);
            }
            return toCosine(dot, qNorm, dNorm);
        }

        /**
         * Shared implementation of both bulk overloads: one pass over the query, four dot products
         * and five squared norms.
         *
         * <p>When {@code qOffset} is {@code -1} the query is read from {@code qArray}; otherwise it
         * is read from {@code seg} at byte offset {@code qOffset}.
         */
        private void cosineBulkImpl(MemorySegment seg, float[] scores, float[] qArray, long qOffset, long d1, long d2, long d3, long d4, int elementCount) {
            final boolean queryOnHeap = qOffset == -1L;
            FloatVector dotAcc1 = FloatVector.zero(FLOAT_SPECIES);
            FloatVector dotAcc2 = FloatVector.zero(FLOAT_SPECIES);
            FloatVector dotAcc3 = FloatVector.zero(FLOAT_SPECIES);
            FloatVector dotAcc4 = FloatVector.zero(FLOAT_SPECIES);
            FloatVector qSqAcc = FloatVector.zero(FLOAT_SPECIES);
            FloatVector sqAcc1 = FloatVector.zero(FLOAT_SPECIES);
            FloatVector sqAcc2 = FloatVector.zero(FLOAT_SPECIES);
            FloatVector sqAcc3 = FloatVector.zero(FLOAT_SPECIES);
            FloatVector sqAcc4 = FloatVector.zero(FLOAT_SPECIES);
            final int limit = FLOAT_SPECIES.loopBound(elementCount);
            int i = 0;
            for (; i < limit; i += FLOAT_SPECIES.length()) {
                final long offset = (long) i * Float.BYTES;
                FloatVector dv1 = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, d1 + offset, LE);
                FloatVector dv2 = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, d2 + offset, LE);
                FloatVector dv3 = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, d3 + offset, LE);
                FloatVector dv4 = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, d4 + offset, LE);
                FloatVector qv = queryOnHeap
                    ? FloatVector.fromArray(FLOAT_SPECIES, qArray, i)
                    : FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, qOffset + offset, LE);
                qSqAcc = PanamaVectorUtilSupport.fma(qv, qv, qSqAcc);
                sqAcc1 = PanamaVectorUtilSupport.fma(dv1, dv1, sqAcc1);
                dotAcc1 = PanamaVectorUtilSupport.fma(qv, dv1, dotAcc1);
                sqAcc2 = PanamaVectorUtilSupport.fma(dv2, dv2, sqAcc2);
                dotAcc2 = PanamaVectorUtilSupport.fma(qv, dv2, dotAcc2);
                sqAcc3 = PanamaVectorUtilSupport.fma(dv3, dv3, sqAcc3);
                dotAcc3 = PanamaVectorUtilSupport.fma(qv, dv3, dotAcc3);
                sqAcc4 = PanamaVectorUtilSupport.fma(dv4, dv4, sqAcc4);
                dotAcc4 = PanamaVectorUtilSupport.fma(qv, dv4, dotAcc4);
            }
            float dot1 = dotAcc1.reduceLanes(VectorOperators.ADD);
            float dot2 = dotAcc2.reduceLanes(VectorOperators.ADD);
            float dot3 = dotAcc3.reduceLanes(VectorOperators.ADD);
            float dot4 = dotAcc4.reduceLanes(VectorOperators.ADD);
            float qNorm = qSqAcc.reduceLanes(VectorOperators.ADD);
            float norm1 = sqAcc1.reduceLanes(VectorOperators.ADD);
            float norm2 = sqAcc2.reduceLanes(VectorOperators.ADD);
            float norm3 = sqAcc3.reduceLanes(VectorOperators.ADD);
            float norm4 = sqAcc4.reduceLanes(VectorOperators.ADD);
            // Scalar tail for the elements that do not fill a whole vector.
            for (; i < elementCount; i++) {
                final long offset = (long) i * Float.BYTES;
                final float qValue = queryOnHeap ? qArray[i] : seg.get(LAYOUT_LE_FLOAT, qOffset + offset);
                final float v1 = seg.get(LAYOUT_LE_FLOAT, d1 + offset);
                final float v2 = seg.get(LAYOUT_LE_FLOAT, d2 + offset);
                final float v3 = seg.get(LAYOUT_LE_FLOAT, d3 + offset);
                final float v4 = seg.get(LAYOUT_LE_FLOAT, d4 + offset);
                dot1 = PanamaVectorUtilSupport.fma(qValue, v1, dot1);
                dot2 = PanamaVectorUtilSupport.fma(qValue, v2, dot2);
                dot3 = PanamaVectorUtilSupport.fma(qValue, v3, dot3);
                dot4 = PanamaVectorUtilSupport.fma(qValue, v4, dot4);
                qNorm = PanamaVectorUtilSupport.fma(qValue, qValue, qNorm);
                norm1 = PanamaVectorUtilSupport.fma(v1, v1, norm1);
                norm2 = PanamaVectorUtilSupport.fma(v2, v2, norm2);
                norm3 = PanamaVectorUtilSupport.fma(v3, v3, norm3);
                norm4 = PanamaVectorUtilSupport.fma(v4, v4, norm4);
            }
            scores[0] = toCosine(dot1, qNorm, norm1);
            scores[1] = toCosine(dot2, qNorm, norm2);
            scores[2] = toCosine(dot3, qNorm, norm3);
            scores[3] = toCosine(dot4, qNorm, norm4);
        }

        /** Divides the dot product by sqrt(qNorm * dNorm) in double precision and narrows to float. */
        private static float toCosine(float dot, float qNorm, float dNorm) {
            return (float) ((double) dot / Math.sqrt((double) qNorm * (double) dNorm));
        }
    }

    /**
     * Squared-Euclidean-distance kernels over little-endian {@code float} data stored in a
     * {@link MemorySegment}.
     *
     * <p>Every {@code long} vector argument is a byte offset into the segment, and
     * {@code elementCount} is the number of floats in each vector. Each kernel sums the squares
     * of the element-wise differences; no square root is taken.
     */
    public static final class SqrDistance {
        private SqrDistance() {
        }

        /**
         * Scores an on-heap query against four vectors stored in {@code dataSeg}.
         *
         * @param dataSeg segment holding the four data vectors
         * @param scores output array; indices 0 to 3 receive the scores for {@code d1} to {@code d4}
         * @param q query vector, read from index 0 up to {@code elementCount}
         * @param d1 byte offset of the first data vector
         * @param d2 byte offset of the second data vector
         * @param d3 byte offset of the third data vector
         * @param d4 byte offset of the fourth data vector
         * @param elementCount number of floats in each vector
         */
        public void sqrDistanceBulk(MemorySegment dataSeg, float[] scores, float[] q, long d1, long d2, long d3, long d4, int elementCount) {
            // -1 is the internal marker telling the implementation to read the query from the array.
            this.sqrDistanceBulkImpl(dataSeg, scores, q, -1L, d1, d2, d3, d4, elementCount);
        }

        /**
         * Scores a query stored in {@code seg} against four vectors stored in the same segment.
         *
         * @param seg segment holding the query and the four data vectors
         * @param scores output array; indices 0 to 3 receive the scores for {@code d1} to {@code d4}
         * @param q byte offset of the query vector; the value {@code -1} is reserved internally to
         *     mean "query is on heap"
         * @param d1 byte offset of the first data vector
         * @param d2 byte offset of the second data vector
         * @param d3 byte offset of the third data vector
         * @param d4 byte offset of the fourth data vector
         * @param elementCount number of floats in each vector
         */
        public void sqrDistanceBulk(MemorySegment seg, float[] scores, long q, long d1, long d2, long d3, long d4, int elementCount) {
            this.sqrDistanceBulkImpl(seg, scores, null, q, d1, d2, d3, d4, elementCount);
        }

        /**
         * Computes the squared distance between two vectors stored in the same segment.
         *
         * @param seg segment holding both vectors
         * @param q byte offset of the first vector
         * @param d byte offset of the second vector
         * @param elementCount number of floats in each vector
         * @return the accumulated sum of squared element-wise differences
         */
        public float sqrDistance(MemorySegment seg, long q, long d, int elementCount) {
            FloatVector acc = FloatVector.zero(FLOAT_SPECIES);
            final int limit = FLOAT_SPECIES.loopBound(elementCount);
            int i = 0;
            for (; i < limit; i += FLOAT_SPECIES.length()) {
                final long offset = (long) i * Float.BYTES;
                FloatVector qv = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, q + offset, LE);
                FloatVector dv = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, d + offset, LE);
                FloatVector delta = qv.sub(dv);
                acc = PanamaVectorUtilSupport.fma(delta, delta, acc);
            }
            float score = acc.reduceLanes(VectorOperators.ADD);
            // Scalar tail for the elements that do not fill a whole vector.
            for (; i < elementCount; i++) {
                final long offset = (long) i * Float.BYTES;
                final float delta = seg.get(LAYOUT_LE_FLOAT, q + offset) - seg.get(LAYOUT_LE_FLOAT, d + offset);
                score = PanamaVectorUtilSupport.fma(delta, delta, score);
            }
            return score;
        }

        /**
         * Shared implementation of both bulk overloads: one pass over the query, four running sums.
         *
         * <p>When {@code qOffset} is {@code -1} the query is read from {@code qArray}; otherwise it
         * is read from {@code seg} at byte offset {@code qOffset}.
         */
        private void sqrDistanceBulkImpl(MemorySegment seg, float[] scores, float[] qArray, long qOffset, long d1, long d2, long d3, long d4, int elementCount) {
            final boolean queryOnHeap = qOffset == -1L;
            FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES);
            FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES);
            FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES);
            FloatVector acc4 = FloatVector.zero(FLOAT_SPECIES);
            final int limit = FLOAT_SPECIES.loopBound(elementCount);
            int i = 0;
            for (; i < limit; i += FLOAT_SPECIES.length()) {
                final long offset = (long) i * Float.BYTES;
                FloatVector dv1 = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, d1 + offset, LE);
                FloatVector dv2 = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, d2 + offset, LE);
                FloatVector dv3 = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, d3 + offset, LE);
                FloatVector dv4 = FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, d4 + offset, LE);
                FloatVector qv = queryOnHeap
                    ? FloatVector.fromArray(FLOAT_SPECIES, qArray, i)
                    : FloatVector.fromMemorySegment(FLOAT_SPECIES, seg, qOffset + offset, LE);
                FloatVector delta1 = qv.sub(dv1);
                FloatVector delta2 = qv.sub(dv2);
                FloatVector delta3 = qv.sub(dv3);
                FloatVector delta4 = qv.sub(dv4);
                acc1 = PanamaVectorUtilSupport.fma(delta1, delta1, acc1);
                acc2 = PanamaVectorUtilSupport.fma(delta2, delta2, acc2);
                acc3 = PanamaVectorUtilSupport.fma(delta3, delta3, acc3);
                acc4 = PanamaVectorUtilSupport.fma(delta4, delta4, acc4);
            }
            float sum1 = acc1.reduceLanes(VectorOperators.ADD);
            float sum2 = acc2.reduceLanes(VectorOperators.ADD);
            float sum3 = acc3.reduceLanes(VectorOperators.ADD);
            float sum4 = acc4.reduceLanes(VectorOperators.ADD);
            // Scalar tail for the elements that do not fill a whole vector.
            for (; i < elementCount; i++) {
                final long offset = (long) i * Float.BYTES;
                final float qValue = queryOnHeap ? qArray[i] : seg.get(LAYOUT_LE_FLOAT, qOffset + offset);
                final float delta1 = qValue - seg.get(LAYOUT_LE_FLOAT, d1 + offset);
                final float delta2 = qValue - seg.get(LAYOUT_LE_FLOAT, d2 + offset);
                final float delta3 = qValue - seg.get(LAYOUT_LE_FLOAT, d3 + offset);
                final float delta4 = qValue - seg.get(LAYOUT_LE_FLOAT, d4 + offset);
                sum1 = PanamaVectorUtilSupport.fma(delta1, delta1, sum1);
                sum2 = PanamaVectorUtilSupport.fma(delta2, delta2, sum2);
                sum3 = PanamaVectorUtilSupport.fma(delta3, delta3, sum3);
                sum4 = PanamaVectorUtilSupport.fma(delta4, delta4, sum4);
            }
            scores[0] = sum1;
            scores[1] = sum2;
            scores[2] = sum3;
            scores[3] = sum4;
        }
    }
}

