Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support when-then-otherwise #2258

Open
wants to merge 26 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions python/test/test_polars_ternary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# This is more thoroughly tested than other Exprs only because
# Chuck is new to this, and is more comfortable testing in python than rust, for now.

import pytest
import opendp.prelude as dp


def example_lf(margin=None, **kwargs):
    """Build the (LazyFrame domain, LazyFrame) pair shared by the tests in this module.

    When `margin` is given, the domain is passed through `dp.with_margin`
    with `by=margin`; any extra keyword arguments are forwarded unchanged.
    """
    pl = pytest.importorskip("polars")
    domains, series = example_series()
    lf_domain = dp.lazyframe_domain(domains)
    lf = pl.LazyFrame(series)
    if margin is None:
        return lf_domain, lf
    return dp.with_margin(lf_domain, by=margin, **kwargs), lf


def example_series():
    """Return a list of series domains together with the matching list of series.

    There is a single column "A": an optional f64 domain, paired with fifty 1.0 values.
    """
    pl = pytest.importorskip("polars")
    column_domains = [
        dp.series_domain("A", dp.option_domain(dp.atom_domain(T=dp.f64))),
    ]
    column_data = [
        pl.Series("A", [1.0] * 50, dtype=pl.Float64),
    ]
    return column_domains, column_data


def test_when_then_otherwise():
    """Check both arms of when/then/otherwise on the column of fifty 1.0 values."""
    pl = pytest.importorskip("polars")
    lf_domain, lf = example_lf()

    # Every row equals 1, so the first indicator sums to 50 and the second to 0.
    is_one = pl.when(pl.col("A") == 1).then(1).otherwise(0).alias('fifty')
    is_zero = pl.when(pl.col("A") == 0).then(1).otherwise(0).alias('zero')

    m_lf = dp.t.make_stable_lazyframe(
        lf_domain,
        dp.symmetric_distance(),
        lf.select(is_one, is_zero),
    )
    totals = m_lf(lf).collect().sum()
    assert totals['fifty'].item() == 50
    assert totals['zero'].item() == 0


def test_when_then_otherwise_strings():
    """Check that string literals are accepted in both arms of the ternary."""
    pl = pytest.importorskip("polars")
    lf_domain, lf = example_lf()

    labelled = pl.when(pl.col("A") == 1).then(pl.lit("one")).otherwise(pl.lit("other"))
    m_lf = dp.t.make_stable_lazyframe(
        lf_domain,
        dp.symmetric_distance(),
        lf.select(labelled),
    )
    collected = m_lf(lf).collect()
    # The expression carries no alias; the result is read from the column named 'literal'.
    assert collected['literal'][0] == 'one'


def test_when_then_otherwise_mismatch_types():
    """Check that an integer arm paired with a string arm is rejected at construction."""
    pl = pytest.importorskip("polars")
    lf_domain, lf = example_lf()

    expected_message = r'output domains in ternary must match'
    with pytest.raises(dp.OpenDPException, match=expected_message):
        mixed_arms = pl.when(pl.col("A") == 1).then(1).otherwise(pl.lit("!!!")).alias('fifty')
        dp.t.make_stable_lazyframe(
            lf_domain,
            dp.symmetric_distance(),
            lf.select(mixed_arms),
        )


def test_when_then_otherwise_incomplete():
    """Check that a when/then with no otherwise arm is rejected at construction."""
    pl = pytest.importorskip("polars")
    lf_domain, lf = example_lf()

    expected_message = r'unsupported literal value: null'
    with pytest.raises(Exception, match=expected_message):
        no_otherwise = pl.when(pl.col("A") == 1).then(1).alias('fifty')
        dp.t.make_stable_lazyframe(
            lf_domain,
            dp.symmetric_distance(),
            lf.select(no_otherwise),
        )
    # TODO: Should there be a better error message?
    # NOTE(review): a maintainer suggested allowing this computation to run instead,
    # with the output domain then marked as possibly containing null.
7 changes: 7 additions & 0 deletions rust/src/domains/polars/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,16 @@ impl LazyFrameDomain {
}

/// Internal wrapper used when wrapping built-in Polars expressions.
#[derive(Clone)]
pub struct ExprPlan {
    /// The Polars plan. This can just be cloned.
    pub plan: DslPlan,
    /// The expression itself.
    /// Typically the fields are pulled out of an incoming expression,
    /// processed, and then used to create a new equivalent expression.
    pub expr: Expr,
    /// Fill value, whose expected content depends on position in the chain:
    /// * before an aggregation, this should be `None`;
    /// * if this *is* an aggregation, it should be a literal;
    /// * after an aggregation, this should just duplicate `expr`.
    pub fill: Option<Expr>,
}

Expand Down
99 changes: 99 additions & 0 deletions rust/src/transformations/make_stable_expr/expr_ternary/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
use polars::prelude::*;
use polars_plan::dsl::Expr;

use crate::core::{Function, MetricSpace, StabilityMap, Transformation};
use crate::domains::{ExprDomain, ExprPlan, OuterMetric, WildExprDomain};
use crate::error::*;
use crate::transformations::DatasetMetric;

use super::StableExpr;

#[cfg(test)]
mod test;

/// Make a Transformation that returns a `ternary` expression
///
/// # Arguments
/// * `input_domain` - Expr domain
/// * `input_metric` - The metric under which neighboring LazyFrames are compared
/// * `expr` - The ternary expression
pub fn make_expr_ternary<M: OuterMetric>(
input_domain: WildExprDomain,
input_metric: M,
expr: Expr,
) -> Fallible<Transformation<WildExprDomain, ExprDomain, M, M>>
where
M::InnerMetric: DatasetMetric,
M::Distance: Clone,
(WildExprDomain, M): MetricSpace,
(ExprDomain, M): MetricSpace,
Expr: StableExpr<M, M>,
{
let Expr::Ternary {
predicate,
truthy,
falsy,
} = expr
else {
return fallible!(MakeTransformation, "expected ternary expression");
};

let t_predicate = predicate
.as_ref()
.clone()
.make_stable(input_domain.as_row_by_row(), input_metric.clone())?;
let t_truthy = truthy
.as_ref()
.clone()
.make_stable(input_domain.as_row_by_row(), input_metric.clone())?;
let t_falsy = falsy
.as_ref()
.clone()
.make_stable(input_domain.as_row_by_row(), input_metric.clone())?;

let (truthy_domain, _truthy_metric) = t_truthy.output_space();
let (falsy_domain, _falsy_metric) = t_falsy.output_space();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be good to check that the metrics match too!


if truthy_domain != falsy_domain {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only need to check that dtypes match. It's ok if the names of the columns in the branch arms are different, and similarly if nullability differs between them, and so on.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just to prevent polars fallible casting!

return fallible!(
MakeTransformation,
"output domains in ternary must match, instead found {:?} and {:?}",
truthy_domain,
falsy_domain
);
}

if matches!(truthy_domain.column.dtype(), DataType::Categorical(_, _)) {
// Since literal categorical values aren't possible,
// not clear if this is actually reachable.
return fallible!(MakeTransformation, "ternary cannot be applied to categorical data, because it may trigger a data-dependent CategoricalRemappingWarning in Polars");
}

let mut output_domain = truthy_domain.clone();
output_domain.column.drop_bounds().ok();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, if the truthy domain does not have nans and falsey domain has nans, then the resulting output domain would have made the claim that the data does not have nans. Instead, lets clear those descriptors:

This is a shorthand to completely replace the element domain with the loosest descriptors.

Suggested change
output_domain.column.drop_bounds().ok();
output_domain.column.set_dtype(output_domain.column.dtype())?;

output_domain.column.nullable = false;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
output_domain.column.nullable = false;
output_domain.column.nullable |= falsey_domain.column.nullable;

output_domain.context = input_domain.context.clone();

Transformation::new(
input_domain,
output_domain,
Function::new_fallible(move |arg| {
let predicate = t_predicate.invoke(arg)?;
let truthy = t_truthy.invoke(arg)?;
let falsy = t_falsy.invoke(arg)?;

Ok(ExprPlan {
plan: arg.clone(),
expr: Expr::Ternary {
predicate: Arc::new(predicate.expr),
truthy: Arc::new(truthy.expr),
falsy: Arc::new(falsy.expr),
},
fill: None, // Ternary is run before aggregation, so there's no empty group that needs a default filled in.
})
}),
input_metric.clone(),
input_metric,
StabilityMap::new(Clone::clone),
)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some tests would be good to add!

6 changes: 6 additions & 0 deletions rust/src/transformations/make_stable_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ mod expr_sum;
#[cfg(feature = "contrib")]
mod expr_to_physical;

#[cfg(feature = "contrib")]
mod expr_ternary;

#[cfg(feature = "contrib")]
mod namespace_dt;

Expand Down Expand Up @@ -174,6 +177,9 @@ where
..
} => namespace_str::make_namespace_str(input_domain, input_metric, self),

#[cfg(feature = "contrib")]
Expr::Ternary { .. } => expr_ternary::make_expr_ternary(input_domain, input_metric, self),

expr => fallible!(
MakeTransformation,
"Expr is not recognized at this time: {:?}. {}If you would like to see this supported, please file an issue.",
Expand Down
4 changes: 3 additions & 1 deletion rust/src/transformations/make_stable_lazyframe/source/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ where
if &input_domain.schema() != schema.as_ref() {
return fallible!(
MakeTransformation,
"Schema mismatch. LazyFrame schema must match the schema from the input domain."
"Schema mismatch. LazyFrame schema must match the schema from the input domain. {:?} != {:?}",
&input_domain.schema(),
schema.as_ref()
);
}

Expand Down
Loading