/*
 * Decompiled with CFR 0.152.
 */
package org.graalvm.compiler.lir.aarch64;

import java.util.Arrays;
import jdk.vm.ci.aarch64.AArch64;
import jdk.vm.ci.aarch64.AArch64Kind;
import jdk.vm.ci.code.Register;
import jdk.vm.ci.code.ValueUtil;
import jdk.vm.ci.meta.AllocatableValue;
import jdk.vm.ci.meta.Value;
import org.graalvm.compiler.asm.Label;
import org.graalvm.compiler.asm.aarch64.AArch64ASIMDAssembler;
import org.graalvm.compiler.asm.aarch64.AArch64Address;
import org.graalvm.compiler.asm.aarch64.AArch64Assembler;
import org.graalvm.compiler.asm.aarch64.AArch64MacroAssembler;
import org.graalvm.compiler.code.DataSection;
import org.graalvm.compiler.core.common.Stride;
import org.graalvm.compiler.core.common.StrideUtil;
import org.graalvm.compiler.debug.GraalError;
import org.graalvm.compiler.lir.LIRInstruction;
import org.graalvm.compiler.lir.LIRInstructionClass;
import org.graalvm.compiler.lir.Opcode;
import org.graalvm.compiler.lir.aarch64.AArch64ComplexVectorOp;
import org.graalvm.compiler.lir.aarch64.AArch64ControlFlow;
import org.graalvm.compiler.lir.aarch64.AArch64LIRInstruction;
import org.graalvm.compiler.lir.asm.CompilationResultBuilder;
import org.graalvm.compiler.lir.gen.LIRGeneratorTool;

/**
 * LIR instruction that emits AArch64 (NEON) code comparing two memory regions, each addressed as
 * {@code array + offset}, where the two regions may use different element strides.
 *
 * <p>The emitted code writes zero to the result register when no difference is found; otherwise the
 * result is derived from subtracting the first differing pair of elements (element of A minus
 * element of B). Strides are either fixed at construction time or, when a dynamic-strides value is
 * supplied, selected at run time through a jump table over all S1/S2/S4 combinations.
 */
@Opcode(value="ARRAY_REGION_COMPARE_TO")
public final class AArch64ArrayRegionCompareToOp
extends AArch64ComplexVectorOp {
    /** Instruction class descriptor created for this op and handed to the superclass constructor. */
    public static final LIRInstructionClass<AArch64ArrayRegionCompareToOp> TYPE = LIRInstructionClass.create(AArch64ArrayRegionCompareToOp.class);
    /** Stride of array A; used by {@code emitCode} only when no dynamic-strides value is present. */
    private final Stride argStrideA;
    /** Stride of array B; used by {@code emitCode} only when no dynamic-strides value is present. */
    private final Stride argStrideB;
    /** Register receiving the comparison result; the constructor requires a DWORD kind. */
    @LIRInstruction.Def(value={LIRInstruction.OperandFlag.REG})
    protected Value resultValue;
    /** Base pointer of array A; the constructor requires a QWORD kind. */
    @LIRInstruction.Alive(value={LIRInstruction.OperandFlag.REG})
    protected Value arrayAValue;
    /** Offset added to the base of array A; the constructor requires a QWORD kind. */
    @LIRInstruction.Alive(value={LIRInstruction.OperandFlag.REG})
    protected Value offsetAValue;
    /** Base pointer of array B; must have the same (QWORD) kind as array A. */
    @LIRInstruction.Alive(value={LIRInstruction.OperandFlag.REG})
    protected Value arrayBValue;
    /** Offset added to the base of array B; the constructor requires a QWORD kind. */
    @LIRInstruction.Alive(value={LIRInstruction.OperandFlag.REG})
    protected Value offsetBValue;
    /** Length of the region to compare; the constructor requires a DWORD kind. */
    @LIRInstruction.Alive(value={LIRInstruction.OperandFlag.REG})
    protected Value lengthValue;
    /** Run-time stride selector, or {@code Value.ILLEGAL} when the strides are fixed. */
    @LIRInstruction.Alive(value={LIRInstruction.OperandFlag.REG, LIRInstruction.OperandFlag.ILLEGAL})
    private Value dynamicStridesValue;
    /**
     * General-purpose temporaries: [0] working copy of the length, [1] scratch, and, only when
     * dynamic strides are in use, [2] a copy of the stride selector.
     */
    @LIRInstruction.Temp(value={LIRInstruction.OperandFlag.REG})
    protected AllocatableValue[] temp;
    /** Seven vector temporaries: four data vectors, two difference vectors and the index mask. */
    @LIRInstruction.Temp(value={LIRInstruction.OperandFlag.REG})
    protected AllocatableValue[] vectorTemp;

    /**
     * Creates the instruction, validates the operand kinds and reserves the temporaries.
     *
     * @param tool generator used to allocate the temporary registers
     * @param strideA stride of array A, used when {@code dynamicStrides} is {@code null}
     * @param strideB stride of array B, used when {@code dynamicStrides} is {@code null}
     * @param result destination of the comparison result; must be of kind DWORD
     * @param arrayA base pointer of array A; must be of kind QWORD
     * @param offsetA offset added to {@code arrayA}; must be of kind QWORD
     * @param arrayB base pointer of array B; must be of kind QWORD
     * @param offsetB offset added to {@code arrayB}; must be of kind QWORD
     * @param length region length; must be of kind DWORD
     * @param dynamicStrides run-time stride selector, or {@code null} for fixed strides
     */
    public AArch64ArrayRegionCompareToOp(LIRGeneratorTool tool, Stride strideA, Stride strideB, Value result, Value arrayA, Value offsetA, Value arrayB, Value offsetB, Value length, Value dynamicStrides) {
        super(TYPE);

        // Operand kind checks, in the same order as the operands are declared.
        GraalError.guarantee(result.getPlatformKind() == AArch64Kind.DWORD, "int value expected");
        GraalError.guarantee(arrayA.getPlatformKind() == AArch64Kind.QWORD && arrayB.getPlatformKind() == AArch64Kind.QWORD, "pointer value expected");
        GraalError.guarantee(offsetA.getPlatformKind() == AArch64Kind.QWORD, "long value expected");
        GraalError.guarantee(offsetB.getPlatformKind() == AArch64Kind.QWORD, "long value expected");
        GraalError.guarantee(length.getPlatformKind() == AArch64Kind.DWORD, "int value expected");

        argStrideA = strideA;
        argStrideB = strideB;
        resultValue = result;
        arrayAValue = arrayA;
        offsetAValue = offsetA;
        arrayBValue = arrayB;
        offsetBValue = offsetB;
        lengthValue = length;
        if (dynamicStrides == null) {
            // An illegal value marks the fixed-stride configuration.
            dynamicStridesValue = Value.ILLEGAL;
        } else {
            dynamicStridesValue = dynamicStrides;
        }

        // The dynamic-stride variant needs one extra general-purpose temporary for the selector.
        int generalTempCount = 2;
        if (withDynamicStrides()) {
            generalTempCount = 3;
        }
        temp = allocateTempRegisters(tool, generalTempCount);
        vectorTemp = allocateVectorRegisters(tool, 7);
    }

    /**
     * Emits the comparison code. The effective addresses {@code array + offset} are computed into
     * two scratch registers and the length is copied into a temporary, then either a single
     * fixed-stride comparison is emitted or, with dynamic strides, a jump table followed by one
     * comparison variant per S1/S2/S4 stride pair. Every path ends at a common aligned end label.
     */
    @Override
    public void emitCode(CompilationResultBuilder crb, AArch64MacroAssembler masm) {
        try (AArch64MacroAssembler.ScratchRegister scratchA = masm.getScratchRegister();
             AArch64MacroAssembler.ScratchRegister scratchB = masm.getScratchRegister()) {
            Register addressA = scratchA.getRegister();
            Register addressB = scratchB.getRegister();
            Register remaining = ValueUtil.asRegister(temp[0]);
            Register scratch = ValueUtil.asRegister(temp[1]);
            Register result = ValueUtil.asRegister(resultValue);
            Label done = new Label();

            // Effective start addresses and a writable copy of the length.
            masm.add(64, addressA, ValueUtil.asRegister(arrayAValue), ValueUtil.asRegister(offsetAValue));
            masm.add(64, addressB, ValueUtil.asRegister(arrayBValue), ValueUtil.asRegister(offsetBValue));
            masm.mov(32, remaining, ValueUtil.asRegister(lengthValue));

            if (!withDynamicStrides()) {
                emitArrayCompare(crb, masm, argStrideA, argStrideB, addressA, addressB, remaining, scratch, result, done);
            } else {
                // One entry label per stride combination, dispatched through a jump table.
                Label[] variantEntries = new Label[9];
                Arrays.setAll(variantEntries, index -> new Label());
                Register selector = ValueUtil.asRegister(temp[2]);
                masm.mov(32, selector, ValueUtil.asRegister(dynamicStridesValue));
                AArch64ControlFlow.RangeTableSwitchOp.emitJumpTable(crb, masm, scratch, selector, 0, 8, Arrays.stream(variantEntries));

                Stride[] candidates = {Stride.S1, Stride.S2, Stride.S4};
                for (Stride variantStrideA : candidates) {
                    for (Stride variantStrideB : candidates) {
                        masm.align(16);
                        masm.bind(variantEntries[StrideUtil.getDirectStubCallIndex(variantStrideA, variantStrideB)]);
                        emitArrayCompare(crb, masm, variantStrideA, variantStrideB, addressA, addressB, remaining, scratch, result, done);
                        masm.jmp(done);
                    }
                }
            }
            masm.align(16);
            masm.bind(done);
        }
    }

    /**
     * Reports whether a run-time stride selector is in use, i.e. whether the stored dynamic-strides
     * value is anything other than an illegal value.
     */
    private boolean withDynamicStrides() {
        boolean fixedStrides = ValueUtil.isIllegal(dynamicStridesValue);
        return !fixedStrides;
    }

    /**
     * Emits the comparison for one concrete stride pair.
     *
     * <p>Layout of the emitted code: a 32-byte vector loop (sizes measured in the wider stride),
     * a final overlapping 32-byte compare that ends exactly at the end of the region, a chain of
     * tail paths for regions shorter than 32 bytes, and two result-computation paths selected by
     * whether the difference lies in the first or the second vector pair.
     *
     * @param crb compilation result builder, used to place the index mask in the data section
     * @param asm assembler receiving the code
     * @param strideA element stride of array A
     * @param strideB element stride of array B
     * @param arrayA register holding the current address in array A (modified)
     * @param arrayB register holding the current address in array B (modified)
     * @param len register holding the element count (modified, later reused as reference address)
     * @param tmp general-purpose scratch register
     * @param ret register receiving the result
     * @param end label to jump to once the result is in {@code ret}
     */
    private void emitArrayCompare(CompilationResultBuilder crb, AArch64MacroAssembler asm, Stride strideA, Stride strideB, Register arrayA, Register arrayB, Register len, Register tmp, Register ret, Label end) {
        Stride strideMax = Stride.max(strideA, strideB);
        Stride strideMin = Stride.min(strideA, strideB);
        Register maxStrideArray;
        Register minStrideArray;
        if (strideA.value < strideB.value) {
            maxStrideArray = arrayB;
            minStrideArray = arrayA;
        } else {
            maxStrideArray = arrayA;
            minStrideArray = arrayB;
        }

        Register vecA1 = ValueUtil.asRegister(vectorTemp[0]);
        Register vecA2 = ValueUtil.asRegister(vectorTemp[1]);
        Register vecB1 = ValueUtil.asRegister(vectorTemp[2]);
        Register vecB2 = ValueUtil.asRegister(vectorTemp[3]);
        Register vecDiff1 = ValueUtil.asRegister(vectorTemp[4]);
        Register vecDiff2 = ValueUtil.asRegister(vectorTemp[5]);
        Register vecMask = ValueUtil.asRegister(vectorTemp[6]);

        Label vectorLoop = new Label();
        Label diffFound = new Label();
        Label returnV1 = new Label();
        Label tailLessThan32 = new Label();
        Label tailLessThan16 = new Label();
        Label tailLessThan8 = new Label();
        Label tailLessThan4 = new Label();
        Label tailLessThan2 = new Label();

        // Ascending byte indices 0..15, loaded into vecMask for the result computation.
        byte[] laneIndices = new byte[16];
        for (int lane = 0; lane < laneIndices.length; lane++) {
            laneIndices[lane] = (byte) lane;
        }
        DataSection.Data laneIndexData = writeToDataSection(crb, laneIndices);
        loadDataSectionAddress(crb, asm, tmp, laneIndexData);
        asm.fldr(128, vecMask, AArch64Address.createBaseRegisterOnlyAddress(128, tmp));

        // Subtract 32 bytes' worth of elements; a negative result means fewer than that remain.
        asm.subs(64, len, len, 32 >> strideMax.log2);
        asm.branchConditionally(AArch64Assembler.ConditionFlag.MI, tailLessThan32);

        // len is reused as the address of the last full 32-byte chunk of the wider-stride array.
        Register refAddress = len;
        asm.add(64, refAddress, maxStrideArray, len, AArch64Assembler.ShiftType.LSL, strideMax.log2);

        // Main loop: runs while the wider-stride pointer is below the reference address.
        asm.align(16);
        asm.bind(vectorLoop);
        emitCompareTwoVectors(asm, strideMax, strideA, strideB, arrayA, arrayB, vecA1, vecA2, vecB1, vecB2, vecDiff1, vecDiff2, diffFound);
        asm.cmp(64, maxStrideArray, refAddress);
        asm.branchConditionally(AArch64Assembler.ConditionFlag.LO, vectorLoop);

        // Loop tail: move both pointers back by the overshoot (scaled for the narrower array) so
        // that one more 32-byte compare ends exactly at the end of the region.
        asm.sub(64, tmp, maxStrideArray, refAddress);
        asm.mov(64, maxStrideArray, refAddress);
        asm.sub(64, minStrideArray, minStrideArray, tmp, AArch64Assembler.ShiftType.LSR, strideMax.log2 - strideMin.log2);
        emitCompareTwoVectors(asm, strideMax, strideA, strideB, arrayA, arrayB, vecA1, vecA2, vecB1, vecB2, vecDiff1, vecDiff2, diffFound);
        // No difference anywhere: result is zero.
        asm.mov(64, ret, AArch64.zr);
        asm.jmp(end);

        // Tail with two overlapping 16-byte vectors; vecDiff1 is prepared for the diffFound path.
        tailLoad2Vec(asm, strideA, strideB, strideMax, arrayA, arrayB, len, vecA1, vecA2, vecB1, vecB2, tailLessThan32, tailLessThan16, 16);
        asm.neon.eorVVV(AArch64ASIMDAssembler.ASIMDSize.FullReg, vecDiff1, vecA1, vecB1);
        asm.jmp(diffFound);

        // Progressively smaller tails; each one falls through to the next via its nextTail label.
        tailLoad1Vec(asm, strideA, strideB, strideMax, arrayA, arrayB, len, vecA1, vecA2, vecB1, vecB2, tmp, ret, tailLessThan16, tailLessThan8, returnV1, end, 8);
        tailLoad1Vec(asm, strideA, strideB, strideMax, arrayA, arrayB, len, vecA1, vecA2, vecB1, vecB2, tmp, ret, tailLessThan8, tailLessThan4, returnV1, end, 4);
        tailLoad1Vec(asm, strideA, strideB, strideMax, arrayA, arrayB, len, vecA1, vecA2, vecB1, vecB2, tmp, ret, tailLessThan4, tailLessThan2, returnV1, end, 2);
        tailLoad1Vec(asm, strideA, strideB, strideMax, arrayA, arrayB, len, vecA1, vecA2, vecB1, vecB2, tmp, ret, tailLessThan2, null, returnV1, end, 1);

        // vecDiff1 holds the XOR of the first vector pair: non-zero means the difference is there,
        // otherwise the result is computed from the second pair.
        asm.align(16);
        asm.bind(diffFound);
        vectorCheckZero(asm, vecDiff1, vecDiff1);
        asm.branchConditionally(AArch64Assembler.ConditionFlag.NE, returnV1);
        calcReturnValue(asm, ret, vecA2, vecB2, vecA1, vecB1, vecMask, strideMax);
        asm.jmp(end);

        asm.align(16);
        asm.bind(returnV1);
        calcReturnValue(asm, ret, vecA1, vecB1, vecA2, vecB2, vecMask, strideMax);
    }

    /**
     * Emits one 32-byte comparison step: loads (and widens where needed) two vectors from each
     * array, XORs the corresponding vectors into {@code vecDiff1} and {@code vecDiff2}, ORs the
     * first difference into the second, and branches to {@code diffFound} if the combination is
     * non-zero.
     */
    private static void emitCompareTwoVectors(AArch64MacroAssembler asm, Stride strideMax, Stride strideA, Stride strideB, Register arrayA, Register arrayB, Register vecA1, Register vecA2, Register vecB1, Register vecB2, Register vecDiff1, Register vecDiff2, Label diffFound) {
        loadAndExtend(asm, strideMax, strideA, arrayA, vecA1, vecA2);
        loadAndExtend(asm, strideMax, strideB, arrayB, vecB1, vecB2);
        asm.neon.eorVVV(AArch64ASIMDAssembler.ASIMDSize.FullReg, vecDiff1, vecA1, vecB1);
        asm.neon.eorVVV(AArch64ASIMDAssembler.ASIMDSize.FullReg, vecDiff2, vecA2, vecB2);
        asm.neon.orrVVV(AArch64ASIMDAssembler.ASIMDSize.FullReg, vecDiff2, vecDiff2, vecDiff1);
        vectorCheckZero(asm, vecDiff2, vecDiff2);
        asm.branchConditionally(AArch64Assembler.ConditionFlag.NE, diffFound);
    }

    /**
     * Loads data from {@code arrayAddress} with post-increment and widens it so that the two
     * result vectors together hold 32 bytes of elements of stride {@code strideDst}.
     *
     * <ul>
     * <li>same stride: a pair load of two 128-bit vectors, advancing the address by 32;</li>
     * <li>one doubling: a 128-bit load (advance 16), upper half widened into {@code vectorHi},
     * lower half widened in place;</li>
     * <li>two doublings: a 64-bit load (advance 8), widened once in place and then split and
     * widened again as in the previous case.</li>
     * </ul>
     *
     * @throws GraalError if the stride difference is not 0, 1 or 2 doublings
     */
    private static void loadAndExtend(AArch64MacroAssembler asm, Stride strideDst, Stride strideSrc, Register arrayAddress, Register vectorLo, Register vectorHi) {
        assert arrayAddress.getRegisterCategory().equals(AArch64.CPU);
        assert vectorLo.getRegisterCategory().equals(AArch64.SIMD);
        assert vectorHi.getRegisterCategory().equals(AArch64.SIMD);

        int doublings = strideDst.log2 - strideSrc.log2;
        if (doublings == 0) {
            AArch64Address pairAddress = AArch64Address.createImmediateAddress(128, AArch64Address.AddressingMode.IMMEDIATE_PAIR_POST_INDEXED, arrayAddress, 32);
            asm.fldp(128, vectorLo, vectorHi, pairAddress);
        } else if (doublings == 1) {
            AArch64Address fullAddress = AArch64Address.createImmediateAddress(128, AArch64Address.AddressingMode.IMMEDIATE_POST_INDEXED, arrayAddress, 16);
            asm.fldr(128, vectorLo, fullAddress);
            AArch64ASIMDAssembler.ElementSize srcSize = AArch64ASIMDAssembler.ElementSize.fromStride(strideSrc);
            // The upper half must be widened first, since widening the lower half overwrites vectorLo.
            asm.neon.uxtl2VV(srcSize, vectorHi, vectorLo);
            asm.neon.uxtlVV(srcSize, vectorLo, vectorLo);
        } else if (doublings == 2) {
            AArch64Address halfAddress = AArch64Address.createImmediateAddress(64, AArch64Address.AddressingMode.IMMEDIATE_POST_INDEXED, arrayAddress, 8);
            asm.fldr(64, vectorLo, halfAddress);
            AArch64ASIMDAssembler.ElementSize srcSize = AArch64ASIMDAssembler.ElementSize.fromStride(strideSrc);
            asm.neon.uxtlVV(srcSize, vectorLo, vectorLo);
            AArch64ASIMDAssembler.ElementSize midSize = srcSize.expand();
            asm.neon.uxtl2VV(midSize, vectorHi, vectorLo);
            asm.neon.uxtlVV(midSize, vectorLo, vectorLo);
        } else {
            throw GraalError.unimplemented("conversion from " + strideSrc + " to " + strideDst + " not implemented");
        }
    }

    /**
     * Emits the code that turns a pair of differing vectors into the scalar result.
     *
     * <p>The sequence compares the vectors lane-wise, combines the comparison result with the
     * index mask so that only lanes that differ keep their index bytes, reduces with an unsigned
     * minimum to pick one lane, and then uses a table lookup to fetch the elements at that
     * position and subtract them (A minus B). For stride S4 the subtraction is done on whole
     * vectors before the lookup; for narrower strides the looked-up elements are subtracted with
     * a widening subtract.
     *
     * @param ret register receiving the result
     * @param vecArrayA vector of A elements known to contain a difference (may be overwritten)
     * @param vecArrayB vector of B elements known to contain a difference (may be overwritten)
     * @param vecTmp scratch vector
     * @param vecIndex scratch vector used for the lane index
     * @param vecMask vector of index bytes supplied by the caller
     * @param strideMax element stride of the data in the vectors
     */
    private static void calcReturnValue(AArch64MacroAssembler asm, Register ret, Register vecArrayA, Register vecArrayB, Register vecTmp, Register vecIndex, Register vecMask, Stride strideMax) {
        AArch64ASIMDAssembler.ASIMDSize fullReg = AArch64ASIMDAssembler.ASIMDSize.FullReg;
        AArch64ASIMDAssembler.ElementSize elementSize = AArch64ASIMDAssembler.ElementSize.fromStride(strideMax);

        // Lane-wise equality, then keep index bytes only where the lanes differ and fill the
        // equal lanes with the comparison result again before taking the unsigned minimum.
        asm.neon.cmeqVVV(fullReg, elementSize, vecTmp, vecArrayA, vecArrayB);
        asm.neon.bicVVV(fullReg, vecIndex, vecMask, vecTmp);
        asm.neon.orrVVV(fullReg, vecIndex, vecIndex, vecTmp);
        asm.neon.uminvSV(fullReg, elementSize, vecIndex, vecIndex);

        if (strideMax != Stride.S4) {
            // Fetch the selected element from both vectors, then subtract with widening.
            asm.neon.tblVVV(fullReg, vecArrayA, vecArrayA, vecIndex);
            asm.neon.tblVVV(fullReg, vecArrayB, vecArrayB, vecIndex);
            asm.neon.usublVVV(elementSize, vecTmp, vecArrayA, vecArrayB);
            asm.neon.moveFromIndex(AArch64ASIMDAssembler.ElementSize.Word, elementSize.expand(), ret, vecTmp, 0);
            return;
        }

        // S4: subtract whole vectors first, then fetch the selected difference.
        asm.neon.subVVV(fullReg, elementSize, vecTmp, vecArrayA, vecArrayB);
        asm.neon.tblVVV(fullReg, vecTmp, vecTmp, vecIndex);
        asm.neon.moveFromIndex(AArch64ASIMDAssembler.ElementSize.Word, elementSize, ret, vecTmp, 0);
    }

    /**
     * Emits the common part of a tail path: binds {@code entry}, adds {@code nBytes} worth of
     * elements back onto {@code len} and branches to {@code nextTail} if the result is still
     * negative. Otherwise it loads one chunk from the current position of each array, advances
     * both array pointers by {@code len} elements (scaled by the respective stride) and loads a
     * second chunk from the new position.
     *
     * @param nBytes chunk size in bytes, measured in the wider stride {@code strideMax}
     */
    private static void tailLoad(AArch64MacroAssembler asm, Stride strideA, Stride strideB, Stride strideMax, Register arrayA, Register arrayB, Register len, Register vecArrayA1, Register vecArrayA2, Register vecArrayB1, Register vecArrayB2, Label entry, Label nextTail, int nBytes) {
        int loadSizeA = loadBits(strideA, strideMax, nBytes);
        int loadSizeB = loadBits(strideB, strideMax, nBytes);
        int chunkElements = nBytes >> strideMax.log2;

        asm.bind(entry);
        asm.adds(64, len, len, chunkElements);
        asm.branchConditionally(AArch64Assembler.ConditionFlag.MI, nextTail);

        // First chunk: from the current position of each array.
        AArch64Address firstChunkA = AArch64Address.createBaseRegisterOnlyAddress(loadSizeA, arrayA);
        asm.fldr(loadSizeA, vecArrayA1, firstChunkA);
        AArch64Address firstChunkB = AArch64Address.createBaseRegisterOnlyAddress(loadSizeB, arrayB);
        asm.fldr(loadSizeB, vecArrayB1, firstChunkB);

        // Advance by the remaining element count, scaled per array.
        asm.add(64, arrayA, arrayA, len, AArch64Assembler.ShiftType.LSL, strideA.log2);
        asm.add(64, arrayB, arrayB, len, AArch64Assembler.ShiftType.LSL, strideB.log2);

        // Second chunk: from the advanced position.
        AArch64Address secondChunkA = AArch64Address.createBaseRegisterOnlyAddress(loadSizeA, arrayA);
        asm.fldr(loadSizeA, vecArrayA2, secondChunkA);
        AArch64Address secondChunkB = AArch64Address.createBaseRegisterOnlyAddress(loadSizeB, arrayB);
        asm.fldr(loadSizeB, vecArrayB2, secondChunkB);
    }

    /**
     * Emits a tail path whose data fits into a single vector per array.
     *
     * <ul>
     * <li>If the wider stride is smaller than {@code nBytes}, two chunks per array are loaded via
     * {@code tailLoad}, the second is inserted as element 1 of the first, the narrower array's
     * vector is widened, and control jumps to {@code tailLoaded}.</li>
     * <li>If the wider stride equals {@code nBytes}, this is the single-element case: the result
     * is zeroed, and if an element remains, one element of each array is loaded into
     * general-purpose registers and subtracted (A minus B); control then goes to {@code end}.</li>
     * <li>If the wider stride is larger than {@code nBytes}, nothing is emitted and
     * {@code entry} is not bound.</li>
     * </ul>
     *
     * @param nextTail label of the next smaller tail; only used in the first case
     * @param nBytes chunk size in bytes, at most 8
     */
    private static void tailLoad1Vec(AArch64MacroAssembler asm, Stride strideA, Stride strideB, Stride strideMax, Register arrayA, Register arrayB, Register len, Register vecArrayA1, Register vecArrayA2, Register vecArrayB1, Register vecArrayB2, Register tmp, Register ret, Label entry, Label nextTail, Label tailLoaded, Label end, int nBytes) {
        assert nBytes <= 8;

        if (strideMax.value > nBytes) {
            // A chunk smaller than one element cannot occur for this stride.
            return;
        }

        if (strideMax.value == nBytes) {
            int widthA = strideA.getBitCount();
            int widthB = strideB.getBitCount();
            asm.bind(entry);
            asm.mov(64, ret, AArch64.zr);
            asm.adds(64, len, len, nBytes >> strideMax.log2);
            asm.branchConditionally(AArch64Assembler.ConditionFlag.MI, end);
            asm.ldr(widthA, tmp, AArch64Address.createBaseRegisterOnlyAddress(widthA, arrayA));
            asm.ldr(widthB, ret, AArch64Address.createBaseRegisterOnlyAddress(widthB, arrayB));
            asm.sub(64, ret, tmp, ret);
            asm.jmp(end);
            return;
        }

        int loadSizeA = loadBits(strideA, strideMax, nBytes);
        int loadSizeB = loadBits(strideB, strideMax, nBytes);
        tailLoad(asm, strideA, strideB, strideMax, arrayA, arrayB, len, vecArrayA1, vecArrayA2, vecArrayB1, vecArrayB2, entry, nextTail, nBytes);
        // Merge the two chunks of each array into one vector.
        asm.neon.insXX(AArch64ASIMDAssembler.ElementSize.fromSize(loadSizeA), vecArrayA1, 1, vecArrayA2, 0);
        asm.neon.insXX(AArch64ASIMDAssembler.ElementSize.fromSize(loadSizeB), vecArrayB1, 1, vecArrayB2, 0);
        boolean widenA = strideA.value < strideMax.value;
        boolean widenB = strideB.value < strideMax.value;
        if (widenA) {
            tailExtend(asm, strideA, strideMax, vecArrayA1);
        } else if (widenB) {
            tailExtend(asm, strideB, strideMax, vecArrayB1);
        }
        asm.jmp(tailLoaded);
    }

    /**
     * Emits a tail path that keeps two vectors per array: loads both chunks via {@code tailLoad}
     * and then widens both vectors of whichever array has the narrower stride, if any.
     */
    private static void tailLoad2Vec(AArch64MacroAssembler asm, Stride strideA, Stride strideB, Stride strideMax, Register arrayA, Register arrayB, Register len, Register vecArrayA1, Register vecArrayA2, Register vecArrayB1, Register vecArrayB2, Label entry, Label nextTail, int nBytes) {
        tailLoad(asm, strideA, strideB, strideMax, arrayA, arrayB, len, vecArrayA1, vecArrayA2, vecArrayB1, vecArrayB2, entry, nextTail, nBytes);

        boolean widenA = strideA.value < strideMax.value;
        if (widenA) {
            tailExtend(asm, strideA, strideMax, vecArrayA1);
            tailExtend(asm, strideA, strideMax, vecArrayA2);
            return;
        }

        boolean widenB = strideB.value < strideMax.value;
        if (widenB) {
            tailExtend(asm, strideB, strideMax, vecArrayB1);
            tailExtend(asm, strideB, strideMax, vecArrayB2);
        }
    }

    /**
     * Emits the in-place zero-extension of the elements in {@code vecArray} from {@code stride}
     * to {@code strideMax}.
     *
     * <p>One stride doubling takes a single widening step. Two doublings take two steps, and the
     * second step operates on the already-widened elements, so it uses the expanded element size,
     * the same way {@code loadAndExtend} does for its two-step case. (Using the source element
     * size for the second step yields the same lane values only because the first step leaves the
     * upper half of every intermediate element zero; the expanded size states the intent directly.)
     *
     * @param asm assembler receiving the code
     * @param stride current element stride of the data in {@code vecArray}
     * @param strideMax target element stride
     * @param vecArray vector register to widen in place
     * @throws GraalError if {@code strideMax} is not exactly one or two doublings of {@code stride}
     */
    private static void tailExtend(AArch64MacroAssembler asm, Stride stride, Stride strideMax, Register vecArray) {
        switch (strideMax.log2 - stride.log2) {
            case 1: {
                asm.neon.uxtlVV(AArch64ASIMDAssembler.ElementSize.fromStride(stride), vecArray, vecArray);
                break;
            }
            case 2: {
                AArch64ASIMDAssembler.ElementSize sourceSize = AArch64ASIMDAssembler.ElementSize.fromStride(stride);
                // First step: source size -> double size.
                asm.neon.uxtlVV(sourceSize, vecArray, vecArray);
                // Second step: double size -> quadruple size.
                asm.neon.uxtlVV(sourceSize.expand(), vecArray, vecArray);
                break;
            }
            default: {
                throw GraalError.shouldNotReachHere();
            }
        }
    }

    /**
     * Computes how many bits must be loaded from an array of stride {@code strideA} so that,
     * once widened to {@code strideMax}, the data occupies {@code nBytes} bytes.
     *
     * @return {@code nBytes * 8}, shifted right by the log2 stride difference
     */
    private static int loadBits(Stride strideA, Stride strideMax, int nBytes) {
        int bitsAtMaxStride = nBytes * Byte.SIZE;
        int doublings = strideMax.log2 - strideA.log2;
        return bitsAtMaxStride >> doublings;
    }
}

