/*
 * Decompiled with CFR 0.152.
 */
package net.luminis.quic.packet;

import java.io.Serializable;
import java.nio.ByteBuffer;
import java.time.Instant;
import java.util.ArrayList;
import java.util.stream.Collectors;
import net.luminis.quic.DecryptionException;
import net.luminis.quic.EncryptionLevel;
import net.luminis.quic.InvalidPacketException;
import net.luminis.quic.PacketProcessor;
import net.luminis.quic.PnSpace;
import net.luminis.quic.Version;
import net.luminis.quic.crypto.Aead;
import net.luminis.quic.frame.QuicFrame;
import net.luminis.quic.log.Logger;
import net.luminis.quic.packet.QuicPacket;
import net.luminis.tls.util.ByteUtils;

/**
 * A QUIC packet with a short header, sent and received at the {@code App} encryption level.
 *
 * <p>Wire layout as read by {@link #parse} and written by {@link #generatePacketBytes}: one flags
 * byte, the destination connection id (whose length is not encoded in the packet), the protected
 * packet number, and the protected payload.
 */
public class ShortHeaderPacket
extends QuicPacket {
    /** Key phase bit (0 or 1), carried in bit {@code 0x04} of the flags byte. */
    protected short keyPhaseBit;

    /**
     * Creates an empty packet, intended to be filled in by {@link #parse}.
     *
     * @param quicVersion the QUIC version this packet belongs to
     */
    public ShortHeaderPacket(Version quicVersion) {
        this.quicVersion = quicVersion;
    }

    /**
     * Creates an outgoing packet.
     *
     * @param quicVersion the QUIC version this packet belongs to
     * @param destinationConnectionId the connection id to put in the header
     * @param frame an initial frame to carry, or {@code null} to start with no frames
     */
    public ShortHeaderPacket(Version quicVersion, byte[] destinationConnectionId, QuicFrame frame) {
        this.quicVersion = quicVersion;
        this.destinationConnectionId = destinationConnectionId;
        this.frames = new ArrayList<>();
        if (frame != null) {
            this.frames.add(frame);
        }
    }

    /**
     * Parses (and decrypts) a short header packet that starts at position 0 of {@code buffer}.
     *
     * <p>If decryption succeeds, any key update in progress on {@code aead} is confirmed; if it
     * fails, the key update is cancelled. In all cases {@code packetSize} is set to the number of
     * bytes consumed from the buffer.
     *
     * @param buffer the packet bytes; its position must be 0
     * @param aead keys used to remove header protection and decrypt the payload
     * @param largestPacketNumber passed on to packet number decoding
     * @param log logger for debug output
     * @param sourceConnectionIdLength length of the destination connection id in this packet; the
     *     short header does not encode it, so the caller supplies the length of its own connection id
     * @throws InvalidPacketException if the buffer is too short to hold the flags byte and connection id
     * @throws DecryptionException if the packet cannot be decrypted
     * @throws IllegalStateException if the buffer position is not 0
     */
    @Override
    public void parse(ByteBuffer buffer, Aead aead, long largestPacketNumber, Logger log, int sourceConnectionIdLength) throws DecryptionException, InvalidPacketException {
        log.debug("Parsing " + this.getClass().getSimpleName());
        if (buffer.remaining() < 1 + sourceConnectionIdLength) {
            throw new InvalidPacketException();
        }
        if (buffer.position() != 0) {
            // The packet is assumed to start at offset 0 (packetSize below is computed from that).
            throw new IllegalStateException();
        }
        byte flags = buffer.get();
        this.checkPacketType(flags);
        byte[] packetConnectionId = new byte[sourceConnectionIdLength];
        this.destinationConnectionId = packetConnectionId;
        buffer.get(packetConnectionId);
        log.debug("Destination connection id", packetConnectionId);
        try {
            this.parsePacketNumberAndPayload(buffer, flags, buffer.limit() - buffer.position(), aead, largestPacketNumber, log);
            aead.confirmKeyUpdateIfInProgress();
        }
        catch (DecryptionException cantDecrypt) {
            aead.cancelKeyUpdateIfInProgress();
            throw cantDecrypt;
        }
        finally {
            // The packet starts at 0, so the current position is the number of bytes consumed.
            this.packetSize = buffer.position();
        }
    }

    /**
     * Records the key phase bit from the flags byte once header protection has been removed.
     *
     * @param decryptedFlags the unprotected flags byte
     */
    @Override
    protected void setUnprotectedHeader(byte decryptedFlags) {
        this.keyPhaseBit = (short) ((decryptedFlags & 0x04) >> 2);
    }

    /**
     * Estimates the serialized length of this packet if {@code additionalPayload} more payload bytes
     * were added.
     *
     * <p>When no packet number has been assigned yet ({@code packetNumber < 0}), 4 bytes are
     * reserved for it. Packet number plus payload is padded up to at least 4 bytes (presumably so
     * header protection has enough bytes to sample — confirm). The final 16 bytes are presumably the
     * AEAD expansion (authentication tag).
     *
     * @param additionalPayload extra payload bytes to include in the estimate
     * @return the estimated packet length in bytes
     */
    @Override
    public int estimateLength(int additionalPayload) {
        int packetNumberSize = ShortHeaderPacket.computePacketNumberSize(this.packetNumber);
        int payloadSize = this.getFrames().stream().mapToInt(f -> f.getFrameLength()).sum() + additionalPayload;
        int padding = Integer.max(0, 4 - packetNumberSize - payloadSize);
        return 1
                + this.destinationConnectionId.length
                + (this.packetNumber < 0L ? 4 : packetNumberSize)
                + payloadSize
                + padding
                + 16;
    }

    @Override
    public EncryptionLevel getEncryptionLevel() {
        return EncryptionLevel.App;
    }

    @Override
    public PnSpace getPnSpace() {
        return PnSpace.App;
    }

    /**
     * Serializes and protects this packet.
     *
     * <p>Side effects: sets {@code keyPhaseBit} from {@code aead} and {@code packetSize} to the
     * length of the returned array. The packet number must already be assigned.
     *
     * @param aead keys used to encrypt the payload and protect the header
     * @return the complete packet bytes
     */
    @Override
    public byte[] generatePacketBytes(Aead aead) {
        assert this.packetNumber >= 0L : "packet number must be assigned before serialization";
        // NOTE(review): fixed-size scratch buffer; a packet larger than 1500 bytes would overflow it.
        ByteBuffer buffer = ByteBuffer.allocate(1500);
        this.keyPhaseBit = aead.getKeyPhase();
        // 0x40: top two bits 01, the pattern checkPacketType requires; key phase lives in bit 0x04.
        byte flags = (byte) (0x40 | (this.keyPhaseBit << 2));
        flags = ShortHeaderPacket.encodePacketNumberLength(flags, this.packetNumber);
        buffer.put(flags);
        buffer.put(this.destinationConnectionId);
        byte[] encodedPacketNumber = ShortHeaderPacket.encodePacketNumber(this.packetNumber);
        buffer.put(encodedPacketNumber);
        ByteBuffer frameBytes = this.generatePayloadBytes(encodedPacketNumber.length);
        // The literal 0 is presumably the offset of the header start within buffer — confirm in QuicPacket.
        this.protectPacketNumberAndPayload(buffer, encodedPacketNumber.length, frameBytes, 0, aead);
        // Everything from 0 up to the current position is the packet; copy exactly that range out.
        buffer.flip();
        byte[] packetBytes = new byte[buffer.remaining()];
        buffer.get(packetBytes);
        this.packetSize = packetBytes.length;
        return packetBytes;
    }

    @Override
    public PacketProcessor.ProcessResult accept(PacketProcessor processor, Instant time) {
        return processor.process(this, time);
    }

    /**
     * Verifies that {@code flags} has the short header pattern: top two bits {@code 01}.
     *
     * @param flags the first byte of the packet
     * @throws IllegalArgumentException if the flags byte does not match; reaching this indicates the
     *     packet was routed to the wrong parser
     */
    protected void checkPacketType(byte flags) {
        if ((flags & 0xC0) != 0x40) {
            throw new IllegalArgumentException(
                    "Not a short header packet: flags=0x" + Integer.toHexString(flags & 0xFF));
        }
    }

    @Override
    public byte[] getDestinationConnectionId() {
        return this.destinationConnectionId;
    }

    /**
     * Returns a compact one-line description: probe marker, encryption level initial, packet number
     * ({@code .} if unassigned), key phase, destination connection id (hex), size, frame count and
     * the frames themselves.
     */
    @Override
    public String toString() {
        String packetNumberText = this.packetNumber >= 0L ? String.valueOf(this.packetNumber) : ".";
        return "Packet " + (this.isProbe ? "P" : "") + this.getEncryptionLevel().name().charAt(0)
                + "|" + packetNumberText
                + "|S" + this.keyPhaseBit
                + "|" + ByteUtils.bytesToHex(this.destinationConnectionId)
                + "|" + this.packetSize
                + "|" + this.frames.size()
                + "  " + this.frames.stream().map(Object::toString).collect(Collectors.joining(" "));
    }
}

