diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c index 8440a76fbdc..914b02ceee4 100644 --- a/src/backend/executor/nodeAgg.c +++ b/src/backend/executor/nodeAgg.c @@ -461,10 +461,11 @@ static void hashagg_tapeinfo_release(HashTapeInfo *tapeinfo, int tapenum); static Datum GetAggInitVal(Datum textInitVal, Oid transtype); static void build_pertrans_for_aggref(AggStatePerTrans pertrans, AggState *aggstate, EState *estate, - Aggref *aggref, Oid aggtransfn, Oid aggtranstype, - Oid aggserialfn, Oid aggdeserialfn, - Datum initValue, bool initValueIsNull, - Oid *inputTypes, int numArguments); + Aggref *aggref, Oid transfn_oid, + Oid aggtranstype, Oid aggserialfn, + Oid aggdeserialfn, Datum initValue, + bool initValueIsNull, Oid *inputTypes, + int numArguments); /* @@ -3724,8 +3725,8 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) Aggref *aggref = lfirst(l); AggStatePerAgg peragg; AggStatePerTrans pertrans; - Oid inputTypes[FUNC_MAX_ARGS]; - int numArguments; + Oid aggTransFnInputTypes[FUNC_MAX_ARGS]; + int numAggTransFnArgs; int numDirectArgs; HeapTuple aggTuple; Form_pg_aggregate aggform; @@ -3859,14 +3860,15 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) * could be different from the agg's declared input types, when the * agg accepts ANY or a polymorphic type. */ - numArguments = get_aggregate_argtypes(aggref, inputTypes); + numAggTransFnArgs = get_aggregate_argtypes(aggref, + aggTransFnInputTypes); /* Count the "direct" arguments, if any */ numDirectArgs = list_length(aggref->aggdirectargs); /* Detect how many arguments to pass to the finalfn */ if (aggform->aggfinalextra) - peragg->numFinalArgs = numArguments + 1; + peragg->numFinalArgs = numAggTransFnArgs + 1; else peragg->numFinalArgs = numDirectArgs + 1; @@ -3880,7 +3882,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) */ if (OidIsValid(finalfn_oid)) { - build_aggregate_finalfn_expr(inputTypes, + build_aggregate_finalfn_expr(aggTransFnInputTypes, peragg->numFinalArgs, aggtranstype, aggref->aggtype, @@ -3911,7 +3913,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) /* * If this aggregation is performing state combines, then instead * of using the transition function, we'll use the combine - * function + * function. */ if (DO_AGGSPLIT_COMBINE(aggstate->aggsplit)) { @@ -3924,8 +3926,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) else transfn_oid = aggform->aggtransfn; - aclresult = pg_proc_aclcheck(transfn_oid, aggOwner, - ACL_EXECUTE); + aclresult = pg_proc_aclcheck(transfn_oid, aggOwner, ACL_EXECUTE); if (aclresult != ACLCHECK_OK) aclcheck_error(aclresult, OBJECT_FUNCTION, get_func_name(transfn_oid)); @@ -3943,11 +3944,72 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) else initValue = GetAggInitVal(textInitVal, aggtranstype); - build_pertrans_for_aggref(pertrans, aggstate, estate, - aggref, transfn_oid, aggtranstype, - serialfn_oid, deserialfn_oid, - initValue, initValueIsNull, - inputTypes, numArguments); + if (DO_AGGSPLIT_COMBINE(aggstate->aggsplit)) + { + Oid combineFnInputTypes[] = {aggtranstype, + aggtranstype}; + + /* + * When combining there's only one input, the to-be-combined + * transition value. The transition value is not counted + * here. + */ + pertrans->numTransInputs = 1; + + /* aggcombinefn always has two arguments of aggtranstype */ + build_pertrans_for_aggref(pertrans, aggstate, estate, + aggref, transfn_oid, aggtranstype, + serialfn_oid, deserialfn_oid, + initValue, initValueIsNull, + combineFnInputTypes, 2); + + /* + * Ensure that a combine function to combine INTERNAL states + * is not strict. This should have been checked during CREATE + * AGGREGATE, but the strict property could have been changed + * since then. + */ + if (pertrans->transfn.fn_strict && aggtranstype == INTERNALOID) + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("combine function with transition type %s must not be declared STRICT", + format_type_be(aggtranstype)))); + } + else + { + /* Detect how many arguments to pass to the transfn */ + if (AGGKIND_IS_ORDERED_SET(aggref->aggkind)) + pertrans->numTransInputs = list_length(aggref->args); + else + pertrans->numTransInputs = numAggTransFnArgs; + + build_pertrans_for_aggref(pertrans, aggstate, estate, + aggref, transfn_oid, aggtranstype, + serialfn_oid, deserialfn_oid, + initValue, initValueIsNull, + aggTransFnInputTypes, + numAggTransFnArgs); + + /* + * If the transfn is strict and the initval is NULL, make sure + * input type and transtype are the same (or at least + * binary-compatible), so that it's OK to use the first + * aggregated input value as the initial transValue. This + * should have been checked at agg definition time, but we + * must check again in case the transfn's strictness property + * has been changed. + */ + if (pertrans->transfn.fn_strict && pertrans->initValueIsNull) + { + if (numAggTransFnArgs <= numDirectArgs || + !IsBinaryCoercible(aggTransFnInputTypes[numDirectArgs], + aggtranstype)) + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("aggregate %u needs to have compatible input type and transition type", + aggref->aggfnoid))); + } + } } else pertrans->aggshared = true; @@ -4039,20 +4101,24 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) * Build the state needed to calculate a state value for an aggregate. * * This initializes all the fields in 'pertrans'. 'aggref' is the aggregate - * to initialize the state for. 'aggtransfn', 'aggtranstype', and the rest + * to initialize the state for. 'transfn_oid', 'aggtranstype', and the rest * of the arguments could be calculated from 'aggref', but the caller has * calculated them already, so might as well pass them. + * + * 'transfn_oid' may be either the Oid of the aggtransfn or the aggcombinefn. */ static void build_pertrans_for_aggref(AggStatePerTrans pertrans, AggState *aggstate, EState *estate, Aggref *aggref, - Oid aggtransfn, Oid aggtranstype, + Oid transfn_oid, Oid aggtranstype, Oid aggserialfn, Oid aggdeserialfn, Datum initValue, bool initValueIsNull, Oid *inputTypes, int numArguments) { int numGroupingSets = Max(aggstate->maxsets, 1); + Expr *transfnexpr; + int numTransArgs; Expr *serialfnexpr = NULL; Expr *deserialfnexpr = NULL; ListCell *lc; @@ -4067,7 +4133,7 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, pertrans->aggref = aggref; pertrans->aggshared = false; pertrans->aggCollation = aggref->inputcollid; - pertrans->transfn_oid = aggtransfn; + pertrans->transfn_oid = transfn_oid; pertrans->serialfn_oid = aggserialfn; pertrans->deserialfn_oid = aggdeserialfn; pertrans->initValue = initValue; @@ -4081,111 +4147,34 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, pertrans->aggtranstype = aggtranstype; + /* account for the current transition state */ + numTransArgs = pertrans->numTransInputs + 1; + /* - * When combining states, we have no use at all for the aggregate - * function's transfn. Instead we use the combinefn. In this case, the - * transfn and transfn_oid fields of pertrans refer to the combine - * function rather than the transition function. + * Set up infrastructure for calling the transfn. Note that invtrans is + * not needed here. */ - if (DO_AGGSPLIT_COMBINE(aggstate->aggsplit)) - { - Expr *combinefnexpr; - size_t numTransArgs; + build_aggregate_transfn_expr(inputTypes, + numArguments, + numDirectArgs, + aggref->aggvariadic, + aggtranstype, + aggref->inputcollid, + transfn_oid, + InvalidOid, + &transfnexpr, + NULL); - /* - * When combining there's only one input, the to-be-combined added - * transition value from below (this node's transition value is - * counted separately). - */ - pertrans->numTransInputs = 1; + fmgr_info(transfn_oid, &pertrans->transfn); + fmgr_info_set_expr((Node *) transfnexpr, &pertrans->transfn); - /* account for the current transition state */ - numTransArgs = pertrans->numTransInputs + 1; - - build_aggregate_combinefn_expr(aggtranstype, - aggref->inputcollid, - aggtransfn, - &combinefnexpr); - fmgr_info(aggtransfn, &pertrans->transfn); - fmgr_info_set_expr((Node *) combinefnexpr, &pertrans->transfn); - - pertrans->transfn_fcinfo = - (FunctionCallInfo) palloc(SizeForFunctionCallInfo(2)); - InitFunctionCallInfoData(*pertrans->transfn_fcinfo, - &pertrans->transfn, - numTransArgs, - pertrans->aggCollation, - (void *) aggstate, NULL); - - /* - * Ensure that a combine function to combine INTERNAL states is not - * strict. This should have been checked during CREATE AGGREGATE, but - * the strict property could have been changed since then. - */ - if (pertrans->transfn.fn_strict && aggtranstype == INTERNALOID) - ereport(ERROR, - (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), - errmsg("combine function with transition type %s must not be declared STRICT", - format_type_be(aggtranstype)))); - } - else - { - Expr *transfnexpr; - size_t numTransArgs; - - /* Detect how many arguments to pass to the transfn */ - if (AGGKIND_IS_ORDERED_SET(aggref->aggkind)) - pertrans->numTransInputs = numInputs; - else - pertrans->numTransInputs = numArguments; - - /* account for the current transition state */ - numTransArgs = pertrans->numTransInputs + 1; - - /* - * Set up infrastructure for calling the transfn. Note that - * invtransfn is not needed here. - */ - build_aggregate_transfn_expr(inputTypes, - numArguments, - numDirectArgs, - aggref->aggvariadic, - aggtranstype, - aggref->inputcollid, - aggtransfn, - InvalidOid, - &transfnexpr, - NULL); - fmgr_info(aggtransfn, &pertrans->transfn); - fmgr_info_set_expr((Node *) transfnexpr, &pertrans->transfn); - - pertrans->transfn_fcinfo = - (FunctionCallInfo) palloc(SizeForFunctionCallInfo(numTransArgs)); - InitFunctionCallInfoData(*pertrans->transfn_fcinfo, - &pertrans->transfn, - numTransArgs, - pertrans->aggCollation, - (void *) aggstate, NULL); - - /* - * If the transfn is strict and the initval is NULL, make sure input - * type and transtype are the same (or at least binary-compatible), so - * that it's OK to use the first aggregated input value as the initial - * transValue. This should have been checked at agg definition time, - * but we must check again in case the transfn's strictness property - * has been changed. - */ - if (pertrans->transfn.fn_strict && pertrans->initValueIsNull) - { - if (numArguments <= numDirectArgs || - !IsBinaryCoercible(inputTypes[numDirectArgs], - aggtranstype)) - ereport(ERROR, - (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), - errmsg("aggregate %u needs to have compatible input type and transition type", - aggref->aggfnoid))); - } - } + pertrans->transfn_fcinfo = + (FunctionCallInfo) palloc(SizeForFunctionCallInfo(numTransArgs)); + InitFunctionCallInfoData(*pertrans->transfn_fcinfo, + &pertrans->transfn, + numTransArgs, + pertrans->aggCollation, + (void *) aggstate, NULL); /* get info about the state value's datatype */ get_typlenbyval(aggtranstype, @@ -4276,6 +4265,9 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, */ Assert(aggstate->aggstrategy != AGG_HASHED && aggstate->aggstrategy != AGG_MIXED); + /* ORDER BY aggregates are not supported with partial aggregation */ + Assert(!DO_AGGSPLIT_COMBINE(aggstate->aggsplit)); + /* If we have only one input, we need its len/byval info. */ if (numInputs == 1) { diff --git a/src/backend/parser/parse_agg.c b/src/backend/parser/parse_agg.c index a25f8d5b989..24268eb5024 100644 --- a/src/backend/parser/parse_agg.c +++ b/src/backend/parser/parse_agg.c @@ -1959,6 +1959,11 @@ resolve_aggregate_transtype(Oid aggfuncid, * latter may be InvalidOid, however if invtransfn_oid is set then * transfn_oid must also be set. * + * transfn_oid may also be passed as the aggcombinefn when the *transfnexpr is + * to be used for a combine aggregate phase. We expect invtransfn_oid to be + * InvalidOid in this case since there is no such thing as an inverse + * combinefn. + * * Pointers to the constructed trees are returned into *transfnexpr, * *invtransfnexpr. If there is no invtransfn, the respective pointer is set * to NULL. Since use of the invtransfn is optional, NULL may be passed for @@ -2021,35 +2026,6 @@ build_aggregate_transfn_expr(Oid *agg_input_types, } } -/* - * Like build_aggregate_transfn_expr, but creates an expression tree for the - * combine function of an aggregate, rather than the transition function. - */ -void -build_aggregate_combinefn_expr(Oid agg_state_type, - Oid agg_input_collation, - Oid combinefn_oid, - Expr **combinefnexpr) -{ - Node *argp; - List *args; - FuncExpr *fexpr; - - /* combinefn takes two arguments of the aggregate state type */ - argp = make_agg_arg(agg_state_type, agg_input_collation); - - args = list_make2(argp, argp); - - fexpr = makeFuncExpr(combinefn_oid, - agg_state_type, - args, - InvalidOid, - agg_input_collation, - COERCE_EXPLICIT_CALL); - /* combinefn is currently never treated as variadic */ - *combinefnexpr = (Expr *) fexpr; -} - /* * Like build_aggregate_transfn_expr, but creates an expression tree for the * serialization function of an aggregate. diff --git a/src/include/parser/parse_agg.h b/src/include/parser/parse_agg.h index 4dea01752af..bffbb82df66 100644 --- a/src/include/parser/parse_agg.h +++ b/src/include/parser/parse_agg.h @@ -46,11 +46,6 @@ extern void build_aggregate_transfn_expr(Oid *agg_input_types, Expr **transfnexpr, Expr **invtransfnexpr); -extern void build_aggregate_combinefn_expr(Oid agg_state_type, - Oid agg_input_collation, - Oid combinefn_oid, - Expr **combinefnexpr); - extern void build_aggregate_serialfn_expr(Oid serialfn_oid, Expr **serialfnexpr);