diff --git a/crypto/rsa/rsa_ameth.c b/crypto/rsa/rsa_ameth.c index 22c06a2139..f5911ad233 100644 --- a/crypto/rsa/rsa_ameth.c +++ b/crypto/rsa/rsa_ameth.c @@ -34,6 +34,7 @@ static int rsa_cms_encrypt(CMS_RecipientInfo *ri); #endif static RSA_PSS_PARAMS *rsa_pss_decode(const X509_ALGOR *alg); +static int rsa_sync_to_pss_params_30(RSA *rsa); /* Set any parameters associated with pkey */ static int rsa_param_encode(const EVP_PKEY *pkey, @@ -78,6 +79,8 @@ static int rsa_param_decode(RSA *rsa, const X509_ALGOR *alg) rsa->pss = rsa_pss_decode(alg); if (rsa->pss == NULL) return 0; + if (!rsa_sync_to_pss_params_30(rsa)) + return 0; return 1; } @@ -118,6 +121,20 @@ static int rsa_pub_decode(EVP_PKEY *pkey, const X509_PUBKEY *pubkey) RSA_free(rsa); return 0; } + + RSA_clear_flags(rsa, RSA_FLAG_TYPE_MASK); + switch (pkey->ameth->pkey_id) { + case EVP_PKEY_RSA: + RSA_set_flags(rsa, RSA_FLAG_TYPE_RSA); + break; + case EVP_PKEY_RSA_PSS: + RSA_set_flags(rsa, RSA_FLAG_TYPE_RSASSAPSS); + break; + default: + /* Leave the type bits zero */ + break; + } + if (!EVP_PKEY_assign(pkey, pkey->ameth->pkey_id, rsa)) { RSA_free(rsa); return 0; @@ -729,9 +746,34 @@ static int rsa_pss_to_ctx(EVP_MD_CTX *ctx, EVP_PKEY_CTX *pkctx, return rv; } -int rsa_pss_get_param(const RSA_PSS_PARAMS *pss, const EVP_MD **pmd, - const EVP_MD **pmgf1md, int *psaltlen) +static int rsa_pss_verify_param(const EVP_MD **pmd, const EVP_MD **pmgf1md, + int *psaltlen, int *ptrailerField) { + if (psaltlen != NULL && *psaltlen < 0) { + ERR_raise(ERR_LIB_RSA, RSA_R_INVALID_SALT_LENGTH); + return 0; + } + /* + * low-level routines support only trailer field 0xbc (value 1) and + * PKCS#1 says we should reject any other value anyway. + */ + if (ptrailerField != NULL && *ptrailerField != 1) { + ERR_raise(ERR_LIB_RSA, RSA_R_INVALID_TRAILER); + return 0; + } + return 1; +} + +static int rsa_pss_get_param_unverified(const RSA_PSS_PARAMS *pss, + const EVP_MD **pmd, + const EVP_MD **pmgf1md, + int *psaltlen, int *ptrailerField) +{ + RSA_PSS_PARAMS_30 pss_params; + + /* Get the defaults from the ONE place */ + (void)rsa_pss_params_30_set_defaults(&pss_params); + if (pss == NULL) return 0; *pmd = rsa_algor_to_md(pss->hashAlgorithm); @@ -740,25 +782,65 @@ int rsa_pss_get_param(const RSA_PSS_PARAMS *pss, const EVP_MD **pmd, *pmgf1md = rsa_algor_to_md(pss->maskHash); if (*pmgf1md == NULL) return 0; - if (pss->saltLength) { + if (pss->saltLength) *psaltlen = ASN1_INTEGER_get(pss->saltLength); - if (*psaltlen < 0) { - RSAerr(RSA_F_RSA_PSS_GET_PARAM, RSA_R_INVALID_SALT_LENGTH); - return 0; - } - } else { - *psaltlen = 20; - } + else + *psaltlen = rsa_pss_params_30_saltlen(&pss_params); + if (pss->trailerField) + *ptrailerField = ASN1_INTEGER_get(pss->trailerField); + else + *ptrailerField = rsa_pss_params_30_trailerfield(&pss_params);; + return 1; +} + +int rsa_pss_get_param(const RSA_PSS_PARAMS *pss, const EVP_MD **pmd, + const EVP_MD **pmgf1md, int *psaltlen) +{ /* - * low-level routines support only trailer field 0xbc (value 1) and - * PKCS#1 says we should reject any other value anyway. + * Callers do not care about the trailer field, and yet, we must + * pass it from get_param to verify_param, since the latter checks + * its value. + * + * When callers start caring, it's a simple thing to add another + * argument to this function. */ - if (pss->trailerField && ASN1_INTEGER_get(pss->trailerField) != 1) { - RSAerr(RSA_F_RSA_PSS_GET_PARAM, RSA_R_INVALID_TRAILER); - return 0; - } + int trailerField = 0; + return rsa_pss_get_param_unverified(pss, pmd, pmgf1md, psaltlen, + &trailerField) + && rsa_pss_verify_param(pmd, pmgf1md, psaltlen, &trailerField); +} + +static int rsa_sync_to_pss_params_30(RSA *rsa) +{ + if (rsa != NULL && rsa->pss != NULL) { + const EVP_MD *md = NULL, *mgf1md = NULL; + int md_nid, mgf1md_nid, saltlen, trailerField; + RSA_PSS_PARAMS_30 pss_params; + + /* + * We don't care about the validity of the fields here, we just + * want to synchronise values. Verifying here makes it impossible + * to even read a key with invalid values, making it hard to test + * a bad situation. + * + * Other routines use rsa_pss_get_param(), so the values will be + * checked, eventually. + */ + if (!rsa_pss_get_param_unverified(rsa->pss, &md, &mgf1md, + &saltlen, &trailerField)) + return 0; + md_nid = EVP_MD_type(md); + mgf1md_nid = EVP_MD_type(mgf1md); + if (!rsa_pss_params_30_set_defaults(&pss_params) + || !rsa_pss_params_30_set_hashalg(&pss_params, md_nid) + || !rsa_pss_params_30_set_maskgenhashalg(&pss_params, mgf1md_nid) + || !rsa_pss_params_30_set_saltlen(&pss_params, saltlen) + || !rsa_pss_params_30_set_trailerfield(&pss_params, trailerField)) + return 0; + rsa->pss_params = pss_params; + } return 1; }