/* * Copyright © 2020, 2021, 2022, 2023 Peter Doornbosch * * This file is part of Kwik, an implementation of the QUIC protocol in Java. * * Kwik is free software: you can redistribute it and/or modify it under * the terms of the GNU Lesser General Public License as published by the * Free Software Foundation, either version 3 of the License, or (at your option) * any later version. * * Kwik is distributed in the hope that it will be useful, but * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for * more details. * * You should have received a copy of the GNU Lesser General Public License * along with this program. If not, see . */ package net.luminis.quic.stream; import net.luminis.quic.EncryptionLevel; import net.luminis.quic.QuicClientConnectionImpl; import net.luminis.quic.Role; import net.luminis.quic.Version; import net.luminis.quic.frame.QuicFrame; import net.luminis.quic.frame.StreamFrame; import net.luminis.quic.log.Logger; import net.luminis.quic.log.NullLogger; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; import java.io.IOException; import java.nio.ByteBuffer; import java.util.List; import java.util.function.Consumer; import java.util.function.Function; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.*; class EarlyDataStreamTest { private EarlyDataStream stream; private QuicClientConnectionImpl connection; private Logger logger; @BeforeEach void initObjectUnderTest() { connection = mock(QuicClientConnectionImpl.class); when(connection.getMaxShortHeaderPacketOverhead()).thenReturn(29); int maxData = 5000; FlowControl flowController = new FlowControl(Role.Client, maxData, maxData, maxData, maxData); logger = new NullLogger(); stream = new EarlyDataStream(Version.getDefault(), 0, connection, flowController, logger); } @Test void sendingEarlyDataResultsInZeroRttPacket() throws IOException { // When stream.writeEarlyData(new byte[10], false, 10_000); // Then StreamFrame frame = captureFrameSentAndVerifyEncryptionLevel(connection, 1500, EncryptionLevel.ZeroRTT); assertThat(frame.getStreamData().length).isEqualTo(10); } @Test void sendingFinalEarlyDataResultsInClosingStream() throws IOException { // When stream.writeEarlyData(new byte[10], true, 10_000); // Then StreamFrame frame = captureFrameSentAndVerifyEncryptionLevel(connection, 1500, EncryptionLevel.ZeroRTT); assertThat(frame.isFinal()).isTrue(); assertThat(frame.getStreamData().length).isEqualTo(10); } @Test void sendingLargeEarlyDataResultsInMultiplePackets() throws IOException { // When stream.writeEarlyData(new byte[1500], false, 10_000); // Then // Simulate first packet is sent (which will cause second send request to be queued) QuicFrame firstFrame = captureFrameSentAndVerifyEncryptionLevel(connection, 1300, EncryptionLevel.ZeroRTT); QuicFrame secondFrame = captureFrameSentAndVerifyEncryptionLevel(connection, 1300, EncryptionLevel.ZeroRTT); assertThat(((StreamFrame) firstFrame).getStreamData().length).isGreaterThan(1000); assertThat(((StreamFrame) secondFrame).getStreamData().length).isGreaterThan(200); assertThat(((StreamFrame) firstFrame).getStreamData().length + ((StreamFrame) secondFrame).getStreamData().length).isEqualTo(1500); } @Test void earlyDataShouldBeLimitedToFlowControlLimit() throws Exception { // Given int maxData = 1000; FlowControl flowController = new FlowControl(Role.Client, maxData, maxData, maxData, maxData); stream = new EarlyDataStream(Version.getDefault(), 0, connection, flowController, logger); // When stream.writeEarlyData(new byte[1500], false, 10_000); // Then StreamFrame frame1 = captureFrameSentAndVerifyEncryptionLevel(connection, 1500, EncryptionLevel.ZeroRTT); assertThat(frame1.getLength()).isEqualTo(1000); } @Test void earlyDataShouldBeLimitedToInitalMaxData() throws Exception { // When stream.writeEarlyData(new byte[1500], true, 500); // earlyDataSizeLeft should be set to initial max data from session ticket // Then StreamFrame frame1 = captureFrameSentAndVerifyEncryptionLevel(connection, 1500, EncryptionLevel.ZeroRTT); assertThat(frame1.getLength()).isEqualTo(500); } @Test void whenEarlyDataIsLimitedStreamIsNotClosed() throws Exception { // When stream.writeEarlyData(new byte[1500], true, 500); StreamFrame streamFrame = captureFrameSentAndVerifyEncryptionLevel(connection, 1300, EncryptionLevel.ZeroRTT); // Then assertThat(streamFrame.isFinal()).isFalse(); } @Test void whenWritingRemainingAllDataShouldHaveBeenSent() throws Exception { // Given byte[] data = new byte[1500]; for (int i = 0; i < data.length; i++) { data[i] = (byte) i; } // When stream.writeEarlyData(data, true, 500); stream.writeRemaining(true); // Then StreamFrame zeroRttData = captureFrameSentAndVerifyEncryptionLevel(connection, 1500, EncryptionLevel.ZeroRTT); StreamFrame oneRttData = captureFrameSentAndVerifyEncryptionLevel(connection, 1500, EncryptionLevel.App); byte[] transmittedData = transmittedByteStream(List.of(zeroRttData, oneRttData)); assertThat(transmittedData).isEqualTo(data); assertThat(zeroRttData.isFinal()).isFalse(); assertThat(oneRttData.isFinal()).isTrue(); } @Test void whenEarlyDataWasNotAcceptedWritingRemainingShouldSendAll() throws Exception { // Given byte[] data = new byte[1500]; for (int i = 0; i < data.length; i++) { data[i] = (byte) i; } // When stream.writeEarlyData(data, true, 500); // Then StreamFrame zeroRttData = captureFrameSentAndVerifyEncryptionLevel(connection, 1200, EncryptionLevel.ZeroRTT); stream.writeRemaining(false); StreamFrame oneRttData = captureFrameSentAndVerifyEncryptionLevel(connection, 1200, EncryptionLevel.App); StreamFrame oneRttData2 = captureFrameSentAndVerifyEncryptionLevel(connection, 1200, EncryptionLevel.App); byte[] transmittedData = transmittedByteStream(List.of(oneRttData, oneRttData2)); assertThat(transmittedData).isEqualTo(data); assertThat(oneRttData.getOffset()).isEqualTo(0); assertThat(oneRttData2.isFinal()).isTrue(); } @Test void whenAllEarlyDataWasSentNoRemainingShouldBeSend() throws Exception { // Given stream.writeEarlyData(new byte[100], true, 10_000); // When clearInvocations(connection); stream.writeRemaining(true); // Then verify(connection, never()).send(any(Function.class), anyInt(), any(EncryptionLevel.class), any(Consumer.class)); } StreamFrame captureFrameSentAndVerifyEncryptionLevel(QuicClientConnectionImpl connection, int maxFrameSize, EncryptionLevel expectedLevel) { ArgumentCaptor> frameSupplierCaptor = ArgumentCaptor.forClass(Function.class); verify(connection, times(1)).send(frameSupplierCaptor.capture(), anyInt(), argThat(l -> l == expectedLevel), any(Consumer.class), anyBoolean()); clearInvocations(connection); return (StreamFrame) frameSupplierCaptor.getValue().apply(maxFrameSize); } byte[] transmittedByteStream(List streamFrames) { int totalSize = streamFrames.stream().mapToInt(f -> f.getLength()).sum(); ByteBuffer buffer = ByteBuffer.allocate(totalSize); streamFrames.stream() .map(frame -> frame.getStreamData()) .forEach(byteArray -> buffer.put(byteArray)); return buffer.array(); } }