From 12ddfa6b3461698543381a6087bd3f5d3a38d3f0 Mon Sep 17 00:00:00 2001 From: Pauli Date: Tue, 2 Mar 2021 22:46:24 +1000 Subject: [PATCH] support params argument to AES cipher init calls Reviewed-by: Shane Lontis (Merged from https://github.com/openssl/openssl/pull/14383) --- .../ciphers/cipher_aes_cbc_hmac_sha.c | 25 +++++++++++++++++-- .../ciphers/cipher_aes_cts.inc | 24 ++++++++++++++++-- .../implementations/ciphers/cipher_aes_ocb.c | 21 ++++++++++------ .../implementations/ciphers/cipher_aes_siv.c | 23 +++++++++++------ .../implementations/ciphers/cipher_aes_wrp.c | 18 ++++++++----- .../implementations/ciphers/cipher_aes_xts.c | 21 ++++++++++------ 6 files changed, 101 insertions(+), 31 deletions(-) diff --git a/providers/implementations/ciphers/cipher_aes_cbc_hmac_sha.c b/providers/implementations/ciphers/cipher_aes_cbc_hmac_sha.c index b78687ceae..a0eef7c1e5 100644 --- a/providers/implementations/ciphers/cipher_aes_cbc_hmac_sha.c +++ b/providers/implementations/ciphers/cipher_aes_cbc_hmac_sha.c @@ -33,6 +33,8 @@ const OSSL_DISPATCH ossl_##nm##kbits##sub##_functions[] = { \ # define AES_CBC_HMAC_SHA_FLAGS (PROV_CIPHER_FLAG_AEAD \ | PROV_CIPHER_FLAG_TLS1_MULTIBLOCK) +static OSSL_FUNC_cipher_encrypt_init_fn aes_einit; +static OSSL_FUNC_cipher_decrypt_init_fn aes_dinit; static OSSL_FUNC_cipher_freectx_fn aes_cbc_hmac_sha1_freectx; static OSSL_FUNC_cipher_freectx_fn aes_cbc_hmac_sha256_freectx; static OSSL_FUNC_cipher_get_ctx_params_fn aes_get_ctx_params; @@ -40,12 +42,28 @@ static OSSL_FUNC_cipher_gettable_ctx_params_fn aes_gettable_ctx_params; static OSSL_FUNC_cipher_set_ctx_params_fn aes_set_ctx_params; static OSSL_FUNC_cipher_settable_ctx_params_fn aes_settable_ctx_params; # define aes_gettable_params ossl_cipher_generic_gettable_params -# define aes_einit ossl_cipher_generic_einit -# define aes_dinit ossl_cipher_generic_dinit # define aes_update ossl_cipher_generic_stream_update # define aes_final ossl_cipher_generic_stream_final # define aes_cipher ossl_cipher_generic_cipher +static int aes_einit(void *ctx, const unsigned char *key, size_t keylen, + const unsigned char *iv, size_t ivlen, + const OSSL_PARAM params[]) +{ + if (!ossl_cipher_generic_einit(ctx, key, keylen, iv, ivlen, NULL)) + return 0; + return aes_set_ctx_params(ctx, params); +} + +static int aes_dinit(void *ctx, const unsigned char *key, size_t keylen, + const unsigned char *iv, size_t ivlen, + const OSSL_PARAM params[]) +{ + if (!ossl_cipher_generic_dinit(ctx, key, keylen, iv, ivlen, NULL)) + return 0; + return aes_set_ctx_params(ctx, params); +} + static const OSSL_PARAM cipher_aes_known_settable_ctx_params[] = { OSSL_PARAM_octet_string(OSSL_CIPHER_PARAM_AEAD_MAC_KEY, NULL, 0), OSSL_PARAM_octet_string(OSSL_CIPHER_PARAM_AEAD_TLS1_AAD, NULL, 0), @@ -76,6 +94,9 @@ static int aes_set_ctx_params(void *vctx, const OSSL_PARAM params[]) EVP_CTRL_TLS1_1_MULTIBLOCK_PARAM mb_param; # endif + if (params == NULL) + return 1; + p = OSSL_PARAM_locate_const(params, OSSL_CIPHER_PARAM_AEAD_MAC_KEY); if (p != NULL) { if (p->data_type != OSSL_PARAM_OCTET_STRING) { diff --git a/providers/implementations/ciphers/cipher_aes_cts.inc b/providers/implementations/ciphers/cipher_aes_cts.inc index fbd66eb257..2a3b88b2c0 100644 --- a/providers/implementations/ciphers/cipher_aes_cts.inc +++ b/providers/implementations/ciphers/cipher_aes_cts.inc @@ -14,6 +14,8 @@ #define AES_CTS_FLAGS PROV_CIPHER_FLAG_CTS +static OSSL_FUNC_cipher_encrypt_init_fn aes_cbc_cts_einit; +static OSSL_FUNC_cipher_decrypt_init_fn aes_cbc_cts_dinit; static OSSL_FUNC_cipher_get_ctx_params_fn aes_cbc_cts_get_ctx_params; static OSSL_FUNC_cipher_set_ctx_params_fn aes_cbc_cts_set_ctx_params; static OSSL_FUNC_cipher_gettable_ctx_params_fn aes_cbc_cts_gettable_ctx_params; @@ -23,6 +25,24 @@ CIPHER_DEFAULT_GETTABLE_CTX_PARAMS_START(aes_cbc_cts) OSSL_PARAM_utf8_string(OSSL_CIPHER_PARAM_CTS_MODE, NULL, 0), CIPHER_DEFAULT_GETTABLE_CTX_PARAMS_END(aes_cbc_cts) +static int aes_cbc_cts_einit(void *ctx, const unsigned char *key, size_t keylen, + const unsigned char *iv, size_t ivlen, + const OSSL_PARAM params[]) +{ + if (!ossl_cipher_generic_einit(ctx, key, keylen, iv, ivlen, NULL)) + return 0; + return aes_cbc_cts_set_ctx_params(ctx, params); +} + +static int aes_cbc_cts_dinit(void *ctx, const unsigned char *key, size_t keylen, + const unsigned char *iv, size_t ivlen, + const OSSL_PARAM params[]) +{ + if (!ossl_cipher_generic_dinit(ctx, key, keylen, iv, ivlen, NULL)) + return 0; + return aes_cbc_cts_set_ctx_params(ctx, params); +} + static int aes_cbc_cts_get_ctx_params(void *vctx, OSSL_PARAM params[]) { PROV_CIPHER_CTX *ctx = (PROV_CIPHER_CTX *)vctx; @@ -80,8 +100,8 @@ const OSSL_DISPATCH ossl_##alg##kbits##lcmode##_cts_functions[] = { \ (void (*)(void)) alg##_##kbits##_##lcmode##_newctx }, \ { OSSL_FUNC_CIPHER_FREECTX, (void (*)(void)) alg##_freectx }, \ { OSSL_FUNC_CIPHER_DUPCTX, (void (*)(void)) alg##_dupctx }, \ - { OSSL_FUNC_CIPHER_ENCRYPT_INIT, (void (*)(void))ossl_cipher_generic_einit }, \ - { OSSL_FUNC_CIPHER_DECRYPT_INIT, (void (*)(void))ossl_cipher_generic_dinit }, \ + { OSSL_FUNC_CIPHER_ENCRYPT_INIT, (void (*)(void))aes_cbc_cts_einit }, \ + { OSSL_FUNC_CIPHER_DECRYPT_INIT, (void (*)(void))aes_cbc_cts_dinit }, \ { OSSL_FUNC_CIPHER_UPDATE, \ (void (*)(void)) ossl_##alg##_##lcmode##_cts_block_update }, \ { OSSL_FUNC_CIPHER_FINAL, \ diff --git a/providers/implementations/ciphers/cipher_aes_ocb.c b/providers/implementations/ciphers/cipher_aes_ocb.c index 627f146273..ce377ad574 100644 --- a/providers/implementations/ciphers/cipher_aes_ocb.c +++ b/providers/implementations/ciphers/cipher_aes_ocb.c @@ -102,7 +102,8 @@ static ossl_inline int aes_generic_ocb_copy_ctx(PROV_AES_OCB_CTX *dst, * Provider dispatch functions */ static int aes_ocb_init(void *vctx, const unsigned char *key, size_t keylen, - const unsigned char *iv, size_t ivlen, int enc) + const unsigned char *iv, size_t ivlen, + const OSSL_PARAM params[], int enc) { PROV_AES_OCB_CTX *ctx = (PROV_AES_OCB_CTX *)vctx; @@ -131,21 +132,24 @@ static int aes_ocb_init(void *vctx, const unsigned char *key, size_t keylen, ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_KEY_LENGTH); return 0; } - return ctx->base.hw->init(&ctx->base, key, keylen); + if (!ctx->base.hw->init(&ctx->base, key, keylen)) + return 0; } - return 1; + return aes_ocb_set_ctx_params(ctx, params); } static int aes_ocb_einit(void *vctx, const unsigned char *key, size_t keylen, - const unsigned char *iv, size_t ivlen) + const unsigned char *iv, size_t ivlen, + const OSSL_PARAM params[]) { - return aes_ocb_init(vctx, key, keylen, iv, ivlen, 1); + return aes_ocb_init(vctx, key, keylen, iv, ivlen, params, 1); } static int aes_ocb_dinit(void *vctx, const unsigned char *key, size_t keylen, - const unsigned char *iv, size_t ivlen) + const unsigned char *iv, size_t ivlen, + const OSSL_PARAM params[]) { - return aes_ocb_init(vctx, key, keylen, iv, ivlen, 0); + return aes_ocb_init(vctx, key, keylen, iv, ivlen, params, 0); } /* @@ -354,6 +358,9 @@ static int aes_ocb_set_ctx_params(void *vctx, const OSSL_PARAM params[]) const OSSL_PARAM *p; size_t sz; + if (params == NULL) + return 1; + p = OSSL_PARAM_locate_const(params, OSSL_CIPHER_PARAM_AEAD_TAG); if (p != NULL) { if (p->data_type != OSSL_PARAM_OCTET_STRING) { diff --git a/providers/implementations/ciphers/cipher_aes_siv.c b/providers/implementations/ciphers/cipher_aes_siv.c index 9a75f6f5b7..dd3346a81c 100644 --- a/providers/implementations/ciphers/cipher_aes_siv.c +++ b/providers/implementations/ciphers/cipher_aes_siv.c @@ -25,6 +25,8 @@ #define siv_stream_update siv_cipher #define SIV_FLAGS AEAD_FLAGS +static OSSL_FUNC_cipher_set_ctx_params_fn aes_siv_set_ctx_params; + static void *aes_siv_newctx(void *provctx, size_t keybits, unsigned int mode, uint64_t flags) { @@ -75,7 +77,8 @@ static void *siv_dupctx(void *vctx) } static int siv_init(void *vctx, const unsigned char *key, size_t keylen, - const unsigned char *iv, size_t ivlen, int enc) + const unsigned char *iv, size_t ivlen, + const OSSL_PARAM params[], int enc) { PROV_AES_SIV_CTX *ctx = (PROV_AES_SIV_CTX *)vctx; @@ -89,21 +92,24 @@ static int siv_init(void *vctx, const unsigned char *key, size_t keylen, ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_KEY_LENGTH); return 0; } - return ctx->hw->initkey(ctx, key, ctx->keylen); + if (!ctx->hw->initkey(ctx, key, ctx->keylen)) + return 0; } - return 1; + return aes_siv_set_ctx_params(ctx, params); } static int siv_einit(void *vctx, const unsigned char *key, size_t keylen, - const unsigned char *iv, size_t ivlen) + const unsigned char *iv, size_t ivlen, + const OSSL_PARAM params[]) { - return siv_init(vctx, key, keylen, iv, ivlen, 1); + return siv_init(vctx, key, keylen, iv, ivlen, params, 1); } static int siv_dinit(void *vctx, const unsigned char *key, size_t keylen, - const unsigned char *iv, size_t ivlen) + const unsigned char *iv, size_t ivlen, + const OSSL_PARAM params[]) { - return siv_init(vctx, key, keylen, iv, ivlen, 0); + return siv_init(vctx, key, keylen, iv, ivlen, params, 0); } static int siv_cipher(void *vctx, unsigned char *out, size_t *outl, @@ -195,6 +201,9 @@ static int aes_siv_set_ctx_params(void *vctx, const OSSL_PARAM params[]) const OSSL_PARAM *p; unsigned int speed = 0; + if (params == NULL) + return 1; + p = OSSL_PARAM_locate_const(params, OSSL_CIPHER_PARAM_AEAD_TAG); if (p != NULL) { if (ctx->enc) diff --git a/providers/implementations/ciphers/cipher_aes_wrp.c b/providers/implementations/ciphers/cipher_aes_wrp.c index 4428ff0552..f797db4596 100644 --- a/providers/implementations/ciphers/cipher_aes_wrp.c +++ b/providers/implementations/ciphers/cipher_aes_wrp.c @@ -34,6 +34,7 @@ static OSSL_FUNC_cipher_decrypt_init_fn aes_wrap_dinit; static OSSL_FUNC_cipher_update_fn aes_wrap_cipher; static OSSL_FUNC_cipher_final_fn aes_wrap_final; static OSSL_FUNC_cipher_freectx_fn aes_wrap_freectx; +static OSSL_FUNC_cipher_set_ctx_params_fn aes_wrap_set_ctx_params; typedef struct prov_aes_wrap_ctx_st { PROV_CIPHER_CTX base; @@ -75,7 +76,7 @@ static void aes_wrap_freectx(void *vctx) static int aes_wrap_init(void *vctx, const unsigned char *key, size_t keylen, const unsigned char *iv, - size_t ivlen, int enc) + size_t ivlen, const OSSL_PARAM params[], int enc) { PROV_CIPHER_CTX *ctx = (PROV_CIPHER_CTX *)vctx; PROV_AES_WRAP_CTX *wctx = (PROV_AES_WRAP_CTX *)vctx; @@ -121,19 +122,21 @@ static int aes_wrap_init(void *vctx, const unsigned char *key, ctx->block = (block128_f)AES_decrypt; } } - return 1; + return aes_wrap_set_ctx_params(ctx, params); } static int aes_wrap_einit(void *ctx, const unsigned char *key, size_t keylen, - const unsigned char *iv, size_t ivlen) + const unsigned char *iv, size_t ivlen, + const OSSL_PARAM params[]) { - return aes_wrap_init(ctx, key, keylen, iv, ivlen, 1); + return aes_wrap_init(ctx, key, keylen, iv, ivlen, params, 1); } static int aes_wrap_dinit(void *ctx, const unsigned char *key, size_t keylen, - const unsigned char *iv, size_t ivlen) + const unsigned char *iv, size_t ivlen, + const OSSL_PARAM params[]) { - return aes_wrap_init(ctx, key, keylen, iv, ivlen, 0); + return aes_wrap_init(ctx, key, keylen, iv, ivlen, params, 0); } static int aes_wrap_cipher_internal(void *vctx, unsigned char *out, @@ -226,6 +229,9 @@ static int aes_wrap_set_ctx_params(void *vctx, const OSSL_PARAM params[]) const OSSL_PARAM *p; size_t keylen = 0; + if (params == NULL) + return 1; + p = OSSL_PARAM_locate_const(params, OSSL_CIPHER_PARAM_KEYLEN); if (p != NULL) { if (!OSSL_PARAM_get_size_t(p, &keylen)) { diff --git a/providers/implementations/ciphers/cipher_aes_xts.c b/providers/implementations/ciphers/cipher_aes_xts.c index 13552b2a76..5cfb22778e 100644 --- a/providers/implementations/ciphers/cipher_aes_xts.c +++ b/providers/implementations/ciphers/cipher_aes_xts.c @@ -66,7 +66,8 @@ static int aes_xts_check_keys_differ(const unsigned char *key, size_t bytes, * Provider dispatch functions */ static int aes_xts_init(void *vctx, const unsigned char *key, size_t keylen, - const unsigned char *iv, size_t ivlen, int enc) + const unsigned char *iv, size_t ivlen, + const OSSL_PARAM params[], int enc) { PROV_AES_XTS_CTX *xctx = (PROV_AES_XTS_CTX *)vctx; PROV_CIPHER_CTX *ctx = &xctx->base; @@ -87,21 +88,24 @@ static int aes_xts_init(void *vctx, const unsigned char *key, size_t keylen, } if (!aes_xts_check_keys_differ(key, keylen / 2, enc)) return 0; - return ctx->hw->init(ctx, key, keylen); + if (!ctx->hw->init(ctx, key, keylen)) + return 0; } - return 1; + return aes_xts_set_ctx_params(ctx, params); } static int aes_xts_einit(void *vctx, const unsigned char *key, size_t keylen, - const unsigned char *iv, size_t ivlen) + const unsigned char *iv, size_t ivlen, + const OSSL_PARAM params[]) { - return aes_xts_init(vctx, key, keylen, iv, ivlen, 1); + return aes_xts_init(vctx, key, keylen, iv, ivlen, params, 1); } static int aes_xts_dinit(void *vctx, const unsigned char *key, size_t keylen, - const unsigned char *iv, size_t ivlen) + const unsigned char *iv, size_t ivlen, + const OSSL_PARAM params[]) { - return aes_xts_init(vctx, key, keylen, iv, ivlen, 0); + return aes_xts_init(vctx, key, keylen, iv, ivlen, params, 0); } static void *aes_xts_newctx(void *provctx, unsigned int mode, uint64_t flags, @@ -229,6 +233,9 @@ static int aes_xts_set_ctx_params(void *vctx, const OSSL_PARAM params[]) PROV_CIPHER_CTX *ctx = (PROV_CIPHER_CTX *)vctx; const OSSL_PARAM *p; + if (params == NULL) + return 1; + p = OSSL_PARAM_locate_const(params, OSSL_CIPHER_PARAM_KEYLEN); if (p != NULL) { size_t keylen;