/* * Copyright © 2019, 2020, 2021, 2022, 2023 Peter Doornbosch * * This file is part of Agent15, an implementation of TLS 1.3 in Java. * * Agent15 is free software: you can redistribute it and/or modify it under * the terms of the GNU Lesser General Public License as published by the * Free Software Foundation, either version 3 of the License, or (at your option) * any later version. * * Agent15 is distributed in the hope that it will be useful, but * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for * more details. * * You should have received a copy of the GNU Lesser General Public License * along with this program. If not, see . */ package net.luminis.tls.extension; import net.luminis.tls.*; import net.luminis.tls.alert.DecodeErrorException; import net.luminis.tls.util.ByteUtils; import java.math.BigInteger; import java.nio.ByteBuffer; import java.security.AlgorithmParameters; import java.security.KeyFactory; import java.security.NoSuchAlgorithmException; import java.security.PublicKey; import java.security.interfaces.ECPublicKey; import java.security.interfaces.XECPublicKey; import java.security.spec.*; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.stream.Stream; import static net.luminis.tls.TlsConstants.NamedGroup.*; /** * The TLS "key_share" extension contains the endpoint's cryptographic parameters. * See https://tools.ietf.org/html/rfc8446#section-4.2.8 */ public class KeyShareExtension extends Extension { public static final Map CURVE_KEY_LENGTHS = Map.of( secp256r1, 65, x25519, 32, x448, 56 ); public static final List supportedCurves = List.of(secp256r1, x25519, x448); private TlsConstants.HandshakeType handshakeType; private List keyShareEntries = new ArrayList<>(); /** * Assuming KeyShareClientHello: * "In the ClientHello message, the "extension_data" field of this extension contains a "KeyShareClientHello" value..." * @param publicKey * @param ecCurve */ public KeyShareExtension(ECPublicKey publicKey, TlsConstants.NamedGroup ecCurve, TlsConstants.HandshakeType handshakeType) { this.handshakeType = handshakeType; if (! supportedCurves.contains(ecCurve)) { throw new RuntimeException("Only curves supported: " + supportedCurves); } keyShareEntries.add(new ECKeyShareEntry(ecCurve, publicKey)); } public KeyShareExtension(PublicKey publicKey, TlsConstants.NamedGroup ecCurve, TlsConstants.HandshakeType handshakeType) { this.handshakeType = handshakeType; if (! supportedCurves.contains(ecCurve)) { throw new RuntimeException("Only curves supported: " + supportedCurves); } keyShareEntries.add(new KeyShareEntry(ecCurve, publicKey)); } /** * Assuming KeyShareServerHello: * "In a ServerHello message, the "extension_data" field of this extension contains a KeyShareServerHello value..." * @param buffer * @return * @throws TlsProtocolException */ public KeyShareExtension(ByteBuffer buffer, TlsConstants.HandshakeType handshakeType) throws TlsProtocolException { this(buffer, handshakeType, false); } public KeyShareExtension(ByteBuffer buffer, TlsConstants.HandshakeType handshakeType, boolean helloRetryRequestType) throws TlsProtocolException { int extensionDataLength = parseExtensionHeader(buffer, TlsConstants.ExtensionType.key_share, 1); if (extensionDataLength < 2) { throw new DecodeErrorException("extension underflow"); } if (handshakeType == TlsConstants.HandshakeType.client_hello) { int keyShareEntriesSize = buffer.getShort(); if (extensionDataLength != 2 + keyShareEntriesSize) { throw new DecodeErrorException("inconsistent length"); } int remaining = keyShareEntriesSize; while (remaining > 0) { remaining -= parseKeyShareEntry(buffer, helloRetryRequestType); } if (remaining != 0) { throw new DecodeErrorException("inconsistent length"); } } else if (handshakeType == TlsConstants.HandshakeType.server_hello) { int remaining = extensionDataLength; remaining -= parseKeyShareEntry(buffer, helloRetryRequestType); if (remaining != 0) { throw new DecodeErrorException("inconsistent length"); } } else { throw new IllegalArgumentException(); } } protected int parseKeyShareEntry(ByteBuffer buffer, boolean namedGroupOnly) throws TlsProtocolException { int startPosition = buffer.position(); if (namedGroupOnly && buffer.remaining() < 2 || !namedGroupOnly && buffer.remaining() < 4 ) { throw new DecodeErrorException("extension underflow"); } int namedGroupValue = buffer.getShort(); TlsConstants.NamedGroup namedGroup = Stream.of(TlsConstants.NamedGroup.values()).filter(it -> it.value == namedGroupValue).findAny() .orElseThrow(() -> new DecodeErrorException("Invalid named group")); if (! supportedCurves.contains(namedGroup)) { throw new RuntimeException("Curve '" + namedGroup + "' not supported"); } if (namedGroupOnly) { keyShareEntries.add(new ECKeyShareEntry(namedGroup, null)); } else { int keyLength = buffer.getShort(); if (buffer.remaining() < keyLength) { throw new DecodeErrorException("extension underflow"); } if (keyLength != CURVE_KEY_LENGTHS.get(namedGroup)) { throw new DecodeErrorException("Invalid " + namedGroup.name() + " key length: " + keyLength); } if (namedGroup == secp256r1) { int headerByte = buffer.get(); if (headerByte == 4) { byte[] keyData = new byte[keyLength - 1]; buffer.get(keyData); ECPublicKey ecPublicKey = rawToEncodedECPublicKey(namedGroup, keyData); keyShareEntries.add(new ECKeyShareEntry(namedGroup, ecPublicKey)); } else { throw new DecodeErrorException("EC keys must be in legacy form"); } } else if (namedGroup == x25519 || namedGroup == x448) { byte[] keyData = new byte[keyLength]; buffer.get(keyData); PublicKey publicKey = rawToEncodedXDHPublicKey(namedGroup, keyData); keyShareEntries.add(new KeyShareEntry(namedGroup, publicKey)); } } return buffer.position() - startPosition; } @Override public byte[] getBytes() { short keyShareEntryLength = (short) keyShareEntries.stream() .map(ks -> ks.getNamedGroup()) .mapToInt(g -> CURVE_KEY_LENGTHS.get(g)) .map(s -> 2 + 2 + s) // Named Group: 2 bytes, key length: 2 bytes .sum(); short extensionLength = keyShareEntryLength; if (handshakeType == TlsConstants.HandshakeType.client_hello) { extensionLength += 2; } ByteBuffer buffer = ByteBuffer.allocate(4 + extensionLength); buffer.putShort(TlsConstants.ExtensionType.key_share.value); buffer.putShort(extensionLength); // Extension data length (in bytes) if (handshakeType == TlsConstants.HandshakeType.client_hello) { buffer.putShort(keyShareEntryLength); } for (KeyShareEntry keyShare: keyShareEntries) { buffer.putShort(keyShare.getNamedGroup().value); buffer.putShort(CURVE_KEY_LENGTHS.get(keyShare.getNamedGroup()).shortValue()); if (keyShare.getNamedGroup() == secp256r1) { // See https://tools.ietf.org/html/rfc8446#section-4.2.8.2, "For secp256r1, secp384r1, and secp521r1, ..." buffer.put((byte) 4); byte[] affineX = ((ECPublicKey) keyShare.getKey()).getW().getAffineX().toByteArray(); writeAffine(buffer, affineX); byte[] affineY = ((ECPublicKey) keyShare.getKey()).getW().getAffineY().toByteArray(); writeAffine(buffer, affineY); } else if (keyShare.getNamedGroup() == x25519 || keyShare.getNamedGroup() == x448) { byte[] raw = ((XECPublicKey) keyShare.getKey()).getU().toByteArray(); if (raw.length > CURVE_KEY_LENGTHS.get(keyShare.getNamedGroup())) { throw new RuntimeException("Invalid " + keyShare.getNamedGroup() + " key length: " + raw.length); } if (raw.length < CURVE_KEY_LENGTHS.get(keyShare.getNamedGroup())) { // Must pad with leading zeros, but as the encoding is little endian, it is easier to first reverse... reverse(raw); // ... and than pad with zeroes up to the required ledngth byte[] padded = Arrays.copyOf(raw, CURVE_KEY_LENGTHS.get(keyShare.getNamedGroup())); raw = padded; } else { // Encoding is little endian, so reverse the bytes. reverse(raw); } buffer.put(raw); } else { throw new RuntimeException(); } } return buffer.array(); } public List getKeyShareEntries() { return keyShareEntries; } private void writeAffine(ByteBuffer buffer, byte[] affine) { if (affine.length == 32) { buffer.put(affine); } else if (affine.length < 32) { for (int i = 0; i < 32 - affine.length; i++) { buffer.put((byte) 0); } buffer.put(affine, 0, affine.length); } else if (affine.length > 32) { for (int i = 0; i < affine.length - 32; i++) { if (affine[i] != 0) { throw new RuntimeException("W Affine more then 32 bytes, leading bytes not 0 " + ByteUtils.bytesToHex(affine)); } } buffer.put(affine, affine.length - 32, 32); } } public static class KeyShareEntry { protected TlsConstants.NamedGroup namedGroup; protected final PublicKey key; public KeyShareEntry(TlsConstants.NamedGroup namedGroup, PublicKey key) { this.namedGroup = namedGroup; this.key = key; } public TlsConstants.NamedGroup getNamedGroup() { return namedGroup; } public PublicKey getKey() { return key; } } public static class ECKeyShareEntry extends KeyShareEntry { private final ECPublicKey key; public ECKeyShareEntry(TlsConstants.NamedGroup namedGroup, ECPublicKey key) { super(namedGroup, key); this.namedGroup = namedGroup; this.key = key; } public ECPublicKey getKey() { return key; } } static ECPublicKey rawToEncodedECPublicKey(TlsConstants.NamedGroup curveName, byte[] rawBytes) { try { KeyFactory kf = KeyFactory.getInstance("EC"); byte[] x = Arrays.copyOfRange(rawBytes, 0, rawBytes.length/2); byte[] y = Arrays.copyOfRange(rawBytes, rawBytes.length/2, rawBytes.length); ECPoint w = new ECPoint(new BigInteger(1,x), new BigInteger(1,y)); return (ECPublicKey) kf.generatePublic(new ECPublicKeySpec(w, ecParameterSpecForCurve(curveName.name()))); } catch (NoSuchAlgorithmException e) { throw new RuntimeException("Missing support for EC algorithm"); } catch (InvalidKeySpecException e) { throw new RuntimeException("Inappropriate parameter specification"); } } static ECParameterSpec ecParameterSpecForCurve(String curveName) { try { AlgorithmParameters params = AlgorithmParameters.getInstance("EC"); params.init(new ECGenParameterSpec(curveName)); return params.getParameterSpec(ECParameterSpec.class); } catch (NoSuchAlgorithmException e) { throw new RuntimeException("Missing support for EC algorithm"); } catch (InvalidParameterSpecException e) { throw new RuntimeException("Inappropriate parameter specification"); } } static PublicKey rawToEncodedXDHPublicKey(TlsConstants.NamedGroup curve, byte[] keyData) { try { // Encoding is little endian, so reverse the bytes. reverse(keyData); BigInteger u = new BigInteger(keyData); KeyFactory kf = KeyFactory.getInstance("XDH"); NamedParameterSpec paramSpec = new NamedParameterSpec(curve.name().toUpperCase()); XECPublicKeySpec pubSpec = new XECPublicKeySpec(paramSpec, u); return kf.generatePublic(pubSpec); } catch (NoSuchAlgorithmException e) { throw new RuntimeException("Missing support for EC algorithm"); } catch (InvalidKeySpecException e) { throw new RuntimeException("Inappropriate parameter specification"); } } public static void reverse(byte[] array) { if (array == null) { return; } int i = 0; int j = array.length - 1; byte tmp; while (j > i) { tmp = array[j]; array[j] = array[i]; array[i] = tmp; j--; i++; } } }