@@ -386,9 +386,7 @@ fn remove_join_expressions(expr: Expr, join_keys: &JoinKeySet) -> Option<Expr> {
386
386
None
387
387
}
388
388
// Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
389
- Expr :: BinaryExpr ( BinaryExpr { left, op, right } )
390
- if matches ! ( op, Operator :: And | Operator :: Or ) =>
391
- {
389
+ Expr :: BinaryExpr ( BinaryExpr { left, op, right } ) if op == Operator :: And => {
392
390
let l = remove_join_expressions ( * left, join_keys) ;
393
391
let r = remove_join_expressions ( * right, join_keys) ;
394
392
match ( l, r) {
@@ -402,7 +400,20 @@ fn remove_join_expressions(expr: Expr, join_keys: &JoinKeySet) -> Option<Expr> {
402
400
_ => None ,
403
401
}
404
402
}
405
-
403
+ Expr :: BinaryExpr ( BinaryExpr { left, op, right } ) if op == Operator :: Or => {
404
+ let l = remove_join_expressions ( * left, join_keys) ;
405
+ let r = remove_join_expressions ( * right, join_keys) ;
406
+ match ( l, r) {
407
+ ( Some ( ll) , Some ( rr) ) => Some ( Expr :: BinaryExpr ( BinaryExpr :: new (
408
+ Box :: new ( ll) ,
409
+ op,
410
+ Box :: new ( rr) ,
411
+ ) ) ) ,
412
+ // When either `left` or `right` is empty, it means they are `true`
413
+ // so OR'ing anything with them will also be true
414
+ _ => None ,
415
+ }
416
+ }
406
417
_ => Some ( expr) ,
407
418
}
408
419
}
@@ -995,6 +1006,7 @@ mod tests {
995
1006
let t4 = test_table_scan_with_name ( "t4" ) ?;
996
1007
997
1008
// could eliminate to inner join
1009
+ // filter: (t1.a = t2.a OR t2.c < 15) AND (t1.a = t2.a AND tc.2 = 688)
998
1010
let plan1 = LogicalPlanBuilder :: from ( t1)
999
1011
. cross_join ( t2) ?
1000
1012
. filter ( binary_expr (
@@ -1012,6 +1024,10 @@ mod tests {
1012
1024
let plan2 = LogicalPlanBuilder :: from ( t3) . cross_join ( t4) ?. build ( ) ?;
1013
1025
1014
1026
// could eliminate to inner join
1027
+ // filter:
1028
+ // ((t3.a = t1.a AND t4.c < 15) OR (t3.a = t1.a AND t4.c = 688))
1029
+ // AND
1030
+ // ((t3.a = t4.a AND t4.c < 15) OR (t3.a = t4.a AND t3.c = 688) OR (t3.a = t4.a AND t3.b = t4.b))
1015
1031
let plan = LogicalPlanBuilder :: from ( plan1)
1016
1032
. cross_join ( plan2) ?
1017
1033
. filter ( binary_expr (
@@ -1057,7 +1073,7 @@ mod tests {
1057
1073
"Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]" ,
1058
1074
" Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]" ,
1059
1075
" Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]" ,
1060
- " Filter: t2.c < UInt32(15) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]" ,
1076
+ " Filter: t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]" ,
1061
1077
" Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]" ,
1062
1078
" TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]" ,
1063
1079
" TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]" ,
@@ -1084,6 +1100,12 @@ mod tests {
1084
1100
let plan2 = LogicalPlanBuilder :: from ( t3) . cross_join ( t4) ?. build ( ) ?;
1085
1101
1086
1102
// could eliminate to inner join
1103
+ // Filter:
1104
+ // ((t3.a = t1.a AND t4.c < 15) OR (t3.a = t1.a AND t4.c = 688))
1105
+ // AND
1106
+ // ((t3.a = t4.a AND t4.c < 15) OR (t3.a = t4.a AND t3.c = 688) OR (t3.a = t4.a AND t3.b = t4.b))
1107
+ // AND
1108
+ // ((t1.a = t2.a OR t2.c < 15) AND (t1.a = t2.a AND t2.c = 688))
1087
1109
let plan = LogicalPlanBuilder :: from ( plan1)
1088
1110
. cross_join ( plan2) ?
1089
1111
. filter ( binary_expr (
@@ -1142,7 +1164,7 @@ mod tests {
1142
1164
. build ( ) ?;
1143
1165
1144
1166
let expected = vec ! [
1145
- "Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c < UInt32(15) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]" ,
1167
+ "Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]" ,
1146
1168
" Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]" ,
1147
1169
" Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]" ,
1148
1170
" Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]" ,
0 commit comments