Skip to content

Commit d4e644d

Browse files
committed
lab9: agg extraction
Signed-off-by: Runji Wang <[email protected]>
1 parent cb234dc commit d4e644d

File tree

3 files changed

+320
-2
lines changed

3 files changed

+320
-2
lines changed

src/agg.rs

+149
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
use egg::Language;
2+
3+
use super::*;
4+
5+
/// The data type of aggragation analysis.
6+
pub type AggSet = Vec<Expr>;
7+
8+
/// Returns all aggragations in the tree.
9+
///
10+
/// Note: if there is an agg over agg, e.g. `sum(count(a))`, only the upper one will be returned.
11+
pub fn analyze_aggs(egraph: &EGraph, enode: &Expr) -> AggSet {
12+
use Expr::*;
13+
let x = |i: &Id| egraph[*i].data.aggs.clone();
14+
if let Max(_) | Min(_) | Sum(_) | Avg(_) | Count(_) = enode {
15+
return vec![enode.clone()];
16+
}
17+
// merge the set from all children
18+
// TODO: ignore plan nodes
19+
enode.children().iter().flat_map(x).collect()
20+
}
21+
22+
#[derive(Debug, PartialEq, Eq)]
23+
pub enum Error {
24+
// #[error("aggregate function calls cannot be nested")]
25+
NestedAgg(String),
26+
// #[error("WHERE clause cannot contain aggregates")]
27+
AggInWhere,
28+
// #[error("GROUP BY clause cannot contain aggregates")]
29+
AggInGroupBy,
30+
// #[error("column {0} must appear in the GROUP BY clause or be used in an aggregate function")]
31+
ColumnNotInAgg(String),
32+
}
33+
34+
/// Converts the SELECT statement into a plan tree.
35+
///
36+
/// The nodes of all clauses have been added to the `egraph`.
37+
/// `from`, `where_`... are the ids of their root node.
38+
pub fn plan_select(
39+
egraph: &mut EGraph,
40+
from: Id,
41+
where_: Id,
42+
having: Id,
43+
groupby: Id,
44+
orderby: Id,
45+
projection: Id,
46+
) -> Result<Id, Error> {
47+
AggExtractor { egraph }.plan_select(from, where_, having, groupby, orderby, projection)
48+
}
49+
50+
struct AggExtractor<'a> {
51+
egraph: &'a mut EGraph,
52+
}
53+
54+
impl AggExtractor<'_> {
55+
fn aggs(&self, id: Id) -> &[Expr] {
56+
&self.egraph[id].data.aggs
57+
}
58+
59+
fn node(&self, id: Id) -> &Expr {
60+
&self.egraph[id].nodes[0]
61+
}
62+
63+
fn plan_select(
64+
&mut self,
65+
from: Id,
66+
where_: Id,
67+
having: Id,
68+
groupby: Id,
69+
orderby: Id,
70+
projection: Id,
71+
) -> Result<Id, Error> {
72+
if !self.aggs(where_).is_empty() {
73+
return Err(Error::AggInWhere);
74+
}
75+
if !self.aggs(groupby).is_empty() {
76+
return Err(Error::AggInGroupBy);
77+
}
78+
let mut plan = self.egraph.add(Expr::Filter([where_, from]));
79+
let mut to_rewrite = [projection, having, orderby];
80+
plan = self.plan_agg(&mut to_rewrite, groupby, plan)?;
81+
let [projection, having, orderby] = to_rewrite;
82+
plan = self.egraph.add(Expr::Filter([having, plan]));
83+
plan = self.egraph.add(Expr::Order([orderby, plan]));
84+
plan = self.egraph.add(Expr::Proj([projection, plan]));
85+
Ok(plan)
86+
}
87+
88+
/// Extracts all aggregations from `exprs` and generates an [`Agg`](Expr::Agg) plan.
89+
/// If no aggregation is found and no `groupby` keys, returns the original `plan`.
90+
fn plan_agg(&mut self, exprs: &mut [Id], groupby: Id, plan: Id) -> Result<Id, Error> {
91+
let expr_list = self.egraph.add(Expr::List(exprs.to_vec().into()));
92+
let aggs = self.aggs(expr_list).to_vec();
93+
if aggs.is_empty() && self.node(groupby).as_list().is_empty() {
94+
return Ok(plan);
95+
}
96+
// check nested agg
97+
for agg in aggs.iter() {
98+
if agg
99+
.children()
100+
.iter()
101+
.any(|child| !self.aggs(*child).is_empty())
102+
{
103+
return Err(Error::NestedAgg(agg.to_string()));
104+
}
105+
}
106+
let mut list: Vec<_> = aggs.into_iter().map(|agg| self.egraph.add(agg)).collect();
107+
// make sure the order of the aggs is deterministic
108+
list.sort();
109+
list.dedup();
110+
let mut schema = list.clone();
111+
schema.extend_from_slice(self.node(groupby).as_list());
112+
let aggs = self.egraph.add(Expr::List(list.into()));
113+
let plan = self.egraph.add(Expr::Agg([aggs, groupby, plan]));
114+
// check for not aggregated columns
115+
// rewrite the expressions with a wrapper over agg or group keys
116+
for id in exprs {
117+
*id = self.rewrite_agg_in_expr(*id, &schema)?;
118+
}
119+
Ok(plan)
120+
}
121+
122+
/// Rewrites the expression `id` with aggs wrapped in a [`Nested`](Expr::Nested) node.
123+
/// Returns the new expression.
124+
///
125+
/// # Example
126+
/// ```text
127+
/// id: (+ (sum a) (+ b 1))
128+
/// schema: (sum a), (+ b 1)
129+
/// output: (+ (`(sum a)) (`(+ b 1)))
130+
///
131+
/// so that `id` won't be optimized to:
132+
/// (+ b (+ (sum a) 1))
133+
/// which can not be composed by `schema`
134+
/// ```
135+
fn rewrite_agg_in_expr(&mut self, id: Id, schema: &[Id]) -> Result<Id, Error> {
136+
let mut expr = self.node(id).clone();
137+
if schema.contains(&id) {
138+
// found agg, wrap it with Nested
139+
return Ok(self.egraph.add(Expr::Nested(id)));
140+
}
141+
if let Expr::Column(cid) = &expr {
142+
return Err(Error::ColumnNotInAgg(cid.to_string()));
143+
}
144+
for child in expr.children_mut() {
145+
*child = self.rewrite_agg_in_expr(*child, schema)?;
146+
}
147+
Ok(self.egraph.add(expr))
148+
}
149+
}

src/lib.rs

+10-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::hash::Hash;
22

33
use egg::{define_language, Analysis, DidMerge, Id, Var};
44

5+
pub mod agg;
56
pub mod expr;
67
pub mod plan;
78
mod value;
@@ -16,9 +17,9 @@ define_language! {
1617
pub enum Expr {
1718
// values
1819
Constant(Value), // null, true, 1, 'hello'
19-
Column(Column), // t.a, b, c
2020

2121
// utilities
22+
"`" = Nested(Id), // (` expr) a wrapper over expr to prevent optimization
2223
"list" = List(Box<[Id]>), // (list ...)
2324

2425
// unary operations
@@ -75,6 +76,8 @@ define_language! {
7576
"empty" = Empty(Id), // (empty child)
7677
// returns empty chunk
7778
// with the same schema as `child`
79+
80+
Column(Column), // t.a, b, c
7881
}
7982
}
8083

@@ -101,6 +104,9 @@ pub struct Data {
101104

102105
/// All columns involved in the node.
103106
pub columns: plan::ColumnSet,
107+
108+
/// All aggragations in the tree.
109+
pub aggs: agg::AggSet,
104110
}
105111

106112
impl Analysis<Expr> for ExprAnalysis {
@@ -111,6 +117,7 @@ impl Analysis<Expr> for ExprAnalysis {
111117
Data {
112118
constant: expr::eval_constant(egraph, enode),
113119
columns: plan::analyze_columns(egraph, enode),
120+
aggs: agg::analyze_aggs(egraph, enode),
114121
}
115122
}
116123

@@ -125,7 +132,8 @@ impl Analysis<Expr> for ExprAnalysis {
125132
fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge {
126133
let merge_const = egg::merge_max(&mut to.constant, from.constant);
127134
let merge_columns = plan::merge(&mut to.columns, from.columns);
128-
merge_const | merge_columns
135+
let merge_aggs = egg::merge_max(&mut to.aggs, from.aggs);
136+
merge_const | merge_columns | merge_aggs
129137
}
130138

131139
/// Modify the graph after analyzing a node.

tests/9_agg_extraction.rs

+161
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
use egg::Language;
2+
use egg_sql_labs::{
3+
agg::{plan_select, Error},
4+
EGraph, RecExpr,
5+
};
6+
7+
#[test]
8+
fn no_agg() {
9+
// SELECT a FROM t;
10+
test(Case {
11+
select: "(list a)",
12+
from: "(scan t (list a))",
13+
where_: "",
14+
having: "",
15+
groupby: "",
16+
orderby: "",
17+
expected: Ok("
18+
(proj (list a)
19+
(order list
20+
(filter true
21+
(filter true
22+
(scan t (list a))
23+
))))"),
24+
});
25+
}
26+
27+
#[test]
28+
fn agg() {
29+
// SELECT sum(a + b) + (a + 1) FROM t
30+
// WHERE b > 1
31+
// GROUP BY a + 1
32+
// HAVING count(a) > 1
33+
// ORDER BY max(b)
34+
test(Case {
35+
select: "(list (+ (sum (+ a b)) (+ a 1)))",
36+
from: "(scan t (list a b))",
37+
where_: "(> b 1)",
38+
having: "(> (count a) 1)",
39+
groupby: "(list (+ a 1))",
40+
orderby: "(list (asc (max b)))",
41+
expected: Ok("
42+
(proj (list (+ (` (sum (+ a b))) (` (+ a 1))))
43+
(order (list (asc (` (max b))))
44+
(filter (> (` (count a)) 1)
45+
(agg (list (sum (+ a b)) (count a) (max b)) (list (+ a 1))
46+
(filter (> b 1)
47+
(scan t (list a b))
48+
)))))"),
49+
});
50+
}
51+
52+
#[test]
53+
fn error_agg_in_where() {
54+
// SELECT a FROM t WHERE sum(a) > 1
55+
test(Case {
56+
select: "(list a)",
57+
from: "(scan t (list a b))",
58+
where_: "(> (sum a) 1)",
59+
having: "",
60+
groupby: "",
61+
orderby: "",
62+
expected: Err(Error::AggInWhere),
63+
});
64+
}
65+
66+
#[test]
67+
fn error_agg_in_groupby() {
68+
// SELECT a FROM t GROUP BY sum(a)
69+
test(Case {
70+
select: "(list a)",
71+
from: "(scan t (list a b))",
72+
where_: "",
73+
having: "",
74+
groupby: "(list (sum a))",
75+
orderby: "",
76+
expected: Err(Error::AggInGroupBy),
77+
});
78+
}
79+
80+
#[test]
81+
fn error_nested_agg() {
82+
// SELECT count(sum(a)) FROM t
83+
test(Case {
84+
select: "(list (count (sum a)))",
85+
from: "(scan t (list a b))",
86+
where_: "",
87+
having: "",
88+
groupby: "",
89+
orderby: "",
90+
expected: Err(Error::NestedAgg("count".into())),
91+
});
92+
}
93+
94+
#[test]
95+
fn error_column_not_in_agg() {
96+
// SELECT b FROM t GROUP BY a
97+
test(Case {
98+
select: "(list b)",
99+
from: "(scan t (list a b))",
100+
where_: "",
101+
having: "",
102+
groupby: "(list a)",
103+
orderby: "",
104+
expected: Err(Error::ColumnNotInAgg("b".into())),
105+
});
106+
}
107+
108+
struct Case {
109+
select: &'static str,
110+
from: &'static str,
111+
where_: &'static str,
112+
having: &'static str,
113+
groupby: &'static str,
114+
orderby: &'static str,
115+
expected: Result<&'static str, Error>,
116+
}
117+
118+
#[track_caller]
119+
fn test(mut case: Case) {
120+
if case.where_.is_empty() {
121+
case.where_ = "true";
122+
}
123+
if case.having.is_empty() {
124+
case.having = "true";
125+
}
126+
if case.groupby.is_empty() {
127+
case.groupby = "list";
128+
}
129+
if case.orderby.is_empty() {
130+
case.orderby = "list";
131+
}
132+
let mut egraph = EGraph::default();
133+
let projection = egraph.add_expr(&case.select.parse().unwrap());
134+
let from = egraph.add_expr(&case.from.parse().unwrap());
135+
let where_ = egraph.add_expr(&case.where_.parse().unwrap());
136+
let having = egraph.add_expr(&case.having.parse().unwrap());
137+
let groupby = egraph.add_expr(&case.groupby.parse().unwrap());
138+
let orderby = egraph.add_expr(&case.orderby.parse().unwrap());
139+
match plan_select(
140+
&mut egraph,
141+
from,
142+
where_,
143+
having,
144+
groupby,
145+
orderby,
146+
projection,
147+
) {
148+
Err(e) => assert_eq!(case.expected, Err(e)),
149+
Ok(id) => {
150+
let get_node = |id| egraph[id].nodes[0].clone();
151+
let actual = get_node(id).build_recexpr(get_node).to_string();
152+
let expected = case
153+
.expected
154+
.expect(&format!("expect error, but got: {actual:?}"))
155+
.parse::<RecExpr>()
156+
.unwrap()
157+
.to_string();
158+
assert_eq!(actual, expected);
159+
}
160+
}
161+
}

0 commit comments

Comments
 (0)