Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide user-defined invariants for logical node extensions. #14329

Merged
merged 5 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 202 additions & 3 deletions datafusion/core/tests/user_defined/user_defined_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
//!

use std::fmt::Debug;
use std::hash::Hash;
use std::task::{Context, Poll};
use std::{any::Any, collections::BTreeMap, fmt, sync::Arc};

Expand Down Expand Up @@ -93,7 +94,7 @@ use datafusion::{
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::ScalarValue;
use datafusion_expr::{FetchType, Projection, SortExpr};
use datafusion_expr::{FetchType, Invariant, InvariantLevel, Projection, SortExpr};
use datafusion_optimizer::optimizer::ApplyOrder;
use datafusion_optimizer::AnalyzerRule;
use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
Expand Down Expand Up @@ -295,20 +296,175 @@ async fn topk_plan() -> Result<()> {
Ok(())
}

#[tokio::test]
/// Run invariant checks on the logical plan extension [`TopKPlanNode`].
async fn topk_invariants() -> Result<()> {
Comment on lines +300 to +301
Copy link
Contributor Author

@wiedld wiedld Jan 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test: demonstrate the basic use case. That user-defined invariants will fail for an invalid extension node.

// Test: pass an InvariantLevel::Always
let pass = InvariantMock {
should_fail_invariant: false,
kind: InvariantLevel::Always,
};
let ctx = setup_table(make_topk_context_with_invariants(Some(pass))).await?;
run_and_compare_query(ctx, "Topk context").await?;

// Test: fail an InvariantLevel::Always
let fail = InvariantMock {
should_fail_invariant: true,
kind: InvariantLevel::Always,
};
let ctx = setup_table(make_topk_context_with_invariants(Some(fail))).await?;
matches!(
&*run_and_compare_query(ctx, "Topk context")
.await
.unwrap_err()
.message(),
"node fails check, such as improper inputs"
);

// Test: pass an InvariantLevel::Executable
let pass = InvariantMock {
should_fail_invariant: false,
kind: InvariantLevel::Executable,
};
let ctx = setup_table(make_topk_context_with_invariants(Some(pass))).await?;
run_and_compare_query(ctx, "Topk context").await?;

// Test: fail an InvariantLevel::Executable
let fail = InvariantMock {
should_fail_invariant: true,
kind: InvariantLevel::Executable,
};
let ctx = setup_table(make_topk_context_with_invariants(Some(fail))).await?;
matches!(
&*run_and_compare_query(ctx, "Topk context")
.await
.unwrap_err()
.message(),
"node fails check, such as improper inputs"
);

Ok(())
}

#[tokio::test]
async fn topk_invariants_after_invalid_mutation() -> Result<()> {
Comment on lines +349 to +350
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test: demonstrate a failed invariant check after logical plan mutation (during optimizer run).

// CONTROL
// Build a valid topK plan.
let config = SessionConfig::new().with_target_partitions(48);
let runtime = Arc::new(RuntimeEnv::default());
let state = SessionStateBuilder::new()
.with_config(config)
.with_runtime_env(runtime)
.with_default_features()
.with_query_planner(Arc::new(TopKQueryPlanner {}))
// 1. adds a valid TopKPlanNode
.with_optimizer_rule(Arc::new(TopKOptimizerRule {
invariant_mock: Some(InvariantMock {
should_fail_invariant: false,
kind: InvariantLevel::Always,
}),
}))
.with_analyzer_rule(Arc::new(MyAnalyzerRule {}))
.build();
let ctx = setup_table(SessionContext::new_with_state(state)).await?;
run_and_compare_query(ctx, "Topk context").await?;

// Test
// Build a valid topK plan.
// Then have an invalid mutation in an optimizer run.
let config = SessionConfig::new().with_target_partitions(48);
let runtime = Arc::new(RuntimeEnv::default());
let state = SessionStateBuilder::new()
.with_config(config)
.with_runtime_env(runtime)
.with_default_features()
.with_query_planner(Arc::new(TopKQueryPlanner {}))
// 1. adds a valid TopKPlanNode
.with_optimizer_rule(Arc::new(TopKOptimizerRule {
invariant_mock: Some(InvariantMock {
should_fail_invariant: false,
kind: InvariantLevel::Always,
}),
}))
// 2. break the TopKPlanNode
.with_optimizer_rule(Arc::new(OptimizerMakeExtensionNodeInvalid {}))
.with_analyzer_rule(Arc::new(MyAnalyzerRule {}))
.build();
let ctx = setup_table(SessionContext::new_with_state(state)).await?;
matches!(
&*run_and_compare_query(ctx, "Topk context")
.await
.unwrap_err()
.message(),
"node fails check, such as improper inputs"
);

Ok(())
}

fn make_topk_context() -> SessionContext {
make_topk_context_with_invariants(None)
}

/// Builds a [`SessionContext`] wired up with the TopK query planner, the TopK
/// optimizer rule and the test analyzer rule.
///
/// `invariant_mock` is handed to [`TopKOptimizerRule`]; pass `None` for a
/// rule that carries no mock invariant.
fn make_topk_context_with_invariants(
    invariant_mock: Option<InvariantMock>,
) -> SessionContext {
    let config = SessionConfig::new().with_target_partitions(48);
    let runtime = Arc::new(RuntimeEnv::default());
    let state = SessionStateBuilder::new()
        .with_config(config)
        .with_runtime_env(runtime)
        .with_default_features()
        .with_query_planner(Arc::new(TopKQueryPlanner {}))
        // Register the rule exactly once, carrying the requested mock. The
        // field-less `TopKOptimizerRule {}` form no longer type-checks now
        // that the struct has an `invariant_mock` field.
        .with_optimizer_rule(Arc::new(TopKOptimizerRule { invariant_mock }))
        .with_analyzer_rule(Arc::new(MyAnalyzerRule {}))
        .build();
    SessionContext::new_with_state(state)
}

/// Test-only optimizer rule that replaces a [`TopKPlanNode`] extension node
/// with a copy whose mock invariant is set to fail.
#[derive(Debug)]
struct OptimizerMakeExtensionNodeInvalid;

impl OptimizerRule for OptimizerMakeExtensionNodeInvalid {
    /// Name under which this rule is reported.
    fn name(&self) -> &str {
        "OptimizerMakeExtensionNodeInvalid"
    }

    /// Requests a top-down application order.
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::TopDown)
    }

    /// Signals that [`Self::rewrite`] is implemented.
    fn supports_rewrite(&self) -> bool {
        true
    }

    /// Example rewrite pass which impacts the validity of the extension node.
    ///
    /// A plan that is a [`TopKPlanNode`] extension is replaced by a copy with
    /// the same `k`, `input` and `expr` but a mock invariant that fails at
    /// [`InvariantLevel::Always`]. Any other plan is returned untouched.
    fn rewrite(
        &self,
        plan: LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>, DataFusionError> {
        // Build the replacement node (if any) from borrowed data first, so the
        // original `plan` can still be moved into the "unchanged" result.
        let invalidated = match &plan {
            LogicalPlan::Extension(Extension { node }) => node
                .as_any()
                .downcast_ref::<TopKPlanNode>()
                .map(|topk| TopKPlanNode {
                    k: topk.k,
                    input: topk.input.clone(),
                    expr: topk.expr.clone(),
                    // In a real use case, this rewriter could have changed
                    // the number of inputs, etc.
                    invariant_mock: Some(InvariantMock {
                        should_fail_invariant: true,
                        kind: InvariantLevel::Always,
                    }),
                }),
            _ => None,
        };

        Ok(match invalidated {
            Some(broken) => Transformed::yes(LogicalPlan::Extension(Extension {
                node: Arc::new(broken),
            })),
            None => Transformed::no(plan),
        })
    }
}

// ------ The implementation of the TopK code follows -----

#[derive(Debug)]
Expand Down Expand Up @@ -336,7 +492,10 @@ impl QueryPlanner for TopKQueryPlanner {
}

#[derive(Default, Debug)]
struct TopKOptimizerRule {}
struct TopKOptimizerRule {
/// A testing-only hashable fixture.
invariant_mock: Option<InvariantMock>,
}

impl OptimizerRule for TopKOptimizerRule {
fn name(&self) -> &str {
Expand Down Expand Up @@ -380,6 +539,7 @@ impl OptimizerRule for TopKOptimizerRule {
k: fetch,
input: input.as_ref().clone(),
expr: expr[0].clone(),
invariant_mock: self.invariant_mock.clone(),
}),
})));
}
Expand All @@ -396,6 +556,10 @@ struct TopKPlanNode {
/// The sort expression (this example only supports a single sort
/// expr)
expr: SortExpr,

/// A testing-only hashable fixture.
/// For actual use, define the [`Invariant`] in the [`UserDefinedLogicalNodeCore::invariants`].
invariant_mock: Option<InvariantMock>,
}

impl Debug for TopKPlanNode {
Expand All @@ -406,6 +570,20 @@ impl Debug for TopKPlanNode {
}
}

/// Test fixture that selects which mock [`Invariant`] a [`TopKPlanNode`]
/// reports from its `invariants` method.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
struct InvariantMock {
/// When `true` the reported invariant uses the always-failing check,
/// otherwise the always-passing one.
should_fail_invariant: bool,
/// Level assigned to the reported invariant.
kind: InvariantLevel,
}

/// Mock invariant check that accepts every plan.
fn invariant_helper_mock_ok(_plan: &LogicalPlan) -> Result<()> {
    Ok(())
}

/// Mock invariant check that rejects every plan with an internal error.
fn invariant_helper_mock_fails(_plan: &LogicalPlan) -> Result<()> {
    internal_err!("node fails check, such as improper inputs")
}

impl UserDefinedLogicalNodeCore for TopKPlanNode {
fn name(&self) -> &str {
"TopK"
Expand All @@ -420,6 +598,26 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode {
self.input.schema()
}

/// Reports the mock invariant configured on this node, if any.
///
/// With no mock configured this returns an empty list, same as the default
/// impl. Otherwise a single [`Invariant`] of the mock's level is returned,
/// backed by the always-failing or always-passing helper depending on
/// `should_fail_invariant`.
fn invariants(&self) -> Vec<Invariant> {
    match self.invariant_mock.clone() {
        None => vec![],
        Some(InvariantMock {
            should_fail_invariant: true,
            kind,
        }) => vec![Invariant {
            kind,
            fun: Arc::new(invariant_helper_mock_fails),
        }],
        Some(InvariantMock {
            should_fail_invariant: false,
            kind,
        }) => vec![Invariant {
            kind,
            fun: Arc::new(invariant_helper_mock_ok),
        }],
    }
}

/// Returns a clone of the single sort expression held by this node.
fn expressions(&self) -> Vec<Expr> {
    let sort_key = self.expr.expr.clone();
    vec![sort_key]
}
Expand All @@ -440,6 +638,7 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode {
k: self.k,
input: inputs.swap_remove(0),
expr: self.expr.with_expr(exprs.swap_remove(0)),
invariant_mock: self.invariant_mock.clone(),
})
}

Expand Down
31 changes: 31 additions & 0 deletions datafusion/expr/src/logical_plan/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ use std::cmp::Ordering;
use std::hash::{Hash, Hasher};
use std::{any::Any, collections::HashSet, fmt, sync::Arc};

use super::invariants::Invariant;
use super::InvariantLevel;

/// This defines the interface for [`LogicalPlan`] nodes that can be
/// used to extend DataFusion with custom relational operators.
///
Expand Down Expand Up @@ -54,6 +57,22 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync {
/// Return the output schema of this logical plan node.
fn schema(&self) -> &DFSchemaRef;

/// Return the invariants declared by this extension node.
///
/// Override this method to define the invariants for a given logical plan
/// extension. The default implementation declares none.
fn invariants(&self) -> Vec<Invariant> {
    Vec::new()
}

/// Check the extension node's invariants against `plan`.
///
/// Only invariants whose level equals `check` are run. They are run in the
/// order returned by [`Self::invariants`], and the first failure is returned
/// without running the remaining ones.
fn check_invariants(&self, check: InvariantLevel, plan: &LogicalPlan) -> Result<()> {
    for invariant in self.invariants() {
        if check == invariant.kind {
            invariant.check(plan)?;
        }
    }
    Ok(())
}

/// Returns all expressions in the current logical plan node. This should
/// not include expressions of any inputs (aka non-recursively).
///
Expand Down Expand Up @@ -244,6 +263,14 @@ pub trait UserDefinedLogicalNodeCore:
/// Return the output schema of this logical plan node.
fn schema(&self) -> &DFSchemaRef;

/// Return the invariants declared by this extension node.
///
/// Override this method to define the invariants for a given logical plan
/// extension. The default implementation declares none.
fn invariants(&self) -> Vec<Invariant> {
    Vec::new()
}

/// Returns all expressions in the current logical plan node. This
/// should not include expressions of any inputs (aka
/// non-recursively). These expressions are used for optimizer
Expand Down Expand Up @@ -336,6 +363,10 @@ impl<T: UserDefinedLogicalNodeCore> UserDefinedLogicalNode for T {
self.schema()
}

fn invariants(&self) -> Vec<Invariant> {
self.invariants()
}

fn expressions(&self) -> Vec<Expr> {
self.expressions()
}
Expand Down
Loading
Loading