/*
* Copyright © 2022, 2023 Peter Doornbosch
*
* This file is part of Kwik, an implementation of the QUIC protocol in Java.
*
* Kwik is free software: you can redistribute it and/or modify it under
* the terms of the GNU Lesser General Public License as published by the
* Free Software Foundation, either version 3 of the License, or (at your option)
* any later version.
*
* Kwik is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for
* more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this program. If not, see .
*/
package net.luminis.quic.cid;
import net.luminis.quic.Version;
import net.luminis.quic.frame.NewConnectionIdFrame;
import net.luminis.quic.frame.QuicFrame;
import net.luminis.quic.frame.RetireConnectionIdFrame;
import net.luminis.quic.log.Logger;
import net.luminis.quic.send.Sender;
import net.luminis.quic.server.ServerConnectionRegistry;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import java.util.Arrays;
import java.util.List;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import static net.luminis.quic.cid.ConnectionIdManager.MAX_CIDS_PER_CONNECTION;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.*;
class ConnectionIdManagerTest {
private ServerConnectionRegistry connectionRegistry;
private Sender sender;
private ConnectionIdManager connectionIdManager;
private BiConsumer closeCallback;
@BeforeEach
void initObjectUnderTest() {
connectionRegistry = mock(ServerConnectionRegistry.class);
sender = mock(Sender.class);
closeCallback = mock(BiConsumer.class);
connectionIdManager = new ConnectionIdManager(new byte[4], new byte[8], 6, 2, connectionRegistry, sender, closeCallback, mock(Logger.class));
}
@Test
void whenConnectionCreatedNewConnectionIdsShouldBeSent() {
// Given
connectionIdManager.registerPeerCidLimit(2);
// When
connectionIdManager.handshakeFinished();
// Then
verify(sender, atLeastOnce()).send(argThat(frame -> frame instanceof NewConnectionIdFrame), any(), any(Consumer.class));
}
@Test
void firstNewConnectionIdSentShouldHaveSequenceNumberOne() {
// Given
connectionIdManager.registerPeerCidLimit(4);
// When
connectionIdManager.handshakeFinished();
// Then
ArgumentCaptor captor = ArgumentCaptor.forClass(QuicFrame.class);
verify(sender, times(3)).send(captor.capture(), any(), any(Consumer.class));
QuicFrame firstFrame = captor.getAllValues().get(0);
assertThat(((NewConnectionIdFrame) firstFrame).getSequenceNr()).isEqualTo(1);
}
@Test
void initialCidsShouldMatchPeerLimitMinusOne() {
// Given
connectionIdManager.registerPeerCidLimit(4);
// When
connectionIdManager.handshakeFinished();
// Then
verify(sender, times(3)).send(argThat(frame -> frame instanceof NewConnectionIdFrame), any(), any(Consumer.class));
}
@Test
void whenPeerLimitIsLargeinitialCidsShouldMatchServerLimit() {
// Given
connectionIdManager.registerPeerCidLimit(64);
// When
connectionIdManager.handshakeFinished();
// Then
verify(sender, times(MAX_CIDS_PER_CONNECTION - 1)).send(argThat(frame -> frame instanceof NewConnectionIdFrame), any(), any(Consumer.class));
}
@Test
void retireConnectionIdShouldLeadToDeregistering() {
// Given
byte[] originalCid = connectionIdManager.getActiveConnectionIds().get(0);
connectionIdManager.registerPeerCidLimit(4);
connectionIdManager.handshakeFinished();
// When
connectionIdManager.process(new RetireConnectionIdFrame(Version.getDefault(), 0), null);
// Then
ArgumentCaptor captor = ArgumentCaptor.forClass(byte[].class);
verify(connectionRegistry).deregisterConnectionId(captor.capture());
assertThat(captor.getValue()).isEqualTo(originalCid);
}
@Test
void retireConnectionIdShouldLeadToSendingNew() {
// Given
connectionIdManager.registerPeerCidLimit(2);
connectionIdManager.handshakeFinished();
clearInvocations(sender);
// When
connectionIdManager.process(new RetireConnectionIdFrame(Version.getDefault(), 0), null);
// Then
verify(sender).send(argThat(f -> f instanceof NewConnectionIdFrame), any(), any(Consumer.class));
}
@Test
void retiringConnectionIdAlreadyRetiredDoesNothing() {
// Given
connectionIdManager.registerPeerCidLimit(2);
connectionIdManager.handshakeFinished();
connectionIdManager.process(new RetireConnectionIdFrame(Version.getDefault(), 0), null);
clearInvocations(sender);
// When
connectionIdManager.process(new RetireConnectionIdFrame(Version.getDefault(), 0), null);
// Then
verify(sender, never()).send(any(QuicFrame.class), any(), any(Consumer.class));
}
@Test
void retiringNonExistentSequenceNumberLeadsToConnectionClose() {
// Given
connectionIdManager.registerPeerCidLimit(2);
connectionIdManager.handshakeFinished();
// When
connectionIdManager.process(new RetireConnectionIdFrame(Version.getDefault(), 2), null);
// Then
ArgumentCaptor captor = ArgumentCaptor.forClass(Integer.class);
verify(closeCallback).accept(captor.capture(), anyString());
assertThat(captor.getValue()).isEqualTo(0x0a);
}
@Test
void retiringConnectionIdUsedAsDestinationConnectionIdLeadsToConnectionClose() {
// Given
connectionIdManager.registerPeerCidLimit(2);
connectionIdManager.handshakeFinished();
// When
connectionIdManager.process(new RetireConnectionIdFrame(Version.getDefault(), 0), connectionIdManager.getActiveConnectionIds().get(0));
// Then
ArgumentCaptor captor = ArgumentCaptor.forClass(Integer.class);
verify(closeCallback).accept(captor.capture(), anyString());
assertThat(captor.getValue()).isEqualTo(0x0a);
}
@Test
void initiallyThereShouldBeExactlyOneActiveCid() {
assertThat(connectionIdManager.getActiveConnectionIds()).hasSize(1);
}
@Test
void initiallyAtLeastOneNewCidShouldBeAccepted() {
// Given
// When
connectionIdManager.process(new NewConnectionIdFrame(Version.getDefault(), 2, 0, new byte[4]));
// Then
assertThat(connectionIdManager.getActivePeerConnectionIds()).hasSize(2);
}
@Test
void whenNumberOfActiveCidsExceedsLimitConnectionIdLimitErrorIsThrown() {
// Given
connectionIdManager = new ConnectionIdManager(new byte[4], new byte[8], 6, 3, connectionRegistry, sender, closeCallback, mock(Logger.class));
connectionIdManager.process(new NewConnectionIdFrame(Version.getDefault(), 1, 0, new byte[4]));
connectionIdManager.process(new NewConnectionIdFrame(Version.getDefault(), 2, 0, new byte[4]));
// When
connectionIdManager.process(new NewConnectionIdFrame(Version.getDefault(), 3, 0, new byte[4]));
// Then
ArgumentCaptor captor = ArgumentCaptor.forClass(Integer.class);
verify(closeCallback).accept(captor.capture(), anyString());
assertThat(captor.getValue()).isEqualTo(0x09);
}
@Test
void repeatingNewCidWithSequenceNumberShouldNotLeadToError() {
// Given
connectionIdManager.process(new NewConnectionIdFrame(Version.getDefault(), 1, 0, new byte[4]));
// When
connectionIdManager.process(new NewConnectionIdFrame(Version.getDefault(), 1, 0, new byte[4]));
// Then
verify(closeCallback, never()).accept(anyInt(), anyString());
}
@Test
void invalidRetirePriorToFieldShouldLeadToFrameEncodingError() {
// Given
// When
connectionIdManager.process(new NewConnectionIdFrame(Version.getDefault(), 1, 2, new byte[4]));
// Then
ArgumentCaptor captor = ArgumentCaptor.forClass(Integer.class);
verify(closeCallback).accept(captor.capture(), anyString());
assertThat(captor.getValue()).isEqualTo(0x07);
}
@Test
void newConnectionIdFrameWithIncreasedRetirePriorToFieldLeadsToRetireConnectionIdFrame() {
// Given
connectionIdManager.process(new NewConnectionIdFrame(Version.getDefault(), 0, 0, new byte[4]));
// When
connectionIdManager.process(new NewConnectionIdFrame(Version.getDefault(), 1, 1, new byte[4]));
// Then
verify(sender, atLeastOnce()).send(argThat(f -> f instanceof RetireConnectionIdFrame), any(), any(Consumer.class));
}
@Test
void newConnectionIdFrameWithIncreasedRetirePriorToFieldLeadsToDecrementOfActiveCids() {
// Given
connectionIdManager.process(new NewConnectionIdFrame(Version.getDefault(), 1, 0, new byte[4]));
// When
connectionIdManager.process(new NewConnectionIdFrame(Version.getDefault(), 2, 1, new byte[4]));
// Then
assertThat(connectionIdManager.getActivePeerConnectionIds()).hasSize(2);
verify(closeCallback, never()).accept(anyInt(), anyString());
}
@Test
void retiredCidShouldNotBeUsedAnymoreAsDestination() {
// Given
byte[] originalDcid = connectionIdManager.getCurrentPeerConnectionId();
connectionIdManager.process(new NewConnectionIdFrame(Version.getDefault(), 1, 0, new byte[] { 0x34, 0x1f, 0x5a, 0x55 }));
// When
connectionIdManager.process(new NewConnectionIdFrame(Version.getDefault(), 2, 1, new byte[] { 0x5b, 0x2e, 0x1a, 0x44 }));
// Then
assertThat(connectionIdManager.getCurrentPeerConnectionId()).isNotEqualTo(originalDcid);
}
@Test
void newConnectionIdWithSequenceNumberZeroShouldFail() {
// Given
byte[] originalDcid = connectionIdManager.getCurrentPeerConnectionId();
byte[] newDcid = Arrays.copyOf(originalDcid, originalDcid.length);
newDcid[0] += 1; // So now the two or definitely different
// When
connectionIdManager.process(new NewConnectionIdFrame(Version.getDefault(), 0, 0, newDcid));
// Then
ArgumentCaptor captor = ArgumentCaptor.forClass(Integer.class);
verify(closeCallback).accept(captor.capture(), anyString());
assertThat(captor.getValue()).isEqualTo(0x0a);
}
@Test
void whenUsingZeroLengthConnectionIdNewConnectionIdFrameShouldLeadToProtocolViolationError() {
// Given
connectionIdManager = new ConnectionIdManager(new byte[0], new byte[8], 6, 2, connectionRegistry, sender, closeCallback, mock(Logger.class));
// When
connectionIdManager.process(new NewConnectionIdFrame(Version.getDefault(), 1, 0, new byte[4]));
// Then
ArgumentCaptor captor = ArgumentCaptor.forClass(Integer.class);
verify(closeCallback).accept(captor.capture(), anyString());
assertThat(captor.getValue()).isEqualTo(0x0a);
}
@Test
void initialConnectionIdShouldNotChange() {
// Given
byte[] initialConnectionId = connectionIdManager.getInitialConnectionId();
// When
connectionIdManager.process(new RetireConnectionIdFrame(Version.getDefault(), 0), new byte[3]);
// Then
assertThat(connectionIdManager.getInitialConnectionId()).isEqualTo(initialConnectionId);
}
@Test
void testValidateInitialPeerConnectionId() {
// Given
byte[] peerCid = new byte[] { 0x06, 0x0f, 0x08, 0x0b };
connectionIdManager = new ConnectionIdManager(peerCid, new byte[8], 6, 2, connectionRegistry, sender, closeCallback, mock(Logger.class));
// Then
assertThat(connectionIdManager.validateInitialPeerConnectionId(peerCid)).isTrue();
}
@Test
void whenReorderedNewConnectionIdIsAlreadyRetiredRetireConnectionIdFrameShouldBeSent() {
// Given
connectionIdManager.process(new NewConnectionIdFrame(Version.getDefault(), 2, 2, new byte[4]));
// When
connectionIdManager.process(new NewConnectionIdFrame(Version.getDefault(), 1, 0, new byte[4]));
// Then
ArgumentCaptor captor = ArgumentCaptor.forClass(QuicFrame.class);
verify(sender, atLeastOnce()).send(captor.capture(), any(), any(Consumer.class));
List retiredSeqNr = captor.getAllValues().stream()
.filter(f -> f instanceof RetireConnectionIdFrame)
.map(f -> ((RetireConnectionIdFrame) f).getSequenceNr())
.collect(Collectors.toList());
assertThat(retiredSeqNr).contains(1);
}
@Test
void whenSendingNewConnectionIdRetirePriorToIsSet() {
connectionIdManager.sendNewConnectionId(1);
ArgumentCaptor captor = ArgumentCaptor.forClass(QuicFrame.class);
verify(sender, atLeastOnce()).send(captor.capture(), any(), any(Consumer.class));
assertThat(captor.getValue() instanceof NewConnectionIdFrame);
assertThat(((NewConnectionIdFrame) captor.getValue()).getRetirePriorTo()).isEqualTo(1);
}
@Test
void whenPreviouslyUnusedConnectionIdIsUsedNewConnectionIdIsSent() {
// Given
int maxCids = 3;
connectionIdManager.registerPeerCidLimit(maxCids);
connectionIdManager.sendNewConnectionId(0);
clearInvocations(sender);
assertThat(connectionIdManager.getActiveConnectionIds()).hasSize(2);
// When
connectionIdManager.getActiveConnectionIds().forEach(cid -> {
connectionIdManager.registerConnectionIdInUse(cid);
});
// Then
verify(sender, atLeastOnce()).send(argThat(f -> f instanceof NewConnectionIdFrame), any(), any(Consumer.class));
}
@Test
void whenMaxCidsIsReachedRegisterUnusedDoesNotLeadToNew() {
// Given
connectionIdManager = new ConnectionIdManager(new byte[4], new byte[8], 4, 2, connectionRegistry, sender, closeCallback, mock(Logger.class));
int maxCids = 6;
connectionIdManager.registerPeerCidLimit(maxCids);
connectionIdManager.handshakeFinished();
clearInvocations(sender);
assertThat(connectionIdManager.getActiveConnectionIds()).hasSize(maxCids);
// When
connectionIdManager.getActiveConnectionIds().forEach(cid -> {
connectionIdManager.registerConnectionIdInUse(cid);
});
// Then
verify(sender, never()).send(argThat(f -> f instanceof NewConnectionIdFrame), any(), any(Consumer.class));
}
void testValidateRetrySourceConnectionId() {
// Given
connectionIdManager = new ConnectionIdManager(new byte[8], new byte[8], 6, 2, connectionRegistry, sender, closeCallback, mock(Logger.class));
byte[] retryCid = new byte[] { 0x06, 0x0f, 0x08, 0x0b };
// When
connectionIdManager.registerRetrySourceConnectionId(retryCid);
// Then
assertThat(connectionIdManager.validateRetrySourceConnectionId(retryCid)).isTrue();
}
@Test
void whenActiveConnectionIdLimitReachedReceivingRetireShouldNotLeadToNew() {
// Given
connectionIdManager.sendNewConnectionId(0);
// When
connectionIdManager.sendNewConnectionId(1);
clearInvocations(sender);
connectionIdManager.process(new RetireConnectionIdFrame(Version.getDefault(), 0), new byte[3]);
// Then
verify(sender, never()).send(any(QuicFrame.class), any(), any(Consumer.class));
}
@Test
void whenConnectionIdAlreadyRetiredReceivingRetireShouldNotLeadToNew() {
// Given
connectionIdManager.sendNewConnectionId(0);
connectionIdManager.process(new RetireConnectionIdFrame(Version.getDefault(), 0), new byte[3]);
clearInvocations(sender);
assertThat(connectionIdManager.getActiveConnectionIds()).hasSize(2); // Because retire triggers new.
// When
connectionIdManager.process(new RetireConnectionIdFrame(Version.getDefault(), 0), new byte[3]);
// Then
verify(sender, never()).send(any(QuicFrame.class), any(), any(Consumer.class));
}
@Test
void testRegisterInitialPeerCid() {
// Given
assertThat(connectionIdManager.getAllPeerConnectionIds().get(0).getConnectionId()).isNotEqualTo(new byte[] { 0x01, 0x02, 0x03, 0x04 });
// When
connectionIdManager.registerInitialPeerCid(new byte[] { 0x01, 0x02, 0x03, 0x04 });
// Then
assertThat(connectionIdManager.getAllPeerConnectionIds().get(0).getConnectionId()).isEqualTo(new byte[] { 0x01, 0x02, 0x03, 0x04 });
}
}