/*
 * Decompiled with CFR 0.152.
 */
package net.luminis.quic.recovery;

import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.time.LocalTime;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.Delayed;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import net.luminis.quic.EncryptionLevel;
import net.luminis.quic.FrameReceivedListener;
import net.luminis.quic.HandshakeState;
import net.luminis.quic.HandshakeStateListener;
import net.luminis.quic.PnSpace;
import net.luminis.quic.Role;
import net.luminis.quic.cc.CongestionController;
import net.luminis.quic.concurrent.DaemonThreadFactory;
import net.luminis.quic.frame.AckFrame;
import net.luminis.quic.frame.Padding;
import net.luminis.quic.frame.PingFrame;
import net.luminis.quic.frame.QuicFrame;
import net.luminis.quic.log.Logger;
import net.luminis.quic.packet.QuicPacket;
import net.luminis.quic.recovery.LossDetector;
import net.luminis.quic.recovery.RttEstimator;
import net.luminis.quic.send.Sender;

/**
 * Coordinates loss detection and probe timeouts (PTO) across all packet number spaces of a connection.
 *
 * <p>One {@link LossDetector} is kept per {@link PnSpace}. A single-threaded daemon scheduler runs the loss
 * detection timer. {@code scheduleLock} guards every replacement or cancellation of the pending timer task.
 * {@code timerExpiration} records when the current timer is meant to fire, so a stale or prematurely executed
 * task can detect that and back off.
 *
 * <p>The method names ({@code setLossDetectionTimer}, {@code getPtoTimeAndSpace},
 * {@code peerAwaitingAddressValidation}) appear to mirror the RFC 9002 recovery pseudocode. NOTE(review):
 * confirm any deviation from it is intentional.
 */
public class RecoveryManager
implements FrameReceivedListener<AckFrame>,
HandshakeStateListener {
    /** Formatter for {@link #timeNow()}; DateTimeFormatter is immutable and thread-safe, so it can be shared. */
    private static final DateTimeFormatter TIME_NOW_FORMAT = DateTimeFormatter.ofPattern("mm:ss.SSS");

    private final Clock clock;
    private final Role role;
    private final RttEstimator rttEstimator;
    private final LossDetector[] lossDetectors = new LossDetector[PnSpace.values().length];
    private final Sender sender;
    private final Logger log;
    private final ScheduledExecutorService scheduler;
    /**
     * Peer's maximum ack delay. It is added to the App-space PTO duration, which is applied with
     * {@code plusMillis}, so it is interpreted as milliseconds. Volatile because it is written by the
     * connection and read by the loss-detection thread.
     */
    private volatile int receiverMaxAckDelay;
    /** Currently pending timer task; only read or replaced while holding {@link #scheduleLock}. */
    private ScheduledFuture<?> lossDetectionFuture;
    private final Object scheduleLock = new Object();
    private volatile int ptoCount;
    /** Instant at which the loss detection timer should fire, or {@code null} when it is cancelled. */
    private volatile Instant timerExpiration;
    private volatile HandshakeState handshakeState = HandshakeState.Initial;
    private volatile boolean hasBeenReset = false;

    /**
     * Creates a recovery manager that uses the system UTC clock.
     */
    public RecoveryManager(Role role, RttEstimator rttEstimater, CongestionController congestionController, Sender sender, Logger logger) {
        this(Clock.systemUTC(), role, rttEstimater, congestionController, sender, logger);
    }

    /**
     * Creates a recovery manager with an explicit clock (useful for tests).
     *
     * @param clock                source of "now" for all timer computations
     * @param role                 whether this endpoint is client or server; affects address-validation probing
     * @param rttEstimater         RTT estimates used for PTO computation
     * @param congestionController passed to each per-space {@link LossDetector}
     * @param sender               used to send probes and to flush after loss detection
     * @param logger               logger; its qlog is handed to the loss detectors
     */
    public RecoveryManager(Clock clock, Role role, RttEstimator rttEstimater, CongestionController congestionController, Sender sender, Logger logger) {
        this.clock = clock;
        this.role = role;
        this.rttEstimator = rttEstimater;
        this.sender = sender;
        this.log = logger;
        this.scheduler = Executors.newScheduledThreadPool(1, new DaemonThreadFactory("loss-detection"));
        synchronized (this.scheduleLock) {
            this.lossDetectionFuture = new NullScheduledFuture();
        }
        // 'this' escapes to the loss detectors, so all other fields are assigned first.
        for (PnSpace pnSpace : PnSpace.values()) {
            this.lossDetectors[pnSpace.ordinal()] = new LossDetector(clock, this, rttEstimater, congestionController, () -> sender.flush(), logger.getQLog());
        }
    }

    /**
     * (Re)arms or cancels the loss detection timer based on the current state.
     *
     * <p>If any space has a pending loss time, the earliest one wins. Otherwise a PTO timer is armed when
     * ack-eliciting packets are in flight or the peer is awaiting address validation. In all other cases the
     * timer is cancelled.
     */
    void setLossDetectionTimer() {
        PnSpaceTime earliestLossTime = this.getEarliestLossTime(LossDetector::getLossTime);
        if (earliestLossTime != null) {
            this.rescheduleLossDetectionTimeout(earliestLossTime.lossTime);
            return;
        }
        boolean ackElicitingInFlight = this.ackElicitingInFlight();
        boolean peerAwaitingAddressValidation = this.peerAwaitingAddressValidation();
        if (!ackElicitingInFlight && !peerAwaitingAddressValidation) {
            this.log.recovery("cancelling loss detection timer (no loss time set, no ack eliciting in flight, peer not awaiting address validation (2))");
            this.unschedule();
            return;
        }
        PnSpaceTime ptoTimeAndSpace = this.getPtoTimeAndSpace();
        if (ptoTimeAndSpace == null) {
            this.log.recovery("cancelling loss detection timer (no loss time set, no ack eliciting in flight, peer not awaiting address validation (1))");
            this.unschedule();
            return;
        }
        this.rescheduleLossDetectionTimeout(ptoTimeAndSpace.lossTime);
        if (this.log.logRecovery()) {
            int timeout = (int)Duration.between(this.clock.instant(), ptoTimeAndSpace.lossTime).toMillis();
            this.log.recovery("reschedule loss detection timer for PTO over " + timeout + " millis, based on %s/" + ptoTimeAndSpace.pnSpace + ", because " + (peerAwaitingAddressValidation ? "peerAwaitingAddressValidation " : "") + (ackElicitingInFlight ? "ackElicitingInFlight " : "") + "| RTT:" + this.rttEstimator.getSmoothedRtt() + "/" + this.rttEstimator.getRttVar(), ptoTimeAndSpace.lossTime);
        }
    }

    /**
     * Computes when the next PTO should fire and for which packet number space.
     *
     * @return the PTO instant and space, or {@code null} when no eligible space has ack-eliciting data in flight
     */
    private PnSpaceTime getPtoTimeAndSpace() {
        // Exponential backoff factor; ptoCount is read once so both uses agree.
        int backoff = (int)Math.pow(2.0, this.ptoCount);
        int ptoDuration = (this.rttEstimator.getSmoothedRtt() + Integer.max(1, 4 * this.rttEstimator.getRttVar())) * backoff;
        // Anti-deadlock probe: while the peer awaits address validation the PTO is measured from now.
        // NOTE(review): this branch is taken even when ack-eliciting packets are in flight, although the log
        // text says otherwise; confirm against RFC 9002 GetPtoTimeAndSpace.
        if (this.peerAwaitingAddressValidation()) {
            if (this.handshakeState.hasNoHandshakeKeys()) {
                this.log.recovery("getPtoTimeAndSpace: no ack eliciting in flight and no handshake keys -> probe Initial");
                return new PnSpaceTime(PnSpace.Initial, this.clock.instant().plusMillis(ptoDuration));
            }
            this.log.recovery("getPtoTimeAndSpace: no ack eliciting in flight but handshake keys -> probe Handshake");
            return new PnSpaceTime(PnSpace.Handshake, this.clock.instant().plusMillis(ptoDuration));
        }
        Instant ptoTime = Instant.MAX;
        PnSpace ptoSpace = null;
        for (PnSpace pnSpace : PnSpace.values()) {
            LossDetector detector = this.lossDetectors[pnSpace.ordinal()];
            if (!detector.ackElicitingInFlight()) {
                continue;
            }
            if (pnSpace == PnSpace.App && this.handshakeState.isNotConfirmed()) {
                this.log.recovery("getPtoTimeAndSpace is skipping level App, because handshake not yet confirmed!");
                continue;
            }
            // The peer's max ack delay only applies to the App space; compute it per space so it cannot
            // leak into spaces iterated afterwards.
            int spaceDuration = pnSpace == PnSpace.App ? ptoDuration + this.receiverMaxAckDelay * backoff : ptoDuration;
            Instant lastAckElicitingSent = detector.getLastAckElicitingSent();
            if (lastAckElicitingSent == null) {
                continue;
            }
            Instant candidate = lastAckElicitingSent.plusMillis(spaceDuration);
            if (candidate.isBefore(ptoTime)) {
                ptoTime = candidate;
                ptoSpace = pnSpace;
            }
        }
        return ptoSpace != null ? new PnSpaceTime(ptoSpace, ptoTime) : null;
    }

    /**
     * Returns true when this is a client, the handshake is not confirmed, and no ack has been received in the
     * Handshake space.
     */
    private boolean peerAwaitingAddressValidation() {
        return this.role == Role.Client && this.handshakeState.isNotConfirmed() && this.lossDetectors[PnSpace.Handshake.ordinal()].noAckedReceived();
    }

    /**
     * Timer handler: runs loss detection for the space with the earliest loss time, or sends a probe.
     *
     * <p>A task that was cancelled, or that the scheduler runs too early, re-checks {@code timerExpiration}.
     * Such a task either ignores the event or re-arms the timer and returns without acting. Acting anyway would
     * send a PTO probe (and bump the backoff) before the PTO actually expired.
     */
    private void lossDetectionTimeout() {
        Instant expiration = this.timerExpiration;
        if (expiration == null) {
            this.log.warn("Loss detection timeout: Timer was cancelled.");
            return;
        }
        Instant now = this.clock.instant();
        if (now.isBefore(expiration) && Duration.between(now, expiration).toMillis() > 0L) {
            // Use the latest expiration (it may have been moved meanwhile), but guard against a concurrent cancel.
            Instant latestExpiration = this.timerExpiration;
            this.log.warn(String.format("Loss detection timeout running (at %s) is %s ms too early; rescheduling to %s", now, Duration.between(now, expiration).toMillis(), latestExpiration));
            if (latestExpiration != null) {
                this.rescheduleLossDetectionTimeout(latestExpiration);
            }
            return;
        }
        this.log.recovery("%s loss detection timeout handler running", now);
        PnSpaceTime earliestLossTime = this.getEarliestLossTime(LossDetector::getLossTime);
        if (earliestLossTime != null) {
            this.lossDetectors[earliestLossTime.pnSpace.ordinal()].detectLostPackets();
            this.sender.flush();
            this.setLossDetectionTimer();
        } else {
            this.sendProbe();
        }
    }

    /**
     * Increments the PTO count and sends one probe (first PTO) or two probes (subsequent PTOs), or a single
     * anti-deadlock probe when only address validation is pending.
     */
    private void sendProbe() {
        if (this.log.logRecovery()) {
            PnSpaceTime earliestLastAckElicitingSentTime = this.getEarliestLossTime(LossDetector::getLastAckElicitingSent);
            if (earliestLastAckElicitingSentTime != null) {
                this.log.recovery(String.format("Sending probe %d, because no ack since %%s. Current RTT: %d/%d.", this.ptoCount, this.rttEstimator.getSmoothedRtt(), this.rttEstimator.getRttVar()), earliestLastAckElicitingSentTime.lossTime);
            } else {
                this.log.recovery(String.format("Sending probe %d. Current RTT: %d/%d.", this.ptoCount, this.rttEstimator.getSmoothedRtt(), this.rttEstimator.getRttVar()));
            }
        }
        ++this.ptoCount;
        int nrOfProbes = this.ptoCount > 1 ? 2 : 1;
        if (this.ackElicitingInFlight()) {
            PnSpaceTime ptoTimeAndSpace = this.getPtoTimeAndSpace();
            if (ptoTimeAndSpace == null) {
                this.log.recovery("Refraining from sending probe because received ack meanwhile");
                return;
            }
            this.sendOneOrTwoAckElicitingPackets(ptoTimeAndSpace.pnSpace, nrOfProbes);
        } else if (this.peerAwaitingAddressValidation()) {
            this.log.recovery("Sending probe because peer awaiting address validation");
            if (this.handshakeState.hasNoHandshakeKeys()) {
                this.sendOneOrTwoAckElicitingPackets(PnSpace.Initial, 1);
            } else {
                this.sendOneOrTwoAckElicitingPackets(PnSpace.Handshake, 1);
            }
        } else {
            this.log.recovery("Refraining from sending probe as no ack eliciting in flight and no peer awaiting address validation");
        }
    }

    /**
     * Sends {@code numberOfPackets} probes in the given space. Each probe retransmits the frames of the first
     * unacked ack-eliciting packet with real content, or is a PING plus padding when there is none.
     */
    private void sendOneOrTwoAckElicitingPackets(PnSpace pnSpace, int numberOfPackets) {
        if (pnSpace == PnSpace.Initial) {
            List<QuicFrame> framesToRetransmit = this.getFramesToRetransmit(PnSpace.Initial);
            if (!framesToRetransmit.isEmpty()) {
                this.log.recovery("(Probe is an initial retransmit)");
                this.repeatSend(numberOfPackets, () -> this.sender.sendProbe(framesToRetransmit, EncryptionLevel.Initial));
            } else {
                this.log.recovery("(Probe is Initial ping, because there is no Initial data to retransmit)");
                this.repeatSend(numberOfPackets, () -> this.sender.sendProbe(List.of(new PingFrame(), new Padding(2)), EncryptionLevel.Initial));
            }
        } else if (pnSpace == PnSpace.Handshake) {
            List<QuicFrame> framesToRetransmit = this.getFramesToRetransmit(PnSpace.Handshake);
            if (!framesToRetransmit.isEmpty()) {
                this.log.recovery("(Probe is a handshake retransmit)");
                this.repeatSend(numberOfPackets, () -> this.sender.sendProbe(framesToRetransmit, EncryptionLevel.Handshake));
            } else {
                this.log.recovery("(Probe is a handshake ping)");
                this.repeatSend(numberOfPackets, () -> this.sender.sendProbe(List.of(new PingFrame(), new Padding(2)), EncryptionLevel.Handshake));
            }
        } else {
            EncryptionLevel probeLevel = pnSpace.relatedEncryptionLevel();
            List<QuicFrame> framesToRetransmit = this.getFramesToRetransmit(pnSpace);
            if (!framesToRetransmit.isEmpty()) {
                this.log.recovery("(Probe is retransmit on level " + probeLevel + ")");
                this.repeatSend(numberOfPackets, () -> this.sender.sendProbe(framesToRetransmit, probeLevel));
            } else {
                this.log.recovery("(Probe is ping on level " + probeLevel + ")");
                this.repeatSend(numberOfPackets, () -> this.sender.sendProbe(List.of(new PingFrame(), new Padding(2)), probeLevel));
            }
        }
    }

    /**
     * Returns the frames (minus ACK frames) of the first unacked ack-eliciting packet in the given space that
     * contains more than PING/PADDING/ACK frames, or an empty list when there is no such packet.
     */
    List<QuicFrame> getFramesToRetransmit(PnSpace pnSpace) {
        List<QuicPacket> unAckedPackets = this.lossDetectors[pnSpace.ordinal()].unAcked();
        Optional<QuicPacket> ackEliciting = unAckedPackets.stream()
                .filter(QuicPacket::isAckEliciting)
                .filter(p -> !p.getFrames().stream().allMatch(frame -> frame instanceof PingFrame || frame instanceof Padding || frame instanceof AckFrame))
                .findFirst();
        if (!ackEliciting.isPresent()) {
            return Collections.emptyList();
        }
        List<QuicFrame> framesToRetransmit = ackEliciting.get().getFrames().stream()
                .filter(frame -> !(frame instanceof AckFrame))
                .collect(Collectors.toList());
        return framesToRetransmit;
    }

    /**
     * Applies {@code pnSpaceTimeFunction} to each space's loss detector and returns the earliest non-null
     * instant with its space. On ties, the later space in {@link PnSpace#values()} order wins.
     *
     * @return the earliest instant and its space, or {@code null} if the function returned null for every space
     */
    PnSpaceTime getEarliestLossTime(Function<LossDetector, Instant> pnSpaceTimeFunction) {
        PnSpaceTime earliestLossTime = null;
        for (PnSpace pnSpace : PnSpace.values()) {
            Instant pnSpaceLossTime = pnSpaceTimeFunction.apply(this.lossDetectors[pnSpace.ordinal()]);
            if (pnSpaceLossTime == null) {
                continue;
            }
            if (earliestLossTime == null || !earliestLossTime.lossTime.isBefore(pnSpaceLossTime)) {
                earliestLossTime = new PnSpaceTime(pnSpace, pnSpaceLossTime);
            }
        }
        return earliestLossTime;
    }

    /**
     * Cancels any pending timer task and schedules a new one for {@code scheduledTime}.
     *
     * <p>This is a no-op once {@link #stopRecovery()} has run. The check happens under the lock, so no task can be
     * (re)armed after recovery has been stopped.
     */
    void rescheduleLossDetectionTimeout(Instant scheduledTime) {
        synchronized (this.scheduleLock) {
            if (this.hasBeenReset) {
                return;
            }
            this.lossDetectionFuture.cancel(false);
            this.timerExpiration = scheduledTime;
            long delay = Duration.between(this.clock.instant(), scheduledTime).toMillis();
            try {
                this.lossDetectionFuture = this.scheduler.schedule(this::runLossDetectionTimeout, delay, TimeUnit.MILLISECONDS);
            } catch (RejectedExecutionException taskRejected) {
                // Only expected when the scheduler was shut down by stopRecovery().
                if (!this.hasBeenReset) {
                    throw taskRejected;
                }
            }
        }
    }

    /** Scheduler entry point; logs (rather than propagates) exceptions from the timeout handler. */
    private void runLossDetectionTimeout() {
        try {
            this.lossDetectionTimeout();
        }
        catch (Exception error) {
            this.log.error("Runtime exception occurred while running loss detection timeout handler", error);
        }
    }

    /**
     * Cancels the loss detection timer.
     *
     * <p>This holds {@code scheduleLock} so it cannot interleave with a concurrent reschedule. It does not
     * interrupt a task that is already running. Such a task observes {@code timerExpiration == null} on entry,
     * and interrupting it mid-send could disrupt probe transmission, including when the handler itself ends up
     * calling this method.
     */
    void unschedule() {
        synchronized (this.scheduleLock) {
            this.lossDetectionFuture.cancel(false);
            this.timerExpiration = null;
        }
    }

    /**
     * Handles a received ACK frame for the given space.
     *
     * <p>The PTO count is reset to zero unless the peer is still awaiting address validation. Nothing happens
     * once recovery has been stopped.
     */
    public void onAckReceived(AckFrame ackFrame, PnSpace pnSpace, Instant timeReceived) {
        if (this.hasBeenReset) {
            return;
        }
        if (this.ptoCount > 0) {
            if (!this.peerAwaitingAddressValidation()) {
                this.ptoCount = 0;
            } else {
                this.log.recovery("probe count not reset on ack because handshake not yet confirmed");
            }
        }
        this.lossDetectors[pnSpace.ordinal()].onAckReceived(ackFrame, timeReceived);
    }

    /**
     * Registers a sent in-flight packet with its space's loss detector and re-arms the timer. Packets that are
     * not in flight, and any packets after recovery has been stopped, are ignored.
     */
    public void packetSent(QuicPacket packet, Instant sent, Consumer<QuicPacket> packetLostCallback) {
        if (!this.hasBeenReset && packet.isInflightPacket()) {
            this.lossDetectors[packet.getPnSpace().ordinal()].packetSent(packet, sent, packetLostCallback);
            this.setLossDetectionTimer();
        }
    }

    /** Returns true when any space has ack-eliciting packets in flight. */
    private boolean ackElicitingInFlight() {
        return Stream.of(this.lossDetectors).anyMatch(detector -> detector.ackElicitingInFlight());
    }

    /** Sets the peer's max ack delay (milliseconds), used when computing the App-space PTO. */
    public void setReceiverMaxAckDelay(int receiverMaxAckDelay) {
        this.receiverMaxAckDelay = receiverMaxAckDelay;
    }

    /**
     * Permanently stops recovery: cancels the timer, shuts down the scheduler and resets every loss detector.
     * Subsequent calls do nothing.
     */
    public void stopRecovery() {
        if (!this.hasBeenReset) {
            this.hasBeenReset = true;
            this.unschedule();
            this.scheduler.shutdown();
            for (PnSpace pnSpace : PnSpace.values()) {
                this.lossDetectors[pnSpace.ordinal()].reset();
            }
        }
    }

    /**
     * Stops recovery for a single packet number space. Resets that space's detector and the PTO count, then
     * re-evaluates the timer.
     */
    public void stopRecovery(PnSpace pnSpace) {
        if (!this.hasBeenReset) {
            this.lossDetectors[pnSpace.ordinal()].reset();
            this.ptoCount = 0;
            this.setLossDetectionTimer();
        }
    }

    /** Returns the sum of {@code getLost()} over all loss detectors. */
    public long getLost() {
        return Stream.of(this.lossDetectors).mapToLong(ld -> ld.getLost()).sum();
    }

    /** Records the new handshake state and re-evaluates the timer the first time the state becomes Confirmed. */
    @Override
    public void handshakeStateChangedEvent(HandshakeState newState) {
        if (!this.hasBeenReset) {
            HandshakeState oldState = this.handshakeState;
            this.handshakeState = newState;
            if (newState == HandshakeState.Confirmed && oldState != HandshakeState.Confirmed) {
                this.log.recovery("State is set to " + newState);
                this.setLossDetectionTimer();
            }
        }
    }

    /** Frame listener callback; delegates to {@link #onAckReceived(AckFrame, PnSpace, Instant)}. */
    @Override
    public void received(AckFrame frame, PnSpace pnSpace, Instant timeReceived) {
        this.onAckReceived(frame, pnSpace, timeReceived);
    }

    /**
     * Runs {@code task} {@code count} times, sleeping 1 ms between runs (not after the last one).
     *
     * <p>An interrupt cuts the pause short but the remaining sends still happen. The interrupt status is restored
     * rather than swallowed.
     */
    private void repeatSend(int count, Runnable task) {
        for (int i = 0; i < count; ++i) {
            task.run();
            if (i + 1 < count) {
                try {
                    Thread.sleep(1L);
                }
                catch (InterruptedException interrupted) {
                    Thread.currentThread().interrupt();
                }
            }
        }
    }

    /** Returns the clock's current time in the system default zone, formatted as {@code mm:ss.SSS}. */
    String timeNow() {
        return TIME_NOW_FORMAT.format(this.clock.instant().atZone(ZoneId.systemDefault()).toLocalTime());
    }

    /**
     * Placeholder future used before any timer has been scheduled. Cancelling it is a no-op, and
     * {@code get()} returns null.
     */
    private static class NullScheduledFuture
    implements ScheduledFuture<Void> {
        private NullScheduledFuture() {
        }

        @Override
        public int compareTo(Delayed o) {
            return 0;
        }

        @Override
        public long getDelay(TimeUnit unit) {
            return 0L;
        }

        @Override
        public boolean cancel(boolean mayInterruptIfRunning) {
            return false;
        }

        @Override
        public boolean isCancelled() {
            return false;
        }

        @Override
        public boolean isDone() {
            return false;
        }

        @Override
        public Void get() throws InterruptedException, ExecutionException {
            return null;
        }

        @Override
        public Void get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
            return null;
        }
    }

    /** Pairs an instant (a loss time, PTO time or send time, depending on the caller) with its packet number space. */
    static class PnSpaceTime {
        public PnSpace pnSpace;
        public Instant lossTime;

        public PnSpaceTime(PnSpace pnSpace, Instant pnSpaceLossTime) {
            this.pnSpace = pnSpace;
            this.lossTime = pnSpaceLossTime;
        }

        @Override
        public String toString() {
            return this.lossTime.toString() + " (in " + this.pnSpace + ")";
        }
    }
}

