QUIC CHANNEL: Finish moving SRT handling to SRTM

Reviewed-by: Tomas Mraz <tomas@openssl.org>
Reviewed-by: Matt Caswell <matt@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/22674)
This commit is contained in:
Hugo Landau 2023-11-09 10:27:14 +00:00
parent cbf4b68333
commit 5f86ae32c2
4 changed files with 14 additions and 132 deletions

View File

@ -113,6 +113,8 @@ typedef struct quic_channel_args_st {
QUIC_PORT *port;
/* LCIDM to register LCIDs with. */
QUIC_LCIDM *lcidm;
/* SRTM to register SRTs with. */
QUIC_SRTM *srtm;
int is_server;
SSL *tls;

View File

@ -13,6 +13,7 @@
#include "internal/quic_error.h"
#include "internal/quic_rx_depack.h"
#include "internal/quic_lcidm.h"
#include "internal/quic_srtm.h"
#include "../ssl_local.h"
#include "quic_channel_local.h"
#include "quic_port_local.h"
@ -117,95 +118,6 @@ static int gen_rand_conn_id(OSSL_LIB_CTX *libctx, size_t len, QUIC_CONN_ID *cid)
return 1;
}
static unsigned long chan_reset_token_hash(const QUIC_SRT_ELEM *a)
{
unsigned long h;
assert(sizeof(h) <= sizeof(a->token));
memcpy(&h, &a->token, sizeof(h));
return h;
}
static int chan_reset_token_cmp(const QUIC_SRT_ELEM *a, const QUIC_SRT_ELEM *b)
{
/* RFC 9000 s. 10.3.1:
* When comparing a datagram to stateless reset token values,
* endpoints MUST perform the comparison without leaking
* information about the value of the token. For example,
* performing this comparison in constant time protects the
* value of individual stateless reset tokens from information
* leakage through timing side channels.
*
* TODO(QUIC FUTURE): make this a memcmp when obfuscation is done and update
* comment above.
*/
return CRYPTO_memcmp(&a->token, &b->token, sizeof(a->token));
}
static int reset_token_obfuscate(QUIC_SRT_ELEM *out, const unsigned char *in)
{
/*
* TODO(QUIC FUTURE): update this to AES encrypt the token in ECB mode with a
* random (per channel) key.
*/
memcpy(&out->token, in, sizeof(out->token));
return 1;
}
/*
* Add a stateless reset token to the channel
*/
static int chan_add_reset_token(QUIC_CHANNEL *ch, const unsigned char *new,
uint64_t seq_num)
{
QUIC_SRT_ELEM *srte;
int err;
/* Add to list by sequence number (always the tail) */
if ((srte = OPENSSL_malloc(sizeof(*srte))) == NULL)
return 0;
ossl_list_stateless_reset_tokens_init_elem(srte);
ossl_list_stateless_reset_tokens_insert_tail(&ch->srt_list_seq, srte);
reset_token_obfuscate(srte, new);
srte->seq_num = seq_num;
lh_QUIC_SRT_ELEM_insert(ch->srt_hash_tok, srte);
err = lh_QUIC_SRT_ELEM_error(ch->srt_hash_tok);
if (err > 0) {
ossl_list_stateless_reset_tokens_remove(&ch->srt_list_seq, srte);
OPENSSL_free(srte);
return 0;
}
return 1;
}
/*
* Remove a stateless reset token from the channel
* If the token isn't known, we just ignore the remove request which is safe.
*/
static void chan_remove_reset_token(QUIC_CHANNEL *ch, uint64_t seq_num)
{
QUIC_SRT_ELEM *srte;
/*
* Because the list is ordered and we only ever remove CIDs in order,
* this loop should never iterate, but safer to provide the option.
*/
for (srte = ossl_list_stateless_reset_tokens_head(&ch->srt_list_seq);
srte != NULL;
srte = ossl_list_stateless_reset_tokens_next(srte)) {
if (srte->seq_num > seq_num)
return;
if (srte->seq_num == seq_num) {
ossl_list_stateless_reset_tokens_remove(&ch->srt_list_seq, srte);
(void)lh_QUIC_SRT_ELEM_delete(ch->srt_hash_tok, srte);
OPENSSL_free(srte);
return;
}
}
}
/*
* QUIC Channel Initialization and Teardown
* ========================================
@ -228,13 +140,7 @@ static int ch_init(QUIC_CHANNEL *ch)
size_t rx_short_dcid_len = ossl_quic_port_get_rx_short_dcid_len(ch->port);
size_t tx_init_dcid_len = ossl_quic_port_get_tx_init_dcid_len(ch->port);
if (ch->port == NULL || ch->lcidm == NULL)
goto err;
ossl_list_stateless_reset_tokens_init(&ch->srt_list_seq);
ch->srt_hash_tok = lh_QUIC_SRT_ELEM_new(&chan_reset_token_hash,
&chan_reset_token_cmp);
if (ch->srt_hash_tok == NULL)
if (ch->port == NULL || ch->lcidm == NULL || ch->srtm == NULL)
goto err;
/* For clients, generate our initial DCID. */
@ -346,13 +252,6 @@ static int ch_init(QUIC_CHANNEL *ch)
ossl_quic_tx_packetiser_set_ack_tx_cb(ch->txp, ch_on_txp_ack_tx, ch);
/*
* Setup a handler to detect stateless reset tokens.
*/
//ossl_quic_demux_set_stateless_reset_handler(ch->demux,
// &ch_stateless_reset_token_handler,
// ch);
qrx_args.libctx = ch->port->libctx;
qrx_args.demux = ch->port->demux;
qrx_args.short_conn_id_len = rx_short_dcid_len;
@ -435,7 +334,6 @@ err:
static void ch_cleanup(QUIC_CHANNEL *ch)
{
QUIC_SRT_ELEM *srte, *srte_next;
uint32_t pn_space;
if (ch->ackm != NULL)
@ -445,6 +343,7 @@ static void ch_cleanup(QUIC_CHANNEL *ch)
ossl_ackm_on_pkt_space_discarded(ch->ackm, pn_space);
ossl_quic_lcidm_cull(ch->lcidm, ch);
ossl_quic_srtm_cull(ch->srtm, ch);
ossl_quic_tx_packetiser_free(ch->txp);
ossl_quic_txpim_free(ch->txpim);
ossl_quic_cfq_free(ch->cfq);
@ -473,16 +372,6 @@ static void ch_cleanup(QUIC_CHANNEL *ch)
OSSL_ERR_STATE_free(ch->err_state);
OPENSSL_free(ch->ack_range_scratch);
/* Free the stateless reset tokens */
for (srte = ossl_list_stateless_reset_tokens_head(&ch->srt_list_seq);
srte != NULL;
srte = srte_next) {
srte_next = ossl_list_stateless_reset_tokens_next(srte);
ossl_list_stateless_reset_tokens_remove(&ch->srt_list_seq, srte);
(void)lh_QUIC_SRT_ELEM_delete(ch->srt_hash_tok, srte);
OPENSSL_free(srte);
}
lh_QUIC_SRT_ELEM_free(ch->srt_hash_tok);
if (ch->on_port_list) {
ossl_list_ch_remove(&ch->port->channel_list, ch);
ch->on_port_list = 0;
@ -500,6 +389,7 @@ QUIC_CHANNEL *ossl_quic_channel_new(const QUIC_CHANNEL_ARGS *args)
ch->is_server = args->is_server;
ch->tls = args->tls;
ch->lcidm = args->lcidm;
ch->srtm = args->srtm;
if (!ch_init(ch)) {
OPENSSL_free(ch);
@ -1550,7 +1440,8 @@ static int ch_on_transport_params(const unsigned char *params,
reason = TP_REASON_MALFORMED("STATELESS_RESET_TOKEN");
goto malformed;
}
if (!chan_add_reset_token(ch, body, ch->cur_remote_seq_num)) {
if (!ossl_quic_srtm_add(ch->srtm, ch, ch->cur_remote_seq_num,
(const QUIC_STATELESS_RESET_TOKEN *)body)) {
reason = TP_REASON_INTERNAL_ERROR("STATELESS_RESET_TOKEN");
goto malformed;
}
@ -2877,7 +2768,7 @@ static int ch_enqueue_retire_conn_id(QUIC_CHANNEL *ch, uint64_t seq_num)
WPACKET wpkt;
size_t l;
chan_remove_reset_token(ch, seq_num);
ossl_quic_srtm_remove(ch->srtm, ch, seq_num);
if ((buf_mem = BUF_MEM_new()) == NULL)
goto err;
@ -2982,8 +2873,8 @@ void ossl_quic_channel_on_new_conn_id(QUIC_CHANNEL *ch,
if (new_remote_seq_num > ch->cur_remote_seq_num) {
/* Add new stateless reset token */
if (!chan_add_reset_token(ch, f->stateless_reset.token,
new_remote_seq_num)) {
if (!ossl_quic_srtm_add(ch->srtm, ch, new_remote_seq_num,
&f->stateless_reset)) {
ossl_quic_channel_raise_protocol_error(
ch, QUIC_ERR_CONNECTION_ID_LIMIT_ERROR,
OSSL_QUIC_FRAME_TYPE_NEW_CONN_ID,

View File

@ -54,6 +54,8 @@ struct quic_channel_st {
/* Port LCIDM we use to register LCIDs. */
QUIC_LCIDM *lcidm;
/* SRTM we register SRTs with. */
QUIC_SRTM *srtm;
/*
* The transport parameter block we will send or have sent.
@ -124,14 +126,6 @@ struct quic_channel_st {
/*
* The DCID we currently use to talk to the peer and its sequence num.
*
* TODO(QUIC FUTURE) consider removing the second two, both are contained in
* srt_list_seq (defined below).
*
* cur_remote_seq_num is same as the sequence number in the last element.
* cur_retire_prior_to corresponds to the sequence number in first element.
*
* Leaving them here avoids null checking etc
*/
QUIC_CONN_ID cur_remote_dcid;
uint64_t cur_remote_seq_num;
@ -140,12 +134,6 @@ struct quic_channel_st {
/* Server only: The DCID we currently expect the peer to use to talk to us. */
QUIC_CONN_ID cur_local_cid;
/* Hash of stateless reset tokens keyed on the token */
LHASH_OF(QUIC_SRT_ELEM) *srt_hash_tok;
/* List of the stateless reset tokens ordered by sequence number */
OSSL_LIST(stateless_reset_tokens) srt_list_seq;
/* Transport parameter values we send to our peer. */
uint64_t tx_init_max_stream_data_bidi_local;
uint64_t tx_init_max_stream_data_bidi_remote;

View File

@ -281,6 +281,7 @@ static QUIC_CHANNEL *port_make_channel(QUIC_PORT *port, SSL *tls, int is_server)
args.is_server = is_server;
args.tls = (tls != NULL ? tls : port_new_handshake_layer(port));
args.lcidm = port->lcidm;
args.srtm = port->srtm;
if (args.tls == NULL)
return NULL;