Skip to content

Commit

Permalink
Allow partial, merge and merge_extract companion functions
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo committed Jan 8, 2025
1 parent 3f86559 commit 9911e8f
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 30 deletions.
11 changes: 1 addition & 10 deletions velox/exec/AggregateCompanionAdapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,20 +340,11 @@ bool registerAggregateFunction(
const TypePtr& resultType,
const core::QueryConfig& config)
-> std::unique_ptr<Aggregate> {
const auto& [originalResultType, _] =
resolveAggregateFunction(mergeExtractFunctionName, argTypes);
if (!originalResultType) {
// TODO: limitation -- result type must be resolvable given
// intermediate type of the original UDAF.
VELOX_UNREACHABLE(
"Signatures whose result types are not resolvable given intermediate types should have been excluded.");
}

if (auto func = getAggregateFunctionEntry(name)) {
auto fn = func->factory(
core::AggregationNode::Step::kFinal,
argTypes,
originalResultType,
resultType,
config);
VELOX_CHECK_NOT_NULL(fn);
return std::make_unique<
Expand Down
11 changes: 0 additions & 11 deletions velox/exec/AggregateCompanionSignatures.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,6 @@ CompanionSignatures::partialFunctionSignatures(
const std::vector<AggregateFunctionSignaturePtr>& signatures) {
std::vector<AggregateFunctionSignaturePtr> partialSignatures;
for (const auto& signature : signatures) {
if (!isResultTypeResolvableGivenIntermediateType(signature)) {
continue;
}
std::vector<TypeSignature> usedTypes = signature->argumentTypes();
usedTypes.push_back(signature->intermediateType());
auto variables = usedTypeVariables(usedTypes, signature->variables());
Expand All @@ -124,10 +121,6 @@ std::string CompanionSignatures::partialFunctionName(const std::string& name) {

AggregateFunctionSignaturePtr CompanionSignatures::mergeFunctionSignature(
const AggregateFunctionSignaturePtr& signature) {
if (!isResultTypeResolvableGivenIntermediateType(signature)) {
return nullptr;
}

std::vector<TypeSignature> usedTypes = {signature->intermediateType()};
auto variables = usedTypeVariables(usedTypes, signature->variables());
return std::make_shared<AggregateFunctionSignature>(
Expand Down Expand Up @@ -170,10 +163,6 @@ bool CompanionSignatures::hasSameIntermediateTypesAcrossSignatures(
AggregateFunctionSignaturePtr
CompanionSignatures::mergeExtractFunctionSignature(
const AggregateFunctionSignaturePtr& signature) {
if (!isResultTypeResolvableGivenIntermediateType(signature)) {
return nullptr;
}

std::vector<TypeSignature> usedTypes = {
signature->intermediateType(), signature->returnType()};
auto variables = usedTypeVariables(usedTypes, signature->variables());
Expand Down
23 changes: 14 additions & 9 deletions velox/exec/tests/AggregateCompanionAdapterTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,22 +414,27 @@ TEST_F(
TEST_F(
AggregateCompanionRegistryTest,
resultTypeNotResolvableFromIntermediateType) {
// We only register companion functions for original signatures whose result
// type can be resolved from its intermediate type.
// We only register partial, merge and merge_extract companion functions for
// original signatures whose result type cannot be resolved from its
// intermediate type.
std::vector<std::shared_ptr<AggregateFunctionSignature>> signatures{
AggregateFunctionSignatureBuilder()
.typeVariable("T")
.returnType("array(T)")
.intermediateType("varbinary")
.argumentType("T")
.integerVariable("a_precision")
.integerVariable("a_scale")
.integerVariable("i_precision", "min(38, a_precision + 10)")
.integerVariable("r_precision", "min(38, a_precision + 4)")
.integerVariable("r_scale", "min(38, a_scale + 4)")
.returnType("DECIMAL(r_precision, r_scale)")
.intermediateType("ROW(DECIMAL(i_precision, a_scale), bigint)")
.argumentType("DECIMAL(a_precision, a_scale)")
.build()};
registerDummyAggregateFunction("aggregateFunc6", signatures);

checkAggregateSignaturesCount("aggregateFunc6_partial", 0);
checkAggregateSignaturesCount("aggregateFunc6_partial", 1);

checkAggregateSignaturesCount("aggregateFunc6_merge", 0);
checkAggregateSignaturesCount("aggregateFunc6_merge", 1);

checkAggregateSignaturesCount("aggregateFunc6_merge_extract", 0);
checkAggregateSignaturesCount("aggregateFunc6_merge_extract", 1);

checkScalarSignaturesCount("aggregateFunc6_extract", 0);
}
Expand Down

0 comments on commit 9911e8f

Please sign in to comment.