Prevent numeric overflows in parallel numeric aggregates.

Formerly various numeric aggregate functions supported parallel
aggregation by having each worker convert partial aggregate values to
Numeric and use numeric_send() as part of serializing their state.
That's problematic, since the range of Numeric is smaller than that of
NumericVar, so it's possible for it to overflow (on either side of the
decimal point) in cases that would succeed in non-parallel mode.

Fix by serializing NumericVars instead, to avoid the overflow risk and
ensure that parallel and non-parallel modes work the same.

A side benefit is that this improves the efficiency of the
serialization/deserialization code, which can make a noticeable
difference to performance with large numbers of parallel workers.

No back-patch due to risk from changing the binary format of the
aggregate serialization states, as well as lack of prior field
complaints and low probability of such overflows in practice.

Patch by me. Thanks to David Rowley for review and performance
testing, and Ranier Vilela for an additional suggestion.

Discussion: https://postgr.es/m/CAEZATCUmeFWCrq2dNzZpRj5+6LfN85jYiDoqm+ucSXhb9U2TbA@mail.gmail.com
This commit is contained in:
Dean Rasheed 2021-07-05 10:16:42 +01:00
parent 903d9aa780
commit f025f2390e
3 changed files with 203 additions and 138 deletions

View File

@ -515,6 +515,9 @@ static void set_var_from_var(const NumericVar *value, NumericVar *dest);
static char *get_str_from_var(const NumericVar *var);
static char *get_str_from_var_sci(const NumericVar *var, int rscale);
static void numericvar_serialize(StringInfo buf, const NumericVar *var);
static void numericvar_deserialize(StringInfo buf, NumericVar *var);
static Numeric duplicate_numeric(Numeric num);
static Numeric make_result(const NumericVar *var);
static Numeric make_result_opt_error(const NumericVar *var, bool *error);
@ -4943,8 +4946,6 @@ numeric_avg_serialize(PG_FUNCTION_ARGS)
{
NumericAggState *state;
StringInfoData buf;
Datum temp;
bytea *sumX;
bytea *result;
NumericVar tmp_var;
@ -4954,19 +4955,7 @@ numeric_avg_serialize(PG_FUNCTION_ARGS)
state = (NumericAggState *) PG_GETARG_POINTER(0);
/*
* This is a little wasteful since make_result converts the NumericVar
* into a Numeric and numeric_send converts it back again. Is it worth
* splitting the tasks in numeric_send into separate functions to stop
* this? Doing so would also remove the fmgr call overhead.
*/
init_var(&tmp_var);
accum_sum_final(&state->sumX, &tmp_var);
temp = DirectFunctionCall1(numeric_send,
NumericGetDatum(make_result(&tmp_var)));
sumX = DatumGetByteaPP(temp);
free_var(&tmp_var);
pq_begintypsend(&buf);
@ -4974,7 +4963,8 @@ numeric_avg_serialize(PG_FUNCTION_ARGS)
pq_sendint64(&buf, state->N);
/* sumX */
pq_sendbytes(&buf, VARDATA_ANY(sumX), VARSIZE_ANY_EXHDR(sumX));
accum_sum_final(&state->sumX, &tmp_var);
numericvar_serialize(&buf, &tmp_var);
/* maxScale */
pq_sendint32(&buf, state->maxScale);
@ -4993,6 +4983,8 @@ numeric_avg_serialize(PG_FUNCTION_ARGS)
result = pq_endtypsend(&buf);
free_var(&tmp_var);
PG_RETURN_BYTEA_P(result);
}
@ -5006,15 +4998,16 @@ numeric_avg_deserialize(PG_FUNCTION_ARGS)
{
bytea *sstate;
NumericAggState *result;
Datum temp;
NumericVar tmp_var;
StringInfoData buf;
NumericVar tmp_var;
if (!AggCheckCallContext(fcinfo, NULL))
elog(ERROR, "aggregate function called in non-aggregate context");
sstate = PG_GETARG_BYTEA_PP(0);
init_var(&tmp_var);
/*
* Copy the bytea into a StringInfo so that we can "receive" it using the
* standard recv-function infrastructure.
@ -5029,11 +5022,7 @@ numeric_avg_deserialize(PG_FUNCTION_ARGS)
result->N = pq_getmsgint64(&buf);
/* sumX */
temp = DirectFunctionCall3(numeric_recv,
PointerGetDatum(&buf),
ObjectIdGetDatum(InvalidOid),
Int32GetDatum(-1));
init_var_from_num(DatumGetNumeric(temp), &tmp_var);
numericvar_deserialize(&buf, &tmp_var);
accum_sum_add(&(result->sumX), &tmp_var);
/* maxScale */
@ -5054,6 +5043,8 @@ numeric_avg_deserialize(PG_FUNCTION_ARGS)
pq_getmsgend(&buf);
pfree(buf.data);
free_var(&tmp_var);
PG_RETURN_POINTER(result);
}
@ -5067,11 +5058,8 @@ numeric_serialize(PG_FUNCTION_ARGS)
{
NumericAggState *state;
StringInfoData buf;
Datum temp;
bytea *sumX;
NumericVar tmp_var;
bytea *sumX2;
bytea *result;
NumericVar tmp_var;
/* Ensure we disallow calling when not in aggregate context */
if (!AggCheckCallContext(fcinfo, NULL))
@ -5079,36 +5067,20 @@ numeric_serialize(PG_FUNCTION_ARGS)
state = (NumericAggState *) PG_GETARG_POINTER(0);
/*
* This is a little wasteful since make_result converts the NumericVar
* into a Numeric and numeric_send converts it back again. Is it worth
* splitting the tasks in numeric_send into separate functions to stop
* this? Doing so would also remove the fmgr call overhead.
*/
init_var(&tmp_var);
accum_sum_final(&state->sumX, &tmp_var);
temp = DirectFunctionCall1(numeric_send,
NumericGetDatum(make_result(&tmp_var)));
sumX = DatumGetByteaPP(temp);
accum_sum_final(&state->sumX2, &tmp_var);
temp = DirectFunctionCall1(numeric_send,
NumericGetDatum(make_result(&tmp_var)));
sumX2 = DatumGetByteaPP(temp);
free_var(&tmp_var);
pq_begintypsend(&buf);
/* N */
pq_sendint64(&buf, state->N);
/* sumX */
pq_sendbytes(&buf, VARDATA_ANY(sumX), VARSIZE_ANY_EXHDR(sumX));
accum_sum_final(&state->sumX, &tmp_var);
numericvar_serialize(&buf, &tmp_var);
/* sumX2 */
pq_sendbytes(&buf, VARDATA_ANY(sumX2), VARSIZE_ANY_EXHDR(sumX2));
accum_sum_final(&state->sumX2, &tmp_var);
numericvar_serialize(&buf, &tmp_var);
/* maxScale */
pq_sendint32(&buf, state->maxScale);
@ -5127,6 +5099,8 @@ numeric_serialize(PG_FUNCTION_ARGS)
result = pq_endtypsend(&buf);
free_var(&tmp_var);
PG_RETURN_BYTEA_P(result);
}
@ -5140,16 +5114,16 @@ numeric_deserialize(PG_FUNCTION_ARGS)
{
bytea *sstate;
NumericAggState *result;
Datum temp;
NumericVar sumX_var;
NumericVar sumX2_var;
StringInfoData buf;
NumericVar tmp_var;
if (!AggCheckCallContext(fcinfo, NULL))
elog(ERROR, "aggregate function called in non-aggregate context");
sstate = PG_GETARG_BYTEA_PP(0);
init_var(&tmp_var);
/*
* Copy the bytea into a StringInfo so that we can "receive" it using the
* standard recv-function infrastructure.
@ -5164,20 +5138,12 @@ numeric_deserialize(PG_FUNCTION_ARGS)
result->N = pq_getmsgint64(&buf);
/* sumX */
temp = DirectFunctionCall3(numeric_recv,
PointerGetDatum(&buf),
ObjectIdGetDatum(InvalidOid),
Int32GetDatum(-1));
init_var_from_num(DatumGetNumeric(temp), &sumX_var);
accum_sum_add(&(result->sumX), &sumX_var);
numericvar_deserialize(&buf, &tmp_var);
accum_sum_add(&(result->sumX), &tmp_var);
/* sumX2 */
temp = DirectFunctionCall3(numeric_recv,
PointerGetDatum(&buf),
ObjectIdGetDatum(InvalidOid),
Int32GetDatum(-1));
init_var_from_num(DatumGetNumeric(temp), &sumX2_var);
accum_sum_add(&(result->sumX2), &sumX2_var);
numericvar_deserialize(&buf, &tmp_var);
accum_sum_add(&(result->sumX2), &tmp_var);
/* maxScale */
result->maxScale = pq_getmsgint(&buf, 4);
@ -5197,6 +5163,8 @@ numeric_deserialize(PG_FUNCTION_ARGS)
pq_getmsgend(&buf);
pfree(buf.data);
free_var(&tmp_var);
PG_RETURN_POINTER(result);
}
@ -5459,9 +5427,8 @@ numeric_poly_serialize(PG_FUNCTION_ARGS)
{
PolyNumAggState *state;
StringInfoData buf;
bytea *sumX;
bytea *sumX2;
bytea *result;
NumericVar tmp_var;
/* Ensure we disallow calling when not in aggregate context */
if (!AggCheckCallContext(fcinfo, NULL))
@ -5477,32 +5444,8 @@ numeric_poly_serialize(PG_FUNCTION_ARGS)
* day we might like to send these over to another server for further
* processing and we want a standard format to work with.
*/
{
Datum temp;
NumericVar num;
init_var(&num);
#ifdef HAVE_INT128
int128_to_numericvar(state->sumX, &num);
#else
accum_sum_final(&state->sumX, &num);
#endif
temp = DirectFunctionCall1(numeric_send,
NumericGetDatum(make_result(&num)));
sumX = DatumGetByteaPP(temp);
#ifdef HAVE_INT128
int128_to_numericvar(state->sumX2, &num);
#else
accum_sum_final(&state->sumX2, &num);
#endif
temp = DirectFunctionCall1(numeric_send,
NumericGetDatum(make_result(&num)));
sumX2 = DatumGetByteaPP(temp);
free_var(&num);
}
init_var(&tmp_var);
pq_begintypsend(&buf);
@ -5510,13 +5453,25 @@ numeric_poly_serialize(PG_FUNCTION_ARGS)
pq_sendint64(&buf, state->N);
/* sumX */
pq_sendbytes(&buf, VARDATA_ANY(sumX), VARSIZE_ANY_EXHDR(sumX));
#ifdef HAVE_INT128
int128_to_numericvar(state->sumX, &tmp_var);
#else
accum_sum_final(&state->sumX, &tmp_var);
#endif
numericvar_serialize(&buf, &tmp_var);
/* sumX2 */
pq_sendbytes(&buf, VARDATA_ANY(sumX2), VARSIZE_ANY_EXHDR(sumX2));
#ifdef HAVE_INT128
int128_to_numericvar(state->sumX2, &tmp_var);
#else
accum_sum_final(&state->sumX2, &tmp_var);
#endif
numericvar_serialize(&buf, &tmp_var);
result = pq_endtypsend(&buf);
free_var(&tmp_var);
PG_RETURN_BYTEA_P(result);
}
@ -5530,17 +5485,16 @@ numeric_poly_deserialize(PG_FUNCTION_ARGS)
{
bytea *sstate;
PolyNumAggState *result;
Datum sumX;
NumericVar sumX_var;
Datum sumX2;
NumericVar sumX2_var;
StringInfoData buf;
NumericVar tmp_var;
if (!AggCheckCallContext(fcinfo, NULL))
elog(ERROR, "aggregate function called in non-aggregate context");
sstate = PG_GETARG_BYTEA_PP(0);
init_var(&tmp_var);
/*
* Copy the bytea into a StringInfo so that we can "receive" it using the
* standard recv-function infrastructure.
@ -5555,34 +5509,26 @@ numeric_poly_deserialize(PG_FUNCTION_ARGS)
result->N = pq_getmsgint64(&buf);
/* sumX */
sumX = DirectFunctionCall3(numeric_recv,
PointerGetDatum(&buf),
ObjectIdGetDatum(InvalidOid),
Int32GetDatum(-1));
/* sumX2 */
sumX2 = DirectFunctionCall3(numeric_recv,
PointerGetDatum(&buf),
ObjectIdGetDatum(InvalidOid),
Int32GetDatum(-1));
init_var_from_num(DatumGetNumeric(sumX), &sumX_var);
numericvar_deserialize(&buf, &tmp_var);
#ifdef HAVE_INT128
numericvar_to_int128(&sumX_var, &result->sumX);
numericvar_to_int128(&tmp_var, &result->sumX);
#else
accum_sum_add(&result->sumX, &sumX_var);
accum_sum_add(&result->sumX, &tmp_var);
#endif
init_var_from_num(DatumGetNumeric(sumX2), &sumX2_var);
/* sumX2 */
numericvar_deserialize(&buf, &tmp_var);
#ifdef HAVE_INT128
numericvar_to_int128(&sumX2_var, &result->sumX2);
numericvar_to_int128(&tmp_var, &result->sumX2);
#else
accum_sum_add(&result->sumX2, &sumX2_var);
accum_sum_add(&result->sumX2, &tmp_var);
#endif
pq_getmsgend(&buf);
pfree(buf.data);
free_var(&tmp_var);
PG_RETURN_POINTER(result);
}
@ -5681,8 +5627,8 @@ int8_avg_serialize(PG_FUNCTION_ARGS)
{
PolyNumAggState *state;
StringInfoData buf;
bytea *sumX;
bytea *result;
NumericVar tmp_var;
/* Ensure we disallow calling when not in aggregate context */
if (!AggCheckCallContext(fcinfo, NULL))
@ -5698,23 +5644,8 @@ int8_avg_serialize(PG_FUNCTION_ARGS)
* like to send these over to another server for further processing and we
* want a standard format to work with.
*/
{
Datum temp;
NumericVar num;
init_var(&num);
#ifdef HAVE_INT128
int128_to_numericvar(state->sumX, &num);
#else
accum_sum_final(&state->sumX, &num);
#endif
temp = DirectFunctionCall1(numeric_send,
NumericGetDatum(make_result(&num)));
sumX = DatumGetByteaPP(temp);
free_var(&num);
}
init_var(&tmp_var);
pq_begintypsend(&buf);
@ -5722,10 +5653,17 @@ int8_avg_serialize(PG_FUNCTION_ARGS)
pq_sendint64(&buf, state->N);
/* sumX */
pq_sendbytes(&buf, VARDATA_ANY(sumX), VARSIZE_ANY_EXHDR(sumX));
#ifdef HAVE_INT128
int128_to_numericvar(state->sumX, &tmp_var);
#else
accum_sum_final(&state->sumX, &tmp_var);
#endif
numericvar_serialize(&buf, &tmp_var);
result = pq_endtypsend(&buf);
free_var(&tmp_var);
PG_RETURN_BYTEA_P(result);
}
@ -5739,14 +5677,15 @@ int8_avg_deserialize(PG_FUNCTION_ARGS)
bytea *sstate;
PolyNumAggState *result;
StringInfoData buf;
Datum temp;
NumericVar num;
NumericVar tmp_var;
if (!AggCheckCallContext(fcinfo, NULL))
elog(ERROR, "aggregate function called in non-aggregate context");
sstate = PG_GETARG_BYTEA_PP(0);
init_var(&tmp_var);
/*
* Copy the bytea into a StringInfo so that we can "receive" it using the
* standard recv-function infrastructure.
@ -5761,20 +5700,18 @@ int8_avg_deserialize(PG_FUNCTION_ARGS)
result->N = pq_getmsgint64(&buf);
/* sumX */
temp = DirectFunctionCall3(numeric_recv,
PointerGetDatum(&buf),
ObjectIdGetDatum(InvalidOid),
Int32GetDatum(-1));
init_var_from_num(DatumGetNumeric(temp), &num);
numericvar_deserialize(&buf, &tmp_var);
#ifdef HAVE_INT128
numericvar_to_int128(&num, &result->sumX);
numericvar_to_int128(&tmp_var, &result->sumX);
#else
accum_sum_add(&result->sumX, &num);
accum_sum_add(&result->sumX, &tmp_var);
#endif
pq_getmsgend(&buf);
pfree(buf.data);
free_var(&tmp_var);
PG_RETURN_POINTER(result);
}
@ -7286,6 +7223,48 @@ get_str_from_var_sci(const NumericVar *var, int rscale)
}
/*
 * numericvar_serialize - append the binary image of a NumericVar to buf
 *
 * The image consists of four 32-bit fields (ndigits, weight, sign, dscale)
 * followed by ndigits 16-bit digit values taken from var->digits in array
 * order.
 *
 * No range checks are applied to the weight or dscale at this level, which
 * lets callers ship intermediate values carrying more precision than the
 * numeric type itself supports.  Note: this layout is not interchangeable
 * with numeric_send/recv(), which use 16-bit integers for these fields.
 */
static void
numericvar_serialize(StringInfo buf, const NumericVar *var)
{
	int			pos = 0;

	/* fixed-size header fields */
	pq_sendint32(buf, var->ndigits);
	pq_sendint32(buf, var->weight);
	pq_sendint32(buf, var->sign);
	pq_sendint32(buf, var->dscale);

	/* variable-length digit array */
	while (pos < var->ndigits)
	{
		pq_sendint16(buf, var->digits[pos]);
		pos++;
	}
}
/*
 * numericvar_deserialize - rebuild a NumericVar from its binary image in buf
 *
 * Fields are consumed in the same order numericvar_serialize() writes them:
 * the digit count, then weight, sign and dscale (each read as a 32-bit
 * value), then that many 16-bit digits.  The digit buffer of var is obtained
 * through alloc_var() before any digits are stored.
 */
static void
numericvar_deserialize(StringInfo buf, NumericVar *var)
{
	int			ndigits;
	int			pos;

	/* the digit count comes first, so the digit buffer can be sized */
	ndigits = pq_getmsgint(buf, sizeof(int32));
	alloc_var(var, ndigits);	/* sets var->ndigits */

	/* remaining header fields */
	var->weight = pq_getmsgint(buf, sizeof(int32));
	var->sign = pq_getmsgint(buf, sizeof(int32));
	var->dscale = pq_getmsgint(buf, sizeof(int32));

	/* digit array */
	pos = 0;
	while (pos < ndigits)
	{
		var->digits[pos] = pq_getmsgint(buf, sizeof(int16));
		pos++;
	}
}
/*
* duplicate_numeric() - copy a packed-format Numeric
*

View File

@ -2966,6 +2966,56 @@ SELECT SUM((-9999)::numeric) FROM generate_series(1, 100000);
-999900000
(1 row)
--
-- Tests for VARIANCE()
--
CREATE TABLE num_variance (a numeric);
INSERT INTO num_variance VALUES (0);
INSERT INTO num_variance VALUES (3e-500);
INSERT INTO num_variance VALUES (-3e-500);
INSERT INTO num_variance VALUES (4e-500 - 1e-16383);
INSERT INTO num_variance VALUES (-4e-500 + 1e-16383);
-- variance is just under 12.5e-1000 and so should round down to 12e-1000
SELECT trim_scale(variance(a) * 1e1000) FROM num_variance;
trim_scale
------------
12
(1 row)
-- check that parallel execution produces the same result
BEGIN;
ALTER TABLE num_variance SET (parallel_workers = 4);
SET LOCAL parallel_setup_cost = 0;
SET LOCAL max_parallel_workers_per_gather = 4;
SELECT trim_scale(variance(a) * 1e1000) FROM num_variance;
trim_scale
------------
12
(1 row)
ROLLBACK;
-- case where sum of squares would overflow but variance does not
DELETE FROM num_variance;
INSERT INTO num_variance SELECT 9e131071 + x FROM generate_series(1, 5) x;
SELECT variance(a) FROM num_variance;
variance
--------------------
2.5000000000000000
(1 row)
-- check that parallel execution produces the same result
BEGIN;
ALTER TABLE num_variance SET (parallel_workers = 4);
SET LOCAL parallel_setup_cost = 0;
SET LOCAL max_parallel_workers_per_gather = 4;
SELECT variance(a) FROM num_variance;
variance
--------------------
2.5000000000000000
(1 row)
ROLLBACK;
DROP TABLE num_variance;
--
-- Tests for GCD()
--

View File

@ -1277,6 +1277,42 @@ select trim_scale(1e100);
SELECT SUM(9999::numeric) FROM generate_series(1, 100000);
SELECT SUM((-9999)::numeric) FROM generate_series(1, 100000);
--
-- Tests for VARIANCE()
--
CREATE TABLE num_variance (a numeric);
INSERT INTO num_variance VALUES (0);
INSERT INTO num_variance VALUES (3e-500);
INSERT INTO num_variance VALUES (-3e-500);
INSERT INTO num_variance VALUES (4e-500 - 1e-16383);
INSERT INTO num_variance VALUES (-4e-500 + 1e-16383);
-- variance is just under 12.5e-1000 and so should round down to 12e-1000
SELECT trim_scale(variance(a) * 1e1000) FROM num_variance;
-- check that parallel execution produces the same result
BEGIN;
ALTER TABLE num_variance SET (parallel_workers = 4);
SET LOCAL parallel_setup_cost = 0;
SET LOCAL max_parallel_workers_per_gather = 4;
SELECT trim_scale(variance(a) * 1e1000) FROM num_variance;
ROLLBACK;
-- case where sum of squares would overflow but variance does not
DELETE FROM num_variance;
INSERT INTO num_variance SELECT 9e131071 + x FROM generate_series(1, 5) x;
SELECT variance(a) FROM num_variance;
-- check that parallel execution produces the same result
BEGIN;
ALTER TABLE num_variance SET (parallel_workers = 4);
SET LOCAL parallel_setup_cost = 0;
SET LOCAL max_parallel_workers_per_gather = 4;
SELECT variance(a) FROM num_variance;
ROLLBACK;
DROP TABLE num_variance;
--
-- Tests for GCD()
--