diff --git a/cli/src/main.rs b/cli/src/main.rs index b426c173d..ea3a302d2 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -12,6 +12,7 @@ use powdr::number::{ BabyBearField, BigUint, Bn254Field, FieldElement, GoldilocksField, KoalaBearField, Mersenne31Field, }; +use powdr::pipeline::pipeline::LinkerMode; use powdr::pipeline::test_runner; use powdr::Pipeline; use std::io; @@ -154,6 +155,10 @@ enum Commands { #[arg(long)] backend_options: Option, + /// Linker mode, deciding how to reduce links to constraints. + #[arg(long)] + linker_mode: Option, + /// Generate a CSV file containing the witness column values. #[arg(long)] #[arg(default_value_t = false)] @@ -464,6 +469,7 @@ fn run_command(command: Commands) { prove_with, params, backend_options, + linker_mode, export_witness_csv, export_all_columns_csv, csv_mode, @@ -478,6 +484,7 @@ fn run_command(command: Commands) { prove_with, params, backend_options, + linker_mode, export_witness_csv, export_all_columns_csv, csv_mode @@ -668,6 +675,7 @@ fn run_pil( prove_with: Option, params: Option, backend_options: Option, + linker_mode: Option, export_witness: bool, export_all_columns: bool, csv_mode: CsvRenderModeCLI, @@ -675,7 +683,9 @@ fn run_pil( let inputs = split_inputs::(&inputs); let pipeline = bind_cli_args( - Pipeline::::default().from_file(PathBuf::from(&file)), + Pipeline::::default() + .from_file(PathBuf::from(&file)) + .with_linker_mode(linker_mode.unwrap_or_default()), inputs.clone(), PathBuf::from(output_directory), force, @@ -804,6 +814,7 @@ mod test { prove_with: Some(BackendType::EStarkDump), params: None, backend_options: Some("stark_gl".to_string()), + linker_mode: None, export_witness_csv: false, export_all_columns_csv: true, csv_mode: CsvRenderModeCLI::Hex, diff --git a/linker/Cargo.toml b/linker/Cargo.toml index 5d05b58d3..a44b2dd5e 100644 --- a/linker/Cargo.toml +++ b/linker/Cargo.toml @@ -12,6 +12,7 @@ powdr-analysis.workspace = true powdr-ast.workspace = true powdr-number.workspace = true powdr-parser-util.workspace = true +strum = { version = "0.24.1", features = ["derive"] } pretty_assertions = "1.4.0" itertools = "0.13" diff --git a/linker/src/lib.rs b/linker/src/lib.rs index ce3c41d1e..f644eb087 100644 --- a/linker/src/lib.rs +++ b/linker/src/lib.rs @@ -4,15 +4,17 @@ use lazy_static::lazy_static; use powdr_analysis::utils::parse_pil_statement; use powdr_ast::{ asm_analysis::{combine_flags, MachineDegree}, - object::{Link, Location, MachineInstanceGraph}, + object::{Link, Location, MachineInstanceGraph, Object}, parsed::{ - asm::{AbsoluteSymbolPath, SymbolPath}, + asm::{AbsoluteSymbolPath, Part, SymbolPath}, build::{index_access, lookup, namespaced_reference, permutation, selected}, - ArrayLiteral, Expression, NamespaceDegree, PILFile, PilStatement, + visitor::{ExpressionVisitable, VisitOrder}, + ArrayLiteral, Expression, FunctionCall, NamespaceDegree, PILFile, PilStatement, }, }; use powdr_parser_util::SourceRef; -use std::{collections::BTreeMap, iter::once}; +use std::{collections::BTreeMap, iter::once, ops::ControlFlow, str::FromStr}; +use strum::{Display, EnumString, EnumVariantNames}; const MAIN_OPERATION_NAME: &str = "main"; /// The log of the default minimum degree @@ -35,75 +37,389 @@ lazy_static! { }; } -/// Convert a [MachineDegree] into a [NamespaceDegree], setting any unset bounds to the relevant default values -fn to_namespace_degree(d: MachineDegree) -> NamespaceDegree { - NamespaceDegree { - min: d - .min - .unwrap_or_else(|| Expression::from(1 << MIN_DEGREE_LOG)), - max: d - .max - .unwrap_or_else(|| Expression::from(1 << *MAX_DEGREE_LOG)), - } +/// Link the objects into a single PIL file, using the specified mode. +pub fn link(graph: MachineInstanceGraph, mode: LinkerMode) -> Result> { + Linker::new(mode).link(graph) } -/// The optional degree of the namespace is set to that of the object if it's set, to that of the main object otherwise. -pub fn link(graph: MachineInstanceGraph) -> Result> { - let main_machine = graph.main; - let main_degree = graph - .objects - .get(&main_machine.location) - .unwrap() - .degree - .clone(); +#[derive(Clone, EnumString, EnumVariantNames, Display, Copy, Default)] +/// Whether to link the machines natively or via a global bus. +pub enum LinkerMode { + #[default] + #[strum(serialize = "native")] + Native, + #[strum(serialize = "bus")] + Bus, +} - let mut pil = process_definitions(graph.statements); +#[derive(Default)] +/// Whether to align the degrees of all machines to the main machine, or to use the degrees of the individual machines. +pub enum DegreeMode { + Monolithic(MachineDegree), + #[default] + Vadcop, +} - for (location, object) in graph.objects.into_iter() { - // create a namespace for this object - let degree = match main_degree.is_static() { - true => main_degree.clone(), - false => object.degree, +#[derive(Default)] +struct Linker { + mode: LinkerMode, + degrees: DegreeMode, + /// for each namespace, we store the statements resulting from processing the links separately, because we need to make sure they do not come first. + namespaces: BTreeMap, Vec)>, + next_interaction_id: u32, +} + +impl Linker { + fn new(mode: LinkerMode) -> Self { + Self { + mode, + ..Default::default() + } + } + + fn next_interaction_id(&mut self) -> u32 { + let id = self.next_interaction_id; + self.next_interaction_id += 1; + id + } + + fn link(mut self, graph: MachineInstanceGraph) -> Result> { + let main_machine = graph.main; + let main_degree = graph + .objects + .get(&main_machine.location) + .unwrap() + .degree + .clone(); + + // If the main object has a static degree, we use the monolithic mode with that degree. + self.degrees = match main_degree.is_static() { + true => DegreeMode::Monolithic(main_degree), + false => DegreeMode::Vadcop, }; + let common_definitions = process_definitions(graph.statements); + + for (location, object) in graph.objects { + self.process_object(location.clone(), object); + + if location == Location::main() { + if let Some(main_operation) = graph + .entry_points + .iter() + .find(|f| f.name == MAIN_OPERATION_NAME) + { + let main_operation_id = main_operation.id.clone(); + let operation_id = main_machine.operation_id.clone(); + match (operation_id, main_operation_id) { + (Some(operation_id), Some(main_operation_id)) => { + // call the main operation by initializing `operation_id` to that of the main operation + let linker_first_step = "_linker_first_step"; + self.namespaces.get_mut(&location.to_string()).unwrap().1.extend([ + parse_pil_statement(&format!( + "col fixed {linker_first_step}(i) {{ if i == 0 {{ 1 }} else {{ 0 }} }};" + )), + parse_pil_statement(&format!( + "{linker_first_step} * ({operation_id} - {main_operation_id}) = 0;" + )), + ]); + } + (None, None) => {} + _ => unreachable!(), + } + } + } + } + + Ok(PILFile( + common_definitions + .into_iter() + .chain( + self.namespaces + .into_iter() + .flat_map(|(_, (statements, links))| statements.into_iter().chain(links)), + ) + .collect(), + )) + } + + fn process_object(&mut self, location: Location, object: Object) { + let degree = match &self.degrees { + DegreeMode::Monolithic(d) => d.clone(), + DegreeMode::Vadcop => object.degree, + }; + + let namespace = location.to_string(); + let namespace_degree = to_namespace_degree(degree); + + let (pil, _) = self.namespaces.entry(namespace.clone()).or_default(); + + // create a namespace for this object pil.push(PilStatement::Namespace( SourceRef::unknown(), - SymbolPath::from_identifier(location.to_string()), - Some(to_namespace_degree(degree)), + SymbolPath::from_identifier(namespace.clone()), + Some(namespace_degree), )); pil.extend(object.pil); - pil.extend(object.links.into_iter().map(process_link)); - - if location == Location::main() { - if let Some(main_operation) = graph - .entry_points - .iter() - .find(|f| f.name == MAIN_OPERATION_NAME) - { - let main_operation_id = main_operation.id.clone(); - let operation_id = main_machine.operation_id.clone(); - match (operation_id, main_operation_id) { - (Some(operation_id), Some(main_operation_id)) => { - // call the main operation by initializing `operation_id` to that of the main operation - let linker_first_step = "_linker_first_step"; - pil.extend([ - parse_pil_statement(&format!( - "col fixed {linker_first_step}(i) {{ if i == 0 {{ 1 }} else {{ 0 }} }};" - )), - parse_pil_statement(&format!( - "{linker_first_step} * ({operation_id} - {main_operation_id}) = 0;" - )), - ]); - } - (None, None) => {} - _ => unreachable!(), + for link in object.links { + self.process_link(link, namespace.clone()); + } + } + + fn process_link(&mut self, link: Link, from_namespace: String) { + let from = link.from; + let to = link.to; + + let to_namespace = to.machine.location.clone().to_string(); + + // the lhs is `instr_flag { operation_id, inputs, outputs }` + let op_id = to.operation.id.iter().cloned().map(|n| n.into()); + + if link.is_permutation { + // permutation lhs is `flag { operation_id, inputs, outputs }` + let lhs = selected( + combine_flags(from.instr_flag, from.link_flag), + ArrayLiteral { + items: op_id + .chain(from.params.inputs) + .chain(from.params.outputs) + .collect(), } + .into(), + ); + + // permutation rhs is `(latch * selector[idx]) { operation_id, inputs, outputs }` + let op_id = to + .machine + .operation_id + .map(|oid| namespaced_reference(to_namespace.clone(), oid)) + .into_iter(); + + let latch = namespaced_reference(to_namespace.clone(), to.machine.latch.unwrap()); + let rhs_selector = if let Some(call_selectors) = to.machine.call_selectors { + let call_selector_array = + namespaced_reference(to_namespace.clone(), call_selectors); + let call_selector = + index_access(call_selector_array, Some(to.selector_idx.unwrap().into())); + latch * call_selector + } else { + latch + }; + + let rhs = selected( + rhs_selector, + ArrayLiteral { + items: op_id + .chain(to.operation.params.inputs_and_outputs().map(|i| { + index_access( + namespaced_reference(to_namespace.clone(), &i.name), + i.index.clone(), + ) + })) + .collect(), + } + .into(), + ); + + self.insert_interaction( + InteractionType::Permutation, + from_namespace, + to_namespace, + lhs, + rhs, + ); + } else { + // plookup lhs is `flag $ [ operation_id, inputs, outputs ]` + let lhs = selected( + combine_flags(from.instr_flag, from.link_flag), + ArrayLiteral { + items: op_id + .chain(from.params.inputs) + .chain(from.params.outputs) + .collect(), + } + .into(), + ); + + let op_id = to + .machine + .operation_id + .map(|oid| namespaced_reference(to_namespace.clone(), oid)) + .into_iter(); + + let latch = namespaced_reference(to_namespace.clone(), to.machine.latch.unwrap()); + + // plookup rhs is `latch $ [ operation_id, inputs, outputs ]` + let rhs = selected( + latch, + ArrayLiteral { + items: op_id + .chain(to.operation.params.inputs_and_outputs().map(|i| { + index_access( + namespaced_reference(to_namespace.clone(), &i.name), + i.index.clone(), + ) + })) + .collect(), + } + .into(), + ); + + self.insert_interaction( + InteractionType::Lookup, + from_namespace, + to_namespace, + lhs, + rhs, + ); + }; + } + + fn insert_interaction( + &mut self, + interaction_type: InteractionType, + from_namespace: String, + to_namespace: String, + lhs: Expression, + rhs: Expression, + ) { + // get a new unique interaction id + let interaction_id = self.next_interaction_id(); + + match self.mode { + LinkerMode::Native => { + self.namespaces.entry(from_namespace).or_default().1.push( + PilStatement::Expression( + SourceRef::unknown(), + match interaction_type { + InteractionType::Lookup => lookup(lhs, rhs), + InteractionType::Permutation => permutation(lhs, rhs), + }, + ), + ); + } + LinkerMode::Bus => { + // send in the origin + self.namespaces + .entry(from_namespace.clone()) + .or_default() + .1 + .push(PilStatement::Expression( + SourceRef::unknown(), + send(interaction_type, lhs.clone(), rhs.clone(), interaction_id), + )); + + // receive in the destination + self.namespaces + .entry(to_namespace) + .or_default() + .1 + .push(PilStatement::Expression( + SourceRef::unknown(), + receive( + interaction_type, + namespaced_expression(from_namespace, lhs), + rhs, + interaction_id, + ), + )); } } } +} + +#[derive(Clone, Copy)] +enum InteractionType { + Lookup, + Permutation, +} + +fn send( + identity_type: InteractionType, + lhs: Expression, + rhs: Expression, + interaction_id: u32, +) -> Expression { + let (function, identity) = match identity_type { + InteractionType::Lookup => ( + SymbolPath::from_str("std::protocols::lookup_via_bus::lookup_send") + .unwrap() + .into(), + lookup(lhs, rhs), + ), + InteractionType::Permutation => ( + SymbolPath::from_str("std::protocols::permutation_via_bus::permutation_send") + .unwrap() + .into(), + permutation(lhs, rhs), + ), + }; - Ok(PILFile(pil)) + Expression::FunctionCall( + SourceRef::unknown(), + FunctionCall { + function: Box::new(Expression::Reference(SourceRef::unknown(), function)), + arguments: vec![interaction_id.into(), identity], + }, + ) +} + +fn receive( + identity_type: InteractionType, + lhs: Expression, + rhs: Expression, + interaction_id: u32, +) -> Expression { + let (function, identity) = match identity_type { + InteractionType::Lookup => ( + SymbolPath::from_str("std::protocols::lookup_via_bus::lookup_receive") + .unwrap() + .into(), + lookup(lhs, rhs), + ), + InteractionType::Permutation => ( + SymbolPath::from_str("std::protocols::permutation_via_bus::permutation_receive") + .unwrap() + .into(), + permutation(lhs, rhs), + ), + }; + + Expression::FunctionCall( + SourceRef::unknown(), + FunctionCall { + function: Box::new(Expression::Reference(SourceRef::unknown(), function)), + arguments: vec![interaction_id.into(), identity], + }, + ) +} + +/// Convert a [MachineDegree] into a [NamespaceDegree], setting any unset bounds to the relevant default values +fn to_namespace_degree(d: MachineDegree) -> NamespaceDegree { + NamespaceDegree { + min: d + .min + .unwrap_or_else(|| Expression::from(1 << MIN_DEGREE_LOG)), + max: d + .max + .unwrap_or_else(|| Expression::from(1 << *MAX_DEGREE_LOG)), + } +} + +fn namespaced_expression(namespace: String, mut expr: Expression) -> Expression { + expr.visit_expressions_mut( + &mut |expr| { + if let Expression::Reference(_, refs) = expr { + refs.path = SymbolPath::from_parts( + once(Part::Named(namespace.clone())).chain(refs.path.clone().into_parts()), + ); + } + ControlFlow::Continue::<(), _>(()) + }, + VisitOrder::Pre, + ); + expr } // Extract the utilities and sort them into namespaces where possible. @@ -130,109 +446,11 @@ fn process_definitions( .collect() } -fn process_link(link: Link) -> PilStatement { - let from = link.from; - let to = link.to; - - // the lhs is `instr_flag { operation_id, inputs, outputs }` - let op_id = to.operation.id.iter().cloned().map(|n| n.into()); - - let expr = if link.is_permutation { - // permutation lhs is `flag { operation_id, inputs, outputs }` - let lhs = selected( - combine_flags(from.instr_flag, from.link_flag), - ArrayLiteral { - items: op_id - .chain(from.params.inputs) - .chain(from.params.outputs) - .collect(), - } - .into(), - ); - - // permutation rhs is `(latch * selector[idx]) { operation_id, inputs, outputs }` - let to_namespace = to.machine.location.clone().to_string(); - let op_id = to - .machine - .operation_id - .map(|oid| namespaced_reference(to_namespace.clone(), oid)) - .into_iter(); - - let latch = namespaced_reference(to_namespace.clone(), to.machine.latch.unwrap()); - let rhs_selector = if let Some(call_selectors) = to.machine.call_selectors { - let call_selector_array = namespaced_reference(to_namespace.clone(), call_selectors); - let call_selector = - index_access(call_selector_array, Some(to.selector_idx.unwrap().into())); - latch * call_selector - } else { - latch - }; - - let rhs = selected( - rhs_selector, - ArrayLiteral { - items: op_id - .chain(to.operation.params.inputs_and_outputs().map(|i| { - index_access( - namespaced_reference(to_namespace.clone(), &i.name), - i.index.clone(), - ) - })) - .collect(), - } - .into(), - ); - - permutation(lhs, rhs) - } else { - // plookup lhs is `flag $ [ operation_id, inputs, outputs ]` - let lhs = selected( - combine_flags(from.instr_flag, from.link_flag), - ArrayLiteral { - items: op_id - .chain(from.params.inputs) - .chain(from.params.outputs) - .collect(), - } - .into(), - ); - - let to_namespace = to.machine.location.clone().to_string(); - let op_id = to - .machine - .operation_id - .map(|oid| namespaced_reference(to_namespace.clone(), oid)) - .into_iter(); - - let latch = namespaced_reference(to_namespace.clone(), to.machine.latch.unwrap()); - - // plookup rhs is `latch $ [ operation_id, inputs, outputs ]` - let rhs = selected( - latch, - ArrayLiteral { - items: op_id - .chain(to.operation.params.inputs_and_outputs().map(|i| { - index_access( - namespaced_reference(to_namespace.clone(), &i.name), - i.index.clone(), - ) - })) - .collect(), - } - .into(), - ); - - lookup(lhs, rhs) - }; - - PilStatement::Expression(SourceRef::unknown(), expr) -} - #[cfg(test)] mod test { use std::{fs, path::PathBuf}; - use powdr_ast::object::MachineInstanceGraph; + use powdr_ast::{object::MachineInstanceGraph, parsed::PILFile}; use powdr_number::{FieldElement, GoldilocksField}; use powdr_analysis::convert_asm_to_pil; @@ -240,7 +458,15 @@ mod test { use pretty_assertions::assert_eq; - use crate::{link, MAX_DEGREE_LOG, MIN_DEGREE_LOG}; + use crate::{MAX_DEGREE_LOG, MIN_DEGREE_LOG}; + + fn link_native(graph: MachineInstanceGraph) -> Result> { + super::link(graph, super::LinkerMode::Native) + } + + fn link_with_bus(graph: MachineInstanceGraph) -> Result> { + super::link(graph, super::LinkerMode::Bus) + } fn parse_analyze_and_compile_file(file: &str) -> MachineInstanceGraph { let contents = fs::read_to_string(file).unwrap(); @@ -270,7 +496,7 @@ mod test { #[test] fn compile_empty_vm() { - let expectation = r#"namespace main(4 + 4); + let native_expectation = r#"namespace main(4 + 4); let _operation_id; query |__i| std::prover::provide_if_unknown(_operation_id, __i, || 2); pol constant _block_enforcer_last_step = [0]* + [1]; @@ -296,10 +522,39 @@ namespace main__rom(4 + 4); pol constant latch = [1]*; "#; + let bus_expectation = r#"namespace main(4 + 4); + let _operation_id; + query |__i| std::prover::provide_if_unknown(_operation_id, __i, || 2); + pol constant _block_enforcer_last_step = [0]* + [1]; + let _operation_id_no_change = (1 - _block_enforcer_last_step) * (1 - instr_return); + _operation_id_no_change * (_operation_id' - _operation_id) = 0; + pol commit pc; + pol commit instr__jump_to_operation; + pol commit instr__reset; + pol commit instr__loop; + pol commit instr_return; + pol constant first_step = [1] + [0]*; + pol commit pc_update; + pc_update = instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); + pc' = (1 - first_step') * pc_update; + std::protocols::lookup_via_bus::lookup_send(0, 1 $ [0, pc, instr__jump_to_operation, instr__reset, instr__loop, instr_return] in main__rom::latch $ [main__rom::operation_id, main__rom::p_line, main__rom::p_instr__jump_to_operation, main__rom::p_instr__reset, main__rom::p_instr__loop, main__rom::p_instr_return]); +namespace main__rom(4 + 4); + pol constant p_line = [0, 1, 2] + [2]*; + pol constant p_instr__jump_to_operation = [0, 1, 0] + [0]*; + pol constant p_instr__loop = [0, 0, 1] + [1]*; + pol constant p_instr__reset = [1, 0, 0] + [0]*; + pol constant p_instr_return = [0]*; + pol constant operation_id = [0]*; + pol constant latch = [1]*; + std::protocols::lookup_via_bus::lookup_receive(0, 1 $ [0, main::pc, main::instr__jump_to_operation, main::instr__reset, main::instr__loop, main::instr_return] in main__rom::latch $ [main__rom::operation_id, main__rom::p_line, main__rom::p_instr__jump_to_operation, main__rom::p_instr__reset, main__rom::p_instr__loop, main__rom::p_instr_return]); +"#; + let file_name = "../test_data/asm/empty_vm.asm"; let graph = parse_analyze_and_compile_file::(file_name); - let pil = link(graph).unwrap(); - assert_eq!(extract_main(&format!("{pil}")), expectation); + let pil = link_native(graph.clone()).unwrap(); + assert_eq!(extract_main(&format!("{pil}")), native_expectation); + let pil = link_with_bus(graph).unwrap(); + assert_eq!(extract_main(&format!("{pil}")), bus_expectation); } #[test] @@ -312,7 +567,7 @@ namespace main__rom(4 + 4); ); let graph = parse_analyze_and_compile::(""); - let pil = link(graph).unwrap(); + let pil = link_native(graph).unwrap(); assert_eq!(extract_main(&format!("{pil}")), expectation); } @@ -320,7 +575,7 @@ namespace main__rom(4 + 4); fn compile_pil_without_machine() { let input = " let even = std::array::new(5, |i| 2 * i);"; let graph = parse_analyze_and_compile::(input); - let pil = link(graph).unwrap().to_string(); + let pil = link_native(graph).unwrap().to_string(); assert_eq!(&pil[0..input.len()], input); } @@ -429,7 +684,7 @@ namespace main_sub__rom(16); "#; let file_name = "../test_data/asm/different_signatures.asm"; let graph = parse_analyze_and_compile_file::(file_name); - let pil = link(graph).unwrap(); + let pil = link_native(graph).unwrap(); assert_eq!(extract_main(&format!("{pil}")), expectation); } @@ -511,7 +766,7 @@ namespace main__rom(16); "#; let file_name = "../test_data/asm/simple_sum.asm"; let graph = parse_analyze_and_compile_file::(file_name); - let pil = link(graph).unwrap(); + let pil = link_native(graph).unwrap(); assert_eq!(extract_main(&format!("{pil}")), expectation); } @@ -572,7 +827,7 @@ namespace main__rom(8); pol constant latch = [1]*; "#; let graph = parse_analyze_and_compile::(source); - let pil = link(graph).unwrap(); + let pil = link_native(graph).unwrap(); assert_eq!(extract_main(&format!("{pil}")), expectation); } @@ -592,7 +847,7 @@ machine NegativeForUnsigned { } "#; let graph = parse_analyze_and_compile::(source); - let _ = link(graph); + let _ = link_native(graph); } #[test] @@ -676,7 +931,7 @@ namespace main_vm(64..128); y = x + 5; "#; let graph = parse_analyze_and_compile::(asm); - let pil = link(graph).unwrap(); + let pil = link_native(graph).unwrap(); assert_eq!(extract_main(&(pil.to_string())), expected); } @@ -797,7 +1052,7 @@ namespace main_bin(128); "#; let file_name = "../test_data/asm/permutations/vm_to_block.asm"; let graph = parse_analyze_and_compile_file::(file_name); - let pil = link(graph).unwrap(); + let pil = link_native(graph).unwrap(); assert_eq!(extract_main(&format!("{pil}")), expected); } @@ -954,7 +1209,7 @@ namespace main_submachine(32); "#; let file_name = "../test_data/asm/permutations/link_merging.asm"; let graph = parse_analyze_and_compile_file::(file_name); - let pil = link(graph).unwrap(); + let pil = link_native(graph).unwrap(); assert_eq!(extract_main(&format!("{pil}")), expected); } } diff --git a/pipeline/src/pipeline.rs b/pipeline/src/pipeline.rs index b55340a17..c8e7a4c8f 100644 --- a/pipeline/src/pipeline.rs +++ b/pipeline/src/pipeline.rs @@ -27,6 +27,7 @@ use powdr_executor::{ WitgenCallbackContext, WitnessGenerator, }, }; +pub use powdr_linker::LinkerMode; use powdr_number::{write_polys_csv_file, CsvRenderMode, FieldElement, ReadWrite}; use powdr_schemas::SerializedAnalyzed; @@ -102,6 +103,8 @@ struct Arguments { backend: Option, /// Backend options backend_options: BackendOptions, + /// Backend options + linker_mode: LinkerMode, /// CSV render mode for witness generation. csv_render_mode: CsvRenderMode, /// Whether to export the witness as a CSV file. @@ -335,6 +338,11 @@ impl Pipeline { self.add_query_callback(Arc::new(dict_data_to_query_callback(inputs))) } + pub fn with_linker_mode(mut self, linker_mode: LinkerMode) -> Self { + self.arguments.linker_mode = linker_mode; + self + } + pub fn with_backend(mut self, backend: BackendType, options: Option) -> Self { self.arguments.backend = Some(backend); self.arguments.backend_options = options.unwrap_or_default(); @@ -834,7 +842,7 @@ impl Pipeline { let graph = self.artifact.linked_machine_graph.take().unwrap(); self.log("Run linker"); - let linked = powdr_linker::link(graph)?; + let linked = powdr_linker::link(graph, self.arguments.linker_mode)?; log::trace!("{linked}"); self.maybe_write_pil(&linked, "")?;