diff --git a/include/internal/quic_tserver.h b/include/internal/quic_tserver.h index 9213f60666..4f358dd4e8 100644 --- a/include/internal/quic_tserver.h +++ b/include/internal/quic_tserver.h @@ -211,6 +211,10 @@ int ossl_quic_tserver_new_ticket(QUIC_TSERVER *srv); int ossl_quic_tserver_set_max_early_data(QUIC_TSERVER *srv, uint32_t max_early_data); +/* Set the find session callback for getting a server PSK */ +void ossl_quic_tserver_set_psk_find_session_cb(QUIC_TSERVER *srv, + SSL_psk_find_session_cb_func cb); + # endif #endif diff --git a/ssl/quic/quic_tserver.c b/ssl/quic/quic_tserver.c index 788d4780d8..92c17d10f3 100644 --- a/ssl/quic/quic_tserver.c +++ b/ssl/quic/quic_tserver.c @@ -99,10 +99,12 @@ QUIC_TSERVER *ossl_quic_tserver_new(const QUIC_TSERVER_ARGS *args, if (srv->ctx == NULL) goto err; - if (SSL_CTX_use_certificate_file(srv->ctx, certfile, SSL_FILETYPE_PEM) <= 0) + if (certfile != NULL + && SSL_CTX_use_certificate_file(srv->ctx, certfile, SSL_FILETYPE_PEM) <= 0) goto err; - if (SSL_CTX_use_PrivateKey_file(srv->ctx, keyfile, SSL_FILETYPE_PEM) <= 0) + if (keyfile != NULL + && SSL_CTX_use_PrivateKey_file(srv->ctx, keyfile, SSL_FILETYPE_PEM) <= 0) goto err; SSL_CTX_set_alpn_select_cb(srv->ctx, alpn_select_cb, srv); @@ -556,3 +558,9 @@ int ossl_quic_tserver_set_max_early_data(QUIC_TSERVER *srv, { return SSL_set_max_early_data(srv->tls, max_early_data); } + +void ossl_quic_tserver_set_psk_find_session_cb(QUIC_TSERVER *srv, + SSL_psk_find_session_cb_func cb) +{ + SSL_set_psk_find_session_callback(srv->tls, cb); +} diff --git a/test/helpers/ssltestlib.c b/test/helpers/ssltestlib.c index 94a170b9a5..0b1e56f064 100644 --- a/test/helpers/ssltestlib.c +++ b/test/helpers/ssltestlib.c @@ -1247,3 +1247,41 @@ void shutdown_ssl_connection(SSL *serverssl, SSL *clientssl) SSL_free(serverssl); SSL_free(clientssl); } + +SSL_SESSION *create_a_psk(SSL *ssl, size_t mdsize) +{ + const SSL_CIPHER *cipher = NULL; + const unsigned char key[SHA384_DIGEST_LENGTH] = { + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, + 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, + 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, + 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, + 0x2c, 0x2d, 0x2e, 0x2f + }; + SSL_SESSION *sess = NULL; + + if (mdsize == SHA384_DIGEST_LENGTH) { + cipher = SSL_CIPHER_find(ssl, TLS13_AES_256_GCM_SHA384_BYTES); + } else if (mdsize == SHA256_DIGEST_LENGTH) { + /* + * Any ciphersuite using SHA256 will do - it will be compatible with + * the actual ciphersuite selected as long as it too is based on SHA256 + */ + cipher = SSL_CIPHER_find(ssl, TLS13_AES_128_GCM_SHA256_BYTES); + } else { + /* Should not happen */ + return NULL; + } + sess = SSL_SESSION_new(); + if (!TEST_ptr(sess) + || !TEST_ptr(cipher) + || !TEST_true(SSL_SESSION_set1_master_key(sess, key, mdsize)) + || !TEST_true(SSL_SESSION_set_cipher(sess, cipher)) + || !TEST_true( + SSL_SESSION_set_protocol_version(sess, + TLS1_3_VERSION))) { + SSL_SESSION_free(sess); + return NULL; + } + return sess; +} diff --git a/test/helpers/ssltestlib.h b/test/helpers/ssltestlib.h index c8dcb8a82d..c513769ddd 100644 --- a/test/helpers/ssltestlib.h +++ b/test/helpers/ssltestlib.h @@ -12,6 +12,12 @@ # include +#define TLS13_AES_128_GCM_SHA256_BYTES ((const unsigned char *)"\x13\x01") +#define TLS13_AES_256_GCM_SHA384_BYTES ((const unsigned char *)"\x13\x02") +#define TLS13_CHACHA20_POLY1305_SHA256_BYTES ((const unsigned char *)"\x13\x03") +#define TLS13_AES_128_CCM_SHA256_BYTES ((const unsigned char *)"\x13\x04") +#define TLS13_AES_128_CCM_8_SHA256_BYTES ((const unsigned char *)"\x13\05") + int create_ssl_ctx_pair(OSSL_LIB_CTX *libctx, const SSL_METHOD *sm, const SSL_METHOD *cm, int min_proto_version, int max_proto_version, SSL_CTX **sctx, SSL_CTX **cctx, @@ -60,4 +66,6 @@ typedef struct mempacket_st MEMPACKET; DEFINE_STACK_OF(MEMPACKET) +SSL_SESSION *create_a_psk(SSL *ssl, size_t mdsize); + #endif /* OSSL_TEST_SSLTESTLIB_H */ diff --git a/test/quicapitest.c b/test/quicapitest.c index 87c134eb88..a24946a649 100644 --- a/test/quicapitest.c +++ b/test/quicapitest.c @@ -1061,6 +1061,92 @@ static int test_non_io_retry(int idx) return testresult; } +static int use_session_cb_cnt = 0; +static int find_session_cb_cnt = 0; +static const char *pskid = "Identity"; +static SSL_SESSION *serverpsk = NULL, *clientpsk = NULL; + +static int use_session_cb(SSL *ssl, const EVP_MD *md, const unsigned char **id, + size_t *idlen, SSL_SESSION **sess) +{ + use_session_cb_cnt++; + + if (clientpsk == NULL) + return 0; + + SSL_SESSION_up_ref(clientpsk); + + *sess = clientpsk; + *id = (const unsigned char *)pskid; + *idlen = strlen(pskid); + + return 1; +} + +static int find_session_cb(SSL *ssl, const unsigned char *identity, + size_t identity_len, SSL_SESSION **sess) +{ + find_session_cb_cnt++; + + if (serverpsk == NULL) + return 0; + + /* Identity should match that set by the client */ + if (strlen(pskid) != identity_len + || strncmp(pskid, (const char *)identity, identity_len) != 0) + return 0; + + SSL_SESSION_up_ref(serverpsk); + *sess = serverpsk; + + return 1; +} + +static int test_quic_psk(void) +{ + SSL_CTX *cctx = SSL_CTX_new_ex(libctx, NULL, OSSL_QUIC_client_method()); + SSL *clientquic = NULL; + QUIC_TSERVER *qtserv = NULL; + int testresult = 0; + + if (!TEST_ptr(cctx) + /* No cert or private key for the server, i.e. PSK only */ + || !TEST_true(qtest_create_quic_objects(libctx, cctx, NULL, NULL, + NULL, 0, &qtserv, + &clientquic, NULL))) + goto end; + + SSL_set_psk_use_session_callback(clientquic, use_session_cb); + ossl_quic_tserver_set_psk_find_session_cb(qtserv, find_session_cb); + use_session_cb_cnt = 0; + find_session_cb_cnt = 0; + + clientpsk = serverpsk = create_a_psk(clientquic, SHA384_DIGEST_LENGTH); + if (!TEST_ptr(clientpsk)) + goto end; + /* We already had one ref. Add another one */ + SSL_SESSION_up_ref(clientpsk); + + if (!TEST_true(qtest_create_quic_connection(qtserv, clientquic)) + || !TEST_int_eq(1, find_session_cb_cnt) + || !TEST_int_eq(1, use_session_cb_cnt) + /* Check that we actually used the PSK */ + || !TEST_true(SSL_session_reused(clientquic))) + goto end; + + testresult = 1; + + end: + SSL_free(clientquic); + ossl_quic_tserver_free(qtserv); + SSL_CTX_free(cctx); + SSL_SESSION_free(clientpsk); + SSL_SESSION_free(serverpsk); + clientpsk = serverpsk = NULL; + + return testresult; +} + OPT_TEST_DECLARE_USAGE("provider config certsdir datadir\n") int setup_tests(void) @@ -1131,6 +1217,7 @@ int setup_tests(void) ADD_TEST(test_back_pressure); ADD_TEST(test_multiple_dgrams); ADD_ALL_TESTS(test_non_io_retry, 2); + ADD_TEST(test_quic_psk); return 1; err: cleanup_tests(); diff --git a/test/sslapitest.c b/test/sslapitest.c index 756675c1dc..ec29157007 100644 --- a/test/sslapitest.c +++ b/test/sslapitest.c @@ -77,8 +77,6 @@ static int find_session_cb(SSL *ssl, const unsigned char *identity, static int use_session_cb_cnt = 0; static int find_session_cb_cnt = 0; - -static SSL_SESSION *create_a_psk(SSL *ssl, size_t mdsize); #endif static char *certsdir = NULL; @@ -3385,51 +3383,6 @@ static unsigned int psk_server_cb(SSL *ssl, const char *identity, #define MSG6 "test" #define MSG7 "message." -#define TLS13_AES_128_GCM_SHA256_BYTES ((const unsigned char *)"\x13\x01") -#define TLS13_AES_256_GCM_SHA384_BYTES ((const unsigned char *)"\x13\x02") -#define TLS13_CHACHA20_POLY1305_SHA256_BYTES ((const unsigned char *)"\x13\x03") -#define TLS13_AES_128_CCM_SHA256_BYTES ((const unsigned char *)"\x13\x04") -#define TLS13_AES_128_CCM_8_SHA256_BYTES ((const unsigned char *)"\x13\05") - - -static SSL_SESSION *create_a_psk(SSL *ssl, size_t mdsize) -{ - const SSL_CIPHER *cipher = NULL; - const unsigned char key[] = { - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, - 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, - 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, - 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, - 0x2c, 0x2d, 0x2e, 0x2f /* SHA384_DIGEST_LENGTH bytes */ - }; - SSL_SESSION *sess = NULL; - - if (mdsize == SHA384_DIGEST_LENGTH) { - cipher = SSL_CIPHER_find(ssl, TLS13_AES_256_GCM_SHA384_BYTES); - } else if (mdsize == SHA256_DIGEST_LENGTH) { - /* - * Any ciphersuite using SHA256 will do - it will be compatible with - * the actual ciphersuite selected as long as it too is based on SHA256 - */ - cipher = SSL_CIPHER_find(ssl, TLS13_AES_128_GCM_SHA256_BYTES); - } else { - /* Should not happen */ - return NULL; - } - sess = SSL_SESSION_new(); - if (!TEST_ptr(sess) - || !TEST_ptr(cipher) - || !TEST_true(SSL_SESSION_set1_master_key(sess, key, mdsize)) - || !TEST_true(SSL_SESSION_set_cipher(sess, cipher)) - || !TEST_true( - SSL_SESSION_set_protocol_version(sess, - TLS1_3_VERSION))) { - SSL_SESSION_free(sess); - return NULL; - } - return sess; -} - static int artificial_ticket_time = 0; static int ed_gen_cb(SSL *s, void *arg)