/*
* Copyright © 2019, 2020, 2021, 2022, 2023 Peter Doornbosch
*
* This file is part of Kwik, an implementation of the QUIC protocol in Java.
*
* Kwik is free software: you can redistribute it and/or modify it under
* the terms of the GNU Lesser General Public License as published by the
* Free Software Foundation, either version 3 of the License, or (at your option)
* any later version.
*
* Kwik is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for
* more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package net.luminis.quic;
import net.luminis.quic.frame.CryptoFrame;
import net.luminis.quic.frame.QuicFrame;
import net.luminis.quic.log.Logger;
import net.luminis.quic.send.Sender;
import net.luminis.quic.test.FieldSetter;
import net.luminis.tls.*;
import net.luminis.tls.handshake.*;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Random;
import java.util.function.Consumer;
import java.util.function.Function;

import static net.luminis.quic.test.FieldSetter.setField;
import static net.luminis.tls.TlsConstants.HandshakeType.certificate_request;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
/**
 * Unit tests for {@link CryptoStream}.
 * <p>
 * The receiving tests feed {@link CryptoFrame}s into the stream and inspect the messages returned by
 * {@code getTlsMessages()}; the sending tests write data to the stream and inspect the frame generator
 * functions handed to the mocked {@link Sender}.
 * <p>
 * The stream's {@code tlsMessageParser} field is replaced by a mock that turns each buffer it is given into a
 * {@link MockTlsMessage}. Test messages use a 4-byte header (first byte: message type, remaining three bytes:
 * payload length) followed by the payload text, see {@link #convertToMsgBytes(int, String)}.
 */
class CryptoStreamTest {

    public static final Version QUIC_VERSION = Version.getDefault();

    private CryptoStream cryptoStream;
    private TlsMessageParser messageParser;
    private Sender sender;

    /**
     * Creates a fresh stream with a mocked sender and installs the mock message parser.
     */
    @BeforeEach
    void prepareObjectUnderTest() throws Exception {
        sender = mock(Sender.class);
        cryptoStream = new CryptoStream(new VersionHolder(QUIC_VERSION), EncryptionLevel.Handshake, null,
                Role.Client, new TlsClientEngine(mock(ClientMessageSender.class), mock(TlsStatusEventHandler.class)), mock(Logger.class), sender);
        messageParser = mock(TlsMessageParser.class);
        setField(cryptoStream, cryptoStream.getClass().getDeclaredField("tlsMessageParser"), messageParser);

        setParseFunction(buffer -> {
            // Peek at the first byte (message type) without consuming it ...
            buffer.mark();
            int type = buffer.get();
            buffer.reset();
            // ... then read the full 4-byte header and keep only the low 3 bytes as payload length.
            int length = buffer.getInt() & 0x00ffffff;
            byte[] stringBytes = new byte[length];
            buffer.get(stringBytes);
            return new MockTlsMessage(type, new String(stringBytes, StandardCharsets.UTF_8));
        });
    }

    /** A complete message delivered in one frame is available immediately, with its type preserved. */
    @Test
    void parseSingleMessageInSingleFrame() throws Exception {
        cryptoStream.add(new CryptoFrame(QUIC_VERSION, convertToMsgBytes(13, "first crypto frame")));

        assertThat(cryptoStream.getTlsMessages())
                .isNotEmpty()
                .contains(new MockTlsMessage("first crypto frame"));
        assertThat(((MockTlsMessage) cryptoStream.getTlsMessages().get(0)).getType()).isEqualTo(certificate_request);
    }

    /** No message is reported until the last of three consecutive in-order frames has been added. */
    @Test
    void parserWaitsForAllFramesNeededToParseWholeMessage() throws Exception {
        byte[] rawMessageBytes = convertToMsgBytes("first frame second frame last crypto frame");

        // First frame: 4 header bytes + "first frame " (12 bytes).
        cryptoStream.add(new CryptoFrame(QUIC_VERSION, 0, Arrays.copyOf(rawMessageBytes, 4 + 12)));
        assertThat(cryptoStream.getTlsMessages()).isEmpty();

        cryptoStream.add(new CryptoFrame(QUIC_VERSION, 16, utf8("second frame ")));
        assertThat(cryptoStream.getTlsMessages()).isEmpty();

        cryptoStream.add(new CryptoFrame(QUIC_VERSION, 29, utf8("last crypto frame")));
        assertThat(cryptoStream.getTlsMessages()).contains(new MockTlsMessage("first frame second frame last crypto frame"));
    }

    /** Same as the in-order case, but the frames are added last, first, middle. */
    @Test
    void parserWaitsForAllOutOfOrderFramesNeededToParseWholeMessage() throws Exception {
        byte[] rawMessageBytes = convertToMsgBytes("first frame second frame last crypto frame");

        cryptoStream.add(new CryptoFrame(QUIC_VERSION, 29, utf8("last crypto frame")));
        assertThat(cryptoStream.getTlsMessages()).isEmpty();

        cryptoStream.add(new CryptoFrame(QUIC_VERSION, 0, Arrays.copyOf(rawMessageBytes, 4 + 12)));
        assertThat(cryptoStream.getTlsMessages()).isEmpty();

        cryptoStream.add(new CryptoFrame(QUIC_VERSION, 16, utf8("second frame ")));
        assertThat(cryptoStream.getTlsMessages()).contains(new MockTlsMessage("first frame second frame last crypto frame"));
    }

    /** The message is completed by retransmitted data that is cut at different offsets than the first attempt. */
    @Test
    void handleRetransmittedFramesWithDifferentSegmentation() throws Exception {
        byte[] rawMessageBytes = convertToMsgBytes("first frame second frame last crypto frame");

        cryptoStream.add(new CryptoFrame(QUIC_VERSION, 29, utf8("last crypto frame")));
        assertThat(cryptoStream.getTlsMessages()).isEmpty();

        cryptoStream.add(new CryptoFrame(QUIC_VERSION, 0, Arrays.copyOf(rawMessageBytes, 4 + 12)));
        assertThat(cryptoStream.getTlsMessages()).isEmpty();

        // Simulate second frame is never received, but all crypto content is retransmitted in different frames.
        cryptoStream.add(new CryptoFrame(QUIC_VERSION, 0, Arrays.copyOf(rawMessageBytes, 4 + 19)));
        assertThat(cryptoStream.getTlsMessages()).isEmpty();

        cryptoStream.add(new CryptoFrame(QUIC_VERSION, 23, utf8("frame last crypto frame")));
        assertThat(cryptoStream.getTlsMessages()).contains(new MockTlsMessage("first frame second frame last crypto frame"));
    }

    /** Frames whose byte ranges overlap each other still yield exactly the original message content. */
    @Test
    void handleOverlappingFrames() throws Exception {
        byte[] rawMessageBytes = convertToMsgBytes("abcdefghijklmnopqrstuvwxyz");

        // Offsets are written as "4 + n": 4 header bytes plus the position of the letter in the alphabet.
        cryptoStream.add(new CryptoFrame(QUIC_VERSION, 4 + 2, utf8("cdefghijk")));
        assertThat(cryptoStream.getTlsMessages()).isEmpty();

        cryptoStream.add(new CryptoFrame(QUIC_VERSION, 4 + 4, utf8("efghi")));
        assertThat(cryptoStream.getTlsMessages()).isEmpty();

        cryptoStream.add(new CryptoFrame(QUIC_VERSION, 4 + 12, utf8("mn")));
        assertThat(cryptoStream.getTlsMessages()).isEmpty();

        cryptoStream.add(new CryptoFrame(QUIC_VERSION, 4 + 10, utf8("klmnop")));
        assertThat(cryptoStream.getTlsMessages()).isEmpty();

        cryptoStream.add(new CryptoFrame(QUIC_VERSION, 0, Arrays.copyOfRange(rawMessageBytes, 0, 8)));
        assertThat(cryptoStream.getTlsMessages()).isEmpty();

        cryptoStream.add(new CryptoFrame(QUIC_VERSION, 4 + 8, utf8("ijklmnopqrstuvwxyz")));
        assertThat(cryptoStream.getTlsMessages()).contains(new MockTlsMessage("abcdefghijklmnopqrstuvwxyz"));
    }

    /** Two back-to-back messages become available together once the missing start of the stream arrives. */
    @Test
    void parseMultipleMessages() throws Exception {
        byte[] rawMessageBytes1 = convertToMsgBytes("abcdefghijklmnopqrstuvwxyz");
        byte[] rawMessageBytes2 = convertToMsgBytes("0123456789");

        // The second message starts right after the first one (4 header bytes + 26 payload bytes).
        cryptoStream.add(new CryptoFrame(QUIC_VERSION, 4 + 26, rawMessageBytes2));
        cryptoStream.add(new CryptoFrame(QUIC_VERSION, 4 + 8, utf8("ijklmnopqrstuvwxyz")));
        cryptoStream.add(new CryptoFrame(QUIC_VERSION, 4 + 10, utf8("klmnopqrstuvwxyz")));
        assertThat(cryptoStream.getTlsMessages()).isEmpty();

        cryptoStream.add(new CryptoFrame(QUIC_VERSION, 0, Arrays.copyOfRange(rawMessageBytes1, 0, 18)));
        assertThat(cryptoStream.getTlsMessages()).contains(new MockTlsMessage("abcdefghijklmnopqrstuvwxyz"));
        assertThat(cryptoStream.getTlsMessages()).contains(new MockTlsMessage("0123456789"));
    }

    /**
     * The second message's 4-byte header is itself spread over several frames; the second message must only
     * appear once its header and payload are complete.
     */
    @Test
    void parseMessageSplitAccrossMultipleFrames() throws Exception {
        byte[] rawMessageBytes = new byte[4 + 5 + 4 + 5];
        System.arraycopy(convertToMsgBytes("abcde"), 0, rawMessageBytes, 0, 4 + 5);
        System.arraycopy(convertToMsgBytes("12345"), 0, rawMessageBytes, 4 + 5, 4 + 5);

        // Whole first message plus the first 2 header bytes of the second one.
        cryptoStream.add(new CryptoFrame(QUIC_VERSION, 0, Arrays.copyOfRange(rawMessageBytes, 0, 11)));
        assertThat(cryptoStream.getTlsMessages()).hasSize(1);
        assertThat(cryptoStream.getTlsMessages()).contains(new MockTlsMessage("abcde"));

        // Third header byte of the second message.
        cryptoStream.add(new CryptoFrame(QUIC_VERSION, 11, Arrays.copyOfRange(rawMessageBytes, 11, 12)));
        assertThat(cryptoStream.getTlsMessages()).hasSize(1);
        assertThat(cryptoStream.getTlsMessages()).contains(new MockTlsMessage("abcde"));

        // Last header byte and first payload byte of the second message.
        cryptoStream.add(new CryptoFrame(QUIC_VERSION, 12, Arrays.copyOfRange(rawMessageBytes, 12, 14)));
        assertThat(cryptoStream.getTlsMessages()).hasSize(1);
        assertThat(cryptoStream.getTlsMessages()).contains(new MockTlsMessage("abcde"));

        // Remaining payload of the second message.
        cryptoStream.add(new CryptoFrame(QUIC_VERSION, 14, Arrays.copyOfRange(rawMessageBytes, 14, 18)));
        assertThat(cryptoStream.getTlsMessages()).hasSize(2);
        assertThat(cryptoStream.getTlsMessages()).contains(new MockTlsMessage("abcde"));
        assertThat(cryptoStream.getTlsMessages()).contains(new MockTlsMessage("12345"));
    }

    /** A single small write results in one send request whose generator yields a crypto frame at offset 0. */
    @Test
    void writingDataToStreamLeadsToCallingSenderWithSendFrameCallback() {
        // Given
        byte[] dataToSend = new byte[120];

        // When
        cryptoStream.write(dataToSend);

        // Then
        ArgumentCaptor<Function<Integer, QuicFrame>> captor = ArgumentCaptor.forClass(Function.class);
        verify(sender).send(captor.capture(), anyInt(), any(EncryptionLevel.class), any(Consumer.class));
        Function<Integer, QuicFrame> frameGeneratorFunction = captor.getValue();

        QuicFrame frameToSend = frameGeneratorFunction.apply(1000);
        assertThat(frameToSend).isInstanceOf(CryptoFrame.class);
        assertThat(((CryptoFrame) frameToSend).getStreamData()).hasSize(120);
        assertThat(((CryptoFrame) frameToSend).getOffset()).isEqualTo(0);
    }

    /**
     * Data larger than the maximum frame size passed to the generator (1000) is delivered through several
     * generated frames, each within that size, which together reproduce the written data.
     */
    @Test
    void writingDataThatDoesNotFitInFrameLeadsToMultipleCallbacks() {
        // Given
        byte[] dataToSend = new byte[1800];
        new Random().nextBytes(dataToSend);

        // When
        cryptoStream.write(dataToSend);

        // Then
        ByteBuffer dataReceived = ByteBuffer.allocate(1800);
        // Keep draining: invoking a generator may lead to new send requests on the mock, which are picked up
        // in the next round. Stop as soon as a round finds no new requests.
        while (true) {
            ArgumentCaptor<Function<Integer, QuicFrame>> captor = ArgumentCaptor.forClass(Function.class);
            verify(sender, atMost(99)).send(captor.capture(), anyInt(), any(EncryptionLevel.class), any(Consumer.class));
            List<Function<Integer, QuicFrame>> frameGeneratorFunctions = captor.getAllValues();
            clearInvocations(sender);
            if (frameGeneratorFunctions.isEmpty()) {
                break;
            }
            frameGeneratorFunctions.forEach(f -> {
                QuicFrame frameToSend = f.apply(1000);
                assertThat(frameToSend).isInstanceOf(CryptoFrame.class);
                assertThat(((CryptoFrame) frameToSend).getFrameLength()).isLessThanOrEqualTo(1000);
                dataReceived.put(((CryptoFrame) frameToSend).getStreamData());
            });
        }
        assertThat(dataReceived.array()).isEqualTo(dataToSend);
    }

    /** Several consecutive writes are sent as one contiguous byte sequence, in write order. */
    @Test
    void dataInMultipleWritesIsConcatenatedIntoStream() {
        // Given
        byte[] dataToSend = new byte[1800];
        new Random().nextBytes(dataToSend);

        // When
        cryptoStream.write(Arrays.copyOfRange(dataToSend, 0, 200));
        cryptoStream.write(Arrays.copyOfRange(dataToSend, 200, 1413));
        cryptoStream.write(Arrays.copyOfRange(dataToSend, 1413, 1509));
        cryptoStream.write(Arrays.copyOfRange(dataToSend, 1509, 1628));
        cryptoStream.write(Arrays.copyOfRange(dataToSend, 1628, 1800));

        // Then
        ByteBuffer dataReceived = ByteBuffer.allocate(1800);
        // Same draining loop as in the previous test; here a generator is allowed to return null (no frame).
        while (true) {
            ArgumentCaptor<Function<Integer, QuicFrame>> captor = ArgumentCaptor.forClass(Function.class);
            verify(sender, atMost(99)).send(captor.capture(), anyInt(), any(EncryptionLevel.class), any(Consumer.class));
            List<Function<Integer, QuicFrame>> frameGeneratorFunctions = captor.getAllValues();
            clearInvocations(sender);
            if (frameGeneratorFunctions.isEmpty()) {
                break;
            }
            frameGeneratorFunctions.forEach(f -> {
                QuicFrame frameToSend = f.apply(1000);
                if (frameToSend != null) {
                    assertThat(frameToSend).isInstanceOf(CryptoFrame.class);
                    dataReceived.put(((CryptoFrame) frameToSend).getStreamData());
                }
            });
        }
        assertThat(dataReceived.array()).isEqualTo(dataToSend);
    }

    /**
     * Stubs the mocked message parser so that every call to {@code parseAndProcessHandshakeMessage} returns the
     * result of applying the given function to the buffer argument.
     *
     * @param parseFunction converts the buffer passed to the parser into the message to return
     */
    private void setParseFunction(Function<ByteBuffer, Message> parseFunction) throws Exception {
        when(messageParser.parseAndProcessHandshakeMessage(any(ByteBuffer.class), any(TlsClientEngine.class), any(ProtectionKeysType.class))).thenAnswer(new Answer<Message>() {
            @Override
            public Message answer(InvocationOnMock invocation) throws Throwable {
                ByteBuffer buffer = invocation.getArgument(0);
                return parseFunction.apply(buffer);
            }
        });
    }

    /**
     * Encodes text with an explicit charset, so the byte offsets used in the tests do not depend on the
     * platform default encoding.
     */
    private static byte[] utf8(String text) {
        return text.getBytes(StandardCharsets.UTF_8);
    }

    /**
     * Builds the raw bytes of a test message with type 0.
     *
     * @param content payload text
     * @return 4 header bytes followed by the encoded payload
     */
    private byte[] convertToMsgBytes(String content) {
        return convertToMsgBytes(0, content);
    }

    /**
     * Builds the raw bytes of a test message: the payload length as a 4-byte int whose first byte is then
     * overwritten with the message type, followed by the encoded payload.
     *
     * @param type    message type, stored in the first byte (only the low 8 bits are kept)
     * @param content payload text
     * @return 4 header bytes followed by the encoded payload
     */
    private byte[] convertToMsgBytes(int type, String content) {
        byte[] payload = utf8(content);
        byte[] bytes = new byte[payload.length + 4];
        ByteBuffer buffer = ByteBuffer.wrap(bytes);
        buffer.putInt(payload.length);
        buffer.put(payload);
        // The type replaces the most significant length byte.
        buffer.put(0, (byte) type);
        return bytes;
    }

    /**
     * Minimal handshake message used as parse result in these tests.
     * <p>
     * Equality and hash code are based on the contents only, so tests can compare against an expected message
     * without having to specify its type.
     */
    static class MockTlsMessage extends HandshakeMessage {

        private final int type;
        private final String contents;

        /**
         * @param type     numeric handshake type, resolved to an enum constant by {@link #getType()}
         * @param contents payload text
         */
        public MockTlsMessage(int type, String contents) {
            this.type = type;
            this.contents = contents;
        }

        /** Creates a message with type 0. */
        public MockTlsMessage(String contents) {
            this(0, contents);
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            MockTlsMessage that = (MockTlsMessage) o;
            return Objects.equals(contents, that.contents);
        }

        @Override
        public int hashCode() {
            return Objects.hash(contents);
        }

        @Override
        public String toString() {
            return "Message: " + contents;
        }

        /**
         * Looks up the handshake type constant whose {@code value} equals this message's numeric type.
         *
         * @throws NoSuchElementException if no handshake type constant has that value
         */
        @Override
        public TlsConstants.HandshakeType getType() {
            return Arrays.stream(TlsConstants.HandshakeType.values())
                    .filter(v -> v.value == this.type)
                    .findFirst()
                    .orElseThrow(() -> new NoSuchElementException("No TLS handshake type with value " + type));
        }

        @Override
        public byte[] getBytes() {
            return new byte[0];
        }
    }
}