/*
 * Decompiled with CFR 0.152.
 */
package net.luminis.tls.handshake;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.security.MessageDigest;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import net.luminis.tls.ProtectionKeysType;
import net.luminis.tls.TlsConstants;
import net.luminis.tls.TlsProtocolException;
import net.luminis.tls.TlsState;
import net.luminis.tls.TranscriptHash;
import net.luminis.tls.alert.DecryptErrorAlert;
import net.luminis.tls.alert.HandshakeFailureAlert;
import net.luminis.tls.alert.IllegalParameterAlert;
import net.luminis.tls.alert.MissingExtensionAlert;
import net.luminis.tls.alert.UnexpectedMessageAlert;
import net.luminis.tls.extension.ClientHelloPreSharedKeyExtension;
import net.luminis.tls.extension.EarlyDataExtension;
import net.luminis.tls.extension.Extension;
import net.luminis.tls.extension.KeyShareExtension;
import net.luminis.tls.extension.PskKeyExchangeModesExtension;
import net.luminis.tls.extension.ServerPreSharedKeyExtension;
import net.luminis.tls.extension.SignatureAlgorithmsExtension;
import net.luminis.tls.extension.SupportedGroupsExtension;
import net.luminis.tls.extension.SupportedVersionsExtension;
import net.luminis.tls.handshake.CertificateMessage;
import net.luminis.tls.handshake.CertificateVerifyMessage;
import net.luminis.tls.handshake.ClientHello;
import net.luminis.tls.handshake.EncryptedExtensions;
import net.luminis.tls.handshake.FinishedMessage;
import net.luminis.tls.handshake.NewSessionTicketMessage;
import net.luminis.tls.handshake.ServerHello;
import net.luminis.tls.handshake.ServerMessageProcessor;
import net.luminis.tls.handshake.ServerMessageSender;
import net.luminis.tls.handshake.TlsEngine;
import net.luminis.tls.handshake.TlsSession;
import net.luminis.tls.handshake.TlsSessionRegistry;
import net.luminis.tls.handshake.TlsStatusEventHandler;

public class TlsServerEngine
extends TlsEngine
implements ServerMessageProcessor {
    private final Set<TlsConstants.CipherSuite> supportedCiphers;
    private final ArrayList<Extension> extensions;
    private ServerMessageSender serverMessageSender;
    protected TlsStatusEventHandler statusHandler;
    private List<X509Certificate> serverCertificateChain;
    private PrivateKey certificatePrivateKey;
    private TranscriptHash transcriptHash;
    private TlsConstants.CipherSuite selectedCipher;
    private List<Extension> serverExtensions;
    private List<TlsConstants.PskKeyExchangeMode> clientSupportedKeyExchangeModes;
    private TlsSessionRegistry sessionRegistry;
    private byte currentTicketNumber = 0;
    private String selectedApplicationLayerProtocol;
    private Long maxEarlyDataSize = 0xFFFFFFFFL;
    private byte[] additionalSessionData;
    private Function<ByteBuffer, Boolean> sessionDataVerificationCallback;

    public TlsServerEngine(List<X509Certificate> certificates, PrivateKey certificateKey, ServerMessageSender serverMessageSender, TlsStatusEventHandler tlsStatusHandler, TlsSessionRegistry tlsSessionRegistry) {
        this.serverCertificateChain = certificates;
        this.certificatePrivateKey = certificateKey;
        this.serverMessageSender = serverMessageSender;
        this.statusHandler = tlsStatusHandler;
        this.supportedCiphers = new HashSet<TlsConstants.CipherSuite>();
        this.supportedCiphers.add(TlsConstants.CipherSuite.TLS_AES_128_GCM_SHA256);
        this.extensions = new ArrayList();
        this.serverExtensions = new ArrayList<Extension>();
        this.clientSupportedKeyExchangeModes = new ArrayList<TlsConstants.PskKeyExchangeMode>();
        this.sessionRegistry = tlsSessionRegistry;
    }

    public TlsServerEngine(X509Certificate serverCertificate, PrivateKey certificateKey, ServerMessageSender serverMessageSender, TlsStatusEventHandler tlsStatusHandler, TlsSessionRegistry tlsSessionRegistry) {
        this(List.of(serverCertificate), certificateKey, serverMessageSender, tlsStatusHandler, tlsSessionRegistry);
    }

    @Override
    public void received(ClientHello clientHello, ProtectionKeysType protectedBy) throws TlsProtocolException, IOException {
        this.selectedCipher = clientHello.getCipherSuites().stream().filter(it -> this.supportedCiphers.contains(it)).findFirst().orElseThrow(() -> new HandshakeFailureAlert("Failed to negotiate a cipher (server only supports " + this.supportedCiphers.stream().map(c -> c.toString()).collect(Collectors.joining(", ")) + ")"));
        SupportedGroupsExtension supportedGroupsExt = (SupportedGroupsExtension)clientHello.getExtensions().stream().filter(ext -> ext instanceof SupportedGroupsExtension).findFirst().orElseThrow(() -> new MissingExtensionAlert("supported groups extension is required in Client Hello"));
        List<TlsConstants.NamedGroup> serverSupportedGroups = List.of(TlsConstants.NamedGroup.secp256r1, TlsConstants.NamedGroup.x25519);
        if (supportedGroupsExt.getNamedGroups().stream().filter(serverSupportedGroups::contains).findFirst().isEmpty()) {
            throw new HandshakeFailureAlert(String.format("Failed to negotiate supported group (server only supports %s)", serverSupportedGroups));
        }
        KeyShareExtension keyShareExtension = (KeyShareExtension)clientHello.getExtensions().stream().filter(ext -> ext instanceof KeyShareExtension).findFirst().orElseThrow(() -> new MissingExtensionAlert("key share extension is required in Client Hello"));
        KeyShareExtension.KeyShareEntry keyShareEntry = keyShareExtension.getKeyShareEntries().stream().filter(entry -> serverSupportedGroups.contains((Object)entry.getNamedGroup())).findFirst().orElseThrow(() -> new IllegalParameterAlert("key share named group not supported (and no HelloRetryRequest support)"));
        SignatureAlgorithmsExtension signatureAlgorithmsExtension = (SignatureAlgorithmsExtension)clientHello.getExtensions().stream().filter(ext -> ext instanceof SignatureAlgorithmsExtension).findFirst().orElseThrow(() -> new MissingExtensionAlert("signature algorithms extension is required in Client Hello"));
        clientHello.getExtensions().stream().filter(ext -> ext instanceof PskKeyExchangeModesExtension).findFirst().ifPresent(extension -> this.clientSupportedKeyExchangeModes.addAll(((PskKeyExchangeModesExtension)extension).getKeyExchangeModes()));
        if (!signatureAlgorithmsExtension.getSignatureAlgorithms().contains((Object)TlsConstants.SignatureScheme.rsa_pss_rsae_sha256)) {
            throw new HandshakeFailureAlert("Failed to negotiate signature algorithm (server only supports rsa_pss_rsae_sha256");
        }
        Optional<Extension> pskExtension = clientHello.getExtensions().stream().filter(ext -> ext instanceof ClientHelloPreSharedKeyExtension).findFirst();
        this.statusHandler.extensionsReceived(clientHello.getExtensions());
        boolean earlyDataAccepted = false;
        Integer selectedIdentity = null;
        if (pskExtension.isPresent()) {
            TlsSession resumedSession;
            ClientHelloPreSharedKeyExtension preSharedKeyExtension;
            if (this.clientSupportedKeyExchangeModes.isEmpty()) {
                throw new MissingExtensionAlert("psk_key_exchange_modes extension required with pre_shared_key");
            }
            if (this.clientSupportedKeyExchangeModes.contains((Object)TlsConstants.PskKeyExchangeMode.psk_dhe_ke) && (selectedIdentity = this.sessionRegistry.selectIdentity((preSharedKeyExtension = (ClientHelloPreSharedKeyExtension)pskExtension.get()).getIdentities(), this.selectedCipher)) != null && this.isAcceptable(this.sessionRegistry.peekSessionData(preSharedKeyExtension.getIdentities().get(selectedIdentity))) && (resumedSession = this.sessionRegistry.useSession(preSharedKeyExtension.getIdentities().get(selectedIdentity))) != null) {
                this.transcriptHash = new TranscriptHash(TlsServerEngine.hashLength(this.selectedCipher));
                this.state = new TlsState(this.transcriptHash, resumedSession.getPsk(), TlsServerEngine.keyLength(this.selectedCipher), TlsServerEngine.hashLength(this.selectedCipher));
                if (!this.validateBinder(preSharedKeyExtension.getBinders().get(selectedIdentity), preSharedKeyExtension.getBinderPosition(), clientHello)) {
                    this.state = null;
                    throw new DecryptErrorAlert("Invalid PSK binder");
                }
                if (clientHello.getExtensions().stream().filter(ext -> ext instanceof EarlyDataExtension).findAny().isPresent() && selectedIdentity == 0 && this.selectedApplicationLayerProtocol != null && this.selectedApplicationLayerProtocol.equals(resumedSession.getApplicationLayerProtocol())) {
                    earlyDataAccepted = this.statusHandler.isEarlyDataAccepted();
                }
            }
        }
        if (this.state == null) {
            this.transcriptHash = new TranscriptHash(TlsServerEngine.hashLength(this.selectedCipher));
            this.state = new TlsState(this.transcriptHash, TlsServerEngine.keyLength(this.selectedCipher), TlsServerEngine.hashLength(this.selectedCipher));
            selectedIdentity = null;
        }
        this.transcriptHash.record(clientHello);
        this.generateKeys(keyShareEntry.getNamedGroup());
        this.state.setOwnKey(this.privateKey);
        this.state.computeEarlyTrafficSecret();
        this.statusHandler.earlySecretsKnown();
        List<Extension> extensions = List.of(new SupportedVersionsExtension(TlsConstants.HandshakeType.server_hello), new KeyShareExtension(this.publicKey, keyShareEntry.getNamedGroup(), TlsConstants.HandshakeType.server_hello));
        if (selectedIdentity != null) {
            extensions = new ArrayList<KeyShareExtension>(extensions);
            extensions.add(new ServerPreSharedKeyExtension(selectedIdentity.shortValue()));
        }
        ServerHello serverHello = new ServerHello(this.selectedCipher, extensions);
        this.serverMessageSender.send(serverHello);
        this.transcriptHash.record(serverHello);
        this.state.setPeerKey(keyShareEntry.getKey());
        this.state.computeSharedSecret();
        this.state.computeHandshakeSecrets();
        this.statusHandler.handshakeSecretsKnown();
        if (earlyDataAccepted) {
            this.serverExtensions.add(new EarlyDataExtension());
        }
        EncryptedExtensions encryptedExtensions = new EncryptedExtensions(this.serverExtensions);
        this.serverMessageSender.send(encryptedExtensions);
        this.transcriptHash.record(encryptedExtensions);
        if (selectedIdentity == null) {
            CertificateMessage certificate = new CertificateMessage(this.serverCertificateChain);
            this.serverMessageSender.send(certificate);
            this.transcriptHash.recordServer(certificate);
            byte[] hash = this.transcriptHash.getServerHash(TlsConstants.HandshakeType.certificate);
            byte[] signature = this.computeSignature(hash, this.certificatePrivateKey, TlsConstants.SignatureScheme.rsa_pss_rsae_sha256, false);
            CertificateVerifyMessage certificateVerify = new CertificateVerifyMessage(TlsConstants.SignatureScheme.rsa_pss_rsae_sha256, signature);
            this.serverMessageSender.send(certificateVerify);
            this.transcriptHash.recordServer(certificateVerify);
        }
        byte[] hmac = this.computeFinishedVerifyData(this.transcriptHash.getServerHash(TlsConstants.HandshakeType.certificate_verify), this.state.getServerHandshakeTrafficSecret());
        FinishedMessage finished = new FinishedMessage(hmac);
        this.serverMessageSender.send(finished);
        this.transcriptHash.recordServer(finished);
        this.state.computeApplicationSecrets();
    }

    private boolean isAcceptable(byte[] sessionData) {
        if (this.sessionDataVerificationCallback == null || sessionData == null) {
            return true;
        }
        return this.sessionDataVerificationCallback.apply(ByteBuffer.wrap(sessionData));
    }

    @Override
    public void received(FinishedMessage clientFinished, ProtectionKeysType protectedBy) throws TlsProtocolException, IOException {
        if (protectedBy != ProtectionKeysType.Handshake) {
            throw new UnexpectedMessageAlert("incorrect protection level");
        }
        this.transcriptHash.recordClient(clientFinished);
        byte[] serverHmac = this.computeFinishedVerifyData(this.transcriptHash.getServerHash(TlsConstants.HandshakeType.finished), this.state.getClientHandshakeTrafficSecret());
        if (!Arrays.equals(clientFinished.getVerifyData(), serverHmac)) {
            throw new DecryptErrorAlert("incorrect finished message");
        }
        this.state.computeResumptionMasterSecret();
        this.statusHandler.handshakeFinished();
        if (this.sessionRegistry != null && this.clientSupportedKeyExchangeModes.contains((Object)TlsConstants.PskKeyExchangeMode.psk_dhe_ke)) {
            byte by = this.currentTicketNumber;
            this.currentTicketNumber = (byte)(by + 1);
            NewSessionTicketMessage newSessionTicketMessage = this.sessionRegistry.createNewSessionTicketMessage(by, this.selectedCipher, this.state, this.selectedApplicationLayerProtocol, this.maxEarlyDataSize, this.additionalSessionData);
            this.serverMessageSender.send(newSessionTicketMessage);
        }
    }

    protected boolean validateBinder(ClientHelloPreSharedKeyExtension.PskBinderEntry pskBinderEntry, int binderPosition, ClientHello clientHello) {
        byte[] partialCH = Arrays.copyOfRange(clientHello.getBytes(), 0, clientHello.getPskExtensionStartPosition() + binderPosition);
        byte[] binder = this.state.computePskBinder(partialCH);
        boolean valid = Arrays.equals(pskBinderEntry.getHmac(), binder);
        return valid;
    }

    public void addSupportedCiphers(List<TlsConstants.CipherSuite> cipherSuites) {
        this.supportedCiphers.addAll(cipherSuites);
    }

    public void setServerMessageSender(ServerMessageSender serverMessageSender) {
        this.serverMessageSender = serverMessageSender;
    }

    public void setStatusHandler(TlsStatusEventHandler statusHandler) {
        this.statusHandler = statusHandler;
    }

    @Override
    public TlsConstants.CipherSuite getSelectedCipher() {
        return this.selectedCipher;
    }

    public List<Extension> getServerExtensions() {
        return this.serverExtensions;
    }

    public void addServerExtensions(Extension extension) {
        this.serverExtensions.add(extension);
    }

    public void setSelectedApplicationLayerProtocol(String applicationProtocol) {
        if (applicationProtocol == null) {
            throw new IllegalArgumentException();
        }
        this.selectedApplicationLayerProtocol = applicationProtocol;
    }

    public void setSessionData(byte[] additionalSessionData) {
        this.additionalSessionData = additionalSessionData;
    }

    public void setSessionDataVerificationCallback(Function<ByteBuffer, Boolean> callback) {
        this.sessionDataVerificationCallback = callback;
    }
}

