Add a test for using a PSK with QUIC

Check that we can set and use a PSK when establishing a QUIC connection.

Fixes openssl/project#83

Reviewed-by: Hugo Landau <hlandau@openssl.org>
Reviewed-by: Paul Dale <pauli@openssl.org>
Reviewed-by: Tomas Mraz <tomas@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/22011)
This commit is contained in:
Matt Caswell 2023-09-07 17:36:13 +01:00 committed by Tomas Mraz
parent 4ee8c1fb51
commit 1e4fc0b2e5
6 changed files with 147 additions and 49 deletions

View File

@ -211,6 +211,10 @@ int ossl_quic_tserver_new_ticket(QUIC_TSERVER *srv);
int ossl_quic_tserver_set_max_early_data(QUIC_TSERVER *srv,
uint32_t max_early_data);
/* Set the find session callback for getting a server PSK */
void ossl_quic_tserver_set_psk_find_session_cb(QUIC_TSERVER *srv,
SSL_psk_find_session_cb_func cb);
# endif
#endif

View File

@ -99,10 +99,12 @@ QUIC_TSERVER *ossl_quic_tserver_new(const QUIC_TSERVER_ARGS *args,
if (srv->ctx == NULL)
goto err;
if (SSL_CTX_use_certificate_file(srv->ctx, certfile, SSL_FILETYPE_PEM) <= 0)
if (certfile != NULL
&& SSL_CTX_use_certificate_file(srv->ctx, certfile, SSL_FILETYPE_PEM) <= 0)
goto err;
if (SSL_CTX_use_PrivateKey_file(srv->ctx, keyfile, SSL_FILETYPE_PEM) <= 0)
if (keyfile != NULL
&& SSL_CTX_use_PrivateKey_file(srv->ctx, keyfile, SSL_FILETYPE_PEM) <= 0)
goto err;
SSL_CTX_set_alpn_select_cb(srv->ctx, alpn_select_cb, srv);
@ -556,3 +558,9 @@ int ossl_quic_tserver_set_max_early_data(QUIC_TSERVER *srv,
{
return SSL_set_max_early_data(srv->tls, max_early_data);
}
void ossl_quic_tserver_set_psk_find_session_cb(QUIC_TSERVER *srv,
SSL_psk_find_session_cb_func cb)
{
SSL_set_psk_find_session_callback(srv->tls, cb);
}

View File

@ -1247,3 +1247,41 @@ void shutdown_ssl_connection(SSL *serverssl, SSL *clientssl)
SSL_free(serverssl);
SSL_free(clientssl);
}
SSL_SESSION *create_a_psk(SSL *ssl, size_t mdsize)
{
const SSL_CIPHER *cipher = NULL;
const unsigned char key[SHA384_DIGEST_LENGTH] = {
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a,
0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15,
0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b,
0x2c, 0x2d, 0x2e, 0x2f
};
SSL_SESSION *sess = NULL;
if (mdsize == SHA384_DIGEST_LENGTH) {
cipher = SSL_CIPHER_find(ssl, TLS13_AES_256_GCM_SHA384_BYTES);
} else if (mdsize == SHA256_DIGEST_LENGTH) {
/*
* Any ciphersuite using SHA256 will do - it will be compatible with
* the actual ciphersuite selected as long as it too is based on SHA256
*/
cipher = SSL_CIPHER_find(ssl, TLS13_AES_128_GCM_SHA256_BYTES);
} else {
/* Should not happen */
return NULL;
}
sess = SSL_SESSION_new();
if (!TEST_ptr(sess)
|| !TEST_ptr(cipher)
|| !TEST_true(SSL_SESSION_set1_master_key(sess, key, mdsize))
|| !TEST_true(SSL_SESSION_set_cipher(sess, cipher))
|| !TEST_true(
SSL_SESSION_set_protocol_version(sess,
TLS1_3_VERSION))) {
SSL_SESSION_free(sess);
return NULL;
}
return sess;
}

View File

@ -12,6 +12,12 @@
# include <openssl/ssl.h>
#define TLS13_AES_128_GCM_SHA256_BYTES ((const unsigned char *)"\x13\x01")
#define TLS13_AES_256_GCM_SHA384_BYTES ((const unsigned char *)"\x13\x02")
#define TLS13_CHACHA20_POLY1305_SHA256_BYTES ((const unsigned char *)"\x13\x03")
#define TLS13_AES_128_CCM_SHA256_BYTES ((const unsigned char *)"\x13\x04")
#define TLS13_AES_128_CCM_8_SHA256_BYTES ((const unsigned char *)"\x13\05")
int create_ssl_ctx_pair(OSSL_LIB_CTX *libctx, const SSL_METHOD *sm,
const SSL_METHOD *cm, int min_proto_version,
int max_proto_version, SSL_CTX **sctx, SSL_CTX **cctx,
@ -60,4 +66,6 @@ typedef struct mempacket_st MEMPACKET;
DEFINE_STACK_OF(MEMPACKET)
SSL_SESSION *create_a_psk(SSL *ssl, size_t mdsize);
#endif /* OSSL_TEST_SSLTESTLIB_H */

View File

@ -1061,6 +1061,92 @@ static int test_non_io_retry(int idx)
return testresult;
}
static int use_session_cb_cnt = 0;
static int find_session_cb_cnt = 0;
static const char *pskid = "Identity";
static SSL_SESSION *serverpsk = NULL, *clientpsk = NULL;
static int use_session_cb(SSL *ssl, const EVP_MD *md, const unsigned char **id,
size_t *idlen, SSL_SESSION **sess)
{
use_session_cb_cnt++;
if (clientpsk == NULL)
return 0;
SSL_SESSION_up_ref(clientpsk);
*sess = clientpsk;
*id = (const unsigned char *)pskid;
*idlen = strlen(pskid);
return 1;
}
static int find_session_cb(SSL *ssl, const unsigned char *identity,
size_t identity_len, SSL_SESSION **sess)
{
find_session_cb_cnt++;
if (serverpsk == NULL)
return 0;
/* Identity should match that set by the client */
if (strlen(pskid) != identity_len
|| strncmp(pskid, (const char *)identity, identity_len) != 0)
return 0;
SSL_SESSION_up_ref(serverpsk);
*sess = serverpsk;
return 1;
}
static int test_quic_psk(void)
{
SSL_CTX *cctx = SSL_CTX_new_ex(libctx, NULL, OSSL_QUIC_client_method());
SSL *clientquic = NULL;
QUIC_TSERVER *qtserv = NULL;
int testresult = 0;
if (!TEST_ptr(cctx)
/* No cert or private key for the server, i.e. PSK only */
|| !TEST_true(qtest_create_quic_objects(libctx, cctx, NULL, NULL,
NULL, 0, &qtserv,
&clientquic, NULL)))
goto end;
SSL_set_psk_use_session_callback(clientquic, use_session_cb);
ossl_quic_tserver_set_psk_find_session_cb(qtserv, find_session_cb);
use_session_cb_cnt = 0;
find_session_cb_cnt = 0;
clientpsk = serverpsk = create_a_psk(clientquic, SHA384_DIGEST_LENGTH);
if (!TEST_ptr(clientpsk))
goto end;
/* We already had one ref. Add another one */
SSL_SESSION_up_ref(clientpsk);
if (!TEST_true(qtest_create_quic_connection(qtserv, clientquic))
|| !TEST_int_eq(1, find_session_cb_cnt)
|| !TEST_int_eq(1, use_session_cb_cnt)
/* Check that we actually used the PSK */
|| !TEST_true(SSL_session_reused(clientquic)))
goto end;
testresult = 1;
end:
SSL_free(clientquic);
ossl_quic_tserver_free(qtserv);
SSL_CTX_free(cctx);
SSL_SESSION_free(clientpsk);
SSL_SESSION_free(serverpsk);
clientpsk = serverpsk = NULL;
return testresult;
}
OPT_TEST_DECLARE_USAGE("provider config certsdir datadir\n")
int setup_tests(void)
@ -1131,6 +1217,7 @@ int setup_tests(void)
ADD_TEST(test_back_pressure);
ADD_TEST(test_multiple_dgrams);
ADD_ALL_TESTS(test_non_io_retry, 2);
ADD_TEST(test_quic_psk);
return 1;
err:
cleanup_tests();

View File

@ -77,8 +77,6 @@ static int find_session_cb(SSL *ssl, const unsigned char *identity,
static int use_session_cb_cnt = 0;
static int find_session_cb_cnt = 0;
static SSL_SESSION *create_a_psk(SSL *ssl, size_t mdsize);
#endif
static char *certsdir = NULL;
@ -3385,51 +3383,6 @@ static unsigned int psk_server_cb(SSL *ssl, const char *identity,
#define MSG6 "test"
#define MSG7 "message."
#define TLS13_AES_128_GCM_SHA256_BYTES ((const unsigned char *)"\x13\x01")
#define TLS13_AES_256_GCM_SHA384_BYTES ((const unsigned char *)"\x13\x02")
#define TLS13_CHACHA20_POLY1305_SHA256_BYTES ((const unsigned char *)"\x13\x03")
#define TLS13_AES_128_CCM_SHA256_BYTES ((const unsigned char *)"\x13\x04")
#define TLS13_AES_128_CCM_8_SHA256_BYTES ((const unsigned char *)"\x13\05")
static SSL_SESSION *create_a_psk(SSL *ssl, size_t mdsize)
{
const SSL_CIPHER *cipher = NULL;
const unsigned char key[] = {
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a,
0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15,
0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b,
0x2c, 0x2d, 0x2e, 0x2f /* SHA384_DIGEST_LENGTH bytes */
};
SSL_SESSION *sess = NULL;
if (mdsize == SHA384_DIGEST_LENGTH) {
cipher = SSL_CIPHER_find(ssl, TLS13_AES_256_GCM_SHA384_BYTES);
} else if (mdsize == SHA256_DIGEST_LENGTH) {
/*
* Any ciphersuite using SHA256 will do - it will be compatible with
* the actual ciphersuite selected as long as it too is based on SHA256
*/
cipher = SSL_CIPHER_find(ssl, TLS13_AES_128_GCM_SHA256_BYTES);
} else {
/* Should not happen */
return NULL;
}
sess = SSL_SESSION_new();
if (!TEST_ptr(sess)
|| !TEST_ptr(cipher)
|| !TEST_true(SSL_SESSION_set1_master_key(sess, key, mdsize))
|| !TEST_true(SSL_SESSION_set_cipher(sess, cipher))
|| !TEST_true(
SSL_SESSION_set_protocol_version(sess,
TLS1_3_VERSION))) {
SSL_SESSION_free(sess);
return NULL;
}
return sess;
}
static int artificial_ticket_time = 0;
static int ed_gen_cb(SSL *s, void *arg)