Skip to content

Commit

Permalink
support simple lateral joins
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Chi Z <[email protected]>
  • Loading branch information
skyzh committed Feb 11, 2025
1 parent 9c12919 commit ab3832f
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 1 deletion.
106 changes: 106 additions & 0 deletions datafusion/optimizer/src/decorrelate_lateral_join.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! [`DecorrelateLateralJoin`] decorrelates logical plans produced by lateral joins.
use crate::optimizer::ApplyOrder;
use crate::{decorrelate_predicate_subquery, OptimizerConfig, OptimizerRule};

use datafusion_common::tree_node::{Transformed, TreeNodeRecursion, TreeNodeVisitor};
use datafusion_common::Result;
use datafusion_expr::logical_plan::JoinType;
use datafusion_expr::LogicalPlan;

/// Optimizer rule that turns lateral joins into ordinary joins.
///
/// The rule carries no state; every instance behaves the same.
#[derive(Debug, Default)]
pub struct DecorrelateLateralJoin {}

impl DecorrelateLateralJoin {
    /// Creates a new instance of the rule.
    ///
    /// The struct has no fields, so this is the same value as `Default::default()`.
    pub fn new() -> Self {
        Self {}
    }
}

impl OptimizerRule for DecorrelateLateralJoin {
    /// Always returns `true`.
    fn supports_rewrite(&self) -> bool {
        true
    }

    /// Rewrites a condition-free inner join whose right input is a correlated
    /// subquery (i.e., the apply operator) into a regular join.
    ///
    /// The plan is returned unchanged (`Transformed::no`) when any of the
    /// following holds:
    /// - `plan` is not a `LogicalPlan::Join`;
    /// - the join type is not `JoinType::Inner`;
    /// - the join carries equi-join keys or a filter;
    /// - nothing in the right input contains an outer reference;
    /// - the right input is not a `LogicalPlan::Subquery`;
    /// - `build_join` yields `None`.
    ///
    /// On success the rewritten plan is returned with `TreeNodeRecursion::Jump`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `build_join`.
    fn rewrite(
        &self,
        plan: LogicalPlan,
        config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        // Find cross joins with outer column references on the right side (i.e., the apply operator).
        let LogicalPlan::Join(join) = &plan else {
            return Ok(Transformed::no(plan));
        };
        if join.join_type != JoinType::Inner {
            return Ok(Transformed::no(plan));
        }
        // Only condition-free joins are handled. `build_join` below receives
        // neither `join.on` nor `join.filter`, so rewriting a join that carries
        // conditions would silently drop them.
        if !join.on.is_empty() || join.filter.is_some() {
            return Ok(Transformed::no(plan));
        }
        // TODO: this check makes the rule quadratic in the number of plan nodes; in theory, the
        // property could be computed bottom-up instead.
        if !plan_contains_outer_reference(&join.right) {
            return Ok(Transformed::no(plan));
        }
        // The right side contains outer references, we need to decorrelate it.
        let LogicalPlan::Subquery(subquery) = &*join.right else {
            return Ok(Transformed::no(plan));
        };
        // NOTE(review): the rewritten join keeps the `Inner` join type. For an aggregate
        // subquery this may drop left rows that have no matching input on the right, even
        // though the subquery alone would still produce one row for them. Confirm whether
        // this case needs separate handling.
        let alias = config.alias_generator();
        let Some(new_plan) = decorrelate_predicate_subquery::build_join(
            &join.left,
            subquery.subquery.as_ref(),
            None,
            join.join_type,
            alias.next("__lateral_sq"),
        )?
        else {
            return Ok(Transformed::no(plan));
        };
        Ok(Transformed::new(new_plan, true, TreeNodeRecursion::Jump))
    }

    /// Name of this rule: `"decorrelate_lateral_join"`.
    fn name(&self) -> &str {
        "decorrelate_lateral_join"
    }

    /// This rule asks to be applied in top-down order.
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::TopDown)
    }
}

/// Returns `true` if `plan`, or any node reached by `visit_with_subqueries`
/// starting from it, reports `contains_outer_reference()`.
///
/// The traversal stops at the first node that does.
fn plan_contains_outer_reference(plan: &LogicalPlan) -> bool {
    /// Records whether an outer reference has been seen during the walk.
    struct OuterRefFinder {
        found: bool,
    }

    impl<'n> TreeNodeVisitor<'n> for OuterRefFinder {
        type Node = LogicalPlan;

        fn f_down(&mut self, node: &'n LogicalPlan) -> Result<TreeNodeRecursion> {
            if !node.contains_outer_reference() {
                return Ok(TreeNodeRecursion::Continue);
            }
            // One hit is enough; no need to look at the rest of the tree.
            self.found = true;
            Ok(TreeNodeRecursion::Stop)
        }
    }

    let mut finder = OuterRefFinder { found: false };
    // `f_down` never returns `Err`, so this `unwrap` is not expected to fire
    // (assuming the traversal itself introduces no errors — TODO confirm).
    plan.visit_with_subqueries(&mut finder).unwrap();
    finder.found
}
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/decorrelate_predicate_subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ fn mark_join(
)
}

fn build_join(
pub(crate) fn build_join(
left: &LogicalPlan,
subquery: &LogicalPlan,
in_predicate_opt: Option<Expr>,
Expand Down
1 change: 1 addition & 0 deletions datafusion/optimizer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
pub mod analyzer;
pub mod common_subexpr_eliminate;
pub mod decorrelate;
pub mod decorrelate_lateral_join;
pub mod decorrelate_predicate_subquery;
pub mod eliminate_cross_join;
pub mod eliminate_duplicated_expr;
Expand Down
2 changes: 2 additions & 0 deletions datafusion/optimizer/src/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use datafusion_common::{internal_err, DFSchema, DataFusionError, HashSet, Result
use datafusion_expr::logical_plan::LogicalPlan;

use crate::common_subexpr_eliminate::CommonSubexprEliminate;
use crate::decorrelate_lateral_join::DecorrelateLateralJoin;
use crate::decorrelate_predicate_subquery::DecorrelatePredicateSubquery;
use crate::eliminate_cross_join::EliminateCrossJoin;
use crate::eliminate_duplicated_expr::EliminateDuplicatedExpr;
Expand Down Expand Up @@ -248,6 +249,7 @@ impl Optimizer {
Arc::new(EliminateJoin::new()),
Arc::new(DecorrelatePredicateSubquery::new()),
Arc::new(ScalarSubqueryToJoin::new()),
Arc::new(DecorrelateLateralJoin::new()),
Arc::new(ExtractEquijoinPredicate::new()),
Arc::new(EliminateDuplicatedExpr::new()),
Arc::new(EliminateFilter::new()),
Expand Down
82 changes: 82 additions & 0 deletions datafusion/sqllogictest/test_files/join.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -1312,3 +1312,85 @@ SELECT a+b*2,

statement ok
drop table t1;


statement ok
CREATE TABLE t1(v0 BIGINT, v1 BIGINT);

statement ok
CREATE TABLE t0(v0 BIGINT, v1 BIGINT);

statement ok
INSERT INTO t0(v0, v1) VALUES (1, 1), (1, 2), (3, 3);

statement ok
INSERT INTO t1(v0, v1) VALUES (1, 1), (3, 2), (3, 5);

query TT
explain SELECT *
FROM t0,
LATERAL (SELECT sum(v1) FROM t1 WHERE t0.v0 = t1.v0);
----
logical_plan
01)Projection: t0.v0, t0.v1, sum(t1.v1)
02)--Inner Join: t0.v0 = __lateral_sq_1.v0
03)----TableScan: t0 projection=[v0, v1]
04)----SubqueryAlias: __lateral_sq_1
05)------Projection: sum(t1.v1), t1.v0
06)--------Aggregate: groupBy=[[t1.v0]], aggr=[[sum(t1.v1)]]
07)----------TableScan: t1 projection=[v0, v1]
physical_plan
01)ProjectionExec: expr=[v0@1 as v0, v1@2 as v1, sum(t1.v1)@0 as sum(t1.v1)]
02)--CoalesceBatchesExec: target_batch_size=8192
03)----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(v0@1, v0@0)], projection=[sum(t1.v1)@0, v0@2, v1@3]
04)------CoalescePartitionsExec
05)--------ProjectionExec: expr=[sum(t1.v1)@1 as sum(t1.v1), v0@0 as v0]
06)----------AggregateExec: mode=FinalPartitioned, gby=[v0@0 as v0], aggr=[sum(t1.v1)]
07)------------CoalesceBatchesExec: target_batch_size=8192
08)--------------RepartitionExec: partitioning=Hash([v0@0], 4), input_partitions=4
09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
10)------------------AggregateExec: mode=Partial, gby=[v0@0 as v0], aggr=[sum(t1.v1)]
11)--------------------DataSourceExec: partitions=1, partition_sizes=[1]
12)------DataSourceExec: partitions=1, partition_sizes=[1]

query TT
explain SELECT *
FROM t0,
LATERAL (SELECT sum(v1) FROM t1 WHERE t0.v0 = t1.v0);
----
logical_plan
01)Projection: t0.v0, t0.v1, sum(t1.v1)
02)--Inner Join: t0.v0 = __lateral_sq_1.v0
03)----TableScan: t0 projection=[v0, v1]
04)----SubqueryAlias: __lateral_sq_1
05)------Projection: sum(t1.v1), t1.v0
06)--------Aggregate: groupBy=[[t1.v0]], aggr=[[sum(t1.v1)]]
07)----------TableScan: t1 projection=[v0, v1]
physical_plan
01)ProjectionExec: expr=[v0@1 as v0, v1@2 as v1, sum(t1.v1)@0 as sum(t1.v1)]
02)--CoalesceBatchesExec: target_batch_size=8192
03)----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(v0@1, v0@0)], projection=[sum(t1.v1)@0, v0@2, v1@3]
04)------CoalescePartitionsExec
05)--------ProjectionExec: expr=[sum(t1.v1)@1 as sum(t1.v1), v0@0 as v0]
06)----------AggregateExec: mode=FinalPartitioned, gby=[v0@0 as v0], aggr=[sum(t1.v1)]
07)------------CoalesceBatchesExec: target_batch_size=8192
08)--------------RepartitionExec: partitioning=Hash([v0@0], 4), input_partitions=4
09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
10)------------------AggregateExec: mode=Partial, gby=[v0@0 as v0], aggr=[sum(t1.v1)]
11)--------------------DataSourceExec: partitions=1, partition_sizes=[1]
12)------DataSourceExec: partitions=1, partition_sizes=[1]

query III
SELECT *
FROM t0,
LATERAL (SELECT sum(v1) FROM t1 WHERE t0.v0 = t1.v0);
----
1 1 1
1 2 1
3 3 7

statement ok
drop table t1;

statement ok
drop table t0;

0 comments on commit ab3832f

Please sign in to comment.