/* * Copyright © 2021, 2022, 2023 Peter Doornbosch * * This file is part of Kwik, an implementation of the QUIC protocol in Java. * * Kwik is free software: you can redistribute it and/or modify it under * the terms of the GNU Lesser General Public License as published by the * Free Software Foundation, either version 3 of the License, or (at your option) * any later version. * * Kwik is distributed in the hope that it will be useful, but * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for * more details. * * You should have received a copy of the GNU Lesser General Public License * along with this program. If not, see . */ package net.luminis.quic.server; import net.luminis.quic.*; import net.luminis.quic.KeyUtils; import net.luminis.quic.Version; import net.luminis.quic.crypto.ConnectionSecrets; import net.luminis.quic.frame.FrameProcessor; import net.luminis.quic.frame.QuicFrame; import net.luminis.quic.packet.HandshakePacket; import net.luminis.quic.packet.QuicPacket; import net.luminis.quic.packet.RetryPacket; import net.luminis.quic.stream.StreamManager; import net.luminis.quic.test.FieldReader; import net.luminis.quic.tls.QuicTransportParametersExtension; import net.luminis.quic.frame.ConnectionCloseFrame; import net.luminis.quic.frame.CryptoFrame; import net.luminis.quic.log.Logger; import net.luminis.quic.packet.InitialPacket; import net.luminis.quic.send.SenderImpl; import net.luminis.tls.*; import net.luminis.tls.alert.HandshakeFailureAlert; import net.luminis.tls.extension.ApplicationLayerProtocolNegotiationExtension; import net.luminis.tls.extension.Extension; import net.luminis.tls.handshake.*; import net.luminis.tls.util.ByteUtils; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; import org.mockito.ArgumentCaptor; import net.luminis.quic.test.FieldSetter; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import java.io.IOException; import java.net.DatagramSocket; import java.net.InetAddress; import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.security.PrivateKey; import java.security.cert.X509Certificate; import java.time.Instant; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; import java.util.function.Supplier; import java.util.stream.Stream; import static net.luminis.quic.QuicConstants.TransportParameterId.*; import static org.assertj.core.api.Assertions.*; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.*; class ServerConnectionImplTest { public static final String DEFAULT_APPLICATION_PROTOCOL = "hq-29"; private ServerConnectionImpl connection; private ApplicationLayerProtocolNegotiationExtension alpn = new ApplicationLayerProtocolNegotiationExtension(DEFAULT_APPLICATION_PROTOCOL); private TlsServerEngine tlsServerEngine; private TlsServerEngineFactory tlsServerEngineFactory; @BeforeEach void setupObjectUnderTest() throws Exception { tlsServerEngineFactory = createTlsServerEngine(); connection = createServerConnection(tlsServerEngineFactory, false, new byte[8]); } @Test void whenParsingClientHelloLeadsToTlsErrorConnectionIsClosed() throws Exception { // When connection.process(new InitialPacket(Version.getDefault(), new byte[8], new byte[8], null, new CryptoFrame(Version.getDefault(), new byte[123])), Instant.now()); // Then verify(connection.getSender()).send(argThat(frame -> frame instanceof ConnectionCloseFrame), eq(EncryptionLevel.Initial)); } @Test void engineNotBeingAbleToNegotiateCipherShouldCloseConnection() throws Exception { // Given ((MockTlsServerEngine) tlsServerEngine).injectErrorInReceivingClientHello(() -> new HandshakeFailureAlert("")); // When List clientExtensions = List.of(alpn, createTransportParametersExtension()); ClientHello ch = new ClientHello("localhost", KeyUtils.generatePublicKey(), false, List.of(TlsConstants.CipherSuite.TLS_CHACHA20_POLY1305_SHA256), List.of(TlsConstants.SignatureScheme.rsa_pss_pss_sha256), TlsConstants.NamedGroup.secp256r1, clientExtensions, null, ClientHello.PskKeyEstablishmentMode.both); CryptoFrame cryptoFrame = new CryptoFrame(Version.getDefault(), ch.getBytes()); connection.process(new InitialPacket(Version.getDefault(), new byte[8], new byte[8], null, cryptoFrame), Instant.now()); // Then verify(connection.getSender()).send(argThat(frame -> frame instanceof ConnectionCloseFrame), eq(EncryptionLevel.Initial)); } @Test void failingAlpnNegotiationLeadsToCloseConnection() throws Exception { // When List clientExtensions = List.of(new ApplicationLayerProtocolNegotiationExtension("h2"), createTransportParametersExtension()); ClientHello ch = new ClientHello("localhost", KeyUtils.generatePublicKey(), false, clientExtensions); CryptoFrame cryptoFrame = new CryptoFrame(Version.getDefault(), ch.getBytes()); connection.process(new InitialPacket(Version.getDefault(), new byte[8], new byte[8], null, cryptoFrame), Instant.now()); // Then verify(connection.getSender()).send(argThat(frame -> frame instanceof ConnectionCloseFrame && ((ConnectionCloseFrame) frame).getErrorCode() == 0x100 + TlsConstants.AlertDescription.no_application_protocol.value), eq(EncryptionLevel.Initial)); } @Test void clientHelloLackingTransportParametersExtensionLeadsToConnectionClose() throws Exception { // When List clientExtensions = List.of(alpn); ClientHello ch = new ClientHello("localhost", KeyUtils.generatePublicKey(), false, clientExtensions); CryptoFrame cryptoFrame = new CryptoFrame(Version.getDefault(), ch.getBytes()); connection.process(new InitialPacket(Version.getDefault(), new byte[8], new byte[8], null, cryptoFrame), Instant.now()); // Then verify(connection.getSender()).send(argThat(frame -> frame instanceof ConnectionCloseFrame && ((ConnectionCloseFrame) frame).getErrorCode() == 0x100 + TlsConstants.AlertDescription.missing_extension.value), eq(EncryptionLevel.Initial)); } @Test void clientHelloWithCorrectTransportParametersIsAccepted() throws Exception { // When List clientExtensions = List.of(alpn, createTransportParametersExtension()); ClientHello ch = new ClientHello("localhost", KeyUtils.generatePublicKey(), false, clientExtensions); CryptoFrame cryptoFrame = new CryptoFrame(Version.getDefault(), ch.getBytes()); connection.process(new InitialPacket(Version.getDefault(), new byte[8], new byte[8], null, cryptoFrame), Instant.now()); // Then List serverExtensions = tlsServerEngine.getServerExtensions(); assertThat(serverExtensions).hasAtLeastOneElementOfType(QuicTransportParametersExtension.class); } @ParameterizedTest @MethodSource("provideTransportParametersWithInvalidValue") void whenTransportParametersContainsInvalidValueServerShouldCloseConnection(TransportParameters tp) throws Exception { // When QuicTransportParametersExtension transportParametersExtension = new QuicTransportParametersExtension(Version.getDefault(), tp, Role.Client); List clientExtensions = List.of(alpn, transportParametersExtension); ClientHello ch = new ClientHello("localhost", KeyUtils.generatePublicKey(), false, clientExtensions); CryptoFrame cryptoFrame = new CryptoFrame(Version.getDefault(), ch.getBytes()); connection.process(new InitialPacket(Version.getDefault(), new byte[8], new byte[8], null, cryptoFrame), Instant.now()); // Then verify(connection.getSender()).send(argThat(frame -> frame instanceof ConnectionCloseFrame && ((ConnectionCloseFrame) frame).getErrorCode() == 0x08), eq(EncryptionLevel.Initial)); } @ParameterizedTest @MethodSource("provideInvalidTransportParametersForClient") void whenTransportParametersContainsInvalidParameterServerShouldCloseConnection(TransportParameters tp) throws Exception { // When QuicTransportParametersExtension transportParametersExtension = new QuicTransportParametersExtensionTest(tp); List clientExtensions = List.of(alpn, transportParametersExtension); ClientHello ch = new ClientHello("localhost", KeyUtils.generatePublicKey(), false, clientExtensions); CryptoFrame cryptoFrame = new CryptoFrame(Version.getDefault(), ch.getBytes()); connection.process(new InitialPacket(Version.getDefault(), new byte[8], new byte[8], null, cryptoFrame), Instant.now()); // Then verify(connection.getSender()).send(argThat(frame -> frame instanceof ConnectionCloseFrame && ((ConnectionCloseFrame) frame).getErrorCode() == 0x08), eq(EncryptionLevel.Initial)); } @Test void whenTransportParametersAreProcessedStreamManagerDefaultsShouldHaveBeenSet() throws Exception { // Given StreamManager streamManager = mock(StreamManager.class); FieldSetter.setField(connection, connection.getClass().getDeclaredField("streamManager"), streamManager); QuicTransportParametersExtension transportParametersExtension = createTransportParametersExtension(); transportParametersExtension.getTransportParameters().setInitialMaxStreamsUni(3); transportParametersExtension.getTransportParameters().setInitialMaxStreamsBidi(100); List clientExtensions = List.of(alpn, transportParametersExtension); ClientHello ch = new ClientHello("localhost", KeyUtils.generatePublicKey(), false, clientExtensions); CryptoFrame cryptoFrame = new CryptoFrame(Version.getDefault(), ch.getBytes()); // When connection.process(new InitialPacket(Version.getDefault(), new byte[8], new byte[8], null, cryptoFrame), Instant.now()); // Then verify(streamManager).setInitialMaxStreamsUni(longThat(value -> value == 3)); verify(streamManager).setInitialMaxStreamsBidi(longThat(value -> value == 100)); } @Test void versionInformationWithSupportedOtherVersionLeadsToVersionChange() throws Exception { var connectionSecrets = spyOnConnectionSecrets(); // Given TransportParameters.VersionInformation versionInfo = new TransportParameters.VersionInformation(Version.QUIC_version_1, List.of(Version.QUIC_version_2, Version.QUIC_version_1)); List clientExtensions = List.of(alpn, createTransportParametersExtension(versionInfo)); ClientHello ch = new ClientHello("localhost", KeyUtils.generatePublicKey(), false, clientExtensions); CryptoFrame cryptoFrame = new CryptoFrame(Version.QUIC_version_1, ch.getBytes()); // When connection.process(new InitialPacket(Version.QUIC_version_1, new byte[8], new byte[8], null, cryptoFrame), Instant.now()); // Then assertThat(connection.getQuicVersion().equals(Version.QUIC_version_2)); verify(connectionSecrets).recomputeInitialKeys(); } @Test void versionInformationWithoutSupportedOtherVersionLeadsToNoVersionChange() throws Exception { var connectionSecrets = spyOnConnectionSecrets(); // Given TransportParameters.VersionInformation versionInfo = new TransportParameters.VersionInformation(Version.QUIC_version_1, List.of(Version.parse(0x1a2a3a4a), Version.QUIC_version_1)); List clientExtensions = List.of(alpn, createTransportParametersExtension(versionInfo)); ClientHello ch = new ClientHello("localhost", KeyUtils.generatePublicKey(), false, clientExtensions); CryptoFrame cryptoFrame = new CryptoFrame(Version.QUIC_version_1, ch.getBytes()); // When connection.process(new InitialPacket(Version.QUIC_version_1, new byte[8], new byte[8], null, cryptoFrame), Instant.now()); // Then assertThat(connection.getQuicVersion().equals(Version.QUIC_version_1)); verify(connectionSecrets, never()).recomputeInitialKeys(); } @Test void serverShouldSendAlpnAndQuicTransportParameterExtensions() throws Exception { // When List clientExtensions = List.of(alpn, createTransportParametersExtension()); ClientHello ch = new ClientHello("localhost", KeyUtils.generatePublicKey(), false, clientExtensions); CryptoFrame cryptoFrame = new CryptoFrame(Version.getDefault(), ch.getBytes()); connection.process(new InitialPacket(Version.getDefault(), new byte[8], new byte[8], null, cryptoFrame), Instant.now()); // Then TlsServerEngine tlsEngine = (TlsServerEngine) new FieldReader(connection, connection.getClass().getDeclaredField("tlsEngine")).read(); assertThat(tlsEngine.getServerExtensions()).hasAtLeastOneElementOfType(ApplicationLayerProtocolNegotiationExtension.class); assertThat(tlsEngine.getServerExtensions()).hasAtLeastOneElementOfType(QuicTransportParametersExtension.class); } @Test void serverShouldSendTransportParameterDisableActiveMigration() throws Exception { // When List clientExtensions = List.of(alpn, createTransportParametersExtension()); ClientHello ch = new ClientHello("localhost", KeyUtils.generatePublicKey(), false, clientExtensions); CryptoFrame cryptoFrame = new CryptoFrame(Version.getDefault(), ch.getBytes()); connection.process(new InitialPacket(Version.getDefault(), new byte[8], new byte[8], null, cryptoFrame), Instant.now()); // Then TlsServerEngine tlsEngine = (TlsServerEngine) new FieldReader(connection, connection.getClass().getDeclaredField("tlsEngine")).read(); assertThat(tlsEngine.getServerExtensions()).hasAtLeastOneElementOfType(QuicTransportParametersExtension.class); QuicTransportParametersExtension tpExtension = (QuicTransportParametersExtension) tlsEngine.getServerExtensions().stream().filter(ext -> ext instanceof QuicTransportParametersExtension).findFirst().get(); assertThat(tpExtension.getTransportParameters().getDisableMigration()).isTrue(); } @Test void retransmittedOriginalInitialMessageIsProcessedToo() throws Exception { byte[] odcid = new byte[] { 0x0f, 0x0e, 0x0d, 0x0c, 0x0b, 0x0a, 0x09, 0x08 }; connection = createServerConnection(tlsServerEngineFactory, false, odcid); CryptoFrame firstFrame = mock(CryptoFrame.class); CryptoFrame secondFrame = mock(CryptoFrame.class); InitialPacket packet1 = new InitialPacket(Version.getDefault(), new byte[8], odcid, null, firstFrame); InitialPacket packet2 = new InitialPacket(Version.getDefault(), new byte[8], odcid, null, secondFrame); connection.process(packet1, Instant.now()); connection.process(packet2, Instant.now()); verify(firstFrame).accept(any(FrameProcessor.class), any(QuicPacket.class), any(Instant.class)); verify(secondFrame).accept(any(FrameProcessor.class), any(QuicPacket.class), any(Instant.class)); } @Test void newServerConnectionUsesOriginalScidAsDcid() throws Exception { byte[] clientSourceCid = new byte[] { 0x03, 0x07, 0x05, 0x01 }; byte[] odcid = new byte[] { 0x0f, 0x0e, 0x0d, 0x0c, 0x0b, 0x0a, 0x09, 0x08 }; // When connection = createServerConnection(tlsServerEngineFactory, false, clientSourceCid, odcid, cid -> {}); // Then assertThat(connection.getDestinationConnectionId()).isEqualTo(clientSourceCid); } @Test void whenRetryIsRequiredFirstInitialLeadsToRetryPacket() throws Exception { // Given connection = createServerConnection(createTlsServerEngine(), true, new byte[8]); // When connection.process(new InitialPacket(Version.getDefault(), new byte[8], new byte[8], null, new CryptoFrame()), Instant.now()); // Then verify(connection.getSender()).send(any(RetryPacket.class)); } @Test void whenRetryIsRequiredAllRetryPacketsContainsSameToken() throws Exception { // Given connection = createServerConnection(createTlsServerEngine(), true, new byte[8]); connection.process(new InitialPacket(Version.getDefault(), new byte[8], new byte[8], null, new CryptoFrame()), Instant.now()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RetryPacket.class); verify(connection.getSender()).send(argumentCaptor.capture()); byte[] retryToken = argumentCaptor.getValue().getRetryToken(); clearInvocations(connection.getSender()); // When connection.process(new InitialPacket(Version.getDefault(), new byte[8], new byte[8], null, new CryptoFrame()), Instant.now()); // Then verify(connection.getSender()).send(argThat(retryPacket -> Arrays.equals(retryPacket.getRetryToken(), retryToken))); } @Test void whenRetryIsRequiredDifferentDestinationConnectionIdsGetDifferentToken() throws Exception { // Given byte[] dcid1 = new byte[] { 0, 1, 2, 3, 4, 5, 6, 7, 8 }; ServerConnectionImpl connection1 = createServerConnection(createTlsServerEngine(), true, dcid1); connection1.process(new InitialPacket(Version.getDefault(), new byte[8], dcid1, null, new CryptoFrame()), Instant.now()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RetryPacket.class); verify(connection1.getSender()).send(argumentCaptor.capture()); byte[] retryToken = argumentCaptor.getValue().getRetryToken(); // When byte[] dcid2 = new byte[] { 8, 7, 6, 5, 4, 3, 2, 1, 0 }; ServerConnectionImpl connection2 = createServerConnection(createTlsServerEngine(), true, dcid2); connection2.process(new InitialPacket(Version.getDefault(), new byte[8], dcid2, null, new CryptoFrame()), Instant.now()); // Then verify(connection2.getSender()).send(argThat(retryPacket -> !Arrays.equals(retryPacket.getRetryToken(), retryToken))); } @Test void whenRetryIsRequiredInitialWithTokenIsProcessed() throws Exception { // Given connection = createServerConnection(createTlsServerEngine(), true, null); connection = createServerConnection(createTlsServerEngine(), true, new byte[8]); connection.process(new InitialPacket(Version.getDefault(), new byte[8], new byte[8], null, new CryptoFrame()), Instant.now()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RetryPacket.class); verify(connection.getSender()).send(argumentCaptor.capture()); byte[] retryToken = argumentCaptor.getValue().getRetryToken(); clearInvocations(connection.getSender()); // When ClientHello ch = new ClientHello("testserver", KeyUtils.generatePublicKey(), false, Collections.emptyList()); CryptoFrame initialCrypto = new CryptoFrame(Version.getDefault(), ch.getBytes()); connection.process(new InitialPacket(Version.getDefault(), new byte[8], new byte[8], retryToken, initialCrypto), Instant.now()); // Then verify(connection.getSender()).send(argThat(frame -> frame instanceof ConnectionCloseFrame && ((ConnectionCloseFrame) frame).getErrorCode() == 256 + TlsConstants.AlertDescription.missing_extension.value), any(EncryptionLevel.class)); } @Test void whenRetryIsRequiredInitialWithInvalidTokenConnectionIsClosed() throws Exception { // Given connection = createServerConnection(createTlsServerEngine(), true, new byte[8]); connection.process(new InitialPacket(Version.getDefault(), new byte[8], new byte[8], null, new CryptoFrame()), Instant.now()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RetryPacket.class); verify(connection.getSender()).send(argumentCaptor.capture()); byte[] retryToken = argumentCaptor.getValue().getRetryToken(); byte[] incorrectToken = Arrays.copyOfRange(retryToken, 0, retryToken.length - 1); // When connection.process(new InitialPacket(Version.getDefault(), new byte[8], new byte[8], incorrectToken, new CryptoFrame()), Instant.now()); // Then verify(connection.getSender()).send(argThat(frame -> frame instanceof ConnectionCloseFrame && ((ConnectionCloseFrame) frame).getErrorCode() == 0x0b), any(EncryptionLevel.class)); } @Test void whenRetryIsRequiredSecondInitialShouldReturnSameRetryPacket() throws Exception { // Given byte[] odcid = { 0x0f, 0x0e, 0x0d, 0x0c, 0x0b, 0x0a, 0x09, 0x08 }; connection = createServerConnection(createTlsServerEngine(), true, odcid); InitialPacket initialPacket = new InitialPacket(Version.getDefault(), new byte[8], odcid, null, new CryptoFrame(Version.getDefault(), new byte[38])); initialPacket.setPacketNumber(0); ConnectionSecrets clientConnectionSecrets = new ConnectionSecrets(VersionHolder.withDefault(), Role.Client, null, mock(Logger.class)); clientConnectionSecrets.computeInitialKeys(odcid); byte[] initialPacketBytes = initialPacket.generatePacketBytes(clientConnectionSecrets.getClientAead(EncryptionLevel.Initial)); ByteBuffer paddedInitial = ByteBuffer.allocate(1200); paddedInitial.put(initialPacketBytes); connection.parseAndProcessPackets(0, Instant.now(), paddedInitial, null); ArgumentCaptor argumentCaptor1 = ArgumentCaptor.forClass(RetryPacket.class); verify(connection.getSender()).send(argumentCaptor1.capture()); byte[] retryPacket1 = argumentCaptor1.getValue().generatePacketBytes(null); clearInvocations(connection.getSender()); // When connection.parseAndProcessPackets(0, Instant.now(), paddedInitial, null); ArgumentCaptor argumentCaptor2 = ArgumentCaptor.forClass(RetryPacket.class); verify(connection.getSender()).send(argumentCaptor2.capture()); RetryPacket retryPacket2 = argumentCaptor1.getValue(); // Then assertThat(retryPacket1).isEqualTo(retryPacket2.generatePacketBytes(null)); } @Test void whenServerConnectionIsAbortedCloseCallbackShouldBeCalled() throws Exception { // Given byte[] odcid = { 0x0f, 0x0e, 0x0d, 0x0c, 0x0b, 0x0a, 0x09, 0x08 }; AtomicBoolean closeCallbackIsCalled = new AtomicBoolean(false); connection = createServerConnection(createTlsServerEngine(), true, new byte[8], odcid, cid -> closeCallbackIsCalled.set(true)); // When connection.abortConnection(new RuntimeException("injected error")); // Then assertThat(closeCallbackIsCalled.get()).isTrue(); } @Test void receivingInitialPacketShouldSetAntiAmplification() throws Exception { // Given byte[] odcid = ByteUtils.hexToBytes("67268378ae7dc13b"); connection = createServerConnection(tlsServerEngineFactory, false, odcid); // When byte[] validInitial = TestUtils.createValidInitial(Version.getDefault()); connection.parseAndProcessPackets(0, Instant.now(), ByteBuffer.wrap(validInitial), null); ArgumentCaptor antiAmplificationLimitCaptor = ArgumentCaptor.forClass(Integer.class); // Then verify(connection.getSender()).setAntiAmplificationLimit(antiAmplificationLimitCaptor.capture()); assertThat(antiAmplificationLimitCaptor.getValue()).isEqualTo(3 * validInitial.length); } @Test void receivingInvalidInitialPacketShouldAddToAntiAmplificationLimit() throws Exception { // When byte[] invalidInitial = TestUtils.createInvalidInitial(Version.getDefault()); connection.parseAndProcessPackets(0, Instant.now(), ByteBuffer.wrap(invalidInitial), null); // Then ArgumentCaptor antiAmplificationLimitCaptor = ArgumentCaptor.forClass(Integer.class); verify(connection.getSender()).setAntiAmplificationLimit(antiAmplificationLimitCaptor.capture()); assertThat(antiAmplificationLimitCaptor.getValue()).isEqualTo(3 * invalidInitial.length); } @Test void whenPeerAddressValidatedAntiAmplificationIsDisabled() { // When connection.process(new HandshakePacket(Version.getDefault(), new byte[0], new byte[0], new CryptoFrame(Version.getDefault(), new byte[300])), Instant.now()); // Then verify(connection.getSender()).unsetAntiAmplificationLimit(); } @Test void whenRetryIsRequiredInitialWithValidTokenDisablesAntiAmplificationLimit() throws Exception { // Given connection = createServerConnection(createTlsServerEngine(), true, new byte[8]); connection.process(new InitialPacket(Version.getDefault(), new byte[8], new byte[8], null, new CryptoFrame()), Instant.now()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RetryPacket.class); verify(connection.getSender()).send(argumentCaptor.capture()); byte[] retryToken = argumentCaptor.getValue().getRetryToken(); clearInvocations(connection.getSender()); // When ClientHello ch = new ClientHello("testserver", KeyUtils.generatePublicKey(), false, Collections.emptyList()); CryptoFrame initialCrypto = new CryptoFrame(Version.getDefault(), ch.getBytes()); connection.process(new InitialPacket(Version.getDefault(), new byte[8], new byte[8], retryToken, initialCrypto), Instant.now()); // Then verify(connection.getSender()).unsetAntiAmplificationLimit(); } @Test void initialPacketCarriedInDatagramSmallerThan1200BytesShouldBeDropped() throws Exception { // Given byte[] initialPacketBytes = TestUtils.createValidInitialNoPadding(Version.getDefault()); byte[] odcid = Arrays.copyOfRange(initialPacketBytes, 6, 6 + 8); connection = createServerConnection(tlsServerEngineFactory, false, odcid); // When connection.parseAndProcessPackets(0, Instant.now(), ByteBuffer.wrap(initialPacketBytes), null); // Then verify(connection.getSender(), never()).send(any(QuicFrame.class), any(EncryptionLevel.class)); } @Test void initialPacketWithPaddingInDatagramShouldBeAccepted() throws Exception { // Given byte[] initialPacketBytes = TestUtils.createValidInitialNoPadding(Version.getDefault()); byte[] odcid = Arrays.copyOfRange(initialPacketBytes, 6, 6 + 8); connection = createServerConnection(tlsServerEngineFactory, false, odcid); // When ByteBuffer buffer = ByteBuffer.allocate(1200); buffer.put(initialPacketBytes); connection.parseAndProcessPackets(0, Instant.now(), buffer, null); // Then verify(connection.getSender(), atLeastOnce()).send(any(QuicFrame.class), any(EncryptionLevel.class)); } @Test void whenInitialPacketPaddedInDatagramAllBytesShouldBeCountedInAntiAmplificationLimit() throws Exception { // Given byte[] initialPacketBytes = TestUtils.createValidInitialNoPadding(Version.getDefault()); byte[] odcid = Arrays.copyOfRange(initialPacketBytes, 6, 6 + 8); connection = createServerConnection(tlsServerEngineFactory, false, odcid); // When ByteBuffer buffer = ByteBuffer.allocate(1200); buffer.put(initialPacketBytes); connection.parseAndProcessPackets(0, Instant.now(), buffer, null); // Then ArgumentCaptor antiAmplicationLimitCaptor = ArgumentCaptor.forClass(Integer.class); verify(connection.getSender()).setAntiAmplificationLimit(antiAmplicationLimitCaptor.capture()); assertThat(antiAmplicationLimitCaptor.getValue()).isEqualTo(3 * 1200); } @Test void whenParsingZeroRttPacketItShouldFailOnMissingKeys() throws Exception { // Given byte[] data = { (byte) 0b11010001, 0x00, 0x00, 0x00, 0x01, 0, 0, 17, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 }; assertThatThrownBy(() -> // When connection.parsePacket(ByteBuffer.wrap(data)) ) // Then .isInstanceOf(MissingKeysException.class) .hasMessageContaining("ZeroRTT"); } static Stream provideTransportParametersWithInvalidValue() { TransportParameters invalidMaxStreamsBidi = createDefaultTransportParameters(); invalidMaxStreamsBidi.setInitialMaxStreamsBidi(0x1000000000000001l); TransportParameters invalidMaxUdpPayloadSize = createDefaultTransportParameters(); invalidMaxUdpPayloadSize.setMaxUdpPayloadSize(1199); TransportParameters invalidAckDelayExponent = createDefaultTransportParameters(); invalidAckDelayExponent.setAckDelayExponent(21); TransportParameters invalidMaxAckDelay = createDefaultTransportParameters(); invalidMaxAckDelay.setMaxAckDelay(0x4001); // 2^14 + 1 TransportParameters invalidActiveConnectionIdLimit = createDefaultTransportParameters(); invalidActiveConnectionIdLimit.setActiveConnectionIdLimit(1); TransportParameters incorrectInitialSourceConnectionId = createDefaultTransportParameters(); incorrectInitialSourceConnectionId.setInitialSourceConnectionId(new byte[] { 0, 0, 7, 0, 0, 0, 0, 0 }); return Stream.of(invalidMaxStreamsBidi, invalidMaxUdpPayloadSize, invalidAckDelayExponent, invalidMaxAckDelay, invalidActiveConnectionIdLimit, incorrectInitialSourceConnectionId); } static Stream provideInvalidTransportParametersForClient() { TransportParameters withOriginalDestinationConnectionId = createDefaultTransportParameters(); withOriginalDestinationConnectionId.setOriginalDestinationConnectionId(new byte[8]); TransportParameters withPreferredAddress = createDefaultTransportParameters(); withPreferredAddress.setPreferredAddress(new TransportParameters.PreferredAddress()); TransportParameters withRetrySourceConnectionId = createDefaultTransportParameters(); withRetrySourceConnectionId.setRetrySourceConnectionId(new byte[8]); TransportParameters withStatelessResetToken = createDefaultTransportParameters(); withStatelessResetToken.setStatelessResetToken(new byte[16]); return Stream.of(withOriginalDestinationConnectionId, withPreferredAddress, withRetrySourceConnectionId, withStatelessResetToken); } private ServerConnectionImpl createServerConnection(TlsServerEngineFactory tlsServerEngineFactory, boolean retryRequired, byte[] odcid) throws Exception { if (odcid == null) { odcid = new byte[]{ 0x0f, 0x0e, 0x0d, 0x0c, 0x0b, 0x0a, 0x09, 0x08 }; } return createServerConnection(tlsServerEngineFactory, retryRequired, new byte[8], odcid, cid -> {}); } private ServerConnectionImpl createServerConnection(TlsServerEngineFactory tlsServerEngineFactory, boolean retryRequired, byte[] clientCid, byte[] odcid, Consumer closeCallback) throws Exception { ApplicationProtocolRegistry applicationProtocolRegistry = new ApplicationProtocolRegistry(); applicationProtocolRegistry.registerApplicationProtocol("hq-29", mock(ApplicationProtocolConnectionFactory.class)); ServerConnectionImpl connection = new ServerConnectionImpl(Version.getDefault(), mock(DatagramSocket.class), new InetSocketAddress(InetAddress.getLoopbackAddress(), 6000), clientCid, odcid, 8, tlsServerEngineFactory, retryRequired, applicationProtocolRegistry, 100, null, closeCallback, mock(Logger.class)); SenderImpl sender = mock(SenderImpl.class); FieldSetter.setField(connection, connection.getClass().getDeclaredField("sender"), sender); return connection; } private TlsServerEngineFactory createTlsServerEngine() { TlsServerEngineFactory tlsServerEngineFactory = mock(TlsServerEngineFactory.class); when(tlsServerEngineFactory.createServerEngine(any(ServerMessageSender.class), any(TlsStatusEventHandler.class))).then(new Answer() { @Override public TlsServerEngine answer(InvocationOnMock invocation) throws Throwable { tlsServerEngine = new MockTlsServerEngine(mock(X509Certificate.class), null, invocation.getArgument(0), invocation.getArgument(1)); return tlsServerEngine; } }); return tlsServerEngineFactory; } private static TransportParameters createDefaultTransportParameters() { TransportParameters tp = new TransportParameters(); tp.setInitialSourceConnectionId(new byte[8]); return tp; } private QuicTransportParametersExtension createTransportParametersExtension() { return new QuicTransportParametersExtension(Version.getDefault(), createDefaultTransportParameters(), Role.Client); } private QuicTransportParametersExtension createTransportParametersExtension(TransportParameters.VersionInformation versionInfo) { TransportParameters transportParameters = createDefaultTransportParameters(); transportParameters.setVersionInformation(versionInfo); return new QuicTransportParametersExtension(Version.getDefault(), transportParameters, Role.Client); } private ConnectionSecrets spyOnConnectionSecrets() throws Exception { ConnectionSecrets connectionSecrets = spy((ConnectionSecrets) new FieldReader(connection, QuicConnectionImpl.class.getDeclaredField("connectionSecrets")).read()); FieldSetter.setField(connection, QuicConnectionImpl.class. getDeclaredField("connectionSecrets"), connectionSecrets); return connectionSecrets; } static class MockTlsServerEngine extends TlsServerEngine { private Supplier exceptionSupplier; public MockTlsServerEngine(X509Certificate serverCertificate, PrivateKey certificateKey, ServerMessageSender serverMessageSender, TlsStatusEventHandler tlsStatusHandler) { super(serverCertificate, certificateKey, serverMessageSender, tlsStatusHandler, null); } @Override public void received(ClientHello clientHello, ProtectionKeysType keyType) throws TlsProtocolException, IOException { if (exceptionSupplier != null) { throw exceptionSupplier.get(); } statusHandler.extensionsReceived(clientHello.getExtensions()); } public void injectErrorInReceivingClientHello(Supplier exceptionSupplier) { this.exceptionSupplier = exceptionSupplier; } } /** * For testing behaviour when invalid parameters are sent (for client or server), the serialize method must be * overridden, because the original will check for each parameter whether it is valid to sent for the given role. */ static class QuicTransportParametersExtensionTest extends QuicTransportParametersExtension { private TransportParameters transportParameters; QuicTransportParametersExtensionTest(TransportParameters transportParameters) { super(Version.getDefault(), transportParameters, Role.Client); this.transportParameters = transportParameters; } @Override protected void serialize() { super.serialize(); ByteBuffer extendedBuffer = ByteBuffer.allocate(1024); extendedBuffer.put(getBytes()); if (transportParameters.getOriginalDestinationConnectionId() != null) { addTransportParameter(extendedBuffer, original_destination_connection_id, transportParameters.getOriginalDestinationConnectionId()); } if (transportParameters.getPreferredAddress() != null) { byte[] addressData = new byte[41]; addressData[0] = 123; // IP address must not be all 0 addTransportParameter(extendedBuffer, preferred_address, addressData); } if (transportParameters.getRetrySourceConnectionId() != null) { addTransportParameter(extendedBuffer, retry_source_connection_id, transportParameters.getRetrySourceConnectionId()); } if (transportParameters.getStatelessResetToken() != null) { addTransportParameter(extendedBuffer, stateless_reset_token, transportParameters.getStatelessResetToken()); } int length = extendedBuffer.position(); extendedBuffer.limit(length); int extensionsSize = length - 2 - 2; // 2 bytes for the length itself and 2 for the type extendedBuffer.putShort(2, (short) extensionsSize); byte[] data = new byte[length]; extendedBuffer.flip(); extendedBuffer.get(data); FieldSetter.setField(this, QuicTransportParametersExtension.class, "data", data); } } }