|
| 1 | +use egg::Language; |
| 2 | + |
| 3 | +use super::*; |
| 4 | + |
| 5 | +/// The data type of aggragation analysis. |
| 6 | +pub type AggSet = Vec<Expr>; |
| 7 | + |
| 8 | +/// Returns all aggragations in the tree. |
| 9 | +/// |
| 10 | +/// Note: if there is an agg over agg, e.g. `sum(count(a))`, only the upper one will be returned. |
| 11 | +pub fn analyze_aggs(egraph: &EGraph, enode: &Expr) -> AggSet { |
| 12 | + use Expr::*; |
| 13 | + let x = |i: &Id| egraph[*i].data.aggs.clone(); |
| 14 | + if let Max(_) | Min(_) | Sum(_) | Avg(_) | Count(_) = enode { |
| 15 | + return vec![enode.clone()]; |
| 16 | + } |
| 17 | + // merge the set from all children |
| 18 | + // TODO: ignore plan nodes |
| 19 | + enode.children().iter().flat_map(x).collect() |
| 20 | +} |
| 21 | + |
| 22 | +#[derive(Debug, PartialEq, Eq)] |
| 23 | +pub enum Error { |
| 24 | + // #[error("aggregate function calls cannot be nested")] |
| 25 | + NestedAgg(String), |
| 26 | + // #[error("WHERE clause cannot contain aggregates")] |
| 27 | + AggInWhere, |
| 28 | + // #[error("GROUP BY clause cannot contain aggregates")] |
| 29 | + AggInGroupBy, |
| 30 | + // #[error("column {0} must appear in the GROUP BY clause or be used in an aggregate function")] |
| 31 | + ColumnNotInAgg(String), |
| 32 | +} |
| 33 | + |
| 34 | +/// Converts the SELECT statement into a plan tree. |
| 35 | +/// |
| 36 | +/// The nodes of all clauses have been added to the `egraph`. |
| 37 | +/// `from`, `where_`... are the ids of their root node. |
| 38 | +pub fn plan_select( |
| 39 | + egraph: &mut EGraph, |
| 40 | + from: Id, |
| 41 | + where_: Id, |
| 42 | + having: Id, |
| 43 | + groupby: Id, |
| 44 | + orderby: Id, |
| 45 | + projection: Id, |
| 46 | +) -> Result<Id, Error> { |
| 47 | + AggExtractor { egraph }.plan_select(from, where_, having, groupby, orderby, projection) |
| 48 | +} |
| 49 | + |
| 50 | +struct AggExtractor<'a> { |
| 51 | + egraph: &'a mut EGraph, |
| 52 | +} |
| 53 | + |
| 54 | +impl AggExtractor<'_> { |
| 55 | + fn aggs(&self, id: Id) -> &[Expr] { |
| 56 | + &self.egraph[id].data.aggs |
| 57 | + } |
| 58 | + |
| 59 | + fn node(&self, id: Id) -> &Expr { |
| 60 | + &self.egraph[id].nodes[0] |
| 61 | + } |
| 62 | + |
| 63 | + fn plan_select( |
| 64 | + &mut self, |
| 65 | + from: Id, |
| 66 | + where_: Id, |
| 67 | + having: Id, |
| 68 | + groupby: Id, |
| 69 | + orderby: Id, |
| 70 | + projection: Id, |
| 71 | + ) -> Result<Id, Error> { |
| 72 | + if !self.aggs(where_).is_empty() { |
| 73 | + return Err(Error::AggInWhere); |
| 74 | + } |
| 75 | + if !self.aggs(groupby).is_empty() { |
| 76 | + return Err(Error::AggInGroupBy); |
| 77 | + } |
| 78 | + let mut plan = self.egraph.add(Expr::Filter([where_, from])); |
| 79 | + let mut to_rewrite = [projection, having, orderby]; |
| 80 | + plan = self.plan_agg(&mut to_rewrite, groupby, plan)?; |
| 81 | + let [projection, having, orderby] = to_rewrite; |
| 82 | + plan = self.egraph.add(Expr::Filter([having, plan])); |
| 83 | + plan = self.egraph.add(Expr::Order([orderby, plan])); |
| 84 | + plan = self.egraph.add(Expr::Proj([projection, plan])); |
| 85 | + Ok(plan) |
| 86 | + } |
| 87 | + |
| 88 | + /// Extracts all aggregations from `exprs` and generates an [`Agg`](Expr::Agg) plan. |
| 89 | + /// If no aggregation is found and no `groupby` keys, returns the original `plan`. |
| 90 | + fn plan_agg(&mut self, exprs: &mut [Id], groupby: Id, plan: Id) -> Result<Id, Error> { |
| 91 | + let expr_list = self.egraph.add(Expr::List(exprs.to_vec().into())); |
| 92 | + let aggs = self.aggs(expr_list).to_vec(); |
| 93 | + if aggs.is_empty() && self.node(groupby).as_list().is_empty() { |
| 94 | + return Ok(plan); |
| 95 | + } |
| 96 | + // check nested agg |
| 97 | + for agg in aggs.iter() { |
| 98 | + if agg |
| 99 | + .children() |
| 100 | + .iter() |
| 101 | + .any(|child| !self.aggs(*child).is_empty()) |
| 102 | + { |
| 103 | + return Err(Error::NestedAgg(agg.to_string())); |
| 104 | + } |
| 105 | + } |
| 106 | + let mut list: Vec<_> = aggs.into_iter().map(|agg| self.egraph.add(agg)).collect(); |
| 107 | + // make sure the order of the aggs is deterministic |
| 108 | + list.sort(); |
| 109 | + list.dedup(); |
| 110 | + let mut schema = list.clone(); |
| 111 | + schema.extend_from_slice(self.node(groupby).as_list()); |
| 112 | + let aggs = self.egraph.add(Expr::List(list.into())); |
| 113 | + let plan = self.egraph.add(Expr::Agg([aggs, groupby, plan])); |
| 114 | + // check for not aggregated columns |
| 115 | + // rewrite the expressions with a wrapper over agg or group keys |
| 116 | + for id in exprs { |
| 117 | + *id = self.rewrite_agg_in_expr(*id, &schema)?; |
| 118 | + } |
| 119 | + Ok(plan) |
| 120 | + } |
| 121 | + |
| 122 | + /// Rewrites the expression `id` with aggs wrapped in a [`Nested`](Expr::Nested) node. |
| 123 | + /// Returns the new expression. |
| 124 | + /// |
| 125 | + /// # Example |
| 126 | + /// ```text |
| 127 | + /// id: (+ (sum a) (+ b 1)) |
| 128 | + /// schema: (sum a), (+ b 1) |
| 129 | + /// output: (+ (`(sum a)) (`(+ b 1))) |
| 130 | + /// |
| 131 | + /// so that `id` won't be optimized to: |
| 132 | + /// (+ b (+ (sum a) 1)) |
| 133 | + /// which can not be composed by `schema` |
| 134 | + /// ``` |
| 135 | + fn rewrite_agg_in_expr(&mut self, id: Id, schema: &[Id]) -> Result<Id, Error> { |
| 136 | + let mut expr = self.node(id).clone(); |
| 137 | + if schema.contains(&id) { |
| 138 | + // found agg, wrap it with Nested |
| 139 | + return Ok(self.egraph.add(Expr::Nested(id))); |
| 140 | + } |
| 141 | + if let Expr::Column(cid) = &expr { |
| 142 | + return Err(Error::ColumnNotInAgg(cid.to_string())); |
| 143 | + } |
| 144 | + for child in expr.children_mut() { |
| 145 | + *child = self.rewrite_agg_in_expr(*child, schema)?; |
| 146 | + } |
| 147 | + Ok(self.egraph.add(expr)) |
| 148 | + } |
| 149 | +} |
0 commit comments