/*
 * Decompiled with CFR 0.152.
 */
package org.graalvm.compiler.lir.amd64;

import java.util.Arrays;
import java.util.EnumSet;
import jdk.vm.ci.amd64.AMD64;
import jdk.vm.ci.code.Register;
import jdk.vm.ci.code.RegisterValue;
import jdk.vm.ci.code.TargetDescription;
import jdk.vm.ci.code.ValueUtil;
import jdk.vm.ci.meta.AllocatableValue;
import jdk.vm.ci.meta.JavaKind;
import jdk.vm.ci.meta.Value;
import org.graalvm.compiler.asm.Label;
import org.graalvm.compiler.asm.amd64.AMD64Address;
import org.graalvm.compiler.asm.amd64.AMD64Assembler;
import org.graalvm.compiler.asm.amd64.AMD64MacroAssembler;
import org.graalvm.compiler.asm.amd64.AVXKind;
import org.graalvm.compiler.core.common.Stride;
import org.graalvm.compiler.core.common.StrideUtil;
import org.graalvm.compiler.debug.GraalError;
import org.graalvm.compiler.lir.LIRInstruction;
import org.graalvm.compiler.lir.LIRInstructionClass;
import org.graalvm.compiler.lir.Opcode;
import org.graalvm.compiler.lir.amd64.AMD64ComplexVectorOp;
import org.graalvm.compiler.lir.amd64.AMD64ControlFlow;
import org.graalvm.compiler.lir.asm.CompilationResultBuilder;
import org.graalvm.compiler.lir.gen.LIRGeneratorTool;

/**
 * LIR instruction that emits AMD64 machine code comparing a region of array A with a region of
 * array B element by element.
 *
 * <p>The emitted code leaves in the result register the difference {@code a[i] - b[i]} of the first
 * pair of elements that differ (each element is loaded with the configured {@link #extendMode}), or
 * zero if no difference is found within {@code length} elements.
 *
 * <p>The element strides are either fixed at construction time ({@link #argStrideA} /
 * {@link #argStrideB}) or, when a dynamic-strides operand is supplied, selected at run time through
 * a jump table over all nine S1/S2/S4 combinations.
 *
 * <p>NOTE(review): the emitted code reads the first element of both arrays unconditionally and the
 * scalar loop tests its counter only after the first iteration, so callers presumably guarantee
 * {@code length > 0} - confirm against the call sites.
 */
@Opcode(value="ARRAY_REGION_COMPARE")
public final class AMD64ArrayRegionCompareToOp
extends AMD64ComplexVectorOp {
    public static final LIRInstructionClass<AMD64ArrayRegionCompareToOp> TYPE = LIRInstructionClass.create(AMD64ArrayRegionCompareToOp.class);

    // Fixed registers the operands are moved into by movParamsAndCreate.
    private static final Register REG_ARRAY_A = AMD64.rsi;
    private static final Register REG_OFFSET_A = AMD64.rax;
    private static final Register REG_ARRAY_B = AMD64.rdi;
    private static final Register REG_OFFSET_B = AMD64.rcx;
    private static final Register REG_LENGTH = AMD64.rdx;
    private static final Register REG_STRIDE = AMD64.r8;

    // "All lanes equal" masks: 16 set bits for an XMM compare mask, 32 set bits otherwise.
    private static final int ONES_16 = 65535;
    private static final int ONES_32 = -1;

    /** Compile-time stride of array A, or {@code null} when no stride was given to the constructor. */
    private final Stride argStrideA;
    /** Compile-time stride of array B, or {@code null} when no stride was given to the constructor. */
    private final Stride argStrideB;
    /** Extension mode passed to every element load (scalar and vector). */
    private final AMD64MacroAssembler.ExtendMode extendMode;

    @LIRInstruction.Def(value={LIRInstruction.OperandFlag.REG})
    private Value resultValue;
    @LIRInstruction.Use(value={LIRInstruction.OperandFlag.REG})
    private Value arrayAValue;
    @LIRInstruction.Use(value={LIRInstruction.OperandFlag.REG})
    private Value offsetAValue;
    @LIRInstruction.Use(value={LIRInstruction.OperandFlag.REG})
    private Value arrayBValue;
    @LIRInstruction.Use(value={LIRInstruction.OperandFlag.REG, LIRInstruction.OperandFlag.ILLEGAL})
    private Value offsetBValue;
    @LIRInstruction.Use(value={LIRInstruction.OperandFlag.REG})
    private Value lengthValue;
    @LIRInstruction.Use(value={LIRInstruction.OperandFlag.REG, LIRInstruction.OperandFlag.ILLEGAL})
    private Value dynamicStridesValue;

    // The *Temp fields hold the same values as the corresponding inputs (see the constructor):
    // the emitted code overwrites every input register, so each is also declared as a temporary.
    @LIRInstruction.Temp(value={LIRInstruction.OperandFlag.REG})
    private Value arrayAValueTemp;
    @LIRInstruction.Temp(value={LIRInstruction.OperandFlag.REG})
    private Value offsetAValueTemp;
    @LIRInstruction.Temp(value={LIRInstruction.OperandFlag.REG})
    private Value arrayBValueTemp;
    @LIRInstruction.Temp(value={LIRInstruction.OperandFlag.REG, LIRInstruction.OperandFlag.ILLEGAL})
    private Value offsetBValueTemp;
    @LIRInstruction.Temp(value={LIRInstruction.OperandFlag.REG})
    private Value lengthValueTemp;
    @LIRInstruction.Temp(value={LIRInstruction.OperandFlag.REG, LIRInstruction.OperandFlag.ILLEGAL})
    private Value dynamicStridesValueTemp;
    /** Vector scratch registers; four when the vector path may be emitted, none otherwise. */
    @LIRInstruction.Temp(value={LIRInstruction.OperandFlag.REG})
    Value[] vectorTemp;

    /**
     * Creates the op from operands that already live in their final registers. Use
     * {@link #movParamsAndCreate} to move arbitrary values into the fixed registers first.
     *
     * @param strideA compile-time stride of array A; {@code null} leaves both strides unset
     * @param strideB compile-time stride of array B; only inspected when {@code strideA} is non-null
     * @param dynamicStrides register holding the run-time stride selector, or an illegal value
     */
    private AMD64ArrayRegionCompareToOp(LIRGeneratorTool tool, Stride strideA, Stride strideB, EnumSet<AMD64.CPUFeature> runtimeCheckedCPUFeatures, Value result, Value arrayA, Value offsetA, Value arrayB, Value offsetB, Value length, Value dynamicStrides, AMD64MacroAssembler.ExtendMode extendMode) {
        super(TYPE, tool, runtimeCheckedCPUFeatures, AVXKind.AVXSize.YMM);
        this.extendMode = extendMode;
        if (strideA == null) {
            this.argStrideA = null;
            this.argStrideB = null;
        } else {
            GraalError.guarantee(strideA.value <= 4, "unsupported strideA");
            GraalError.guarantee(strideB.value <= 4, "unsupported strideB");
            this.argStrideA = strideA;
            this.argStrideB = strideB;
        }
        this.resultValue = result;
        // Every input is registered both as a use and as a temp because emitCode clobbers it.
        this.arrayAValue = this.arrayAValueTemp = arrayA;
        this.offsetAValue = this.offsetAValueTemp = offsetA;
        this.arrayBValue = this.arrayBValueTemp = arrayB;
        this.offsetBValue = this.offsetBValueTemp = offsetB;
        this.lengthValue = this.lengthValueTemp = length;
        this.dynamicStridesValue = this.dynamicStridesValueTemp = dynamicStrides;
        // Vector temporaries are only requested when the vector path can be emitted.
        boolean vectorPath = AMD64ArrayRegionCompareToOp.isVectorCompareSupported(tool.target(), runtimeCheckedCPUFeatures, this.argStrideA, this.argStrideB);
        this.vectorTemp = this.allocateVectorRegisters(tool, JavaKind.Byte, vectorPath ? 4 : 0);
    }

    /**
     * Moves all operands into the fixed registers used by this op and creates the instruction.
     *
     * @param dynamicStrides run-time stride selector, or {@code null} when the strides are the
     *            compile-time constants {@code strideA} / {@code strideB}
     * @return the new instruction, operating on the fixed registers
     */
    public static AMD64ArrayRegionCompareToOp movParamsAndCreate(LIRGeneratorTool tool, Stride strideA, Stride strideB, EnumSet<AMD64.CPUFeature> runtimeCheckedCPUFeatures, Value result, Value arrayA, Value offsetA, Value arrayB, Value offsetB, Value length, Value dynamicStrides, AMD64MacroAssembler.ExtendMode extendMode) {
        RegisterValue regArrayA = REG_ARRAY_A.asValue(arrayA.getValueKind());
        RegisterValue regOffsetA = REG_OFFSET_A.asValue(offsetA.getValueKind());
        RegisterValue regArrayB = REG_ARRAY_B.asValue(arrayB.getValueKind());
        RegisterValue regOffsetB = REG_OFFSET_B.asValue(offsetB.getValueKind());
        RegisterValue regLength = REG_LENGTH.asValue(length.getValueKind());
        // The stride register receives dynamicStrides, so it must carry that value's kind
        // (previously the kind of length was used here).
        AllocatableValue regStride = dynamicStrides == null ? Value.ILLEGAL : REG_STRIDE.asValue(dynamicStrides.getValueKind());
        tool.emitConvertNullToZero(regArrayA, arrayA);
        tool.emitMove(regOffsetA, offsetA);
        tool.emitConvertNullToZero(regArrayB, arrayB);
        tool.emitMove(regOffsetB, offsetB);
        tool.emitMove(regLength, length);
        if (dynamicStrides != null) {
            tool.emitMove(regStride, dynamicStrides);
        }
        return new AMD64ArrayRegionCompareToOp(tool, strideA, strideB, runtimeCheckedCPUFeatures, result, regArrayA, regOffsetA, regArrayB, regOffsetB, regLength, regStride, extendMode);
    }

    /**
     * Emits the comparison. With constant strides a single variant is emitted; with dynamic strides
     * a jump table dispatches to one of nine variants, one per (strideA, strideB) combination.
     */
    @Override
    public void emitCode(CompilationResultBuilder crb, AMD64MacroAssembler masm) {
        Register result = ValueUtil.asRegister(this.resultValue);
        Register arrayA = ValueUtil.asRegister(this.arrayAValue);
        Register arrayB = ValueUtil.asRegister(this.arrayBValue);
        Register length = ValueUtil.asRegister(this.lengthValue);
        // The offset registers double as scratch registers once the offsets have been consumed.
        Register tmp1 = ValueUtil.asRegister(this.offsetAValue);
        Register tmp2 = ValueUtil.asRegister(this.offsetBValue);

        // Fold the offsets (scale 1) into the array pointers.
        masm.leaq(arrayA, new AMD64Address(arrayA, tmp1, Stride.S1));
        masm.leaq(arrayB, new AMD64Address(arrayB, tmp2, Stride.S1));

        if (ValueUtil.isIllegal(this.dynamicStridesValue)) {
            this.emitArrayCompare(crb, masm, this.argStrideA, this.argStrideB, result, arrayA, arrayB, length, tmp1, tmp2);
        } else {
            masm.xorq(tmp2, tmp2);
            Label[] variants = new Label[9];
            Label done = new Label();
            for (int i = 0; i < variants.length; ++i) {
                variants[i] = new Label();
            }
            // Dispatch on the stride selector (range 0..8), using tmp1 as scratch.
            AMD64ControlFlow.RangeTableSwitchOp.emitJumpTable(crb, masm, tmp1, ValueUtil.asRegister(this.dynamicStridesValue), 0, 8, Arrays.stream(variants));
            for (Stride strideA : new Stride[]{Stride.S1, Stride.S2, Stride.S4}) {
                for (Stride strideB : new Stride[]{Stride.S1, Stride.S2, Stride.S4}) {
                    masm.align(this.preferredBranchTargetAlignment(crb));
                    masm.bind(variants[StrideUtil.getDirectStubCallIndex(strideA, strideB)]);
                    this.emitArrayCompare(crb, masm, strideA, strideB, result, arrayA, arrayB, length, tmp1, tmp2);
                    masm.jmp(done);
                }
            }
            masm.bind(done);
        }
    }

    /**
     * Returns whether the vector path can be emitted: always when both strides are identical
     * (including both {@code null}), otherwise only when SSE4.1 is reported as supported.
     */
    private static boolean isVectorCompareSupported(TargetDescription target, EnumSet<AMD64.CPUFeature> runtimeCheckedCPUFeatures, Stride strideA, Stride strideB) {
        return strideA == strideB || AMD64ArrayRegionCompareToOp.supports(target, runtimeCheckedCPUFeatures, AMD64.CPUFeature.SSE4_1, new AMD64.CPUFeature[0]);
    }

    /**
     * Emits one complete comparison variant for the given strides: the vector path when supported,
     * always followed by the scalar loop (which the vector path falls into for short tails).
     */
    private void emitArrayCompare(CompilationResultBuilder crb, AMD64MacroAssembler masm, Stride strideA, Stride strideB, Register result, Register arrayA, Register arrayB, Register length, Register tmp1, Register tmp2) {
        Label returnLabel = new Label();
        if (AMD64ArrayRegionCompareToOp.isVectorCompareSupported(crb.target, this.runtimeCheckedCPUFeatures, strideA, strideB)) {
            this.emitVectorLoop(crb, masm, strideA, strideB, result, arrayA, arrayB, length, tmp1, tmp2, returnLabel);
        }
        this.emitScalarLoop(crb, masm, strideA, strideB, result, arrayA, arrayB, length, tmp1, returnLabel);
        masm.bind(returnLabel);
    }

    /**
     * Emits the vectorized comparison. Control reaches the end of the emitted sequence (the scalar
     * tail label) only when the region is too short for every vector tail variant; in that case
     * {@code length} holds the element count and the array pointers are unchanged.
     */
    private void emitVectorLoop(CompilationResultBuilder crb, AMD64MacroAssembler masm, Stride strideA, Stride strideB, Register result, Register arrayA, Register arrayB, Register length, Register tmp1, Register tmp2, Label returnLabel) {
        Stride maxStride = Stride.max(strideA, strideB);
        Register vector1 = ValueUtil.asRegister(this.vectorTemp[0]);
        Register vector2 = ValueUtil.asRegister(this.vectorTemp[1]);
        Register vector3 = ValueUtil.asRegister(this.vectorTemp[2]);
        Register vector4 = ValueUtil.asRegister(this.vectorTemp[3]);
        int elementsPerVector = AMD64ArrayRegionCompareToOp.getElementsPerVector(this.vectorSize, maxStride);
        Label loop = new Label();
        Label qwordTail = new Label();
        Label scalarTail = new Label();
        Label tail = new Label();
        Label diffFound = new Label();

        // Fast path: compare the first elements as scalars and return their difference if non-zero.
        masm.movSZx(strideA, this.extendMode, result, new AMD64Address(arrayA));
        masm.movSZx(strideB, this.extendMode, tmp1, new AMD64Address(arrayB));
        masm.subqAndJcc(result, tmp1, AMD64Assembler.ConditionFlag.NotZero, returnLabel, false);

        // result := tail count (length mod elementsPerVector); length := count covered by full vectors.
        masm.movl(result, length);
        masm.andl(result, elementsPerVector - 1);
        masm.andlAndJcc(length, -elementsPerVector, AMD64Assembler.ConditionFlag.Zero, tail, false);

        // Point both arrays past the full-vector part and count up from -length to zero.
        masm.leaq(arrayA, new AMD64Address(arrayA, length, strideA));
        masm.leaq(arrayB, new AMD64Address(arrayB, length, strideB));
        masm.negq(length);

        // Main loop: one full vector of elements per iteration.
        masm.align(this.preferredLoopAlignment(crb));
        masm.bind(loop);
        masm.pmovSZx(this.vectorSize, this.extendMode, vector1, maxStride, arrayA, strideA, length, 0);
        masm.pmovSZx(this.vectorSize, this.extendMode, vector2, maxStride, arrayB, strideB, length, 0);
        masm.pcmpeq(this.vectorSize, maxStride, vector1, vector2);
        masm.pmovmsk(this.vectorSize, tmp1, vector1);
        // Invert the equality mask; any remaining set bit marks a differing position.
        masm.xorlAndJcc(tmp1, this.vectorSize == AVXKind.AVXSize.XMM ? ONES_16 : ONES_32, AMD64Assembler.ConditionFlag.NotZero, diffFound, true);
        masm.addqAndJcc(length, elementsPerVector, AMD64Assembler.ConditionFlag.NotZero, loop, true);

        // No tail elements left: result is already zero, return.
        masm.testlAndJcc(result, result, AMD64Assembler.ConditionFlag.Zero, returnLabel, false);

        // Tail: one more vector compare, loaded so that it ends at the end of the region.
        masm.pmovSZx(this.vectorSize, this.extendMode, vector1, maxStride, arrayA, strideA, result, -this.vectorSize.getBytes());
        masm.pmovSZx(this.vectorSize, this.extendMode, vector2, maxStride, arrayB, strideB, result, -this.vectorSize.getBytes());
        // Adjust length so that diffFound computes the index relative to this last load.
        masm.leaq(length, new AMD64Address(length, result, Stride.S1, -elementsPerVector));
        masm.pcmpeq(this.vectorSize, maxStride, vector1, vector2);
        masm.pmovmsk(this.vectorSize, tmp1, vector1);
        masm.xorlAndJcc(tmp1, this.vectorSize == AVXKind.AVXSize.XMM ? ONES_16 : ONES_32, AMD64Assembler.ConditionFlag.NotZero, diffFound, true);
        // All elements equal.
        masm.xorq(result, result);
        masm.jmp(returnLabel);

        masm.bind(diffFound);
        // bsfq is an inherited helper; presumably it yields the index of the lowest set mask bit.
        this.bsfq(masm, tmp2, tmp1);
        if (maxStride.value > 1) {
            // Convert the mask bit index into an element index.
            masm.shrq(tmp2, maxStride.log2);
        }
        masm.addq(tmp2, length);
        // Reload the differing elements as scalars and return their difference.
        masm.movSZx(strideA, this.extendMode, result, new AMD64Address(arrayA, tmp2, strideA));
        masm.movSZx(strideB, this.extendMode, tmp1, new AMD64Address(arrayB, tmp2, strideB));
        masm.subq(result, tmp1);
        masm.jmp(returnLabel);

        // The QWORD tail variant is skipped for the combination of strides S4 and S1.
        boolean canUseQWORD = maxStride != Stride.S4 || Stride.min(strideA, strideB) != Stride.S1;

        // Region shorter than one full vector: restore the element count and try smaller loads.
        masm.bind(tail);
        masm.movl(length, result);
        if (this.supportsAVX2AndYMM()) {
            this.emitVectorizedTail(masm, strideA, strideB, result, arrayA, arrayB, length, tmp1, tmp2, returnLabel, maxStride, vector1, vector2, vector3, vector4, canUseQWORD ? qwordTail : scalarTail, AVXKind.AVXSize.XMM, AVXKind.AVXSize.YMM);
        }
        if (canUseQWORD) {
            masm.bind(qwordTail);
            this.emitVectorizedTail(masm, strideA, strideB, result, arrayA, arrayB, length, tmp1, tmp2, returnLabel, maxStride, vector1, vector2, vector3, vector4, scalarTail, AVXKind.AVXSize.QWORD, AVXKind.AVXSize.XMM);
        }
        // Falls through into whatever the caller emits next (the scalar loop).
        masm.bind(scalarTail);
    }

    /**
     * Emits a comparison for regions shorter than one {@code cmpSize} vector: two {@code loadSize}
     * loads per array, one at the start of the region and one addressed via {@code length} with a
     * negative displacement, are merged into a single {@code cmpSize} vector and compared at once.
     * Jumps to {@code nextTail} when {@code length} is below the element count of one load.
     *
     * @param loadSize size of each individual load; must be half of {@code cmpSize}
     * @param cmpSize size of the merged vector that is compared (XMM or YMM)
     */
    private void emitVectorizedTail(AMD64MacroAssembler masm, Stride strideA, Stride strideB, Register result, Register arrayA, Register arrayB, Register length, Register tmp1, Register tmp2, Label returnLabel, Stride maxStride, Register vector1, Register vector2, Register vector3, Register vector4, Label nextTail, AVXKind.AVXSize loadSize, AVXKind.AVXSize cmpSize) {
        assert (cmpSize.getBytes() == loadSize.getBytes() * 2);
        assert (cmpSize == AVXKind.AVXSize.YMM || cmpSize == AVXKind.AVXSize.XMM);
        masm.cmplAndJcc(length, AMD64ArrayRegionCompareToOp.getElementsPerVector(loadSize, maxStride), AMD64Assembler.ConditionFlag.Less, nextTail, false);
        // vector1/vector2: loads from the region start; vector3/vector4: loads indexed by length.
        if (loadSize == AVXKind.AVXSize.QWORD) {
            masm.pmovSZxQWORD(this.extendMode, vector1, maxStride, arrayA, strideA, Register.None, 0);
            masm.pmovSZxQWORD(this.extendMode, vector2, maxStride, arrayB, strideB, Register.None, 0);
            masm.pmovSZxQWORD(this.extendMode, vector3, maxStride, arrayA, strideA, length, -loadSize.getBytes());
            masm.pmovSZxQWORD(this.extendMode, vector4, maxStride, arrayB, strideB, length, -loadSize.getBytes());
        } else {
            masm.pmovSZx(loadSize, this.extendMode, vector1, maxStride, arrayA, strideA, Register.None, 0);
            masm.pmovSZx(loadSize, this.extendMode, vector2, maxStride, arrayB, strideB, Register.None, 0);
            masm.pmovSZx(loadSize, this.extendMode, vector3, maxStride, arrayA, strideA, length, -loadSize.getBytes());
            masm.pmovSZx(loadSize, this.extendMode, vector4, maxStride, arrayB, strideB, length, -loadSize.getBytes());
        }
        // Merge the two loads of each array into one register of cmpSize.
        if (cmpSize == AVXKind.AVXSize.YMM) {
            AMD64Assembler.VexRVMIOp.VPERM2I128.emit(masm, cmpSize, vector1, vector3, vector1, 2);
            AMD64Assembler.VexRVMIOp.VPERM2I128.emit(masm, cmpSize, vector2, vector4, vector2, 2);
        } else {
            masm.movlhps(vector1, vector3);
            masm.movlhps(vector2, vector4);
        }
        masm.pcmpeq(cmpSize, maxStride, vector1, vector2);
        masm.pmovmsk(cmpSize, result, vector1);
        // Inverted mask is zero when everything matched; result is then zero, which is returned.
        masm.xorlAndJcc(result, cmpSize == AVXKind.AVXSize.XMM ? ONES_16 : ONES_32, AMD64Assembler.ConditionFlag.Zero, returnLabel, false);
        // bsfq is an inherited helper; presumably it yields the index of the lowest set mask bit.
        this.bsfq(masm, tmp2, result);
        if (maxStride.value > 1) {
            // Convert the mask bit index into an element index.
            masm.shrq(tmp2, maxStride.log2);
        }
        // If the index lies beyond the first load, rebase it onto the length-indexed load.
        masm.leaq(tmp1, new AMD64Address(tmp2, length, Stride.S1, -AMD64ArrayRegionCompareToOp.getElementsPerVector(cmpSize, maxStride)));
        masm.cmpq(tmp2, AMD64ArrayRegionCompareToOp.getElementsPerVector(loadSize, maxStride));
        masm.cmovq(AMD64Assembler.ConditionFlag.Greater, tmp2, tmp1);
        // Reload the differing elements as scalars and return their difference.
        masm.movSZx(strideA, this.extendMode, result, new AMD64Address(arrayA, tmp2, strideA));
        masm.movSZx(strideB, this.extendMode, tmp1, new AMD64Address(arrayB, tmp2, strideB));
        masm.subq(result, tmp1);
        masm.jmp(returnLabel);
    }

    /** Returns the number of elements of stride {@code maxStride} in a vector of size {@code vSize}. */
    private static int getElementsPerVector(AVXKind.AVXSize vSize, Stride maxStride) {
        return vSize.getBytes() >> maxStride.log2;
    }

    /**
     * Emits an element-by-element comparison loop. Jumps to {@code returnLabel} with the difference
     * in {@code result} on the first mismatch; otherwise falls through after the last element with
     * {@code result} holding the (zero) difference of the last pair.
     */
    private void emitScalarLoop(CompilationResultBuilder crb, AMD64MacroAssembler masm, Stride strideA, Stride strideB, Register result, Register arrayA, Register arrayB, Register length, Register tmp, Label returnLabel) {
        Label loop = new Label();
        // Point both arrays at the end of the region and count up from -length to zero.
        masm.leaq(arrayA, new AMD64Address(arrayA, length, strideA));
        masm.leaq(arrayB, new AMD64Address(arrayB, length, strideB));
        masm.negq(length);
        masm.align(this.preferredLoopAlignment(crb));
        masm.bind(loop);
        masm.movSZx(strideA, this.extendMode, result, new AMD64Address(arrayA, length, strideA));
        masm.movSZx(strideB, this.extendMode, tmp, new AMD64Address(arrayB, length, strideB));
        masm.subqAndJcc(result, tmp, AMD64Assembler.ConditionFlag.NotZero, returnLabel, true);
        masm.incqAndJcc(length, AMD64Assembler.ConditionFlag.NotZero, loop, true);
    }
}

