/*
 * Decompiled with CFR 0.152.
 */
package net.luminis.quic.cid;

import java.security.SecureRandom;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import net.luminis.quic.EncryptionLevel;
import net.luminis.quic.QuicConstants;
import net.luminis.quic.Version;
import net.luminis.quic.cid.ConnectionIdInfo;
import net.luminis.quic.cid.DestinationConnectionIdRegistry;
import net.luminis.quic.cid.SourceConnectionIdRegistry;
import net.luminis.quic.frame.NewConnectionIdFrame;
import net.luminis.quic.frame.QuicFrame;
import net.luminis.quic.frame.RetireConnectionIdFrame;
import net.luminis.quic.log.Logger;
import net.luminis.quic.send.Sender;
import net.luminis.quic.server.ServerConnectionProxy;
import net.luminis.quic.server.ServerConnectionRegistry;

/**
 * Manages the connection IDs of a QUIC connection: the IDs this endpoint issues to its peer (tracked in a
 * {@link SourceConnectionIdRegistry}) and the IDs the peer issued to this endpoint (tracked in a
 * {@link DestinationConnectionIdRegistry}).
 * <p>
 * Processes received NEW_CONNECTION_ID and RETIRE_CONNECTION_ID frames and sends the corresponding frames at the
 * {@code App} encryption level. Protocol errors are reported through the close-connection callback, which receives
 * a transport error code and a reason phrase.
 */
public class ConnectionIdManager {
    /** Upper bound on the number of connection IDs this endpoint keeps active, whatever limit the peer advertises. */
    public static final int MAX_CIDS_PER_CONNECTION = 6;
    private final int connectionIdLength;
    private final ServerConnectionRegistry connectionRegistry;
    private final Sender sender;
    private final BiConsumer<Integer, String> closeConnectionCallback;
    /** Connection IDs issued by this endpoint. */
    private final SourceConnectionIdRegistry cidRegistry;
    /** Connection IDs issued by the peer; {@code null} when the peer uses a zero-length connection ID. */
    private final DestinationConnectionIdRegistry peerCidRegistry;
    private final byte[] initialConnectionId;
    private final byte[] initialPeerConnectionId;
    private final byte[] originalDestinationConnectionId;
    /** Number of own connection IDs to keep active; starts at 2 and is updated by {@link #registerPeerCidLimit(int)}. */
    private volatile int maxCids = 2;
    /** Maximum number of active peer connection IDs this endpoint accepts. */
    private volatile int maxPeerCids;
    private volatile byte[] retrySourceCid;
    private final Version quicVersion = Version.QUIC_version_1;

    /**
     * Creates a manager for a connection whose peer's initial connection ID is already known (judging by the
     * {@code initialClientCid} parameter, this is the server-side constructor).
     *
     * @param initialClientCid                the peer's initial connection ID; {@code null} or empty means the peer
     *                                        uses zero-length connection IDs, so no peer registry is created
     * @param originalDestinationConnectionId the original destination connection ID of the connection
     * @param connectionIdLength              length of the connection IDs this endpoint generates
     * @param maxPeerCids                     maximum number of active peer connection IDs to accept
     * @param connectionRegistry              registry that is told about newly issued and retired connection IDs
     * @param sender                          used to send connection ID frames
     * @param closeConnectionCallback         invoked with (error code, reason) when a protocol error is detected
     * @param log                             logger passed to the registries
     */
    public ConnectionIdManager(byte[] initialClientCid, byte[] originalDestinationConnectionId, int connectionIdLength, int maxPeerCids, ServerConnectionRegistry connectionRegistry, Sender sender, BiConsumer<Integer, String> closeConnectionCallback, Logger log) {
        this.originalDestinationConnectionId = originalDestinationConnectionId;
        this.connectionIdLength = connectionIdLength;
        this.maxPeerCids = maxPeerCids;
        this.connectionRegistry = connectionRegistry;
        this.sender = sender;
        this.closeConnectionCallback = closeConnectionCallback;
        this.cidRegistry = new SourceConnectionIdRegistry(connectionIdLength, log);
        this.initialConnectionId = this.cidRegistry.getCurrent();
        if (initialClientCid != null && initialClientCid.length != 0) {
            this.peerCidRegistry = new DestinationConnectionIdRegistry(initialClientCid, log);
            this.initialPeerConnectionId = initialClientCid;
        } else {
            // Zero-length peer connection ID: there is nothing to rotate, so no peer registry is kept.
            this.peerCidRegistry = null;
            this.initialPeerConnectionId = new byte[0];
        }
    }

    /**
     * Creates a manager that generates a random 8-byte original destination connection ID and uses it as the initial
     * peer connection ID (the client-side situation, where no server connection registry exists).
     *
     * @param connectionIdLength      length of the connection IDs this endpoint generates, passed as-is to the
     *                                source registry
     * @param maxPeerCids             maximum number of active peer connection IDs to accept
     * @param sender                  used to send connection ID frames
     * @param closeConnectionCallback invoked with (error code, reason) when a protocol error is detected
     * @param log                     logger passed to the registries
     */
    public ConnectionIdManager(Integer connectionIdLength, int maxPeerCids, Sender sender, BiConsumer<Integer, String> closeConnectionCallback, Logger log) {
        this.maxPeerCids = maxPeerCids;
        this.sender = sender;
        this.cidRegistry = new SourceConnectionIdRegistry(connectionIdLength, log);
        this.connectionIdLength = this.cidRegistry.getConnectionIdlength();
        this.initialConnectionId = this.cidRegistry.getCurrent();
        this.closeConnectionCallback = closeConnectionCallback;
        this.originalDestinationConnectionId = new byte[8];
        new SecureRandom().nextBytes(this.originalDestinationConnectionId);
        this.peerCidRegistry = new DestinationConnectionIdRegistry(this.originalDestinationConnectionId, log);
        this.initialPeerConnectionId = this.originalDestinationConnectionId;
        // No server-side registry in this mode: all registration calls are no-ops.
        this.connectionRegistry = new ServerConnectionRegistry() {

            @Override
            public void registerConnection(ServerConnectionProxy connection, byte[] connectionId) {
            }

            @Override
            public void deregisterConnection(ServerConnectionProxy connection, byte[] connectionId) {
            }

            @Override
            public void registerAdditionalConnectionId(byte[] currentConnectionId, byte[] newConnectionId) {
            }

            @Override
            public void deregisterConnectionId(byte[] connectionId) {
            }
        };
    }

    /**
     * Issues additional connection IDs once the handshake has finished, so that {@code maxCids} IDs are active in
     * total (the initial one plus {@code maxCids - 1} new ones).
     */
    public void handshakeFinished() {
        for (int i = 1; i < this.maxCids; ++i) {
            this.sendNewCid(0);
        }
    }

    /**
     * Processes a NEW_CONNECTION_ID frame received from the peer (RFC 9000, section 19.15).
     * <p>
     * Closes the connection when:
     * <ul>
     *   <li>the peer uses a zero-length connection ID ({@code PROTOCOL_VIOLATION}),</li>
     *   <li>Retire Prior To is greater than Sequence Number ({@code FRAME_ENCODING_ERROR}),</li>
     *   <li>a known sequence number arrives with a different connection ID ({@code PROTOCOL_VIOLATION}),</li>
     *   <li>after retiring, more than {@code maxPeerCids} peer IDs are active ({@code CONNECTION_ID_LIMIT_ERROR}).</li>
     * </ul>
     *
     * @param frame the received frame
     */
    public void process(NewConnectionIdFrame frame) {
        if (this.peerCidRegistry == null) {
            this.closeConnectionCallback.accept(QuicConstants.TransportErrorCode.PROTOCOL_VIOLATION.value, "new connection id frame not allowed when using zero-length connection ID");
            return;
        }
        int sequenceNr = frame.getSequenceNr();
        if (frame.getRetirePriorTo() > sequenceNr) {
            this.closeConnectionCallback.accept(QuicConstants.TransportErrorCode.FRAME_ENCODING_ERROR.value, "retire prior to value greater than sequence number");
            return;
        }
        if (!this.peerCidRegistry.connectionIds.containsKey(sequenceNr)) {
            boolean added = this.peerCidRegistry.registerNewConnectionId(sequenceNr, frame.getConnectionId(), frame.getStatelessResetToken());
            if (!added) {
                // The registry refused the ID (presumably because its sequence number lies below a previously
                // received Retire Prior To value), so tell the peer it is retired right away.
                this.sendRetireCid(sequenceNr);
            }
        } else {
            ConnectionIdInfo known = (ConnectionIdInfo) this.peerCidRegistry.connectionIds.get(sequenceNr);
            if (!Arrays.equals(known.getConnectionId(), frame.getConnectionId())) {
                this.closeConnectionCallback.accept(QuicConstants.TransportErrorCode.PROTOCOL_VIOLATION.value, "different connection ids for same sequence number");
                return;
            }
        }
        if (frame.getRetirePriorTo() > 0) {
            List<Integer> retired = this.peerCidRegistry.retireAllBefore(frame.getRetirePriorTo());
            retired.forEach(this::sendRetireCid);
        }
        // The limit check must come after retiring: only IDs that are still active count against the limit.
        if (this.peerCidRegistry.getActiveConnectionIds().size() > this.maxPeerCids) {
            this.closeConnectionCallback.accept(QuicConstants.TransportErrorCode.CONNECTION_ID_LIMIT_ERROR.value, "exceeding active connection id limit");
        }
    }

    /**
     * Processes a RETIRE_CONNECTION_ID frame received from the peer (RFC 9000, section 19.16).
     * <p>
     * Closes the connection with {@code PROTOCOL_VIOLATION} when the sequence number was never issued, or when it
     * refers to the connection ID the frame's own packet was sent to. Otherwise the ID is retired and deregistered.
     * If the number of active IDs then drops below {@code maxCids}, a replacement is issued.
     *
     * @param frame                  the received frame
     * @param destinationConnectionId destination connection ID of the packet that carried the frame
     */
    public void process(RetireConnectionIdFrame frame, byte[] destinationConnectionId) {
        int sequenceNr = frame.getSequenceNr();
        if (sequenceNr > this.cidRegistry.getMaxSequenceNr()) {
            this.closeConnectionCallback.accept(QuicConstants.TransportErrorCode.PROTOCOL_VIOLATION.value, "invalid connection ID sequence number");
            return;
        }
        if (Arrays.equals(this.cidRegistry.get(sequenceNr), destinationConnectionId)) {
            this.closeConnectionCallback.accept(QuicConstants.TransportErrorCode.PROTOCOL_VIOLATION.value, "cannot retire current connection ID");
            return;
        }
        byte[] retiredCid = this.cidRegistry.retireConnectionId(sequenceNr);
        if (retiredCid != null) {
            this.connectionRegistry.deregisterConnectionId(retiredCid);
            if (this.cidRegistry.getActiveConnectionIds().size() < this.maxCids) {
                this.sendNewCid(0);
            }
        }
    }

    /**
     * Sets the maximum number of active peer connection IDs this endpoint accepts.
     *
     * @param maxPeerCids the new limit
     */
    public void setMaxPeerConnectionIds(int maxPeerCids) {
        this.maxPeerCids = maxPeerCids;
    }

    /**
     * Records the peer's active connection ID limit. This endpoint keeps at most that many IDs active, and never more
     * than {@link #MAX_CIDS_PER_CONNECTION}.
     *
     * @param peerCidLimit the limit advertised by the peer
     */
    public void registerPeerCidLimit(int peerCidLimit) {
        this.maxCids = Math.min(peerCidLimit, MAX_CIDS_PER_CONNECTION);
    }

    /**
     * Generates a new own connection ID, registers it with the connection registry and sends a NEW_CONNECTION_ID
     * frame for it. The frame is retransmitted if it is lost.
     *
     * @param retirePriorTo value for the frame's Retire Prior To field
     * @return the newly generated connection ID info
     */
    private ConnectionIdInfo sendNewCid(int retirePriorTo) {
        ConnectionIdInfo cidInfo = this.cidRegistry.generateNew();
        this.connectionRegistry.registerAdditionalConnectionId(this.cidRegistry.getActive(), cidInfo.getConnectionId());
        this.sender.send(new NewConnectionIdFrame(this.quicVersion, cidInfo.getSequenceNumber(), retirePriorTo, cidInfo.getConnectionId()), EncryptionLevel.App, this::retransmitFrame);
        return cidInfo;
    }

    /** Re-sends a lost frame at the App level, again with retransmission on loss. */
    private void retransmitFrame(QuicFrame frame) {
        this.sender.send(frame, EncryptionLevel.App, this::retransmitFrame);
    }

    /** Sends a RETIRE_CONNECTION_ID frame for the given peer sequence number, retransmitted on loss. */
    private void sendRetireCid(Integer seqNr) {
        this.sender.send(new RetireConnectionIdFrame(this.quicVersion, seqNr), EncryptionLevel.App, this::retransmitFrame);
    }

    /** @return the active connection IDs issued by this endpoint */
    public List<byte[]> getActiveConnectionIds() {
        return this.cidRegistry.getActiveConnectionIds();
    }

    /**
     * @return the active connection IDs issued by the peer, or a single empty ID when the peer uses zero-length
     *         connection IDs
     */
    public List<byte[]> getActivePeerConnectionIds() {
        if (this.peerCidRegistry != null) {
            return this.peerCidRegistry.getActiveConnectionIds();
        }
        return List.of(new byte[0]);
    }

    /** @return the peer connection ID currently in use, or an empty array for a zero-length peer connection ID */
    public byte[] getCurrentPeerConnectionId() {
        if (this.peerCidRegistry != null) {
            return this.peerCidRegistry.getCurrent();
        }
        return new byte[0];
    }

    /** @return the first connection ID this endpoint generated */
    public byte[] getInitialConnectionId() {
        return this.initialConnectionId;
    }

    /** @return the original destination connection ID of this connection */
    public byte[] getOriginalDestinationConnectionId() {
        return this.originalDestinationConnectionId;
    }

    /**
     * @param connectionId connection ID to check
     * @return whether {@code connectionId} equals the initial peer connection ID
     */
    public boolean validateInitialPeerConnectionId(byte[] connectionId) {
        return Arrays.equals(connectionId, this.initialPeerConnectionId);
    }

    /**
     * Tells the source registry that the peer used {@code connectionId}. When the registry reports a change and fewer
     * than {@code maxCids} IDs are active, a new connection ID is issued.
     *
     * @param connectionId own connection ID seen in a received packet
     */
    public void registerConnectionIdInUse(byte[] connectionId) {
        if (this.cidRegistry.registerUsedConnectionId(connectionId) && this.cidRegistry.getActiveConnectionIds().size() < this.maxCids) {
            this.sendNewCid(0);
        }
    }

    /**
     * Issues a new own connection ID and announces it to the peer.
     *
     * @param retirePriorTo value for the NEW_CONNECTION_ID frame's Retire Prior To field
     * @return the newly generated connection ID info
     */
    public ConnectionIdInfo sendNewConnectionId(int retirePriorTo) {
        return this.sendNewCid(retirePriorTo);
    }

    /** Stores the source connection ID of a received Retry packet for later validation. */
    public void registerRetrySourceConnectionId(byte[] connectionId) {
        this.retrySourceCid = connectionId;
    }

    /** @return whether {@code connectionId} equals the registered Retry source connection ID */
    public boolean validateRetrySourceConnectionId(byte[] connectionId) {
        return Arrays.equals(this.retrySourceCid, connectionId);
    }

    /** Replaces the initial peer connection ID in the peer registry (requires a non-zero-length peer CID). */
    public void registerInitialPeerCid(byte[] connectionId) {
        this.peerCidRegistry.replaceInitialConnectionId(connectionId);
    }

    /** Sets the stateless reset token of the initial peer connection ID (requires a non-zero-length peer CID). */
    public void setInitialStatelessResetToken(byte[] statelessResetToken) {
        this.peerCidRegistry.setInitialStatelessResetToken(statelessResetToken);
    }

    /**
     * @param data candidate token bytes
     * @return whether {@code data} matches a known peer stateless reset token; always {@code false} when the peer
     *         uses zero-length connection IDs, because then no peer tokens are tracked
     */
    public boolean isStatelessResetToken(byte[] data) {
        if (this.peerCidRegistry == null) {
            return false;
        }
        return this.peerCidRegistry.isStatelessResetToken(data);
    }

    /** @return the length of the connection IDs this endpoint generates */
    public int getConnectionIdLength() {
        return this.connectionIdLength;
    }

    /** @return all connection IDs issued by this endpoint, keyed by sequence number */
    public Map<Integer, ConnectionIdInfo> getAllConnectionIds() {
        return this.cidRegistry.getAll();
    }

    /** @return all connection IDs issued by the peer, keyed by sequence number (requires a non-zero-length peer CID) */
    public Map<Integer, ConnectionIdInfo> getAllPeerConnectionIds() {
        return this.peerCidRegistry.getAll();
    }

    /** @return the result of switching to the next peer connection ID (requires a non-zero-length peer CID) */
    public byte[] nextPeerId() {
        return this.peerCidRegistry.useNext();
    }

    /**
     * Retires a peer connection ID locally and sends a RETIRE_CONNECTION_ID frame for it. If that frame is lost, only
     * the frame is resent; the local retirement is not repeated.
     *
     * @param sequenceNumber sequence number of the peer connection ID to retire
     */
    public void retireConnectionId(Integer sequenceNumber) {
        this.peerCidRegistry.retireConnectionId(sequenceNumber);
        this.sendRetireCid(sequenceNumber);
    }

    /** @return the own connection ID currently in use */
    public byte[] getCurrentConnectionId() {
        return this.cidRegistry.getActive();
    }
}

