Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Witgen: Pass range constraints separately #2451

Merged
merged 8 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions executor/src/witgen/affine_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ where
/// we can deduce the values of all components from the offset part.
pub fn solve_with_range_constraints(
&self,
known_constraints: &impl RangeConstraintSet<K, T>,
known_constraints: &dyn RangeConstraintSet<K, T>,
) -> EvalResult<T, K> {
// Try to solve directly.
let value = self.solve()?;
Expand Down Expand Up @@ -272,7 +272,7 @@ where
/// where `dividend` and `divisor` are known and `remainder` is range-constrained to be smaller than `divisor`.
fn try_solve_division(
&self,
known_constraints: &impl RangeConstraintSet<K, T>,
known_constraints: &dyn RangeConstraintSet<K, T>,
) -> Option<EvalResult<T, K>> {
// Detect pattern: `dividend = divisor * quotient + remainder`
let (first, second, offset) = match self {
Expand Down Expand Up @@ -332,7 +332,7 @@ where

fn try_transfer_constraints(
&self,
known_constraints: &impl RangeConstraintSet<K, T>,
known_constraints: &dyn RangeConstraintSet<K, T>,
) -> Option<(K, RangeConstraint<T>)> {
// We are looking for X = a * Y + b * Z + ... or -X = a * Y + b * Z + ...
// where X is least constrained.
Expand Down Expand Up @@ -378,7 +378,7 @@ where
/// Returns an empty vector if it is not able to solve the equation.
fn try_solve_through_constraints(
&self,
known_constraints: &impl RangeConstraintSet<K, T>,
known_constraints: &dyn RangeConstraintSet<K, T>,
) -> EvalResult<T, K> {
// Get constraints from coefficients and also collect unconstrained indices.
let (constraints, unconstrained): (Vec<_>, Vec<K>) = self
Expand Down
12 changes: 6 additions & 6 deletions executor/src/witgen/data_structures/caller_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use powdr_number::FieldElement;

use crate::witgen::{
machines::LookupCell,
processor::{Left, OuterQuery},
processor::{Arguments, OuterQuery},
EvalError, EvalResult, EvalValue,
};

Expand All @@ -14,20 +14,20 @@ pub struct CallerData<'a, 'b, T> {
/// The raw data of the caller. Unknown values should be ignored.
data: Vec<T>,
/// The affine expressions of the caller.
left: &'b Left<'a, T>,
arguments: &'b Arguments<'a, T>,
}

impl<'a, 'b, T: FieldElement> From<&'b OuterQuery<'a, '_, T>> for CallerData<'a, 'b, T> {
/// Builds a `CallerData` from an `OuterQuery`.
fn from(outer_query: &'b OuterQuery<'a, '_, T>) -> Self {
let data = outer_query
.left
.arguments
.iter()
.map(|l| l.constant_value().unwrap_or_default())
.collect();
Self {
data,
left: &outer_query.left,
arguments: &outer_query.arguments,
}
}
}
Expand All @@ -37,7 +37,7 @@ impl<T: FieldElement> CallerData<'_, '_, T> {
pub fn as_lookup_cells(&mut self) -> Vec<LookupCell<'_, T>> {
self.data
.iter_mut()
.zip_eq(self.left.iter())
.zip_eq(self.arguments.iter())
.map(|(value, left)| match left.constant_value().is_some() {
true => LookupCell::Input(value),
false => LookupCell::Output(value),
Expand All @@ -52,7 +52,7 @@ impl<'a, 'b, T: FieldElement> From<CallerData<'a, 'b, T>> for EvalResult<'a, T>
/// Note that this function assumes that the lookup was successful and complete.
fn from(data: CallerData<'a, 'b, T>) -> EvalResult<'a, T> {
let mut result = EvalValue::complete(vec![]);
for (l, v) in data.left.iter().zip_eq(data.data.iter()) {
for (l, v) in data.arguments.iter().zip_eq(data.data.iter()) {
if !l.is_constant() {
let evaluated = l.clone() - (*v).into();
match evaluated.solve() {
Expand Down
17 changes: 11 additions & 6 deletions executor/src/witgen/data_structures/mutable_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ use bit_vec::BitVec;
use powdr_number::FieldElement;

use crate::witgen::{
global_constraints::RangeConstraintSet,
machines::{KnownMachine, LookupCell, Machine},
range_constraints::RangeConstraint,
rows::RowPair,
EvalError, EvalResult, QueryCallback,
AffineExpression, AlgebraicVariable, EvalError, EvalResult, QueryCallback,
};

/// The container and access method for machines and the query callback.
Expand Down Expand Up @@ -61,11 +61,16 @@ impl<'a, T: FieldElement, Q: QueryCallback<T>> MutableState<'a, T, Q> {
machine.can_process_call_fully(self, identity_id, known_inputs, range_constraints)
}

/// Call the machine responsible for the right-hand-side of an identity given its ID
/// and the row pair of the caller.
pub fn call(&self, identity_id: u64, caller_rows: &RowPair<'_, 'a, T>) -> EvalResult<'a, T> {
/// Call the machine responsible for the right-hand-side of an identity given its ID,
/// the evaluated arguments and the caller's range constraints.
pub fn call(
&self,
identity_id: u64,
arguments: &[AffineExpression<AlgebraicVariable<'a>, T>],
range_constraints: &dyn RangeConstraintSet<AlgebraicVariable<'a>, T>,
) -> EvalResult<'a, T> {
self.responsible_machine(identity_id)?
.process_plookup_timed(self, identity_id, caller_rows)
.process_plookup_timed(self, identity_id, arguments, range_constraints)
}

/// Call the machine responsible for the right-hand-side of an identity given its ID,
Expand Down
24 changes: 12 additions & 12 deletions executor/src/witgen/global_constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,25 +47,26 @@ impl<'a, T: FieldElement> RangeConstraintSet<AlgebraicVariable<'a>, T>
}

/// A range constraint set that combines two other range constraint sets.
pub struct CombinedRangeConstraintSet<'a, R1, R2, K, T>
pub struct CombinedRangeConstraintSet<'a, R, K, T>
where
T: FieldElement,
R1: RangeConstraintSet<K, T>,
R2: RangeConstraintSet<K, T>,
R: RangeConstraintSet<K, T>,
{
range_constraints1: &'a R1,
range_constraints2: &'a R2,
range_constraints1: &'a dyn RangeConstraintSet<K, T>,
range_constraints2: &'a R,
_marker_k: PhantomData<K>,
_marker_t: PhantomData<T>,
}

impl<'a, R1, R2, K, T> CombinedRangeConstraintSet<'a, R1, R2, K, T>
impl<'a, R, K, T> CombinedRangeConstraintSet<'a, R, K, T>
where
T: FieldElement,
R1: RangeConstraintSet<K, T>,
R2: RangeConstraintSet<K, T>,
R: RangeConstraintSet<K, T>,
{
pub fn new(range_constraints1: &'a R1, range_constraints2: &'a R2) -> Self {
pub fn new(
range_constraints1: &'a dyn RangeConstraintSet<K, T>,
range_constraints2: &'a R,
) -> Self {
Self {
range_constraints1,
range_constraints2,
Expand All @@ -75,12 +76,11 @@ where
}
}

impl<R1, R2, K, T> RangeConstraintSet<K, T> for CombinedRangeConstraintSet<'_, R1, R2, K, T>
impl<R, K, T> RangeConstraintSet<K, T> for CombinedRangeConstraintSet<'_, R, K, T>
where
T: FieldElement,
K: Copy,
R1: RangeConstraintSet<K, T>,
R2: RangeConstraintSet<K, T>,
R: RangeConstraintSet<K, T>,
{
fn range_constraint(&self, id: K) -> Option<RangeConstraint<T>> {
match (
Expand Down
16 changes: 13 additions & 3 deletions executor/src/witgen/identity_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,17 @@ impl<'a, 'c, T: FieldElement, Q: QueryCallback<T>> IdentityProcessor<'a, 'c, T,
return Ok(status);
}

self.mutable_state.call(id, rows)
let left = match left
.expressions
.iter()
.map(|e| rows.evaluate(e))
.collect::<Result<Vec<_>, _>>()
{
Ok(expressions) => expressions,
Err(incomplete_cause) => return Ok(EvalValue::incomplete(incomplete_cause)),
};

self.mutable_state.call(id, &left, rows)
}

/// Handles the lookup that connects the current machine to the calling machine.
Expand All @@ -102,11 +112,11 @@ impl<'a, 'c, T: FieldElement, Q: QueryCallback<T>> IdentityProcessor<'a, 'c, T,
.ok_or(EvalError::Generic("Selector is not 1!".to_string()))?;

let range_constraint =
CombinedRangeConstraintSet::new(outer_query.caller_rows, current_rows);
CombinedRangeConstraintSet::new(outer_query.range_constraints, current_rows);

let mut updates = EvalValue::complete(vec![]);

for (l, r) in outer_query.left.iter().zip(right.expressions.iter()) {
for (l, r) in outer_query.arguments.iter().zip(right.expressions.iter()) {
match current_rows.evaluate(r) {
Ok(r) => {
let result = (l.clone() - r).solve_with_range_constraints(&range_constraint)?;
Expand Down
42 changes: 27 additions & 15 deletions executor/src/witgen/machines/block_machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,17 @@ use crate::witgen::block_processor::BlockProcessor;
use crate::witgen::data_structures::caller_data::CallerData;
use crate::witgen::data_structures::finalizable_data::FinalizableData;
use crate::witgen::data_structures::mutable_state::MutableState;
use crate::witgen::global_constraints::RangeConstraintSet;
use crate::witgen::jit::function_cache::FunctionCache;
use crate::witgen::jit::witgen_inference::CanProcessCall;
use crate::witgen::processor::{OuterQuery, Processor, SolverState};
use crate::witgen::range_constraints::RangeConstraint;
use crate::witgen::rows::{Row, RowIndex, RowPair};
use crate::witgen::rows::{Row, RowIndex};
use crate::witgen::sequence_iterator::{
DefaultSequenceIterator, ProcessingSequenceCache, ProcessingSequenceIterator,
};
use crate::witgen::util::try_to_simple_poly;
use crate::witgen::AffineExpression;
use crate::witgen::{machines::Machine, EvalError, EvalValue, IncompleteCause, QueryCallback};
use bit_vec::BitVec;
use powdr_ast::analyzed::{DegreeRange, PolyID, PolynomialType};
Expand Down Expand Up @@ -202,10 +204,12 @@ impl<'a, T: FieldElement> Machine<'a, T> for BlockMachine<'a, T> {
&mut self,
mutable_state: &'b MutableState<'a, T, Q>,
identity_id: u64,
caller_rows: &'b RowPair<'b, 'a, T>,
arguments: &[AffineExpression<AlgebraicVariable<'a>, T>],
range_constraints: &dyn RangeConstraintSet<AlgebraicVariable<'a>, T>,
) -> EvalResult<'a, T> {
let previous_len = self.data.len();
let result = self.process_plookup_internal(mutable_state, identity_id, caller_rows);
let result =
self.process_plookup_internal(mutable_state, identity_id, arguments, range_constraints);
if let Ok(assignments) = &result {
if !assignments.is_complete() {
// rollback the changes.
Expand Down Expand Up @@ -406,22 +410,26 @@ impl<'a, T: FieldElement> BlockMachine<'a, T> {
RowIndex::from_i64(self.rows() as i64 - 1, self.degree)
}

fn process_plookup_internal<'b, Q: QueryCallback<T>>(
fn process_plookup_internal<Q: QueryCallback<T>>(
&mut self,
mutable_state: &MutableState<'a, T, Q>,
identity_id: u64,
caller_rows: &'b RowPair<'b, 'a, T>,
arguments: &[AffineExpression<AlgebraicVariable<'a>, T>],
range_constraints: &dyn RangeConstraintSet<AlgebraicVariable<'a>, T>,
) -> EvalResult<'a, T> {
let outer_query =
match OuterQuery::try_new(caller_rows, self.parts.connections[&identity_id]) {
Ok(outer_query) => outer_query,
Err(incomplete_cause) => return Ok(EvalValue::incomplete(incomplete_cause)),
};
let outer_query = match OuterQuery::try_new(
arguments,
range_constraints,
self.parts.connections[&identity_id],
) {
Ok(outer_query) => outer_query,
Err(incomplete_cause) => return Ok(EvalValue::incomplete(incomplete_cause)),
};

log::trace!("Start processing block machine '{}'", self.name());
log::trace!("Left values of lookup:");
if log::log_enabled!(log::Level::Trace) {
for l in &outer_query.left {
for l in &outer_query.arguments {
log::trace!(" {}", l);
}
}
Expand All @@ -430,7 +438,11 @@ impl<'a, T: FieldElement> BlockMachine<'a, T> {
return Err(EvalError::RowsExhausted(self.name.clone()));
}

let known_inputs = outer_query.left.iter().map(|e| e.is_constant()).collect();
let known_inputs = outer_query
.arguments
.iter()
.map(|e| e.is_constant())
.collect();
if self
.function_cache
.compile_cached(mutable_state, identity_id, &known_inputs)
Expand All @@ -445,7 +457,7 @@ impl<'a, T: FieldElement> BlockMachine<'a, T> {
// TODO this assumes we are always using the same lookup for this machine.
let mut sequence_iterator = self
.processing_sequence_cache
.get_processing_sequence(&outer_query.left);
.get_processing_sequence(&outer_query.arguments);

if !sequence_iterator.has_steps() {
// Shortcut, no need to do anything.
Expand Down Expand Up @@ -477,7 +489,7 @@ impl<'a, T: FieldElement> BlockMachine<'a, T> {

// We solved the query, so report it to the cache.
self.processing_sequence_cache
.report_processing_sequence(&outer_query.left, sequence_iterator);
.report_processing_sequence(&outer_query.arguments, sequence_iterator);
Ok(updates)
}
ProcessResult::Incomplete(updates) => {
Expand All @@ -486,7 +498,7 @@ impl<'a, T: FieldElement> BlockMachine<'a, T> {
self.name()
);
self.processing_sequence_cache
.report_incomplete(&outer_query.left);
.report_incomplete(&outer_query.arguments);
Ok(updates)
}
}
Expand Down
Loading
Loading