openssl/include/internal/safe_math.h
Kelvin Lee 3189e12733 Avoid divide by 0 exception
CLA: trivial

Reviewed-by: Tomas Mraz <tomas@openssl.org>
Reviewed-by: Paul Dale <pauli@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/19310)
2022-10-02 20:41:34 +11:00

444 lines
25 KiB
C

/*
* Copyright 2021-2022 The OpenSSL Project Authors. All Rights Reserved.
*
* Licensed under the Apache License 2.0 (the "License"). You may not use
* this file except in compliance with the License. You can obtain a copy
* in the file LICENSE in the source distribution or at
* https://www.openssl.org/source/license.html
*/
#ifndef OSSL_INTERNAL_SAFE_MATH_H
# define OSSL_INTERNAL_SAFE_MATH_H
# pragma once
# include <openssl/e_os2.h> /* For 'ossl_inline' */
# ifndef OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING
/*
 * has(func) is used only inside #if below: it is nonzero when the compiler
 * is believed to provide the builtin |func|.  __has_builtin is preferred;
 * failing that, GCC newer than 5 is assumed to have the
 * __builtin_*_overflow family.  __GNUC__ is tested with defined() first so
 * that compilers which do not define it do not evaluate an undefined
 * identifier here (which -Wundef warns about).
 */
#  ifdef __has_builtin
#   define has(func) __has_builtin(func)
#  elif defined(__GNUC__) && __GNUC__ > 5
#   define has(func) 1
#  endif
# endif /* OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING */
/* No builtin support known (or disabled): use the portable C versions. */
# ifndef has
#  define has(func) 0
# endif
/*
* Safe addition helpers
*/
# if has(__builtin_add_overflow)
/*
 * safe_add_<type_name>() for a signed |type|, built on the compiler's
 * checked-add builtin.  On overflow bit 0 of *err is set (it is never
 * cleared) and the result saturates: signed addition can only overflow
 * when both operands share a sign, so the sign of a selects min or max.
 */
# define OSSL_SAFE_MATH_ADDS(type_name, type, min, max) \
static ossl_inline ossl_unused type safe_add_ ## type_name(type a, \
                                                           type b, \
                                                           int *err) \
{ \
    type sum; \
    \
    if (__builtin_add_overflow(a, b, &sum)) { \
        *err |= 1; \
        sum = a < 0 ? min : max; \
    } \
    return sum; \
}
/*
 * safe_add_<type_name>() for an unsigned |type|, built on the compiler's
 * checked-add builtin.  On overflow bit 0 of *err is set and the wrapped
 * (modular) sum is returned; the builtin has already stored exactly that
 * value.  |max| is unused here but kept for a uniform macro signature.
 */
# define OSSL_SAFE_MATH_ADDU(type_name, type, max) \
static ossl_inline ossl_unused type safe_add_ ## type_name(type a, \
                                                           type b, \
                                                           int *err) \
{ \
    type sum; \
    \
    if (__builtin_add_overflow(a, b, &sum)) \
        *err |= 1; \
    return sum; \
}
# else /* has(__builtin_add_overflow) */
/*
 * Portable safe_add_<type_name>() for a signed |type| (no builtin).
 * The range is checked before the addition is evaluated, so signed
 * overflow never happens.  On overflow bit 0 of *err is set and the
 * result saturates to min or max according to the sign of a.
 */
# define OSSL_SAFE_MATH_ADDS(type_name, type, min, max) \
static ossl_inline ossl_unused type safe_add_ ## type_name(type a, \
                                                           type b, \
                                                           int *err) \
{ \
    int fits; \
    \
    /* a == 0, or operands on opposite sides of zero, cannot overflow */ \
    if (a == 0 || (a < 0) != (b < 0)) \
        fits = 1; \
    else if (a > 0) \
        fits = b <= max - a; \
    else \
        fits = b >= min - a; \
    if (!fits) { \
        *err |= 1; \
        return a < 0 ? min : max; \
    } \
    return a + b; \
}
/*
 * Portable safe_add_<type_name>() for an unsigned |type| (no builtin).
 * Sets bit 0 of *err when a + b exceeds max; the wrapped sum is returned
 * either way.
 */
# define OSSL_SAFE_MATH_ADDU(type_name, type, max) \
static ossl_inline ossl_unused type safe_add_ ## type_name(type a, \
                                                           type b, \
                                                           int *err) \
{ \
    if (max - a < b) \
        *err |= 1; /* a + b would pass max and wrap */ \
    return a + b; \
}
# endif /* has(__builtin_add_overflow) */
/*
* Safe subtraction helpers
*/
# if has(__builtin_sub_overflow)
/*
 * safe_sub_<type_name>() for a signed |type|, built on the compiler's
 * checked-subtract builtin.  On overflow bit 0 of *err is set and the
 * result saturates according to the sign of a: a negative a can only
 * underflow (min), a non-negative a can only overflow (max).
 */
# define OSSL_SAFE_MATH_SUBS(type_name, type, min, max) \
static ossl_inline ossl_unused type safe_sub_ ## type_name(type a, \
                                                           type b, \
                                                           int *err) \
{ \
    type diff; \
    \
    if (__builtin_sub_overflow(a, b, &diff)) { \
        *err |= 1; \
        return a < 0 ? min : max; \
    } \
    return diff; \
}
# else /* has(__builtin_sub_overflow) */
/*
 * Portable safe_sub_<type_name>() for a signed |type| (no builtin).
 * Subtracting a value of the same sign (zero counting as non-negative), or
 * zero, cannot overflow.  Otherwise a - b is range-checked against min or
 * max before it is evaluated.  On overflow bit 0 of *err is set and the
 * result saturates according to the sign of a.
 */
# define OSSL_SAFE_MATH_SUBS(type_name, type, min, max) \
static ossl_inline ossl_unused type safe_sub_ ## type_name(type a, \
                                                           type b, \
                                                           int *err) \
{ \
    int fits; \
    \
    if ((a < 0) == (b < 0) || b == 0) \
        fits = 1; \
    else if (b > 0) \
        fits = a >= min + b; \
    else \
        fits = a <= max + b; \
    if (fits) \
        return a - b; \
    *err |= 1; \
    return a < 0 ? min : max; \
}
# endif /* has(__builtin_sub_overflow) */
/*
 * safe_sub_<type_name>() for an unsigned |type|.  Sets bit 0 of *err when
 * b exceeds a; the wrapped (modular) difference is returned either way.
 */
# define OSSL_SAFE_MATH_SUBU(type_name, type) \
static ossl_inline ossl_unused type safe_sub_ ## type_name(type a, \
                                                           type b, \
                                                           int *err) \
{ \
    if (a < b) \
        *err |= 1; /* result would wrap below zero */ \
    return a - b; \
}
/*
* Safe multiplication helpers
*/
# if has(__builtin_mul_overflow)
/*
 * safe_mul_<type_name>() for a signed |type|, built on the compiler's
 * checked-multiply builtin.  On overflow bit 0 of *err is set and the
 * result saturates toward the sign of the true product: min when the
 * operand signs differ, max otherwise.
 */
# define OSSL_SAFE_MATH_MULS(type_name, type, min, max) \
static ossl_inline ossl_unused type safe_mul_ ## type_name(type a, \
                                                           type b, \
                                                           int *err) \
{ \
    type prod; \
    \
    if (__builtin_mul_overflow(a, b, &prod)) { \
        *err |= 1; \
        return (a < 0) != (b < 0) ? min : max; \
    } \
    return prod; \
}
/*
 * safe_mul_<type_name>() for an unsigned |type|, built on the compiler's
 * checked-multiply builtin.  On overflow bit 0 of *err is set and the
 * wrapped (modular) product, already stored by the builtin, is returned.
 * |max| is unused here but kept for a uniform macro signature.
 */
# define OSSL_SAFE_MATH_MULU(type_name, type, max) \
static ossl_inline ossl_unused type safe_mul_ ## type_name(type a, \
                                                           type b, \
                                                           int *err) \
{ \
    type prod; \
    \
    if (__builtin_mul_overflow(a, b, &prod)) \
        *err |= 1; \
    return prod; \
}
# else /* has(__builtin_mul_overflow) */
/*
 * Portable safe_mul_<type_name>() for a signed |type| (no builtin).
 * On overflow bit 0 of *err is set and the result saturates to min when
 * the operand signs differ, max otherwise.
 *
 * Magnitudes are compared against max / |b| before multiplying.  A product
 * of operands with opposite signs may additionally be exactly min, whose
 * magnitude is one more than max.  That value is representable, so it is
 * accepted via the -(min / |b|) bound; this matches the
 * __builtin_mul_overflow() implementation, which does not flag it.
 */
# define OSSL_SAFE_MATH_MULS(type_name, type, min, max) \
static ossl_inline ossl_unused type safe_mul_ ## type_name(type a, \
                                                           type b, \
                                                           int *err) \
{ \
    if (a == 0 || b == 0) \
        return 0; \
    if (a == 1) \
        return b; \
    if (b == 1) \
        return a; \
    /* |min| is not representable; any other product with min overflows */ \
    if (a != min && b != min) { \
        const type x = a < 0 ? -a : a; \
        const type y = b < 0 ? -b : b; \
        \
        if (x <= max / y) \
            return a * b; \
        /* Reached only with y >= 2, so -(min / y) cannot overflow */ \
        if ((a < 0) != (b < 0) && x <= -((min) / y)) \
            return a * b; \
    } \
    *err |= 1; \
    return (a < 0) ^ (b < 0) ? min : max; \
}
/*
 * Portable safe_mul_<type_name>() for an unsigned |type| (no builtin).
 * Sets bit 0 of *err when a * b exceeds max; the wrapped product is
 * returned either way.
 */
# define OSSL_SAFE_MATH_MULU(type_name, type, max) \
static ossl_inline ossl_unused type safe_mul_ ## type_name(type a, \
                                                           type b, \
                                                           int *err) \
{ \
    /* b == 0 cannot overflow and would otherwise divide by zero */ \
    if (b > 0 && max / b < a) \
        *err |= 1; \
    return a * b; \
}
# endif /* has(__builtin_mul_overflow) */
/*
* Safe division helpers
*/
/*
 * safe_div_<type_name>() for a signed |type|.  Both undefined cases of C
 * division set bit 0 of *err:
 *   - b == 0: returns min if a is negative, otherwise max;
 *   - min / -1 (quotient not representable): returns max.
 * Otherwise returns a / b (truncated toward zero).
 */
# define OSSL_SAFE_MATH_DIVS(type_name, type, min, max) \
static ossl_inline ossl_unused type safe_div_ ## type_name(type a, \
                                                           type b, \
                                                           int *err) \
{ \
    const int by_zero = b == 0; \
    \
    if (!by_zero && (b != -1 || a != min)) \
        return a / b; \
    *err |= 1; \
    if (by_zero) \
        return a < 0 ? min : max; \
    return max; \
}
/*
 * safe_div_<type_name>() for an unsigned |type|.  Division by zero sets
 * bit 0 of *err and returns max; otherwise returns a / b.
 */
# define OSSL_SAFE_MATH_DIVU(type_name, type, max) \
static ossl_inline ossl_unused type safe_div_ ## type_name(type a, \
                                                           type b, \
                                                           int *err) \
{ \
    if (b == 0) { \
        *err |= 1; \
        return max; \
    } \
    return a / b; \
}
/*
* Safe modulus helpers
*/
/*
 * safe_mod_<type_name>() for a signed |type|.
 *   - b == 0 sets bit 0 of *err and returns 0.
 *   - min % -1 is undefined in C because min / -1 overflows.  This helper
 *     flags it as an error and returns max, mirroring safe_div_, even
 *     though the mathematical remainder would be 0.
 * Otherwise returns a % b (sign follows a, as in C).
 */
# define OSSL_SAFE_MATH_MODS(type_name, type, min, max) \
static ossl_inline ossl_unused type safe_mod_ ## type_name(type a, \
                                                           type b, \
                                                           int *err) \
{ \
    const int by_zero = b == 0; \
    \
    if (!by_zero && (b != -1 || a != min)) \
        return a % b; \
    *err |= 1; \
    return by_zero ? 0 : max; \
}
/*
 * safe_mod_<type_name>() for an unsigned |type|.  A zero divisor sets
 * bit 0 of *err and yields 0; otherwise returns a % b.
 */
# define OSSL_SAFE_MATH_MODU(type_name, type) \
static ossl_inline ossl_unused type safe_mod_ ## type_name(type a, \
                                                           type b, \
                                                           int *err) \
{ \
    if (b == 0) { \
        *err |= 1; \
        return 0; \
    } \
    return a % b; \
}
/*
* Safe negation helpers
*/
/*
 * safe_neg_<type_name>() for a signed |type|.  Negating min is not
 * representable: bit 0 of *err is set and min is returned unchanged.
 */
# define OSSL_SAFE_MATH_NEGS(type_name, type, min) \
static ossl_inline ossl_unused type safe_neg_ ## type_name(type a, \
                                                           int *err) \
{ \
    if (a == min) { \
        *err |= 1; \
        return min; \
    } \
    return -a; \
}
/*
 * safe_neg_<type_name>() for an unsigned |type|.  Only zero has an
 * in-range negation; any other value sets bit 0 of *err.  The modular
 * (two's-complement) negation is returned in every case.
 */
# define OSSL_SAFE_MATH_NEGU(type_name, type) \
static ossl_inline ossl_unused type safe_neg_ ## type_name(type a, \
                                                           int *err) \
{ \
    if (a != 0) \
        *err |= 1; \
    return 0 - a; \
}
/*
* Safe absolute value helpers
*/
/*
 * safe_abs_<type_name>() for a signed |type|.  |min| is not representable:
 * bit 0 of *err is set and min is returned unchanged.
 */
# define OSSL_SAFE_MATH_ABSS(type_name, type, min) \
static ossl_inline ossl_unused type safe_abs_ ## type_name(type a, \
                                                           int *err) \
{ \
    if (a == min) { \
        *err |= 1; \
        return min; \
    } \
    return a < 0 ? -a : a; \
}
/*
 * safe_abs_<type_name>() for an unsigned |type|: every value is its own
 * absolute value, so this never reports an error.
 */
# define OSSL_SAFE_MATH_ABSU(type_name, type) \
static ossl_inline ossl_unused type safe_abs_ ## type_name(type a, \
                                                           int *err) \
{ \
    (void)err; /* no failure mode for unsigned */ \
    return a; \
}
/*
* Safe fused multiply divide helpers
*
* These are a bit obscure:
* . They begin by checking the denominator for zero and getting rid of this
* corner case.
*
* . Second is an attempt to do the multiplication directly, if it doesn't
* overflow, the quotient is returned (for signed values there is a
* potential problem here which isn't present for unsigned).
*
* . Finally, the multiplication/division is transformed so that the larger
* of the numerators is divided first. This requires a remainder
* correction:
*
* a b / c = (a / c) b + (a mod c) b / c, where a > b
*
* The individual operations need to be overflow checked (again signed
* being more problematic).
*
* The algorithm used is not perfect but it should be "good enough".
*/
/*
 * safe_muldiv_<type_name>() for a signed |type|: a * b / c.
 *
 * c == 0 sets bit 0 of *err and returns 0 when a or b is 0, otherwise max
 * (regardless of the operands' signs).
 * If a * b fits, the result is safe_div_<type_name>(a * b, c).
 * Otherwise the operands are ordered by value (not magnitude) so that
 * a >= b, and the remainder-corrected form
 *     (a / c) * b + ((a % c) * b) / c
 * is evaluated with every step overflow-checked into *err.
 * NOTE(review): the preceding comment calls the signed algorithm "not
 * perfect"; with negative operands the value ordering and C's truncating
 * / and % may not give trunc(a * b / c) exactly. Also, c == 0 with a
 * negative product returns max rather than min -- confirm this is intended.
 */
# define OSSL_SAFE_MATH_MULDIVS(type_name, type, max) \
static ossl_inline ossl_unused type safe_muldiv_ ## type_name(type a, \
                                                              type b, \
                                                              type c, \
                                                              int *err) \
{ \
    int e2 = 0; /* overflow flag for the direct a * b attempt only */ \
    type q, r, x, y; \
    \
    if (c == 0) { \
        *err |= 1; \
        return a == 0 || b == 0 ? 0 : max; \
    } \
    x = safe_mul_ ## type_name(a, b, &e2); /* try the direct product */ \
    if (!e2) \
        return safe_div_ ## type_name(x, c, err); \
    if (b > a) { /* swap so a holds the larger value; x is a temporary */ \
        x = b; \
        b = a; \
        a = x; \
    } \
    q = safe_div_ ## type_name(a, c, err); /* a = q * c + r */ \
    r = safe_mod_ ## type_name(a, c, err); \
    x = safe_mul_ ## type_name(r, b, err); /* r * b */ \
    y = safe_mul_ ## type_name(q, b, err); /* q * b */ \
    q = safe_div_ ## type_name(x, c, err); /* (r * b) / c */ \
    return safe_add_ ## type_name(y, q, err); \
}
/*
 * safe_muldiv_<type_name>() for an unsigned |type|: a * b / c.
 *
 * c == 0 sets bit 0 of *err and returns 0 when a or b is 0, otherwise max.
 * If a * b fits it is simply divided by c.  Otherwise the larger factor is
 * split as big = (big / c) * c + big % c, giving
 *     a * b / c = (big / c) * small + ((big % c) * small) / c,
 * which equals floor(a * b / c) whenever neither term overflows.  Overflow
 * in either term, or in their sum, is reported through *err.
 */
# define OSSL_SAFE_MATH_MULDIVU(type_name, type, max) \
static ossl_inline ossl_unused type safe_muldiv_ ## type_name(type a, \
                                                              type b, \
                                                              type c, \
                                                              int *err) \
{ \
    int overflowed = 0; \
    type prod, big, small, low, high; \
    \
    if (c == 0) { \
        *err |= 1; \
        return a == 0 || b == 0 ? 0 : max; \
    } \
    prod = safe_mul_ ## type_name(a, b, &overflowed); \
    if (overflowed == 0) \
        return prod / c; \
    big = b > a ? b : a; \
    small = b > a ? a : b; \
    low = safe_mul_ ## type_name(big % c, small, err); \
    high = safe_mul_ ## type_name(big / c, small, err); \
    return safe_add_ ## type_name(high, low / c, err); \
}
/*
 * Calculate a / b rounded up, i.e. the ceiling of the exact quotient.
 *
 * C division truncates toward zero.  For a positive exact quotient the
 * ceiling is therefore a / b + (a % b != 0), which is usually (less safely)
 * written as (a + b - 1) / b.  When a and b have opposite signs the exact
 * quotient is negative and truncation has already rounded it up, so no
 * correction is added (e.g. -3 / 2 rounds up to -1, not 0).
 *
 * errp may be NULL, in which case errors are discarded.  Bit 0 of *errp is
 * set for b == 0 and, for signed types, for min / -1.  If you *know*
 * neither can occur, it is safe to ignore err.
 */
# define OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, max) \
static ossl_inline ossl_unused type safe_div_round_up_ ## type_name \
    (type a, type b, int *errp) \
{ \
    type x; \
    int *err, err_local = 0; \
    \
    /* Allow errors to be ignored by callers */ \
    err = errp != NULL ? errp : &err_local; \
    /* Fast path, both positive */ \
    if (b > 0 && a > 0) { \
        /* Faster path: no overflow concerns */ \
        if (a < max - b) \
            return (a + b - 1) / b; \
        return a / b + (a % b != 0); \
    } \
    if (b == 0) { \
        *err |= 1; \
        return a == 0 ? 0 : max; \
    } \
    if (a == 0) \
        return 0; \
    /* Rather slow path because there are negatives involved. */ \
    /* a and b are both non-zero here, so "> 0" tells their signs apart */ \
    /* without a "< 0" test that is always false for unsigned types. */ \
    x = safe_mod_ ## type_name(a, b, err); \
    return safe_add_ ## type_name(safe_div_ ## type_name(a, b, err), \
                                  x != 0 && (a > 0) == (b > 0), err); \
}
/*
 * Calculate ranges of types.
 *
 * MAXS avoids shifting a 1 into the sign bit, which is undefined behaviour
 * for signed types in C: it builds 2^(bits-2) - 1, doubles it and adds 1,
 * so no intermediate value leaves the range.  Results are cast back to
 * |type| so that types narrower than int, which << and ~ promote to int,
 * still get correct values.  Assumes 8-bit bytes and two's-complement
 * integers without padding bits.
 */
# define OSSL_SAFE_MATH_MAXS(type) \
    ((type)((((type)1 << (sizeof(type) * 8 - 2)) - 1) * 2 + 1))
# define OSSL_SAFE_MATH_MINS(type) ((type)(-OSSL_SAFE_MATH_MAXS(type) - 1))
# define OSSL_SAFE_MATH_MAXU(type) ((type)~(type)0)
/*
* Wrapper macros to create all the functions of a given type
*/
/*
 * OSSL_SAFE_MATH_SIGNED(type_name, type) defines static inline
 * safe_{neg,abs,add,sub,mul,div,mod,div_round_up,muldiv}_<type_name>()
 * for the signed integer |type|.  The composite helpers (div_round_up and
 * muldiv) are instantiated last because they call the basic ones.
 */
# define OSSL_SAFE_MATH_SIGNED(type_name, type) \
    OSSL_SAFE_MATH_NEGS(type_name, type, OSSL_SAFE_MATH_MINS(type)) \
    OSSL_SAFE_MATH_ABSS(type_name, type, OSSL_SAFE_MATH_MINS(type)) \
    OSSL_SAFE_MATH_ADDS(type_name, type, \
                        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type)) \
    OSSL_SAFE_MATH_SUBS(type_name, type, \
                        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type)) \
    OSSL_SAFE_MATH_MULS(type_name, type, \
                        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type)) \
    OSSL_SAFE_MATH_DIVS(type_name, type, \
                        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type)) \
    OSSL_SAFE_MATH_MODS(type_name, type, \
                        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type)) \
    OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, OSSL_SAFE_MATH_MAXS(type)) \
    OSSL_SAFE_MATH_MULDIVS(type_name, type, OSSL_SAFE_MATH_MAXS(type))
/*
 * OSSL_SAFE_MATH_UNSIGNED(type_name, type) defines static inline
 * safe_{neg,abs,add,sub,mul,div,mod,div_round_up,muldiv}_<type_name>()
 * for the unsigned integer |type|.  The composite helpers (div_round_up
 * and muldiv) come after the basic helpers they call.
 */
# define OSSL_SAFE_MATH_UNSIGNED(type_name, type) \
    OSSL_SAFE_MATH_NEGU(type_name, type) \
    OSSL_SAFE_MATH_ABSU(type_name, type) \
    OSSL_SAFE_MATH_ADDU(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \
    OSSL_SAFE_MATH_SUBU(type_name, type) \
    OSSL_SAFE_MATH_MULU(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \
    OSSL_SAFE_MATH_DIVU(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \
    OSSL_SAFE_MATH_MODU(type_name, type) \
    OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \
    OSSL_SAFE_MATH_MULDIVU(type_name, type, OSSL_SAFE_MATH_MAXU(type))
#endif /* OSSL_INTERNAL_SAFE_MATH_H */