/*
* Copyright © 2021, 2022, 2023 Peter Doornbosch
*
* This file is part of Kwik, an implementation of the QUIC protocol in Java.
*
* Kwik is free software: you can redistribute it and/or modify it under
* the terms of the GNU Lesser General Public License as published by the
* Free Software Foundation, either version 3 of the License, or (at your option)
* any later version.
*
* Kwik is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for
* more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this program. If not, see .
*/
package net.luminis.quic.server;
import net.luminis.quic.TestUtils;
import net.luminis.quic.Version;
import net.luminis.quic.log.Logger;
import net.luminis.quic.packet.InitialPacket;
import net.luminis.quic.send.SenderImpl;
import net.luminis.quic.test.TestClock;
import net.luminis.quic.test.TestScheduledExecutor;
import net.luminis.tls.handshake.TlsServerEngineFactory;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import net.luminis.quic.test.FieldReader;
import java.io.InputStream;
import java.net.DatagramSocket;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.time.Instant;
import java.util.Arrays;
import java.util.function.Consumer;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.*;
class ServerConnectionCandidateTest {
private Logger logger;
private TestClock clock;
private ServerConnectionImpl createdServerConnection;
private ServerConnectionFactory serverConnectionFactory;
private Context context;
private TestScheduledExecutor testExecutor;
@BeforeEach
void initObjectUnderTest() throws Exception {
logger = mock(Logger.class);
clock = new TestClock();
InputStream certificate = getClass().getResourceAsStream("localhost.pem");
InputStream privateKey = getClass().getResourceAsStream("localhost.key");
TlsServerEngineFactory tlsServerEngineFactory = new TlsServerEngineFactory(certificate, privateKey);
serverConnectionFactory = new TestServerConnectionFactory(16, mock(DatagramSocket.class), tlsServerEngineFactory,
false, mock(ApplicationProtocolRegistry.class), 100, cid -> {}, logger);
context = mock(Context.class);
testExecutor = new TestScheduledExecutor(clock);
when(context.getSharedServerExecutor()).thenReturn(testExecutor);
when(context.getSharedScheduledExecutor()).thenReturn(testExecutor);
}
@Test
void firstInitialPacketShouldSetAntiAmplificationLimit() throws Exception {
// Given
byte[] initialPacketBytes = TestUtils.createValidInitial(Version.getDefault());
byte[] scid = new byte[0];
byte[] odcid = Arrays.copyOfRange(initialPacketBytes, 6, 6 + 8);
ServerConnectionRegistry connectionRegistry = mock(ServerConnectionRegistry.class);
InetSocketAddress address = new InetSocketAddress("localhost", 55333);
ServerConnectionCandidate connectionCandidate = new ServerConnectionCandidate(context, Version.getDefault(), address, scid, odcid, serverConnectionFactory, connectionRegistry, logger);
// When
connectionCandidate.parsePackets(0, Instant.now(), ByteBuffer.wrap(initialPacketBytes));
testExecutor.check();
// Then
assertThat(createdServerConnection).isNotNull();
Integer antiAmplificationLimit = (Integer) new FieldReader(createdServerConnection.getSender(), SenderImpl.class.getDeclaredField("antiAmplificationLimit")).read();
assertThat(antiAmplificationLimit).isEqualTo(3 * 1200);
}
@Test
void firstInitialCarriedInSmallDatagramShouldBeDiscarded() throws Exception {
byte[] initialPacketBytes = TestUtils.createValidInitialNoPadding(Version.getDefault());
byte[] scid = new byte[0];
byte[] odcid = Arrays.copyOfRange(initialPacketBytes, 6, 6 + 8);
ServerConnectionRegistry connectionRegistry = mock(ServerConnectionRegistry.class);
InetSocketAddress address = new InetSocketAddress("localhost", 55333);
ServerConnectionCandidate connectionCandidate = new ServerConnectionCandidate(context, Version.getDefault(), address, scid, odcid, serverConnectionFactory, connectionRegistry, logger);
// When
connectionCandidate.parsePackets(0, Instant.now(), ByteBuffer.wrap(initialPacketBytes));
testExecutor.check();
// Then
assertThat(createdServerConnection).isNull();
verify(connectionRegistry, never()).registerConnection(any(ServerConnectionProxy.class), any(byte[].class));
}
@Test
void firstInitialWithPaddingInDatagramShouldCreateConnection() throws Exception {
byte[] initialPacketBytes = TestUtils.createValidInitialNoPadding(Version.getDefault());
byte[] scid = new byte[0];
byte[] odcid = Arrays.copyOfRange(initialPacketBytes, 6, 6 + 8);
ServerConnectionRegistry connectionRegistry = mock(ServerConnectionRegistry.class);
InetSocketAddress address = new InetSocketAddress("localhost", 55333);
ServerConnectionCandidate connectionCandidate = new ServerConnectionCandidate(context, Version.getDefault(), address, scid, odcid, serverConnectionFactory, connectionRegistry, logger);
// When
ByteBuffer datagramBytes = ByteBuffer.allocate(1200);
datagramBytes.put(initialPacketBytes);
datagramBytes.rewind();
connectionCandidate.parsePackets(0, Instant.now(), datagramBytes);
testExecutor.check();
// Then
assertThat(createdServerConnection).isNotNull();
}
class TestServerConnectionFactory extends ServerConnectionFactory {
public TestServerConnectionFactory(int connectionIdLength, DatagramSocket serverSocket, TlsServerEngineFactory tlsServerEngineFactory, boolean requireRetry, ApplicationProtocolRegistry applicationProtocolRegistry, int initalRtt, Consumer closeCallback, Logger log) {
super(connectionIdLength, serverSocket, tlsServerEngineFactory, requireRetry, applicationProtocolRegistry, initalRtt, null, closeCallback, log);
}
@Override
public ServerConnectionImpl createNewConnection(Version version, InetSocketAddress clientAddress, byte[] originalScid, byte[] originalDcid) {
ServerConnectionImpl newConnection = super.createNewConnection(version, clientAddress, originalScid, originalDcid);
createdServerConnection = newConnection;
return newConnection;
}
@Override
public ServerConnectionProxy createServerConnectionProxy(ServerConnectionImpl connection, InitialPacket initialPacket, Instant packetReceived, ByteBuffer datagram) {
return new ServerConnectionWrapper(connection, initialPacket, packetReceived, datagram);
}
}
}