[substrait] Synthetically added grouping expressions in Aggregates can cause mismatched output columns #14348

Closed
anlinc opened this issue Jan 29, 2025 · 1 comment · Fixed by #14860
Labels: bug (Something isn't working)

anlinc commented Jan 29, 2025

Describe the bug

In #8356, support was added so that functionally dependent expressions on unique columns can also be available for output. This was achieved by synthetically adding said expressions as additional grouping expressions in the logical plan builder.

However, Substrait plans use predefined, ordinal-based field references. These ordinals index into the combined list of input and output expressions. Adding expressions to the input shifts every position after them, so the ordinals end up pointing to different, unintended expressions, which alters the final output.
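To make the shift concrete, here is a minimal sketch in plain Rust. It does not use the DataFusion or Substrait APIs: `Expr`, `render`, and `apply_emit` are hypothetical stand-ins that model a project whose emit mapping indexes into the input columns followed by the project expressions. The column names and the `[2, 3]` mapping are taken from the reproduction below.

```rust
/// Hypothetical stand-in for a project expression (not a DataFusion type).
enum Expr {
    /// `lower` applied to the input column at this ordinal.
    Lower(usize),
    /// Bare reference to the input column at this ordinal.
    Field(usize),
}

/// Render an expression against the input columns it refers to.
fn render(expr: &Expr, input: &[&str]) -> String {
    match expr {
        Expr::Lower(i) => format!("lower({})", input[*i]),
        Expr::Field(i) => input[*i].to_string(),
    }
}

/// Model of an emit mapping: the candidate columns are the input columns
/// followed by the project expressions, and the mapping picks from that list.
fn apply_emit(input: &[&str], exprs: &[Expr], output_mapping: &[usize]) -> Vec<String> {
    let mut combined: Vec<String> = input.iter().map(|c| c.to_string()).collect();
    combined.extend(exprs.iter().map(|e| render(e, input)));
    output_mapping.iter().map(|&i| combined[i].clone()).collect()
}

fn main() {
    let exprs = [Expr::Lower(0), Expr::Field(1)];
    let emit: [usize; 2] = [2, 3];

    // Aggregate output as the producer intended: one grouping, one measure.
    let intended = apply_emit(&["product", "sum(count)"], &exprs, &emit);
    // Aggregate output after a grouping expression is added synthetically.
    let widened = apply_emit(&["product", "count", "sum(count)"], &exprs, &emit);

    println!("{intended:?}");
    println!("{widened:?}");
}
```

With the extra input column, ordinal 2 no longer selects the first project expression but the aggregate's measure, which mirrors the swapped projection in the failing test output.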

Given a sample plan without the added grouping:
Image

The sample plan with the added grouping produces something semantically different:
Image

To Reproduce

Create a new test file in datafusion/substrait/tests/testdata/test_plans/multi_layer_aggregation.substrait.json with the following Substrait JSON:

{
  "extensionUris": [{
    "extensionUriAnchor": 1,
    "uri": "/functions_aggregate_generic.yaml"
  }, {
    "extensionUriAnchor": 2,
    "uri": "/functions_arithmetic.yaml"
  }, {
    "extensionUriAnchor": 3,
    "uri": "/functions_string.yaml"
  }],
  "extensions": [{
    "extensionFunction": {
      "extensionUriReference": 1,
      "functionAnchor": 0,
      "name": "count:any"
    }
  }, {
    "extensionFunction": {
      "extensionUriReference": 2,
      "functionAnchor": 1,
      "name": "sum:i64"
    }
  }, {
    "extensionFunction": {
      "extensionUriReference": 3,
      "functionAnchor": 2,
      "name": "lower:str"
    }
  }],
  "relations": [{
    "root": {
      "input": {
        "project": {
          "common": {
            "emit": {
              "outputMapping": [2, 3]
            }
          },
          "input": {
            "aggregate": {
              "common": {
                "direct": {
                }
              },
              "input": {
                "aggregate": {
                  "common": {
                    "direct": {
                    }
                  },
                  "input": {
                    "read": {
                      "common": {
                        "direct": {}
                      },
                      "baseSchema": {
                        "names": [
                          "product"
                        ],
                        "struct": {
                          "types": [
                            {
                              "string": {
                                "nullability": "NULLABILITY_REQUIRED"
                              }
                            }
                          ],
                          "nullability": "NULLABILITY_REQUIRED"
                        }
                      },
                      "namedTable": {
                        "names": [
                          "sales"
                        ]
                      }
                    }
                  },
                  "groupings": [{
                    "groupingExpressions": [{
                      "selection": {
                        "directReference": {
                          "structField": {
                            "field": 0
                          }
                        },
                        "rootReference": {
                        }
                      }
                    }],
                    "expressionReferences": []
                  }],
                  "measures": [{
                    "measure": {
                      "functionReference": 0,
                      "args": [],
                      "sorts": [],
                      "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
                      "outputType": {
                        "i64": {
                          "typeVariationReference": 0,
                          "nullability": "NULLABILITY_REQUIRED"
                        }
                      },
                      "invocation": "AGGREGATION_INVOCATION_ALL",
                      "arguments": [{
                        "value": {
                          "selection": {
                            "directReference": {
                              "structField": {
                                "field": 0
                              }
                            },
                            "rootReference": {
                            }
                          }
                        }
                      }],
                      "options": []
                    }
                  }],
                  "groupingExpressions": []
                }
              },
              "groupings": [{
                "groupingExpressions": [{
                  "selection": {
                    "directReference": {
                      "structField": {
                        "field": 0
                      }
                    },
                    "rootReference": {
                    }
                  }
                }],
                "expressionReferences": []
              }],
              "measures": [{
                "measure": {
                  "functionReference": 1,
                  "args": [],
                  "sorts": [],
                  "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
                  "outputType": {
                    "i64": {
                      "typeVariationReference": 0,
                      "nullability": "NULLABILITY_NULLABLE"
                    }
                  },
                  "invocation": "AGGREGATION_INVOCATION_ALL",
                  "arguments": [{
                    "value": {
                      "selection": {
                        "directReference": {
                          "structField": {
                            "field": 1
                          }
                        },
                        "rootReference": {
                        }
                      }
                    }
                  }],
                  "options": []
                }
              }],
              "groupingExpressions": []
            }
          },
          "expressions": [{
            "scalarFunction": {
              "functionReference": 2,
              "args": [],
              "outputType": {
                "string": {
                  "typeVariationReference": 0,
                  "nullability": "NULLABILITY_NULLABLE"
                }
              },
              "arguments": [{
                "value": {
                  "selection": {
                    "directReference": {
                      "structField": {
                        "field": 0
                      }
                    },
                    "rootReference": {
                    }
                  }
                }
              }],
              "options": []
            }
          }, {
            "selection": {
              "directReference": {
                "structField": {
                  "field": 1
                }
              },
              "rootReference": {
              }
            }
          }]
        }
      },
      "names": ["lower(product)", "product_count"]
    }
  }],
  "expectedTypeUrls": []
}

Note: this was generated from a SQL query of the form:

SELECT lower(product), sum(count) as product_count FROM (
    SELECT product, count(product) as count
    FROM sales
    GROUP BY product
)
GROUP BY product;

Add a test in datafusion/substrait/tests/cases/logical_plans.rs:

#[tokio::test]
async fn multi_layer_aggregation() -> Result<()> {
    let proto_plan =
        read_json("tests/testdata/test_plans/multi_layer_aggregation.substrait.json");
    let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?;
    let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?;

    assert_eq!(
        format!("{}", plan),
        "Projection: lower(sales.product) AS lower(product), sum(count(sales.product)) AS product_count\
        \n  Aggregate: groupBy=[[sales.product]], aggr=[[sum(count(sales.product))]]\
        \n    Aggregate: groupBy=[[sales.product]], aggr=[[count(sales.product)]]\
        \n      TableScan: sales"
    );

    Ok(())
}

The test should succeed, but instead it fails with:

assertion `left == right` failed

left: 
"Projection: sum(count(sales.product)) AS lower(product), lower(sales.product) AS product_count
    Aggregate: groupBy=[[sales.product, count(sales.product)]], aggr=[[sum(count(sales.product))]]
      Aggregate: groupBy=[[sales.product]], aggr=[[count(sales.product)]]
         TableScan: sales"

right: 
"Projection: lower(sales.product) as lower(product), sum(count(sales.product)) as product_count
    Aggregate: groupBy=[[sales.product]], aggr=[[sum(count(sales.product))]]
      Aggregate: groupBy=[[sales.product]], aggr=[[count(sales.product)]]
        TableScan: sales"

Expected behavior

The translated LogicalPlan must preserve the semantics of the Substrait plan.

Not trying to prescribe a particular solution here. There may be a few possible approaches:

  • If we introduce new expression(s), we must produce a new remapping that takes into account the added expression(s).
  • The Substrait -> LogicalPlan translation layer should never modify the original intent of the plan. In other words, maybe in this particular scenario, the additional grouping should not be introduced by the Substrait consumer. Instead, allow the Substrait producer to be responsible for introducing the extra grouping if the same PK Aggregate feature is desired.
  • For this particular query structure, another viable fix may be to not add functionally dependent expressions into the grouping set if they are never referenced or projected.
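As a rough illustration of the first option (remapping), the sketch below computes where the originally declared columns sit once synthetic grouping expressions have been added. It is plain Rust under stated assumptions, not how DataFusion implements any fix: `original_column_positions` is a hypothetical helper, and it assumes the widened aggregate lays out its output as original groupings, then synthetic groupings, then measures.

```rust
/// Hypothetical helper: positions, within the widened aggregate output, of the
/// columns the Substrait plan originally declared.
/// Assumed layout: [original groupings][synthetic groupings][measures].
fn original_column_positions(groupings: usize, synthetic: usize, measures: usize) -> Vec<usize> {
    let first_measure = groupings + synthetic;
    (0..groupings)
        .chain(first_measure..first_measure + measures)
        .collect()
}

fn main() {
    // Widened output from the failing plan: `count` was added as a grouping.
    let widened = ["product", "count", "sum(count)"];

    // One original grouping, one synthetic grouping, one measure.
    let keep = original_column_positions(1, 1, 1);

    // Selecting only these positions restores the layout that the plan's
    // ordinal references were written against.
    let restored: Vec<&str> = keep.iter().map(|&i| widened[i]).collect();
    println!("{keep:?} -> {restored:?}");
}
```

Placing a projection with these positions directly above the widened aggregate would leave downstream ordinals untouched; whether that is preferable to not adding the grouping at all is the open question raised in the options above.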

Additional context

No response
