/*
 * Decompiled with CFR 0.152.
 */
package com.oracle.svm.core.graal.llvm.util;

import com.oracle.svm.core.FrameAccess;
import com.oracle.svm.core.graal.llvm.util.LLVMTargetSpecific;
import com.oracle.svm.core.util.VMError;
import java.lang.invoke.LambdaMetafactory;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
import java.util.function.Supplier;
import org.graalvm.compiler.core.common.NumUtil;

public class LLVMStackMapInfo {
    public static final long DEFAULT_PATCHPOINT_ID = 2882400000L;
    private Map<Long, Function> patchpointToFunction = new HashMap<Long, Function>();
    private Map<Long, Set<Record>> patchpointsByID = new HashMap<Long, Set<Record>>();
    private static final int STATEPOINT_HEADER_LOCATION_COUNT = 3;
    private static final int STATEPOINT_DEOPT_COUNT_LOCATION_INDEX = 2;

    LLVMStackMapInfo(ByteBuffer buffer) {
        int i;
        StackMap stackMap = new StackMap();
        int offset = 0;
        stackMap.version = buffer.get(offset);
        ++offset;
        ++offset;
        stackMap.functions = new Function[buffer.getInt(offset += 2)];
        stackMap.constants = new long[buffer.getInt(offset += 4)];
        int numRecords = buffer.getInt(offset += 4);
        offset += 4;
        long totalNumRecords = 0L;
        for (i = 0; i < stackMap.functions.length; ++i) {
            Function function = new Function();
            function.address = buffer.getLong(offset);
            function.stackSize = buffer.getLong(offset += 8);
            function.records = new Record[NumUtil.safeToInt((long)buffer.getLong(offset += 8))];
            offset += 8;
            stackMap.functions[i] = function;
            totalNumRecords += (long)function.records.length;
        }
        for (i = 0; i < stackMap.constants.length; ++i) {
            stackMap.constants[i] = buffer.getLong(offset);
            offset += 8;
        }
        int fun = 0;
        int rec = 0;
        assert ((long)numRecords == totalNumRecords);
        int i2 = 0;
        while (i2 < numRecords) {
            int j;
            while (rec == stackMap.functions[fun].records.length) {
                ++fun;
                rec = 0;
            }
            Function function = stackMap.functions[fun];
            Record record = new Record();
            record.patchpointID = buffer.getLong(offset);
            record.instructionOffset = buffer.getInt(offset += 8);
            record.flags = buffer.getShort(offset += 4);
            record.locations = new Location[buffer.getShort(offset += 2)];
            offset += 2;
            for (j = 0; j < record.locations.length; ++j) {
                Location location = new Location();
                location.type = Location.Type.decode(buffer.get(offset));
                ++offset;
                location.size = buffer.getShort(++offset);
                location.regNum = buffer.getShort(offset += 2);
                offset += 2;
                location.offset = buffer.getInt(offset += 2);
                offset += 4;
                record.locations[j] = location;
            }
            if (offset % 8 != 0) {
                offset += 4;
            }
            record.liveOuts = new LiveOut[buffer.getShort(offset += 2)];
            offset += 2;
            for (j = 0; j < record.liveOuts.length; ++j) {
                LiveOut liveOut = new LiveOut();
                liveOut.regNum = buffer.getShort(offset);
                offset += 2;
                liveOut.size = buffer.get(++offset);
                ++offset;
                record.liveOuts[j] = liveOut;
            }
            if (offset % 8 != 0) {
                offset += 4;
            }
            function.records[rec] = record;
            if (this.patchpointToFunction.containsKey(record.patchpointID)) assert (record.patchpointID == 2882400000L || this.patchpointToFunction.get(record.patchpointID) == function);
            this.patchpointToFunction.put(record.patchpointID, function);
            this.patchpointsByID.computeIfAbsent(record.patchpointID, v -> new HashSet()).add(record);
            ++i2;
            ++rec;
        }
    }

    long getFunctionStackSize(long startPatchpointID) {
        assert (this.patchpointToFunction.containsKey(startPatchpointID));
        return this.patchpointToFunction.get((Object)Long.valueOf((long)startPatchpointID)).stackSize;
    }

    private long getFunctionOffset(long startPatchpointID) {
        assert (this.patchpointToFunction.containsKey(startPatchpointID));
        return this.patchpointToFunction.get((Object)Long.valueOf((long)startPatchpointID)).address;
    }

    int[] getPatchpointOffsets(long patchpointID) {
        if (this.patchpointsByID.containsKey(patchpointID)) {
            return this.patchpointsByID.get(patchpointID).stream().mapToInt(r -> r.instructionOffset).toArray();
        }
        return new int[0];
    }

    void forEachStatepointOffset(long patchpointID, int instructionOffset, StatepointOffsetCallback callback) {
        Location[] locations = this.patchpointsByID.get((Object)Long.valueOf((long)patchpointID)).stream().filter((Predicate<Record>)LambdaMetafactory.metafactory(null, null, null, (Ljava/lang/Object;)Z, lambda$forEachStatepointOffset$2(int com.oracle.svm.core.graal.llvm.util.LLVMStackMapInfo$Record ), (Lcom/oracle/svm/core/graal/llvm/util/LLVMStackMapInfo$Record;)Z)((int)instructionOffset)).findFirst().orElseThrow((Supplier<RuntimeException>)LambdaMetafactory.metafactory(null, null, null, ()Ljava/lang/Object;, shouldNotReachHere(), ()Ljava/lang/RuntimeException;)()).locations;
        assert (locations.length >= 3);
        Location deoptCountLocation = locations[2];
        assert (deoptCountLocation.type == Location.Type.Constant);
        int deoptCount = deoptCountLocation.offset;
        assert (3 + deoptCount <= locations.length);
        HashSet<Integer> compressedOffsets = new HashSet<Integer>();
        for (int i = 3; i < 3 + deoptCount; ++i) {
            Location loc = locations[i];
            assert (loc.type == Location.Type.Indirect);
            int[] offsets = this.getStackOffsets(patchpointID, loc);
            assert (offsets.length == 1);
            compressedOffsets.add(offsets[0]);
        }
        HashSet<Integer> seenOffsets = new HashSet<Integer>();
        HashSet<Integer> seenBases = new HashSet<Integer>();
        for (int i = 3 + deoptCount; i < locations.length; i += 2) {
            assert (i + 1 < locations.length);
            Location base = locations[i];
            Location ref = locations[i + 1];
            if (base.type == Location.Type.Constant || ref.type == Location.Type.Constant) {
                assert (base.type == ref.type && base.offset == ref.offset);
                if (base.offset == 0) continue;
                seenBases.add((int)((long)base.offset - this.getFunctionOffset(patchpointID)));
                seenOffsets.add((int)((long)ref.offset - this.getFunctionOffset(patchpointID)));
                continue;
            }
            assert (base.type == Location.Type.Indirect);
            int[] baseOffsets = this.getStackOffsets(patchpointID, base);
            assert (ref.type == Location.Type.Indirect);
            int[] derivedOffsets = this.getStackOffsets(patchpointID, ref);
            assert (baseOffsets.length == derivedOffsets.length);
            for (int j = 0; j < baseOffsets.length; ++j) {
                int baseOffset = baseOffsets[j];
                int derivedOffset = derivedOffsets[j];
                seenBases.add(baseOffset);
                if (seenOffsets.contains(derivedOffset)) continue;
                seenOffsets.add(derivedOffset);
                assert (compressedOffsets.contains(derivedOffset) == compressedOffsets.contains(baseOffset));
                callback.accept(derivedOffset, baseOffset, compressedOffsets.contains(derivedOffset));
            }
        }
        assert (seenOffsets.containsAll(seenBases));
    }

    public int getAllocaOffset(long startPatchPointId) {
        Set<Record> startRecords = this.patchpointsByID.get(startPatchPointId);
        assert (startRecords.size() == 1);
        Record startRecord = (Record)startRecords.stream().findAny().orElseThrow(VMError::shouldNotReachHere);
        assert (startRecord.locations.length == 1);
        Location alloca = startRecord.locations[0];
        assert (alloca.type == Location.Type.Direct);
        int[] offsets = this.getStackOffsets(startPatchPointId, alloca);
        assert (offsets.length == 1);
        return offsets[0];
    }

    private int[] getStackOffsets(long patchpointID, Location location) {
        int baseOffset;
        assert (location.size % FrameAccess.wordSize() == 0);
        int numLocations = location.size / FrameAccess.wordSize();
        assert (numLocations > 0);
        if (location.regNum == LLVMTargetSpecific.get().getStackPointerDwarfRegNum()) {
            baseOffset = location.offset;
        } else if (location.regNum == LLVMTargetSpecific.get().getFramePointerDwarfRegNum()) {
            baseOffset = location.offset + NumUtil.safeToInt((long)this.getFunctionStackSize(patchpointID)) + LLVMTargetSpecific.get().getFramePointerOffset();
        } else {
            throw VMError.shouldNotReachHere((String)("found other register " + patchpointID + " " + location.regNum));
        }
        assert (baseOffset >= 0 && (long)(baseOffset + location.size) < this.getFunctionStackSize(patchpointID));
        int[] offsets = new int[numLocations];
        for (int i = 0; i < numLocations; ++i) {
            offsets[i] = baseOffset + i * FrameAccess.wordSize();
        }
        return offsets;
    }

    private static /* synthetic */ boolean lambda$forEachStatepointOffset$2(int instructionOffset, Record r) {
        return r.instructionOffset == instructionOffset;
    }

    @FunctionalInterface
    public static interface StatepointOffsetCallback {
        public void accept(int var1, int var2, boolean var3);
    }

    static class LiveOut {
        short regNum;
        byte size;

        LiveOut() {
        }
    }

    static class Location {
        Type type;
        short size;
        short regNum;
        int offset;

        Location() {
        }

        static enum Type {
            Register(1),
            Direct(2),
            Indirect(3),
            Constant(4),
            ConstantIndex(5);

            private final byte encoding;

            private Type(int encoding) {
                this.encoding = (byte)encoding;
            }

            static Type decode(byte encoding) {
                for (Type type : Type.values()) {
                    if (type.encoding != encoding) continue;
                    return type;
                }
                return null;
            }
        }
    }

    static class Record {
        long patchpointID;
        int instructionOffset;
        short flags;
        Location[] locations;
        LiveOut[] liveOuts;

        Record() {
        }
    }

    static class Function {
        long address;
        long stackSize;
        Record[] records;

        Function() {
        }
    }

    static class StackMap {
        byte version;
        Function[] functions;
        long[] constants;

        StackMap() {
        }
    }
}

