-
Notifications
You must be signed in to change notification settings - Fork 55
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Support when-then-otherwise #2258
base: main
Are you sure you want to change the base?
Changes from 13 commits
114ca10
4f52302
5383dec
5865034
16efb40
c392641
9299f63
f4d488a
a7d699e
0e944bf
b4e9035
c20d0ab
a066bda
ee2cefa
4df9c52
26bb401
9d302f6
3f98b6a
0757bb2
9028500
dc7bed5
e3d87e4
5819d03
70c2834
c95be61
3edcc88
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# This is more thoroughly tested than other Exprs only because | ||
# Chuck is new to this, and is more comfortable testing in python than rust, for now. | ||
|
||
import pytest | ||
import opendp.prelude as dp | ||
|
||
|
||
def example_lf(margin=None, **kwargs): | ||
pl = pytest.importorskip("polars") | ||
domains, series = example_series() | ||
lf_domain, lf = dp.lazyframe_domain(domains), pl.LazyFrame(series) | ||
if margin is not None: | ||
lf_domain = dp.with_margin(lf_domain, by=margin, **kwargs) | ||
return lf_domain, lf | ||
|
||
|
||
def example_series(): | ||
pl = pytest.importorskip("polars") | ||
return [ | ||
dp.series_domain("A", dp.option_domain(dp.atom_domain(T=dp.f64))), | ||
], [ | ||
pl.Series("A", [1.0] * 50, dtype=pl.Float64), | ||
] | ||
|
||
|
||
def test_when_then_otherwise(): | ||
pl = pytest.importorskip("polars") | ||
lf_domain, lf = example_lf() | ||
m_lf = dp.t.make_stable_lazyframe( | ||
lf_domain, | ||
dp.symmetric_distance(), | ||
lf.select( | ||
pl.when(pl.col("A") == 1).then(1).otherwise(0).alias('fifty'), | ||
pl.when(pl.col("A") == 0).then(1).otherwise(0).alias('zero'), | ||
), | ||
) | ||
results = m_lf(lf).collect().sum() | ||
assert results['fifty'].item() == 50 | ||
assert results['zero'].item() == 0 | ||
|
||
|
||
def test_when_then_otherwise_strings(): | ||
pl = pytest.importorskip("polars") | ||
lf_domain, lf = example_lf() | ||
m_lf = dp.t.make_stable_lazyframe( | ||
lf_domain, | ||
dp.symmetric_distance(), | ||
lf.select( | ||
pl.when(pl.col("A") == 1).then(pl.lit("one")).otherwise(pl.lit("other")), | ||
), | ||
) | ||
assert m_lf(lf).collect()['literal'][0] == 'one' | ||
|
||
|
||
def test_when_then_otherwise_mismatch_types(): | ||
pl = pytest.importorskip("polars") | ||
lf_domain, lf = example_lf() | ||
with pytest.raises(dp.OpenDPException, match=r'output domains in ternary must match'): | ||
dp.t.make_stable_lazyframe( | ||
lf_domain, | ||
dp.symmetric_distance(), | ||
lf.select( | ||
pl.when(pl.col("A") == 1).then(1).otherwise(pl.lit("!!!")).alias('fifty'), | ||
), | ||
) | ||
|
||
|
||
def test_when_then_otherwise_incomplete(): | ||
pl = pytest.importorskip("polars") | ||
lf_domain, lf = example_lf() | ||
with pytest.raises(Exception, match=r'unsupported literal value: null'): | ||
dp.t.make_stable_lazyframe( | ||
lf_domain, | ||
dp.symmetric_distance(), | ||
lf.select( | ||
pl.when(pl.col("A") == 1).then(1).alias('fifty'), | ||
), | ||
) | ||
# TODO: Should there be a better error message? |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,99 @@ | ||||||
use polars::prelude::*; | ||||||
use polars_plan::dsl::Expr; | ||||||
|
||||||
use crate::core::{Function, MetricSpace, StabilityMap, Transformation}; | ||||||
use crate::domains::{ExprDomain, ExprPlan, OuterMetric, WildExprDomain}; | ||||||
use crate::error::*; | ||||||
use crate::transformations::DatasetMetric; | ||||||
|
||||||
use super::StableExpr; | ||||||
|
||||||
#[cfg(test)] | ||||||
mod test; | ||||||
|
||||||
/// Make a Transformation that returns a `ternary` expression | ||||||
/// | ||||||
/// # Arguments | ||||||
/// * `input_domain` - Expr domain | ||||||
/// * `input_metric` - The metric under which neighboring LazyFrames are compared | ||||||
/// * `expr` - The ternary expression | ||||||
pub fn make_expr_ternary<M: OuterMetric>( | ||||||
input_domain: WildExprDomain, | ||||||
input_metric: M, | ||||||
expr: Expr, | ||||||
) -> Fallible<Transformation<WildExprDomain, ExprDomain, M, M>> | ||||||
where | ||||||
M::InnerMetric: DatasetMetric, | ||||||
M::Distance: Clone, | ||||||
(WildExprDomain, M): MetricSpace, | ||||||
(ExprDomain, M): MetricSpace, | ||||||
Expr: StableExpr<M, M>, | ||||||
{ | ||||||
let Expr::Ternary { | ||||||
predicate, | ||||||
truthy, | ||||||
falsy, | ||||||
} = expr | ||||||
else { | ||||||
return fallible!(MakeTransformation, "expected ternary expression"); | ||||||
}; | ||||||
|
||||||
let t_predicate = predicate | ||||||
.as_ref() | ||||||
.clone() | ||||||
.make_stable(input_domain.as_row_by_row(), input_metric.clone())?; | ||||||
let t_truthy = truthy | ||||||
.as_ref() | ||||||
.clone() | ||||||
.make_stable(input_domain.as_row_by_row(), input_metric.clone())?; | ||||||
let t_falsy = falsy | ||||||
.as_ref() | ||||||
.clone() | ||||||
.make_stable(input_domain.as_row_by_row(), input_metric.clone())?; | ||||||
|
||||||
let (truthy_domain, _truthy_metric) = t_truthy.output_space(); | ||||||
let (falsy_domain, _falsy_metric) = t_falsy.output_space(); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would be good to check that the metrics match too! |
||||||
|
||||||
if truthy_domain != falsy_domain { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Only need to check that dtypes match. It's ok if the names of the columns in the branch arms are different, and similarly if nullability differs between them, and so on. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is just to prevent polars fallible casting! |
||||||
return fallible!( | ||||||
MakeTransformation, | ||||||
"output domains in ternary must match, instead found {:?} and {:?}", | ||||||
truthy_domain, | ||||||
falsy_domain | ||||||
); | ||||||
} | ||||||
|
||||||
if matches!(truthy_domain.column.dtype(), DataType::Categorical(_, _)) { | ||||||
mccalluc marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
// Since literal categorical values aren't possible, | ||||||
// not clear if this is actually reachable. | ||||||
return fallible!(MakeTransformation, "ternary cannot be applied to categorical data, because it may trigger a data-dependent CategoricalRemappingWarning in Polars"); | ||||||
} | ||||||
|
||||||
let mut output_domain = truthy_domain.clone(); | ||||||
mccalluc marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
output_domain.column.drop_bounds().ok(); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In this case, if the truthy domain does not have nans and falsey domain has nans, then the resulting output domain would have made the claim that the data does not have nans. Instead, lets clear those descriptors: This is a shorthand to completely replace the element domain with the loosest descriptors.
Suggested change
|
||||||
output_domain.column.nullable = false; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
output_domain.context = input_domain.context.clone(); | ||||||
mccalluc marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
Transformation::new( | ||||||
input_domain, | ||||||
output_domain, | ||||||
Function::new_fallible(move |arg| { | ||||||
let predicate = t_predicate.invoke(arg)?; | ||||||
let truthy = t_truthy.invoke(arg)?; | ||||||
let falsy = t_falsy.invoke(arg)?; | ||||||
|
||||||
Ok(ExprPlan { | ||||||
plan: arg.clone(), | ||||||
expr: Expr::Ternary { | ||||||
predicate: Arc::new(predicate.expr), | ||||||
truthy: Arc::new(truthy.expr), | ||||||
falsy: Arc::new(falsy.expr), | ||||||
}, | ||||||
fill: None, // Ternary is run before aggregation, so there's no empty group that needs a default filled in. | ||||||
}) | ||||||
}), | ||||||
input_metric.clone(), | ||||||
input_metric, | ||||||
StabilityMap::new(Clone::clone), | ||||||
) | ||||||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Some tests would be good to add! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here you can allow the computation to run, but the output domain may now contain null.