/*
 * Decompiled with CFR 0.152.
 */
package net.luminis.quic.server;

import java.io.IOException;
import java.io.InputStream;
import java.net.DatagramPacket;
import java.net.DatagramSocket;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.time.Duration;
import java.time.Instant;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.stream.Collectors;
import net.luminis.quic.EncryptionLevel;
import net.luminis.quic.RawPacket;
import net.luminis.quic.Receiver;
import net.luminis.quic.Version;
import net.luminis.quic.log.Logger;
import net.luminis.quic.packet.InitialPacket;
import net.luminis.quic.packet.VersionNegotiationPacket;
import net.luminis.quic.server.ApplicationProtocolConnectionFactory;
import net.luminis.quic.server.ApplicationProtocolRegistry;
import net.luminis.quic.server.ConnectionSource;
import net.luminis.quic.server.Context;
import net.luminis.quic.server.ServerConnectionCandidate;
import net.luminis.quic.server.ServerConnectionFactory;
import net.luminis.quic.server.ServerConnectionImpl;
import net.luminis.quic.server.ServerConnectionProxy;
import net.luminis.quic.server.ServerConnectionRegistry;
import net.luminis.tls.handshake.TlsServerEngineFactory;
import net.luminis.tls.util.ByteUtils;

/**
 * Receives datagrams from a single {@link DatagramSocket} and dispatches the QUIC packets they contain to server
 * connections, which are looked up by destination connection id in an internal table. Unknown destination
 * connection ids in long-header packets may create a new connection candidate; Initial packets carrying an
 * unsupported version are answered with a Version Negotiation packet.
 */
public class ServerConnector
implements ServerConnectionRegistry {
    /** Flags (1) + version (4) + DCID length (1) + SCID length (1): the smallest parseable long header. */
    private static final int MINIMUM_LONG_HEADER_LENGTH = 7;
    /**
     * Length of the connection ids used by this server. It is passed to the connection factory and also used to parse
     * short-header packets, which carry no DCID length field.
     */
    private static final int CONNECTION_ID_LENGTH = 4;
    /** Mask selecting the header form bit and the fixed bit of the first byte. */
    private static final int HEADER_FORM_AND_FIXED_BIT_MASK = 0xC0;
    /** Header form 1 + fixed bit 1: long header. */
    private static final int LONG_HEADER_BITS = 0xC0;
    /** Header form 0 + fixed bit 1: short header. */
    private static final int SHORT_HEADER_BITS = 0x40;
    /** Maximum connection id length in QUIC version 1 (RFC 9000, section 17.2). */
    private static final int MAX_CONNECTION_ID_LENGTH = 20;
    /** A client-chosen Initial DCID must be at least 8 bytes (RFC 9000, section 7.2). */
    private static final int MIN_INITIAL_DCID_LENGTH = 8;
    /** Client Initial datagrams must be at least 1200 bytes (RFC 9000, section 14.1). */
    private static final int MIN_INITIAL_DATAGRAM_SIZE = 1200;

    private final Receiver receiver;
    private final Logger log;
    private final List<Version> supportedVersions;
    private final List<Integer> supportedVersionIds;
    private final DatagramSocket serverSocket;
    private final boolean requireRetry;
    private final Integer initialRtt = 100;
    /** Connection table: every registered connection id maps to the proxy that handles its packets. */
    private final Map<ConnectionSource, ServerConnectionProxy> currentConnections;
    private final TlsServerEngineFactory tlsEngineFactory;
    private final ServerConnectionFactory serverConnectionFactory;
    private final ApplicationProtocolRegistry applicationProtocolRegistry;
    private final ExecutorService sharedExecutor = Executors.newSingleThreadExecutor();
    private final ScheduledExecutorService sharedScheduledExecutor = Executors.newSingleThreadScheduledExecutor();
    private final Context context;

    /**
     * Creates a connector that binds a new {@link DatagramSocket} to the given port.
     *
     * @throws Exception propagated from socket creation or from the TLS engine factory
     */
    public ServerConnector(int port, InputStream certificateFile, InputStream certificateKeyFile, List<Version> supportedVersions, boolean requireRetry, Logger log) throws Exception {
        this(new DatagramSocket(port), certificateFile, certificateKeyFile, supportedVersions, requireRetry, log);
    }

    /**
     * Creates a connector that receives on an existing socket. Nothing is received until {@link #start()} is called.
     *
     * @param log must not be {@code null}
     * @throws Exception propagated from {@link TlsServerEngineFactory} construction
     */
    public ServerConnector(DatagramSocket socket, InputStream certificateFile, InputStream certificateKeyFile, List<Version> supportedVersions, boolean requireRetry, Logger log) throws Exception {
        this.serverSocket = socket;
        this.supportedVersions = supportedVersions;
        this.requireRetry = requireRetry;
        this.log = Objects.requireNonNull(log);
        this.tlsEngineFactory = new TlsServerEngineFactory(certificateFile, certificateKeyFile);
        this.applicationProtocolRegistry = new ApplicationProtocolRegistry();
        this.serverConnectionFactory = new ServerConnectionFactory(CONNECTION_ID_LENGTH, this.serverSocket, this.tlsEngineFactory, this.requireRetry, this.applicationProtocolRegistry, this.initialRtt, this, this::removeConnection, log);
        this.supportedVersionIds = supportedVersions.stream().map(version -> version.getId()).collect(Collectors.toList());
        this.currentConnections = new ConcurrentHashMap<>();
        // NOTE(review): the receiver's exception callback terminates the whole JVM with exit status 9.
        this.receiver = new Receiver(this.serverSocket, log, exception -> System.exit(9));
        this.context = new ServerConnectorContext();
    }

    /** Registers a factory that creates application-level connections for the given ALPN protocol name. */
    public void registerApplicationProtocol(String protocol, ApplicationProtocolConnectionFactory protocolConnectionFactory) {
        this.applicationProtocolRegistry.registerApplicationProtocol(protocol, protocolConnectionFactory);
    }

    /** Returns the protocol names registered through {@link #registerApplicationProtocol}. */
    public Set<String> getRegisteredApplicationProtocols() {
        return this.applicationProtocolRegistry.getRegisteredApplicationProtocols();
    }

    /** Starts the receiver and a dedicated thread that runs {@link #receiveLoop()}. */
    public void start() {
        this.receiver.start();
        new Thread(this::receiveLoop, "server receive loop").start();
        this.log.info("Kwik server connector started on port " + this.serverSocket.getLocalPort() + "; supported application protcols: " + this.applicationProtocolRegistry.getRegisteredApplicationProtocols());
    }

    /**
     * Takes packets from the receiver and processes them until the thread is interrupted. Any other exception is
     * logged, and the loop continues with the next packet.
     */
    protected void receiveLoop() {
        while (true) {
            try {
                // Effectively wait forever; the receiver API takes a timeout in seconds.
                RawPacket rawPacket = this.receiver.get((int) Duration.ofDays(3650L).toSeconds());
                this.process(rawPacket);
            }
            catch (InterruptedException e) {
                this.log.error("receiver interrupted; stopping server receive loop");
                Thread.currentThread().interrupt();
                return;
            }
            catch (Exception runtimeError) {
                this.log.error("Uncaught exception in server receive loop", runtimeError);
            }
        }
    }

    /**
     * Classifies a datagram by the header form and fixed bits of its first byte and dispatches it to the long- or
     * short-header handler. Empty datagrams and datagrams with an invalid first byte are discarded.
     */
    protected void process(RawPacket rawPacket) {
        ByteBuffer data = rawPacket.getData();
        if (!data.hasRemaining()) {
            this.log.warn("Empty datagram is discarded");
            return;
        }
        byte flags = data.get();
        data.rewind();
        int headerBits = flags & HEADER_FORM_AND_FIXED_BIT_MASK;
        if (headerBits == LONG_HEADER_BITS) {
            this.processLongHeaderPacket(new InetSocketAddress(rawPacket.getAddress(), rawPacket.getPort()), data);
        }
        else if (headerBits == SHORT_HEADER_BITS) {
            this.processShortHeaderPacket(new InetSocketAddress(rawPacket.getAddress(), rawPacket.getPort()), data);
        }
        else {
            this.log.warn(String.format("Invalid Quic packet (flags: %02x) is discarded", flags));
        }
    }

    /**
     * Parses the version and connection ids of a long-header packet. It then either routes the packet to an existing
     * connection, creates a new connection candidate, or answers with a Version Negotiation packet. Truncated
     * headers are silently dropped.
     */
    private void processLongHeaderPacket(InetSocketAddress clientAddress, ByteBuffer data) {
        if (data.remaining() < MINIMUM_LONG_HEADER_LENGTH) {
            return;
        }
        data.position(1);
        int version = data.getInt();
        int dcidLength = data.get() & 0xFF;
        if (dcidLength > MAX_CONNECTION_ID_LENGTH) {
            // Too long for version 1. The version-independent invariants (RFC 8999) allow longer ids, though, so an
            // Initial packet of an unknown version may still deserve a Version Negotiation packet.
            if (this.initialWithUnsupportedVersion(data, version)) {
                this.sendVersionNegotiationPacket(clientAddress, data, dcidLength);
            }
            return;
        }
        if (data.remaining() < dcidLength + 1) {
            return;
        }
        byte[] dcid = new byte[dcidLength];
        data.get(dcid);
        int scidLength = data.get() & 0xFF;
        if (data.remaining() < scidLength) {
            return;
        }
        byte[] scid = new byte[scidLength];
        data.get(scid);
        data.rewind();

        Optional<ServerConnectionProxy> connection = this.isExistingConnection(clientAddress, dcid);
        if (connection.isEmpty()) {
            synchronized (this) {
                // Re-check under the lock: the dcid may have been registered concurrently. If so, route the packet
                // to that connection rather than dropping it.
                connection = this.isExistingConnection(clientAddress, dcid);
                if (connection.isEmpty()) {
                    if (this.mightStartNewConnection(data, version, dcid)) {
                        connection = Optional.of(this.createNewConnection(version, clientAddress, scid, dcid));
                    }
                    else if (this.initialWithUnsupportedVersion(data, version)) {
                        this.log.received(Instant.now(), 0, EncryptionLevel.Initial, dcid, scid);
                        this.sendVersionNegotiationPacket(clientAddress, data, dcidLength);
                    }
                }
            }
        }
        connection.ifPresent(c -> c.parsePackets(0, Instant.now(), data));
    }

    /**
     * Routes a short-header packet by its DCID. Short headers carry no length field, so the DCID is assumed to have
     * this server's connection id length. Packets that are too short or that address an unknown connection are
     * discarded.
     */
    private void processShortHeaderPacket(InetSocketAddress clientAddress, ByteBuffer data) {
        if (data.limit() < 1 + CONNECTION_ID_LENGTH) {
            this.log.warn("Discarding short header packet that is too short to contain a connection id");
            return;
        }
        byte[] dcid = new byte[CONNECTION_ID_LENGTH];
        data.position(1);
        data.get(dcid);
        data.rewind();
        this.isExistingConnection(clientAddress, dcid).ifPresentOrElse(
                c -> c.parsePackets(0, Instant.now(), data),
                () -> this.log.warn("Discarding short header packet addressing non existent connection " + ByteUtils.bytesToHex(dcid)));
    }

    /**
     * Returns whether the packet may open a new connection: the version must be supported, and the client-chosen
     * DCID must be long enough. {@code packetBytes} is currently unused.
     */
    private boolean mightStartNewConnection(ByteBuffer packetBytes, int version, byte[] dcid) {
        return dcid.length >= MIN_INITIAL_DCID_LENGTH && this.supportedVersionIds.contains(version);
    }

    /**
     * Returns whether the packet is a full-size Initial packet whose version this server does not support.
     * Side effect: the buffer is rewound, and its position is left just after the first byte.
     */
    private boolean initialWithUnsupportedVersion(ByteBuffer packetBytes, int version) {
        packetBytes.rewind();
        int type = (packetBytes.get() & 0x30) >> 4;
        if (InitialPacket.isInitial(type, Version.parse(version)) && packetBytes.limit() >= MIN_INITIAL_DATAGRAM_SIZE) {
            return !this.supportedVersionIds.contains(version);
        }
        return false;
    }

    /** Creates a connection candidate and registers it under the client-chosen DCID. */
    private ServerConnectionProxy createNewConnection(int versionValue, InetSocketAddress clientAddress, byte[] scid, byte[] dcid) {
        Version version = Version.parse(versionValue);
        ServerConnectionCandidate connectionCandidate = new ServerConnectionCandidate(this.context, version, clientAddress, scid, dcid, this.serverConnectionFactory, this, this.log);
        this.currentConnections.put(new ConnectionSource(dcid), connectionCandidate);
        return connectionCandidate;
    }

    /**
     * Removes all table entries for the given connection: its active connection ids and its original DCID. The
     * removed proxy is then terminated. If no active connection id was registered, nothing is terminated and an
     * error is logged.
     */
    private void removeConnection(ServerConnectionImpl connection) {
        ServerConnectionProxy removed = null;
        for (byte[] connectionId : connection.getActiveConnectionIds()) {
            if (removed == null) {
                removed = this.currentConnections.remove(new ConnectionSource(connectionId));
                if (removed == null) {
                    this.log.error("Cannot remove connection with cid " + ByteUtils.bytesToHex(connectionId));
                }
            }
            else if (removed != this.currentConnections.remove(new ConnectionSource(connectionId))) {
                this.log.error("Removed connections for set of active cids are not identical");
            }
        }
        this.currentConnections.remove(new ConnectionSource(connection.getOriginalDestinationConnectionId()));
        if (removed == null) {
            // Previously this dereferenced null and threw an NPE; there is no proxy to terminate.
            this.log.error("No registered connection found for connection with dcid " + ByteUtils.bytesToHex(connection.getOriginalDestinationConnectionId()));
            return;
        }
        if (!removed.isClosed()) {
            this.log.error("Removed connection with dcid " + ByteUtils.bytesToHex(connection.getOriginalDestinationConnectionId()) + " that is not closed...");
        }
        removed.terminate();
    }

    /** Looks up the connection registered for {@code dcid}. {@code clientAddress} is currently not used. */
    private Optional<ServerConnectionProxy> isExistingConnection(InetSocketAddress clientAddress, byte[] dcid) {
        return Optional.ofNullable(this.currentConnections.get(new ConnectionSource(dcid)));
    }

    /**
     * Sends a Version Negotiation packet listing the supported versions back to {@code clientAddress}. The DCID and
     * SCID are re-read from the received long header. Headers too short to contain both ids are ignored; send
     * failures are logged.
     */
    private void sendVersionNegotiationPacket(InetSocketAddress clientAddress, ByteBuffer data, int dcidLength) {
        data.rewind();
        if (data.remaining() < 6 + dcidLength + 1) {
            return;
        }
        byte[] dcid = new byte[dcidLength];
        data.position(6);
        data.get(dcid);
        int scidLength = data.get() & 0xFF;
        if (data.remaining() < scidLength) {
            return;
        }
        byte[] scid = new byte[scidLength];
        data.get(scid);
        // NOTE(review): RFC 9000 requires the VN packet to echo the client's SCID as DCID and vice versa; this
        // presumes the VersionNegotiationPacket constructor performs that swap — confirm.
        VersionNegotiationPacket versionNegotiationPacket = new VersionNegotiationPacket(this.supportedVersions, dcid, scid);
        byte[] packetBytes = versionNegotiationPacket.generatePacketBytes(null);
        DatagramPacket datagram = new DatagramPacket(packetBytes, packetBytes.length, clientAddress.getAddress(), clientAddress.getPort());
        try {
            this.serverSocket.send(datagram);
            this.log.sent(Instant.now(), versionNegotiationPacket);
        }
        catch (IOException e) {
            this.log.error("Sending version negotiation packet failed", e);
        }
    }

    /** Maps {@code connectionId} to {@code connection}, replacing any existing mapping. */
    @Override
    public void registerConnection(ServerConnectionProxy connection, byte[] connectionId) {
        this.currentConnections.put(new ConnectionSource(connectionId), connection);
    }

    /**
     * Removes the mapping for {@code connectionId}, but only if it points to {@code connection}. If a different
     * proxy is registered for that id, an error is logged instead.
     */
    @Override
    public void deregisterConnection(ServerConnectionProxy connection, byte[] connectionId) {
        boolean removed = this.currentConnections.remove(new ConnectionSource(connectionId), connection);
        if (!removed && this.currentConnections.containsKey(new ConnectionSource(connectionId))) {
            this.log.error("Connection " + connection + " not removed, because " + this.currentConnections.get(new ConnectionSource(connectionId)) + " is registered for " + ByteUtils.bytesToHex(connectionId));
        }
    }

    /**
     * Registers {@code newConnectionId} as an alias for the connection currently registered under
     * {@code currentConnectionId}. Logs an error if there is no such connection.
     */
    @Override
    public void registerAdditionalConnectionId(byte[] currentConnectionId, byte[] newConnectionId) {
        ServerConnectionProxy connection = this.currentConnections.get(new ConnectionSource(currentConnectionId));
        if (connection != null) {
            this.currentConnections.put(new ConnectionSource(newConnectionId), connection);
        }
        else {
            this.log.error("Cannot add additional cid to non-existing connection " + ByteUtils.bytesToHex(currentConnectionId));
        }
    }

    /** Unconditionally removes the mapping for {@code connectionId}. */
    @Override
    public void deregisterConnectionId(byte[] connectionId) {
        this.currentConnections.remove(new ConnectionSource(connectionId));
    }

    /** Logs the connection table, one {@code cid->proxy} line per entry, sorted by the proxy's string form. */
    private void logConnectionTable() {
        this.log.info("Connection table: \n" + this.currentConnections.entrySet().stream()
                .sorted(Comparator.comparing((Map.Entry<ConnectionSource, ServerConnectionProxy> entry) -> entry.getValue().toString()))
                .map(e -> e.getKey() + "->" + e.getValue())
                .collect(Collectors.joining("\n")));
    }

    /** Context handed to connection candidates, exposing the connector's shared executors. */
    private class ServerConnectorContext
    implements Context {
        private ServerConnectorContext() {
        }

        @Override
        public ExecutorService getSharedServerExecutor() {
            return ServerConnector.this.sharedExecutor;
        }

        @Override
        public ScheduledExecutorService getSharedScheduledExecutor() {
            return ServerConnector.this.sharedScheduledExecutor;
        }
    }
}

