/* * Copyright © 2019, 2020, 2021, 2022, 2023 Peter Doornbosch * * This file is part of Kwik, an implementation of the QUIC protocol in Java. * * Kwik is free software: you can redistribute it and/or modify it under * the terms of the GNU Lesser General Public License as published by the * Free Software Foundation, either version 3 of the License, or (at your option) * any later version. * * Kwik is distributed in the hope that it will be useful, but * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for * more details. * * You should have received a copy of the GNU Lesser General Public License * along with this program. If not, see . */ package net.luminis.quic.recovery; import net.luminis.quic.*; import net.luminis.quic.ack.Range; import net.luminis.quic.cc.CongestionControlEventListener; import net.luminis.quic.cc.CongestionController; import net.luminis.quic.cc.NewRenoCongestionController; import net.luminis.quic.frame.AckFrame; import net.luminis.quic.frame.ConnectionCloseFrame; import net.luminis.quic.frame.Padding; import net.luminis.quic.frame.PingFrame; import net.luminis.quic.log.NullLogger; import net.luminis.quic.packet.PacketInfo; import net.luminis.quic.packet.QuicPacket; import net.luminis.quic.qlog.NullQLog; import net.luminis.quic.test.TestClock; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import net.luminis.quic.test.FieldSetter; import java.time.Instant; import java.util.List; import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.*; class LossDetectorTest extends RecoveryTests { private LossDetector lossDetector; private LostPacketHandler lostPacketHandler; private int defaultRtt = 10; private CongestionController congestionController; private RttEstimator rttEstimator; private TestClock clock; @BeforeEach void initObjectUnderTest() throws Exception { clock = new TestClock(); rttEstimator = mock(RttEstimator.class); when(rttEstimator.getSmoothedRtt()).thenReturn(defaultRtt); when(rttEstimator.getLatestRtt()).thenReturn(defaultRtt); congestionController = mock(CongestionController.class); lossDetector = new LossDetector(mock(RecoveryManager.class), rttEstimator, congestionController, () -> {}, new NullQLog()); FieldSetter.setField(lossDetector, lossDetector.getClass().getDeclaredField("clock"), clock); } @BeforeEach void initLostPacketCallback() { lostPacketHandler = mock(LostPacketHandler.class); } @Test void congestionControllerIsOnlyCalledOncePerAck() { List packets = createPackets(1, 2, 3); lossDetector.packetSent(packets.get(0), Instant.now(), lostPacket -> lostPacketHandler.process(lostPacket)); lossDetector.packetSent(packets.get(1), Instant.now(), lostPacket -> lostPacketHandler.process(lostPacket)); lossDetector.packetSent(packets.get(2), Instant.now(), lostPacket -> lostPacketHandler.process(lostPacket)); lossDetector.onAckReceived(new AckFrame(new Range(1L, 2L)), Instant.now()); lossDetector.onAckReceived(new AckFrame(new Range(1L, 2L)), Instant.now()); verify(congestionController, times(2)).registerAcked(any(List.class)); } @Test void congestionControllerRegisterAckedNotCalledWithAckOnlyPacket() { QuicPacket packet = createPacket(1, new AckFrame(10)); lossDetector.packetSent(packet, Instant.now(), lostPacket -> lostPacketHandler.process(lostPacket)); lossDetector.onAckReceived(new AckFrame(1), Instant.now()); verify(congestionController, times(1)).registerAcked(argThat(MoreArgumentMatchers.emptyList())); } @Test void congestionControllerRegisterLostNotCalledWithAckOnlyPacket() { QuicPacket packet = createPacket(1, new AckFrame(10)); lossDetector.packetSent(packet, Instant.now(), lostPacket -> lostPacketHandler.process(lostPacket)); lossDetector.onAckReceived(new AckFrame(4), Instant.now()); verify(congestionController, times(0)).registerLost(anyList()); } @Test void withoutAcksNothingIsDeclaredLost() { int count = 10; Instant now = Instant.now(); for (int i = 0; i < count; i++) { QuicPacket packet = createPacket(i); lossDetector.packetSent(packet, now.minusMillis(100 * (count - i)), lostPacket -> lostPacketHandler.process(lostPacket)); } verify(lostPacketHandler, never()).process(any(QuicPacket.class)); } @Test void packetIsNotYetLostWhenTwoLaterPacketsAreAcked() { List packets = createPackets(1, 2, 3); lossDetector.packetSent(packets.get(0), Instant.now(), lostPacket -> lostPacketHandler.process(lostPacket)); lossDetector.packetSent(packets.get(1), Instant.now(), lostPacket -> lostPacketHandler.process(lostPacket)); lossDetector.packetSent(packets.get(2), Instant.now(), lostPacket -> lostPacketHandler.process(lostPacket)); lossDetector.onAckReceived(new AckFrame(new Range(1L, 2L)), Instant.now()); verify(lostPacketHandler, never()).process(any(QuicPacket.class)); } @Test void packetIsLostWhenThreeLaterPacketsAreAcked() { List packets = createPackets(1, 2, 3, 4); lossDetector.packetSent(packets.get(0), Instant.now(), lostPacket -> lostPacketHandler.process(lostPacket)); lossDetector.packetSent(packets.get(1), Instant.now(), lostPacket -> lostPacketHandler.process(lostPacket)); lossDetector.packetSent(packets.get(2), Instant.now(), lostPacket -> lostPacketHandler.process(lostPacket)); lossDetector.packetSent(packets.get(3), Instant.now(), lostPacket -> lostPacketHandler.process(lostPacket)); lossDetector.onAckReceived(new AckFrame(new Range(2L, 4L)), Instant.now()); verify(lostPacketHandler, times(1)).process(argThat(new PacketMatcherByPacketNumber(1))); } @Test void ackOnlyPacketCannotBeDeclaredLost() { QuicPacket ackOnlyPacket = createPacket(1, new AckFrame()); lossDetector.packetSent(ackOnlyPacket, Instant.now(), lostPacket -> lostPacketHandler.process(lostPacket)); List packets = createPackets(2, 3, 4); packets.forEach(p -> lossDetector.packetSent(p, Instant.now(), lostPacket -> lostPacketHandler.process(lostPacket))); lossDetector.onAckReceived(new AckFrame(new Range(2L, 4L)), Instant.now()); verify(lostPacketHandler, never()).process(any(QuicPacket.class)); } @Test void packetTooOldIsDeclaredLost() { // Given second packets is sent (a little) more than 9/8 rtt int timeDiff = (defaultRtt * 9 / 8) + 1; lossDetector.packetSent(createPacket(6), clock.instant(), lostPacket -> lostPacketHandler.process(lostPacket)); clock.fastForward(timeDiff); lossDetector.packetSent(createPacket(8), clock.instant(), lostPacket -> lostPacketHandler.process(lostPacket)); // When when (only) the second packets is acked lossDetector.onAckReceived(new AckFrame(new Range(8L)), clock.instant()); // Then the first is declared lost. verify(lostPacketHandler, times(1)).process(argThat(new PacketMatcherByPacketNumber(6))); } @Test void packetNotTooOldIsNotDeclaredLost() { Instant now = Instant.now(); int timeDiff = defaultRtt - 1; // Give some time for processing. lossDetector.packetSent(createPacket(6), now.minusMillis(timeDiff), lostPacket -> lostPacketHandler.process(lostPacket)); lossDetector.packetSent(createPacket(8), now, lostPacket -> lostPacketHandler.process(lostPacket)); lossDetector.onAckReceived(new AckFrame(8L), Instant.now()); verify(lostPacketHandler, never()).process(any(QuicPacket.class)); } @Test void oldPacketLaterThanLargestAcknowledgedIsNotDeclaredLost() { Instant now = Instant.now(); int timeDiff = (defaultRtt * 9 / 8) + 10; lossDetector.packetSent(createPacket(1), now.minusMillis(timeDiff), lostPacket -> lostPacketHandler.process(lostPacket)); lossDetector.packetSent(createPacket(3), now.minusMillis(timeDiff), lostPacket -> lostPacketHandler.process(lostPacket)); lossDetector.onAckReceived(new AckFrame(1L), Instant.now()); verify(lostPacketHandler, never()).process(any(QuicPacket.class)); } @Test void packetNotYetLostIsLostAfterLossTime() throws Exception { // Given two packets are sent at the same time and only the last is acked, exactly after RTT lossDetector.packetSent(createPacket(6), clock.instant(), lostPacket -> lostPacketHandler.process(lostPacket)); lossDetector.packetSent(createPacket(8), clock.instant(), lostPacket -> lostPacketHandler.process(lostPacket)); clock.fastForward(defaultRtt); lossDetector.onAckReceived(new AckFrame(8L), clock.instant()); verify(lostPacketHandler, never()).process(any(QuicPacket.class)); assertThat(lossDetector.getLossTime()).isNotNull(); // When time is progressing another 1/8 rtt and another ack is received clock.fastForward(defaultRtt / 8 + 1); lossDetector.onAckReceived(new AckFrame(9L), clock.instant()); // Then the first packet is also acked (because by that time, it is old enough) verify(lostPacketHandler, times(1)).process(argThat(new PacketMatcherByPacketNumber(6))); } @Test void ifAllPacketsAreLostThenLossTimeIsNotSet() { Instant now = Instant.now(); int timeDiff = (defaultRtt * 9 / 8) + 1; lossDetector.packetSent(createPacket(1), now.minusMillis(timeDiff), lostPacket -> lostPacketHandler.process(lostPacket)); lossDetector.packetSent(createPacket(5), now, lostPacket -> lostPacketHandler.process(lostPacket)); lossDetector.packetSent(createPacket(8), now, lostPacket -> lostPacketHandler.process(lostPacket)); lossDetector.onAckReceived(new AckFrame(8L), Instant.now()); assertThat(lossDetector.getLossTime()).isNull(); } @Test void ifAllPacketsAreAckedThenLossTimeIsNotSet() { Instant now = Instant.now(); int timeDiff = defaultRtt / 2; lossDetector.packetSent(createPacket(1), now.minusMillis(timeDiff), lostPacket -> lostPacketHandler.process(lostPacket)); lossDetector.packetSent(createPacket(7), now, lostPacket -> lostPacketHandler.process(lostPacket)); lossDetector.packetSent(createPacket(8), now, lostPacket -> lostPacketHandler.process(lostPacket)); lossDetector.onAckReceived(new AckFrame(List.of(new Range(7L, 8L), new Range(1L))), Instant.now()); assertThat(lossDetector.getLossTime()).isNull(); } @Test void ifAllPacketsAreAckedBeforeLossTimeThenLossTimeIsNotSet() { Instant now = Instant.now(); int timeDiff = defaultRtt / 2; lossDetector.packetSent(createPacket(1), now.minusMillis(timeDiff), lostPacket -> lostPacketHandler.process(lostPacket)); lossDetector.packetSent(createPacket(7), now, lostPacket -> lostPacketHandler.process(lostPacket)); lossDetector.packetSent(createPacket(8), now, lostPacket -> lostPacketHandler.process(lostPacket)); lossDetector.onAckReceived(new AckFrame(List.of(new Range(8L), new Range(1L))), Instant.now()); assertThat(lossDetector.getLossTime()).isNotNull(); lossDetector.onAckReceived(new AckFrame(List.of(new Range(7L, 8L), new Range(1L))), Instant.now()); assertThat(lossDetector.getLossTime()).isNull(); } @Test void ackOnlyPacketShouldNotSetLossTime() { lossDetector.packetSent(createPacket(1, new AckFrame(1)), Instant.now(), p -> {}); lossDetector.packetSent(createPacket(2), Instant.now(), p -> {}); lossDetector.onAckReceived(new AckFrame(new Range(2L)), Instant.now()); assertThat(lossDetector.getLossTime()).isNull(); } @Test void detectUnacked() { lossDetector.packetSent(createPacket(2), Instant.now(), p -> {}); assertThat(lossDetector.unAcked()).isNotEmpty(); } @Test void ackedPacketIsNotDetectedAsUnacked() { lossDetector.packetSent(createPacket(2), Instant.now(), p -> {}); lossDetector.onAckReceived(new AckFrame(2), Instant.now()); assertThat(lossDetector.unAcked()).isEmpty(); } @Test void lostPacketIsNotDetectedAsUnacked() throws Exception { // Given two packets sent lossDetector.packetSent(createPacket(2), clock.instant(), p -> {}); lossDetector.packetSent(createPacket(3), clock.instant(), p -> {}); // When after 2 rtt an Ack is received that will cause first packet to be lost clock.fastForward(defaultRtt * 2); lossDetector.onAckReceived(new AckFrame(3), clock.instant()); // So 2 will be lost. // Then the lost packet is not considered un-acknowlegded. assertThat(lossDetector.unAcked()).isEmpty(); } @Test void nonAckElicitingIsNotDetectedAsUnacked() { lossDetector.packetSent(createPacket(2, new AckFrame(0)), Instant.now(), p -> {}); assertThat(lossDetector.unAcked()).isEmpty(); } @Test void whenResetNoPacketsAreUnacked() { lossDetector.packetSent(createPacket(2), Instant.now(), p -> {}); lossDetector.reset(); assertThat(lossDetector.unAcked()).isEmpty(); } @Test void whenResetLossTimeIsUnset() { lossDetector.packetSent(createPacket(2), Instant.now(), p -> {}); lossDetector.packetSent(createPacket(3), Instant.now(), p -> {}); lossDetector.onAckReceived(new AckFrame(3), Instant.now()); lossDetector.detectLostPackets(); assertThat(lossDetector.getLossTime()).isNotNull(); lossDetector.reset(); assertThat(lossDetector.getLossTime()).isNull(); } @Test void whenResetNoAckElicitingAreInFlight() { lossDetector.packetSent(createPacket(2), Instant.now(), p -> {}); assertThat(lossDetector.ackElicitingInFlight()).isTrue(); lossDetector.reset(); assertThat(lossDetector.ackElicitingInFlight()).isFalse(); } @Test void testNoAckedReceivedWhenNoAckReceived() { lossDetector.packetSent(createPacket(2), Instant.now(), p -> {}); assertThat(lossDetector.noAckedReceived()).isTrue(); } @Test void testNoAckedReceivedWhenAckReceived() { lossDetector.packetSent(createPacket(0), Instant.now(), p -> {}); lossDetector.onAckReceived(new AckFrame(0), Instant.now()); assertThat(lossDetector.noAckedReceived()).isFalse(); } @Test void whenCongestionControllerIsResetAllNonAckedPacketsShouldBeDiscarded() { lossDetector.packetSent(createPacket(0), Instant.now(), p -> {}); lossDetector.packetSent(createPacket(1), Instant.now(), p -> {}); lossDetector.packetSent(createPacket(2), Instant.now(), p -> {}); lossDetector.onAckReceived(new AckFrame(0), Instant.now()); lossDetector.reset(); verify(congestionController, times(1)).discard(argThat(l -> containsPackets(l, 1, 2))); } @Test void whenCongestionControllerIsResetAllNotLostPacketsShouldBeDiscarded() { lossDetector.packetSent(createPacket(0), Instant.now(), p -> {}); lossDetector.packetSent(createPacket(1), Instant.now(), p -> {}); lossDetector.packetSent(createPacket(8), Instant.now(), p -> {}); lossDetector.packetSent(createPacket(9), Instant.now(), p -> {}); lossDetector.onAckReceived(new AckFrame(9), Instant.now()); lossDetector.reset(); verify(congestionController, times(1)).discard(argThat(l -> containsPackets(l, 8))); } @Test void packetWithConnectionCloseOnlyDoesNotIncreaseBytesInFlight() { lossDetector.packetSent(createPacket(0, new ConnectionCloseFrame(Version.getDefault())), Instant.now(), p -> {}); verify(congestionController, never()).registerInFlight(any(QuicPacket.class)); } @Test void ackPacketWithConnectionCloseOnlyDoesNotDecreaseBytesInFlight() { lossDetector.packetSent(createPacket(0, new ConnectionCloseFrame(Version.getDefault())), Instant.now(), p -> {}); lossDetector.onAckReceived(new AckFrame(0), Instant.now()); verify(congestionController, never()).registerAcked(argThat(l -> ! l.isEmpty())); // It's okay when it is called with an empty list } @Test void lostPacketWithConnectionCloseOnlyDoesNotDecreaseBytesInFlight() { lossDetector.packetSent(createPacket(0, new ConnectionCloseFrame(Version.getDefault())), Instant.now(), p -> {}); lossDetector.packetSent(createPacket(1, new ConnectionCloseFrame(Version.getDefault())), Instant.now(), p -> {}); lossDetector.packetSent(createPacket(2, new ConnectionCloseFrame(Version.getDefault())), Instant.now(), p -> {}); lossDetector.packetSent(createPacket(9, new ConnectionCloseFrame(Version.getDefault())), Instant.now(), p -> {}); lossDetector.onAckReceived(new AckFrame(9), Instant.now()); verify(congestionController, never()).registerLost(argThat(l -> ! l.isEmpty())); // It's okay when it is called with an empty list } @Test void packetWithPaddingOnlyDoesIncreaseBytesInFlight() { lossDetector.packetSent(createPacket(0, new Padding(99)), Instant.now(), p -> {}); verify(congestionController, times(1)).registerInFlight(any(QuicPacket.class)); } @Test void lostPacketWithPaddingOnlyDoesNotDecreaseBytesInFlight() { lossDetector.packetSent(createPacket(0, new Padding(99)), Instant.now(), p -> {}); lossDetector.packetSent(createPacket(1, new Padding(99)), Instant.now(), p -> {}); lossDetector.packetSent(createPacket(2, new Padding(99)), Instant.now(), p -> {}); lossDetector.packetSent(createPacket(9, new Padding(99)), Instant.now(), p -> {}); lossDetector.onAckReceived(new AckFrame(9), Instant.now()); verify(congestionController, atLeast(1)).registerLost(any(List.class)); } @Test void congestionControlStateDoesNotChangeWithUnrelatedAck() throws Exception { congestionController = new NewRenoCongestionController(new NullLogger(), mock(CongestionControlEventListener.class)); setCongestionWindowSize(congestionController, 1240); FieldSetter.setField(lossDetector, LossDetector.class.getDeclaredField("congestionController"), congestionController); lossDetector.packetSent(new MockPacket(0, 12, EncryptionLevel.App, new PingFrame(), "packet 1"), Instant.now(), p -> {}); lossDetector.packetSent(new MockPacket(1, 1200, EncryptionLevel.App, new PingFrame(), "packet 2"), Instant.now(), p -> {}); lossDetector.packetSent(new MockPacket(2, 40, EncryptionLevel.App, new PingFrame(), "packet 1"), Instant.now(), p -> {}); assertThat(congestionController.remainingCwnd()).isLessThan(1); // An ack on a non-existent packet, shouldn't change anything. lossDetector.onAckReceived(new AckFrame(0), Instant.now()); assertThat(congestionController.remainingCwnd()).isLessThan(12 + 1); // Because the 12 is acked, the cwnd is increased by 12 too. } @Test void congestionControlStateDoesNotChangeWithIncorrectAck() throws Exception { congestionController = new NewRenoCongestionController(new NullLogger(), mock(CongestionControlEventListener.class)); setCongestionWindowSize(congestionController, 1240); FieldSetter.setField(lossDetector, LossDetector.class.getDeclaredField("congestionController"), congestionController); lossDetector.packetSent(new MockPacket(10, 1200, EncryptionLevel.App, new PingFrame(), "packet 1"), Instant.now(), p -> {}); lossDetector.packetSent(new MockPacket(11, 1200, EncryptionLevel.App, new PingFrame(), "packet 2"), Instant.now(), p -> {}); assertThat(congestionController.remainingCwnd()).isLessThan(1); // An ack on a non-existent packet, shouldn't change anything. lossDetector.onAckReceived(new AckFrame(3), Instant.now()); assertThat(congestionController.remainingCwnd()).isLessThan(1); } @Test void testAckElicitingInFlightAcked() { lossDetector.packetSent(new MockPacket(10, 1200, EncryptionLevel.App, new PingFrame(), "packet 1"), Instant.now(), p -> {}); lossDetector.packetSent(new MockPacket(11, 1200, EncryptionLevel.App, new Padding(10), "packet 2"), Instant.now(), p -> {}); lossDetector.packetSent(new MockPacket(12, 1200, EncryptionLevel.App, new PingFrame(), "packet 2"), Instant.now(), p -> {}); lossDetector.onAckReceived(new AckFrame(10), Instant.now()); assertThat(lossDetector.ackElicitingInFlight()).isTrue(); lossDetector.onAckReceived(new AckFrame(12), Instant.now()); assertThat(lossDetector.ackElicitingInFlight()).isFalse(); } @Test void testAckElicitingInFlightLost() { lossDetector.packetSent(new MockPacket(10, 1200, EncryptionLevel.App, new PingFrame(), "packet 1"), Instant.now(), p -> {}); lossDetector.packetSent(new MockPacket(11, 1200, EncryptionLevel.App, new Padding(10), "packet 2"), Instant.now(), p -> {}); lossDetector.packetSent(new MockPacket(15, 1200, EncryptionLevel.App, new PingFrame(), "packet 2"), Instant.now(), p -> {}); lossDetector.onAckReceived(new AckFrame(15), Instant.now()); assertThat(lossDetector.ackElicitingInFlight()).isFalse(); } // This test was used to reproduce a race condition in the LossDetector. It is of no use to run it in each build. // To check the test is actually testing the race condition, insert system.out.print's in reset and onAckReceived methods. // @Test void maybeReproduceRaceConditionInOnAckdReceived() throws InterruptedException { int numberOfTestRuns = 500; for (int tc = 1; tc <= numberOfTestRuns; tc++) { System.out.print("\n" + tc + ": "); final int testRun = tc; for (int i = 0; i < 10000; i++) { lossDetector.packetSent(new MockPacket(i, 100, "packet " + i), Instant.now(), p -> {}); } Thread lossDetectorResetThread = new Thread(() -> { for (int i = 0; i < 1; i++) { try { Thread.sleep(100); } catch (InterruptedException e) {} lossDetector.reset(); } }); Thread onAckReceivedThread = new Thread(() -> { for (int i = 0; i < 100; i++) { try { lossDetector.onAckReceived(new AckFrame(i), Instant.now()); } catch (Exception e) { System.out.println("ERROR in test run " + testRun + ": " + e); e.printStackTrace(); System.exit(1); } } }); onAckReceivedThread.start(); lossDetectorResetThread.start(); lossDetectorResetThread.join(); onAckReceivedThread.join(); } } private void setCongestionWindowSize(CongestionController congestionController, int cwnd) throws Exception { FieldSetter.setField(congestionController, congestionController.getClass().getSuperclass().getDeclaredField("congestionWindow"), cwnd); } private boolean containsPackets(List packets, long... packetNumbers) { List listPacketNumbers = packets.stream().map(p -> p.packet().getPacketNumber()).collect(Collectors.toList()); for (long pn: packetNumbers) { if (! listPacketNumbers.contains(pn)) { return false; } } return true; } }