From d325a7207988dab9a46a1472d28028fef123cee2 Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 11 Nov 2024 18:30:48 +0100 Subject: [PATCH 1/5] introduce linker mode, expose in CLI, refactor linker to create statements in source and destination --- cli/src/main.rs | 13 +- linker/Cargo.toml | 1 + linker/src/lib.rs | 468 ++++++++++++++++++++++++++------------- pipeline/src/pipeline.rs | 10 +- 4 files changed, 335 insertions(+), 157 deletions(-) diff --git a/cli/src/main.rs b/cli/src/main.rs index b426c173d9..ea3a302d2d 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -12,6 +12,7 @@ use powdr::number::{ BabyBearField, BigUint, Bn254Field, FieldElement, GoldilocksField, KoalaBearField, Mersenne31Field, }; +use powdr::pipeline::pipeline::LinkerMode; use powdr::pipeline::test_runner; use powdr::Pipeline; use std::io; @@ -154,6 +155,10 @@ enum Commands { #[arg(long)] backend_options: Option, + /// Linker mode, deciding how to reduce links to constraints. + #[arg(long)] + linker_mode: Option, + /// Generate a CSV file containing the witness column values. #[arg(long)] #[arg(default_value_t = false)] @@ -464,6 +469,7 @@ fn run_command(command: Commands) { prove_with, params, backend_options, + linker_mode, export_witness_csv, export_all_columns_csv, csv_mode, @@ -478,6 +484,7 @@ fn run_command(command: Commands) { prove_with, params, backend_options, + linker_mode, export_witness_csv, export_all_columns_csv, csv_mode @@ -668,6 +675,7 @@ fn run_pil( prove_with: Option, params: Option, backend_options: Option, + linker_mode: Option, export_witness: bool, export_all_columns: bool, csv_mode: CsvRenderModeCLI, @@ -675,7 +683,9 @@ fn run_pil( let inputs = split_inputs::(&inputs); let pipeline = bind_cli_args( - Pipeline::::default().from_file(PathBuf::from(&file)), + Pipeline::::default() + .from_file(PathBuf::from(&file)) + .with_linker_mode(linker_mode.unwrap_or_default()), inputs.clone(), PathBuf::from(output_directory), force, @@ -804,6 +814,7 @@ mod test { prove_with: Some(BackendType::EStarkDump), params: None, backend_options: Some("stark_gl".to_string()), + linker_mode: None, export_witness_csv: false, export_all_columns_csv: true, csv_mode: CsvRenderModeCLI::Hex, diff --git a/linker/Cargo.toml b/linker/Cargo.toml index 5d05b58d3c..a44b2dd5eb 100644 --- a/linker/Cargo.toml +++ b/linker/Cargo.toml @@ -12,6 +12,7 @@ powdr-analysis.workspace = true powdr-ast.workspace = true powdr-number.workspace = true powdr-parser-util.workspace = true +strum = { version = "0.24.1", features = ["derive"] } pretty_assertions = "1.4.0" itertools = "0.13" diff --git a/linker/src/lib.rs b/linker/src/lib.rs index ce3c41d1e4..3e3292c722 100644 --- a/linker/src/lib.rs +++ b/linker/src/lib.rs @@ -4,15 +4,16 @@ use lazy_static::lazy_static; use powdr_analysis::utils::parse_pil_statement; use powdr_ast::{ asm_analysis::{combine_flags, MachineDegree}, - object::{Link, Location, MachineInstanceGraph}, + object::{Link, Location, MachineInstanceGraph, Object}, parsed::{ asm::{AbsoluteSymbolPath, SymbolPath}, build::{index_access, lookup, namespaced_reference, permutation, selected}, - ArrayLiteral, Expression, NamespaceDegree, PILFile, PilStatement, + ArrayLiteral, Expression, FunctionCall, NamespaceDegree, PILFile, PilStatement, }, }; use powdr_parser_util::SourceRef; -use std::{collections::BTreeMap, iter::once}; +use std::{collections::BTreeMap, iter::once, str::FromStr}; +use strum::{Display, EnumString, EnumVariantNames}; const MAIN_OPERATION_NAME: &str = "main"; /// The log of the default minimum degree @@ -35,75 +36,326 @@ lazy_static! { }; } -/// Convert a [MachineDegree] into a [NamespaceDegree], setting any unset bounds to the relevant default values -fn to_namespace_degree(d: MachineDegree) -> NamespaceDegree { - NamespaceDegree { - min: d - .min - .unwrap_or_else(|| Expression::from(1 << MIN_DEGREE_LOG)), - max: d - .max - .unwrap_or_else(|| Expression::from(1 << *MAX_DEGREE_LOG)), - } +#[derive(Clone, EnumString, EnumVariantNames, Display, Copy, Default)] +pub enum LinkerMode { + #[default] + Native, + Bus, } -/// The optional degree of the namespace is set to that of the object if it's set, to that of the main object otherwise. -pub fn link(graph: MachineInstanceGraph) -> Result> { - let main_machine = graph.main; - let main_degree = graph - .objects - .get(&main_machine.location) - .unwrap() - .degree - .clone(); +#[derive(Default)] +pub enum DegreesMode { + Monolithic(MachineDegree), + #[default] + Vadcop, +} - let mut pil = process_definitions(graph.statements); +#[derive(Default)] +struct Linker { + mode: LinkerMode, + degrees: DegreesMode, + namespaces: BTreeMap>, + next_interaction_id: u32, +} - for (location, object) in graph.objects.into_iter() { - // create a namespace for this object - let degree = match main_degree.is_static() { - true => main_degree.clone(), - false => object.degree, +impl Linker { + fn new(mode: LinkerMode) -> Self { + Self { + mode, + ..Default::default() + } + } + + fn next_interaction_id(&mut self) -> u32 { + let id = self.next_interaction_id; + self.next_interaction_id += 1; + id + } + + /// The optional degree of the namespace is set to that of the object if it's set, to that of the main object otherwise. + pub fn link(mut self, graph: MachineInstanceGraph) -> Result> { + let main_machine = graph.main; + let main_degree = graph + .objects + .get(&main_machine.location) + .unwrap() + .degree + .clone(); + + // If the main object has a static degree, we use the monolithic mode with that degree. + self.degrees = match main_degree.is_static() { + true => DegreesMode::Monolithic(main_degree), + false => DegreesMode::Vadcop, + }; + + let common_definitions = process_definitions(graph.statements); + + for (location, object) in graph.objects.into_iter() { + self.process_object(location.clone(), object); + + if location == Location::main() { + if let Some(main_operation) = graph + .entry_points + .iter() + .find(|f| f.name == MAIN_OPERATION_NAME) + { + let main_operation_id = main_operation.id.clone(); + let operation_id = main_machine.operation_id.clone(); + match (operation_id, main_operation_id) { + (Some(operation_id), Some(main_operation_id)) => { + // call the main operation by initializing `operation_id` to that of the main operation + let linker_first_step = "_linker_first_step"; + self.namespaces.get_mut(&location.to_string()).unwrap().extend([ + parse_pil_statement(&format!( + "col fixed {linker_first_step}(i) {{ if i == 0 {{ 1 }} else {{ 0 }} }};" + )), + parse_pil_statement(&format!( + "{linker_first_step} * ({operation_id} - {main_operation_id}) = 0;" + )), + ]); + } + (None, None) => {} + _ => unreachable!(), + } + } + } + } + + Ok(PILFile( + common_definitions + .into_iter() + .chain(self.namespaces.into_iter().flat_map(|(_, v)| v.into_iter())) + .collect(), + )) + } + + fn process_object(&mut self, location: Location, object: Object) { + let degree = match &self.degrees { + DegreesMode::Monolithic(d) => d.clone(), + DegreesMode::Vadcop => object.degree, }; + let namespace = location.to_string(); + let namespace_degree = to_namespace_degree(degree); + + let pil = self.namespaces.entry(namespace).or_default(); + + // create a namespace for this object pil.push(PilStatement::Namespace( SourceRef::unknown(), SymbolPath::from_identifier(location.to_string()), - Some(to_namespace_degree(degree)), + Some(namespace_degree), )); pil.extend(object.pil); - pil.extend(object.links.into_iter().map(process_link)); - - if location == Location::main() { - if let Some(main_operation) = graph - .entry_points - .iter() - .find(|f| f.name == MAIN_OPERATION_NAME) - { - let main_operation_id = main_operation.id.clone(); - let operation_id = main_machine.operation_id.clone(); - match (operation_id, main_operation_id) { - (Some(operation_id), Some(main_operation_id)) => { - // call the main operation by initializing `operation_id` to that of the main operation - let linker_first_step = "_linker_first_step"; - pil.extend([ - parse_pil_statement(&format!( - "col fixed {linker_first_step}(i) {{ if i == 0 {{ 1 }} else {{ 0 }} }};" - )), - parse_pil_statement(&format!( - "{linker_first_step} * ({operation_id} - {main_operation_id}) = 0;" - )), - ]); - } - (None, None) => {} - _ => unreachable!(), + for link in object.links { + self.process_link(link, location.clone()); + } + } + + fn process_link(&mut self, link: Link, from_location: Location) { + let from = link.from; + let to = link.to; + + let to_location = to.machine.location.clone(); + + // the lhs is `instr_flag { operation_id, inputs, outputs }` + let op_id = to.operation.id.iter().cloned().map(|n| n.into()); + + if link.is_permutation { + // permutation lhs is `flag { operation_id, inputs, outputs }` + let lhs = selected( + combine_flags(from.instr_flag, from.link_flag), + ArrayLiteral { + items: op_id + .chain(from.params.inputs) + .chain(from.params.outputs) + .collect(), + } + .into(), + ); + + // permutation rhs is `(latch * selector[idx]) { operation_id, inputs, outputs }` + let to_namespace = to.machine.location.clone().to_string(); + let op_id = to + .machine + .operation_id + .map(|oid| namespaced_reference(to_namespace.clone(), oid)) + .into_iter(); + + let latch = namespaced_reference(to_namespace.clone(), to.machine.latch.unwrap()); + let rhs_selector = if let Some(call_selectors) = to.machine.call_selectors { + let call_selector_array = + namespaced_reference(to_namespace.clone(), call_selectors); + let call_selector = + index_access(call_selector_array, Some(to.selector_idx.unwrap().into())); + latch * call_selector + } else { + latch + }; + + let rhs = selected( + rhs_selector, + ArrayLiteral { + items: op_id + .chain(to.operation.params.inputs_and_outputs().map(|i| { + index_access( + namespaced_reference(to_namespace.clone(), &i.name), + i.index.clone(), + ) + })) + .collect(), + } + .into(), + ); + + self.insert_permutation(from_location, to_location, permutation(lhs, rhs)); + } else { + // plookup lhs is `flag $ [ operation_id, inputs, outputs ]` + let lhs = selected( + combine_flags(from.instr_flag, from.link_flag), + ArrayLiteral { + items: op_id + .chain(from.params.inputs) + .chain(from.params.outputs) + .collect(), + } + .into(), + ); + + let to_namespace = to.machine.location.clone().to_string(); + let op_id = to + .machine + .operation_id + .map(|oid| namespaced_reference(to_namespace.clone(), oid)) + .into_iter(); + + let latch = namespaced_reference(to_namespace.clone(), to.machine.latch.unwrap()); + + // plookup rhs is `latch $ [ operation_id, inputs, outputs ]` + let rhs = selected( + latch, + ArrayLiteral { + items: op_id + .chain(to.operation.params.inputs_and_outputs().map(|i| { + index_access( + namespaced_reference(to_namespace.clone(), &i.name), + i.index.clone(), + ) + })) + .collect(), } + .into(), + ); + + self.insert_lookup(from_location, to_location, lookup(lhs, rhs)); + }; + } + + fn insert_lookup(&mut self, from: Location, to: Location, lookup: Expression) { + let interaction_id = self.next_interaction_id(); + + match self.mode { + LinkerMode::Native => { + self.namespaces + .entry(from.to_string()) + .or_default() + .push(PilStatement::Expression(SourceRef::unknown(), lookup)); + } + LinkerMode::Bus => { + self.namespaces.entry(from.to_string()).or_default().push( + PilStatement::Expression( + SourceRef::unknown(), + call_pil_link_function( + "std::protocols::lookup_via_bus::lookup_send", + lookup.clone(), + interaction_id, + ), + ), + ); + self.namespaces + .entry(to.to_string()) + .or_default() + .push(PilStatement::Expression( + SourceRef::unknown(), + call_pil_link_function( + "std::protocols::lookup_via_bus::lookup_receive", + lookup, + interaction_id, + ), + )); + } + } + } + + fn insert_permutation(&mut self, from: Location, to: Location, lookup: Expression) { + let interaction_id = self.next_interaction_id(); + + match self.mode { + LinkerMode::Native => { + self.namespaces + .entry(from.to_string()) + .or_default() + .push(PilStatement::Expression(SourceRef::unknown(), lookup)); + } + LinkerMode::Bus => { + self.namespaces.entry(from.to_string()).or_default().push( + PilStatement::Expression( + SourceRef::unknown(), + call_pil_link_function( + "std::protocols::permmutation_via_bus::permmutation_send", + lookup.clone(), + interaction_id, + ), + ), + ); + self.namespaces + .entry(to.to_string()) + .or_default() + .push(PilStatement::Expression( + SourceRef::unknown(), + call_pil_link_function( + "std::protocols::permmutation_via_bus::permmutation_receive", + lookup, + interaction_id, + ), + )); } } } +} + +fn call_pil_link_function( + function: &str, + constraint: Expression, + interaction_id: u32, +) -> Expression { + Expression::FunctionCall( + SourceRef::unknown(), + FunctionCall { + function: Box::new(Expression::Reference( + SourceRef::unknown(), + SymbolPath::from_str(function).unwrap().into(), + )), + arguments: vec![interaction_id.into(), constraint], + }, + ) +} + +/// Convert a [MachineDegree] into a [NamespaceDegree], setting any unset bounds to the relevant default values +fn to_namespace_degree(d: MachineDegree) -> NamespaceDegree { + NamespaceDegree { + min: d + .min + .unwrap_or_else(|| Expression::from(1 << MIN_DEGREE_LOG)), + max: d + .max + .unwrap_or_else(|| Expression::from(1 << *MAX_DEGREE_LOG)), + } +} - Ok(PILFile(pil)) +pub fn link(graph: MachineInstanceGraph, mode: LinkerMode) -> Result> { + Linker::new(mode).link(graph) } // Extract the utilities and sort them into namespaces where possible. @@ -130,109 +382,11 @@ fn process_definitions( .collect() } -fn process_link(link: Link) -> PilStatement { - let from = link.from; - let to = link.to; - - // the lhs is `instr_flag { operation_id, inputs, outputs }` - let op_id = to.operation.id.iter().cloned().map(|n| n.into()); - - let expr = if link.is_permutation { - // permutation lhs is `flag { operation_id, inputs, outputs }` - let lhs = selected( - combine_flags(from.instr_flag, from.link_flag), - ArrayLiteral { - items: op_id - .chain(from.params.inputs) - .chain(from.params.outputs) - .collect(), - } - .into(), - ); - - // permutation rhs is `(latch * selector[idx]) { operation_id, inputs, outputs }` - let to_namespace = to.machine.location.clone().to_string(); - let op_id = to - .machine - .operation_id - .map(|oid| namespaced_reference(to_namespace.clone(), oid)) - .into_iter(); - - let latch = namespaced_reference(to_namespace.clone(), to.machine.latch.unwrap()); - let rhs_selector = if let Some(call_selectors) = to.machine.call_selectors { - let call_selector_array = namespaced_reference(to_namespace.clone(), call_selectors); - let call_selector = - index_access(call_selector_array, Some(to.selector_idx.unwrap().into())); - latch * call_selector - } else { - latch - }; - - let rhs = selected( - rhs_selector, - ArrayLiteral { - items: op_id - .chain(to.operation.params.inputs_and_outputs().map(|i| { - index_access( - namespaced_reference(to_namespace.clone(), &i.name), - i.index.clone(), - ) - })) - .collect(), - } - .into(), - ); - - permutation(lhs, rhs) - } else { - // plookup lhs is `flag $ [ operation_id, inputs, outputs ]` - let lhs = selected( - combine_flags(from.instr_flag, from.link_flag), - ArrayLiteral { - items: op_id - .chain(from.params.inputs) - .chain(from.params.outputs) - .collect(), - } - .into(), - ); - - let to_namespace = to.machine.location.clone().to_string(); - let op_id = to - .machine - .operation_id - .map(|oid| namespaced_reference(to_namespace.clone(), oid)) - .into_iter(); - - let latch = namespaced_reference(to_namespace.clone(), to.machine.latch.unwrap()); - - // plookup rhs is `latch $ [ operation_id, inputs, outputs ]` - let rhs = selected( - latch, - ArrayLiteral { - items: op_id - .chain(to.operation.params.inputs_and_outputs().map(|i| { - index_access( - namespaced_reference(to_namespace.clone(), &i.name), - i.index.clone(), - ) - })) - .collect(), - } - .into(), - ); - - lookup(lhs, rhs) - }; - - PilStatement::Expression(SourceRef::unknown(), expr) -} - #[cfg(test)] mod test { use std::{fs, path::PathBuf}; - use powdr_ast::object::MachineInstanceGraph; + use powdr_ast::{object::MachineInstanceGraph, parsed::PILFile}; use powdr_number::{FieldElement, GoldilocksField}; use powdr_analysis::convert_asm_to_pil; @@ -240,7 +394,11 @@ mod test { use pretty_assertions::assert_eq; - use crate::{link, MAX_DEGREE_LOG, MIN_DEGREE_LOG}; + use crate::{MAX_DEGREE_LOG, MIN_DEGREE_LOG}; + + fn link(graph: MachineInstanceGraph) -> Result> { + super::link(graph, super::LinkerMode::Native) + } fn parse_analyze_and_compile_file(file: &str) -> MachineInstanceGraph { let contents = fs::read_to_string(file).unwrap(); diff --git a/pipeline/src/pipeline.rs b/pipeline/src/pipeline.rs index b55340a172..c8e7a4c8f1 100644 --- a/pipeline/src/pipeline.rs +++ b/pipeline/src/pipeline.rs @@ -27,6 +27,7 @@ use powdr_executor::{ WitgenCallbackContext, WitnessGenerator, }, }; +pub use powdr_linker::LinkerMode; use powdr_number::{write_polys_csv_file, CsvRenderMode, FieldElement, ReadWrite}; use powdr_schemas::SerializedAnalyzed; @@ -102,6 +103,8 @@ struct Arguments { backend: Option, /// Backend options backend_options: BackendOptions, + /// Backend options + linker_mode: LinkerMode, /// CSV render mode for witness generation. csv_render_mode: CsvRenderMode, /// Whether to export the witness as a CSV file. @@ -335,6 +338,11 @@ impl Pipeline { self.add_query_callback(Arc::new(dict_data_to_query_callback(inputs))) } + pub fn with_linker_mode(mut self, linker_mode: LinkerMode) -> Self { + self.arguments.linker_mode = linker_mode; + self + } + pub fn with_backend(mut self, backend: BackendType, options: Option) -> Self { self.arguments.backend = Some(backend); self.arguments.backend_options = options.unwrap_or_default(); @@ -834,7 +842,7 @@ impl Pipeline { let graph = self.artifact.linked_machine_graph.take().unwrap(); self.log("Run linker"); - let linked = powdr_linker::link(graph)?; + let linked = powdr_linker::link(graph, self.arguments.linker_mode)?; log::trace!("{linked}"); self.maybe_write_pil(&linked, "")?; From ddc50d6fd5372234ad10b44e7acd6b3f0cd3ea68 Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 11 Nov 2024 18:59:30 +0100 Subject: [PATCH 2/5] define cli args names, make sure links come last --- linker/src/lib.rs | 39 +++++++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/linker/src/lib.rs b/linker/src/lib.rs index 3e3292c722..7121487563 100644 --- a/linker/src/lib.rs +++ b/linker/src/lib.rs @@ -39,7 +39,9 @@ lazy_static! { #[derive(Clone, EnumString, EnumVariantNames, Display, Copy, Default)] pub enum LinkerMode { #[default] + #[strum(serialize = "native")] Native, + #[strum(serialize = "bus")] Bus, } @@ -54,7 +56,8 @@ pub enum DegreesMode { struct Linker { mode: LinkerMode, degrees: DegreesMode, - namespaces: BTreeMap>, + /// for each namespace, we store the statements resulting from processing the links separatly, because we need to make sure they do not come first. + namespaces: BTreeMap, Vec)>, next_interaction_id: u32, } @@ -105,7 +108,7 @@ impl Linker { (Some(operation_id), Some(main_operation_id)) => { // call the main operation by initializing `operation_id` to that of the main operation let linker_first_step = "_linker_first_step"; - self.namespaces.get_mut(&location.to_string()).unwrap().extend([ + self.namespaces.get_mut(&location.to_string()).unwrap().0.extend([ parse_pil_statement(&format!( "col fixed {linker_first_step}(i) {{ if i == 0 {{ 1 }} else {{ 0 }} }};" )), @@ -124,7 +127,11 @@ impl Linker { Ok(PILFile( common_definitions .into_iter() - .chain(self.namespaces.into_iter().flat_map(|(_, v)| v.into_iter())) + .chain( + self.namespaces + .into_iter() + .flat_map(|(_, (statements, links))| statements.into_iter().chain(links)), + ) .collect(), )) } @@ -138,7 +145,7 @@ impl Linker { let namespace = location.to_string(); let namespace_degree = to_namespace_degree(degree); - let pil = self.namespaces.entry(namespace).or_default(); + let (pil, _) = self.namespaces.entry(namespace).or_default(); // create a namespace for this object pil.push(PilStatement::Namespace( @@ -260,10 +267,11 @@ impl Linker { self.namespaces .entry(from.to_string()) .or_default() + .1 .push(PilStatement::Expression(SourceRef::unknown(), lookup)); } LinkerMode::Bus => { - self.namespaces.entry(from.to_string()).or_default().push( + self.namespaces.entry(from.to_string()).or_default().1.push( PilStatement::Expression( SourceRef::unknown(), call_pil_link_function( @@ -273,17 +281,16 @@ impl Linker { ), ), ); - self.namespaces - .entry(to.to_string()) - .or_default() - .push(PilStatement::Expression( + self.namespaces.entry(to.to_string()).or_default().1.push( + PilStatement::Expression( SourceRef::unknown(), call_pil_link_function( "std::protocols::lookup_via_bus::lookup_receive", lookup, interaction_id, ), - )); + ), + ); } } } @@ -296,10 +303,11 @@ impl Linker { self.namespaces .entry(from.to_string()) .or_default() + .1 .push(PilStatement::Expression(SourceRef::unknown(), lookup)); } LinkerMode::Bus => { - self.namespaces.entry(from.to_string()).or_default().push( + self.namespaces.entry(from.to_string()).or_default().1.push( PilStatement::Expression( SourceRef::unknown(), call_pil_link_function( @@ -309,17 +317,16 @@ impl Linker { ), ), ); - self.namespaces - .entry(to.to_string()) - .or_default() - .push(PilStatement::Expression( + self.namespaces.entry(to.to_string()).or_default().1.push( + PilStatement::Expression( SourceRef::unknown(), call_pil_link_function( "std::protocols::permmutation_via_bus::permmutation_receive", lookup, interaction_id, ), - )); + ), + ); } } } From 729b6ac11821ae1d2333b89226bad1e6b02318ae Mon Sep 17 00:00:00 2001 From: schaeff Date: Tue, 12 Nov 2024 16:57:31 +0100 Subject: [PATCH 3/5] use absolute names --- linker/src/lib.rs | 196 +++++++++++++++++++++++++++++++--------------- 1 file changed, 134 insertions(+), 62 deletions(-) diff --git a/linker/src/lib.rs b/linker/src/lib.rs index 7121487563..3cacc9d0ea 100644 --- a/linker/src/lib.rs +++ b/linker/src/lib.rs @@ -6,13 +6,14 @@ use powdr_ast::{ asm_analysis::{combine_flags, MachineDegree}, object::{Link, Location, MachineInstanceGraph, Object}, parsed::{ - asm::{AbsoluteSymbolPath, SymbolPath}, + asm::{AbsoluteSymbolPath, Part, SymbolPath}, build::{index_access, lookup, namespaced_reference, permutation, selected}, + visitor::{ExpressionVisitable, VisitOrder}, ArrayLiteral, Expression, FunctionCall, NamespaceDegree, PILFile, PilStatement, }, }; use powdr_parser_util::SourceRef; -use std::{collections::BTreeMap, iter::once, str::FromStr}; +use std::{collections::BTreeMap, iter::once, ops::ControlFlow, str::FromStr}; use strum::{Display, EnumString, EnumVariantNames}; const MAIN_OPERATION_NAME: &str = "main"; @@ -108,7 +109,7 @@ impl Linker { (Some(operation_id), Some(main_operation_id)) => { // call the main operation by initializing `operation_id` to that of the main operation let linker_first_step = "_linker_first_step"; - self.namespaces.get_mut(&location.to_string()).unwrap().0.extend([ + self.namespaces.get_mut(&location.to_string()).unwrap().1.extend([ parse_pil_statement(&format!( "col fixed {linker_first_step}(i) {{ if i == 0 {{ 1 }} else {{ 0 }} }};" )), @@ -216,7 +217,13 @@ impl Linker { .into(), ); - self.insert_permutation(from_location, to_location, permutation(lhs, rhs)); + self.insert_identity( + IdentityType::Permutation, + from_location, + to_location, + lhs, + rhs, + ); } else { // plookup lhs is `flag $ [ operation_id, inputs, outputs ]` let lhs = selected( @@ -255,74 +262,46 @@ impl Linker { .into(), ); - self.insert_lookup(from_location, to_location, lookup(lhs, rhs)); + self.insert_identity(IdentityType::Lookup, from_location, to_location, lhs, rhs); }; } - fn insert_lookup(&mut self, from: Location, to: Location, lookup: Expression) { + fn insert_identity( + &mut self, + identity_type: IdentityType, + from: Location, + to: Location, + lhs: Expression, + rhs: Expression, + ) { let interaction_id = self.next_interaction_id(); match self.mode { LinkerMode::Native => { - self.namespaces - .entry(from.to_string()) - .or_default() - .1 - .push(PilStatement::Expression(SourceRef::unknown(), lookup)); - } - LinkerMode::Bus => { self.namespaces.entry(from.to_string()).or_default().1.push( PilStatement::Expression( SourceRef::unknown(), - call_pil_link_function( - "std::protocols::lookup_via_bus::lookup_send", - lookup.clone(), - interaction_id, - ), - ), - ); - self.namespaces.entry(to.to_string()).or_default().1.push( - PilStatement::Expression( - SourceRef::unknown(), - call_pil_link_function( - "std::protocols::lookup_via_bus::lookup_receive", - lookup, - interaction_id, - ), + match identity_type { + IdentityType::Lookup => lookup(lhs, rhs), + IdentityType::Permutation => permutation(lhs, rhs), + }, ), ); } - } - } - - fn insert_permutation(&mut self, from: Location, to: Location, lookup: Expression) { - let interaction_id = self.next_interaction_id(); - - match self.mode { - LinkerMode::Native => { - self.namespaces - .entry(from.to_string()) - .or_default() - .1 - .push(PilStatement::Expression(SourceRef::unknown(), lookup)); - } LinkerMode::Bus => { self.namespaces.entry(from.to_string()).or_default().1.push( PilStatement::Expression( SourceRef::unknown(), - call_pil_link_function( - "std::protocols::permmutation_via_bus::permmutation_send", - lookup.clone(), - interaction_id, - ), + send(identity_type, lhs.clone(), rhs.clone(), interaction_id), ), ); self.namespaces.entry(to.to_string()).or_default().1.push( PilStatement::Expression( SourceRef::unknown(), - call_pil_link_function( - "std::protocols::permmutation_via_bus::permmutation_receive", - lookup, + receive( + identity_type, + namespaced_expression(from.to_string(), lhs), + rhs, interaction_id, ), ), @@ -332,19 +311,68 @@ impl Linker { } } -fn call_pil_link_function( - function: &str, - constraint: Expression, +#[derive(Clone, Copy)] +enum IdentityType { + Lookup, + Permutation, +} + +fn send( + identity_type: IdentityType, + lhs: Expression, + rhs: Expression, + interaction_id: u32, +) -> Expression { + let (function, identity) = match identity_type { + IdentityType::Lookup => ( + SymbolPath::from_str("std::protocols::lookup_via_bus::lookup_send") + .unwrap() + .into(), + lookup(lhs, rhs), + ), + IdentityType::Permutation => ( + SymbolPath::from_str("std::protocols::permutation_via_bus::permutation_send") + .unwrap() + .into(), + permutation(lhs, rhs), + ), + }; + + Expression::FunctionCall( + SourceRef::unknown(), + FunctionCall { + function: Box::new(Expression::Reference(SourceRef::unknown(), function)), + arguments: vec![interaction_id.into(), identity], + }, + ) +} + +fn receive( + identity_type: IdentityType, + lhs: Expression, + rhs: Expression, interaction_id: u32, ) -> Expression { + let (function, identity) = match identity_type { + IdentityType::Lookup => ( + SymbolPath::from_str("std::protocols::lookup_via_bus::lookup_receive") + .unwrap() + .into(), + lookup(lhs, rhs), + ), + IdentityType::Permutation => ( + SymbolPath::from_str("std::protocols::permutation_via_bus::permutation_receive") + .unwrap() + .into(), + permutation(lhs, rhs), + ), + }; + Expression::FunctionCall( SourceRef::unknown(), FunctionCall { - function: Box::new(Expression::Reference( - SourceRef::unknown(), - SymbolPath::from_str(function).unwrap().into(), - )), - arguments: vec![interaction_id.into(), constraint], + function: Box::new(Expression::Reference(SourceRef::unknown(), function)), + arguments: vec![interaction_id.into(), identity], }, ) } @@ -361,6 +389,21 @@ fn to_namespace_degree(d: MachineDegree) -> NamespaceDegree { } } +fn namespaced_expression(namespace: String, mut expr: Expression) -> Expression { + expr.visit_expressions_mut( + &mut |expr| { + if let Expression::Reference(_, refs) = expr { + refs.path = SymbolPath::from_parts( + once(Part::Named(namespace.clone())).chain(refs.path.clone().into_parts()), + ); + } + ControlFlow::Continue::<(), _>(()) + }, + VisitOrder::Pre, + ); + expr +} + pub fn link(graph: MachineInstanceGraph, mode: LinkerMode) -> Result> { Linker::new(mode).link(graph) } @@ -401,7 +444,7 @@ mod test { use pretty_assertions::assert_eq; - use crate::{MAX_DEGREE_LOG, MIN_DEGREE_LOG}; + use crate::{LinkerMode, MAX_DEGREE_LOG, MIN_DEGREE_LOG}; fn link(graph: MachineInstanceGraph) -> Result> { super::link(graph, super::LinkerMode::Native) @@ -435,7 +478,7 @@ mod test { #[test] fn compile_empty_vm() { - let expectation = r#"namespace main(4 + 4); + let native_expectation = r#"namespace main(4 + 4); let _operation_id; query |__i| std::prover::provide_if_unknown(_operation_id, __i, || 2); pol constant _block_enforcer_last_step = [0]* + [1]; @@ -461,10 +504,39 @@ namespace main__rom(4 + 4); pol constant latch = [1]*; "#; + let bus_expectation = r#"namespace main(4 + 4); + let _operation_id; + query |__i| std::prover::provide_if_unknown(_operation_id, __i, || 2); + pol constant _block_enforcer_last_step = [0]* + [1]; + let _operation_id_no_change = (1 - _block_enforcer_last_step) * (1 - instr_return); + _operation_id_no_change * (_operation_id' - _operation_id) = 0; + pol commit pc; + pol commit instr__jump_to_operation; + pol commit instr__reset; + pol commit instr__loop; + pol commit instr_return; + pol constant first_step = [1] + [0]*; + pol commit pc_update; + pc_update = instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); + pc' = (1 - first_step') * pc_update; + std::protocols::lookup_via_bus::lookup_send(0, 1 $ [0, pc, instr__jump_to_operation, instr__reset, instr__loop, instr_return] in main__rom::latch $ [main__rom::operation_id, main__rom::p_line, main__rom::p_instr__jump_to_operation, main__rom::p_instr__reset, main__rom::p_instr__loop, main__rom::p_instr_return]); +namespace main__rom(4 + 4); + pol constant p_line = [0, 1, 2] + [2]*; + pol constant p_instr__jump_to_operation = [0, 1, 0] + [0]*; + pol constant p_instr__loop = [0, 0, 1] + [1]*; + pol constant p_instr__reset = [1, 0, 0] + [0]*; + pol constant p_instr_return = [0]*; + pol constant operation_id = [0]*; + pol constant latch = [1]*; + std::protocols::lookup_via_bus::lookup_receive(0, 1 $ [0, main::pc, main::instr__jump_to_operation, main::instr__reset, main::instr__loop, main::instr_return] in main__rom::latch $ [main__rom::operation_id, main__rom::p_line, main__rom::p_instr__jump_to_operation, main__rom::p_instr__reset, main__rom::p_instr__loop, main__rom::p_instr_return]); +"#; + let file_name = "../test_data/asm/empty_vm.asm"; let graph = parse_analyze_and_compile_file::(file_name); - let pil = link(graph).unwrap(); - assert_eq!(extract_main(&format!("{pil}")), expectation); + let pil = link(graph.clone()).unwrap(); + assert_eq!(extract_main(&format!("{pil}")), native_expectation); + let pil = super::link(graph, LinkerMode::Bus).unwrap(); + assert_eq!(extract_main(&format!("{pil}")), bus_expectation); } #[test] From 0c66ecd3f00a98b0192cc593c15b8db8059d6a55 Mon Sep 17 00:00:00 2001 From: schaeff Date: Wed, 13 Nov 2024 11:17:47 +0100 Subject: [PATCH 4/5] clean --- linker/src/lib.rs | 146 ++++++++++++++++++++++++++-------------------- 1 file changed, 82 insertions(+), 64 deletions(-) diff --git a/linker/src/lib.rs b/linker/src/lib.rs index 3cacc9d0ea..2bf9a36c2c 100644 --- a/linker/src/lib.rs +++ b/linker/src/lib.rs @@ -37,7 +37,13 @@ lazy_static! { }; } +/// Link the objects into a single PIL file, using the specified mode. +pub fn link(graph: MachineInstanceGraph, mode: LinkerMode) -> Result> { + Linker::new(mode).link(graph) +} + #[derive(Clone, EnumString, EnumVariantNames, Display, Copy, Default)] +/// Whether to link the machines natively or via a global bus. pub enum LinkerMode { #[default] #[strum(serialize = "native")] @@ -47,7 +53,8 @@ pub enum LinkerMode { } #[derive(Default)] -pub enum DegreesMode { +/// Whether to align the degrees of all machines to the main machine, or to use the degrees of the individual machines. +pub enum DegreeMode { Monolithic(MachineDegree), #[default] Vadcop, @@ -56,7 +63,7 @@ pub enum DegreesMode { #[derive(Default)] struct Linker { mode: LinkerMode, - degrees: DegreesMode, + degrees: DegreeMode, /// for each namespace, we store the statements resulting from processing the links separatly, because we need to make sure they do not come first. namespaces: BTreeMap, Vec)>, next_interaction_id: u32, @@ -76,8 +83,7 @@ impl Linker { id } - /// The optional degree of the namespace is set to that of the object if it's set, to that of the main object otherwise. - pub fn link(mut self, graph: MachineInstanceGraph) -> Result> { + fn link(mut self, graph: MachineInstanceGraph) -> Result> { let main_machine = graph.main; let main_degree = graph .objects @@ -88,13 +94,13 @@ impl Linker { // If the main object has a static degree, we use the monolithic mode with that degree. self.degrees = match main_degree.is_static() { - true => DegreesMode::Monolithic(main_degree), - false => DegreesMode::Vadcop, + true => DegreeMode::Monolithic(main_degree), + false => DegreeMode::Vadcop, }; let common_definitions = process_definitions(graph.statements); - for (location, object) in graph.objects.into_iter() { + for (location, object) in graph.objects { self.process_object(location.clone(), object); if location == Location::main() { @@ -139,33 +145,33 @@ impl Linker { fn process_object(&mut self, location: Location, object: Object) { let degree = match &self.degrees { - DegreesMode::Monolithic(d) => d.clone(), - DegreesMode::Vadcop => object.degree, + DegreeMode::Monolithic(d) => d.clone(), + DegreeMode::Vadcop => object.degree, }; let namespace = location.to_string(); let namespace_degree = to_namespace_degree(degree); - let (pil, _) = self.namespaces.entry(namespace).or_default(); + let (pil, _) = self.namespaces.entry(namespace.clone()).or_default(); // create a namespace for this object pil.push(PilStatement::Namespace( SourceRef::unknown(), - SymbolPath::from_identifier(location.to_string()), + SymbolPath::from_identifier(namespace.clone()), Some(namespace_degree), )); pil.extend(object.pil); for link in object.links { - self.process_link(link, location.clone()); + self.process_link(link, namespace.clone()); } } - fn process_link(&mut self, link: Link, from_location: Location) { + fn process_link(&mut self, link: Link, from_namespace: String) { let from = link.from; let to = link.to; - let to_location = to.machine.location.clone(); + let to_namespace = to.machine.location.clone().to_string(); // the lhs is `instr_flag { operation_id, inputs, outputs }` let op_id = to.operation.id.iter().cloned().map(|n| n.into()); @@ -184,7 +190,6 @@ impl Linker { ); // permutation rhs is `(latch * selector[idx]) { operation_id, inputs, outputs }` - let to_namespace = to.machine.location.clone().to_string(); let op_id = to .machine .operation_id @@ -217,10 +222,10 @@ impl Linker { .into(), ); - self.insert_identity( - IdentityType::Permutation, - from_location, - to_location, + self.insert_interaction( + InteractionType::Permutation, + from_namespace, + to_namespace, lhs, rhs, ); @@ -237,7 +242,6 @@ impl Linker { .into(), ); - let to_namespace = to.machine.location.clone().to_string(); let op_id = to .machine .operation_id @@ -262,75 +266,89 @@ impl Linker { .into(), ); - self.insert_identity(IdentityType::Lookup, from_location, to_location, lhs, rhs); + self.insert_interaction( + InteractionType::Lookup, + from_namespace, + to_namespace, + lhs, + rhs, + ); }; } - fn insert_identity( + fn insert_interaction( &mut self, - identity_type: IdentityType, - from: Location, - to: Location, + interaction_type: InteractionType, + from_namespace: String, + to_namespace: String, lhs: Expression, rhs: Expression, ) { + // get a new unique interaction id let interaction_id = self.next_interaction_id(); match self.mode { LinkerMode::Native => { - self.namespaces.entry(from.to_string()).or_default().1.push( + self.namespaces.entry(from_namespace).or_default().1.push( PilStatement::Expression( SourceRef::unknown(), - match identity_type { - IdentityType::Lookup => lookup(lhs, rhs), - IdentityType::Permutation => permutation(lhs, rhs), + match interaction_type { + InteractionType::Lookup => lookup(lhs, rhs), + InteractionType::Permutation => permutation(lhs, rhs), }, ), ); } LinkerMode::Bus => { - self.namespaces.entry(from.to_string()).or_default().1.push( - PilStatement::Expression( + // send in the origin + self.namespaces + .entry(from_namespace.clone()) + .or_default() + .1 + .push(PilStatement::Expression( SourceRef::unknown(), - send(identity_type, lhs.clone(), rhs.clone(), interaction_id), - ), - ); - self.namespaces.entry(to.to_string()).or_default().1.push( - PilStatement::Expression( + send(interaction_type, lhs.clone(), rhs.clone(), interaction_id), + )); + + // receive in the destination + self.namespaces + .entry(to_namespace) + .or_default() + .1 + .push(PilStatement::Expression( SourceRef::unknown(), receive( - identity_type, - namespaced_expression(from.to_string(), lhs), + interaction_type, + namespaced_expression(from_namespace, lhs), rhs, interaction_id, ), - ), - ); + )); } } } } #[derive(Clone, Copy)] -enum IdentityType { +enum InteractionType { Lookup, Permutation, } fn send( - identity_type: IdentityType, + identity_type: InteractionType, lhs: Expression, rhs: Expression, interaction_id: u32, ) -> Expression { let (function, identity) = match identity_type { - IdentityType::Lookup => ( + InteractionType::Lookup => ( SymbolPath::from_str("std::protocols::lookup_via_bus::lookup_send") .unwrap() .into(), lookup(lhs, rhs), ), - IdentityType::Permutation => ( + InteractionType::Permutation => ( SymbolPath::from_str("std::protocols::permutation_via_bus::permutation_send") .unwrap() .into(), @@ -348,19 +366,19 @@ fn send( } fn receive( - identity_type: IdentityType, + identity_type: InteractionType, lhs: Expression, rhs: Expression, interaction_id: u32, ) -> Expression { let (function, identity) = match identity_type { - IdentityType::Lookup => ( + InteractionType::Lookup => ( SymbolPath::from_str("std::protocols::lookup_via_bus::lookup_receive") .unwrap() .into(), lookup(lhs, rhs), ), - IdentityType::Permutation => ( + InteractionType::Permutation => ( SymbolPath::from_str("std::protocols::permutation_via_bus::permutation_receive") .unwrap() .into(), @@ -404,10 +422,6 @@ fn namespaced_expression(namespace: String, mut expr: Expression) -> Expression expr } -pub fn link(graph: MachineInstanceGraph, mode: LinkerMode) -> Result> { - Linker::new(mode).link(graph) -} - // Extract the utilities and sort them into namespaces where possible. fn process_definitions( mut definitions: BTreeMap>, @@ -444,12 +458,16 @@ mod test { use pretty_assertions::assert_eq; - use crate::{LinkerMode, MAX_DEGREE_LOG, MIN_DEGREE_LOG}; + use crate::{MAX_DEGREE_LOG, MIN_DEGREE_LOG}; - fn link(graph: MachineInstanceGraph) -> Result> { + fn link_native(graph: MachineInstanceGraph) -> Result> { super::link(graph, super::LinkerMode::Native) } + fn link_with_bus(graph: MachineInstanceGraph) -> Result> { + super::link(graph, super::LinkerMode::Bus) + } + fn parse_analyze_and_compile_file(file: &str) -> MachineInstanceGraph { let contents = fs::read_to_string(file).unwrap(); let parsed = parse_asm(Some(file), &contents).unwrap_or_else(|e| { @@ -533,9 +551,9 @@ namespace main__rom(4 + 4); let file_name = "../test_data/asm/empty_vm.asm"; let graph = parse_analyze_and_compile_file::(file_name); - let pil = link(graph.clone()).unwrap(); + let pil = link_native(graph.clone()).unwrap(); assert_eq!(extract_main(&format!("{pil}")), native_expectation); - let pil = super::link(graph, LinkerMode::Bus).unwrap(); + let pil = link_with_bus(graph).unwrap(); assert_eq!(extract_main(&format!("{pil}")), bus_expectation); } @@ -549,7 +567,7 @@ namespace main__rom(4 + 4); ); let graph = parse_analyze_and_compile::(""); - let pil = link(graph).unwrap(); + let pil = link_native(graph).unwrap(); assert_eq!(extract_main(&format!("{pil}")), expectation); } @@ -557,7 +575,7 @@ namespace main__rom(4 + 4); fn compile_pil_without_machine() { let input = " let even = std::array::new(5, |i| 2 * i);"; let graph = parse_analyze_and_compile::(input); - let pil = link(graph).unwrap().to_string(); + let pil = link_native(graph).unwrap().to_string(); assert_eq!(&pil[0..input.len()], input); } @@ -666,7 +684,7 @@ namespace main_sub__rom(16); "#; let file_name = "../test_data/asm/different_signatures.asm"; let graph = parse_analyze_and_compile_file::(file_name); - let pil = link(graph).unwrap(); + let pil = link_native(graph).unwrap(); assert_eq!(extract_main(&format!("{pil}")), expectation); } @@ -748,7 +766,7 @@ namespace main__rom(16); "#; let file_name = "../test_data/asm/simple_sum.asm"; let graph = parse_analyze_and_compile_file::(file_name); - let pil = link(graph).unwrap(); + let pil = link_native(graph).unwrap(); assert_eq!(extract_main(&format!("{pil}")), expectation); } @@ -809,7 +827,7 @@ namespace main__rom(8); pol constant latch = [1]*; "#; let graph = parse_analyze_and_compile::(source); - let pil = link(graph).unwrap(); + let pil = link_native(graph).unwrap(); assert_eq!(extract_main(&format!("{pil}")), expectation); } @@ -829,7 +847,7 @@ machine NegativeForUnsigned { } "#; let graph = parse_analyze_and_compile::(source); - let _ = link(graph); + let _ = link_native(graph); } #[test] @@ -913,7 +931,7 @@ namespace main_vm(64..128); y = x + 5; "#; let graph = parse_analyze_and_compile::(asm); - let pil = link(graph).unwrap(); + let pil = link_native(graph).unwrap(); assert_eq!(extract_main(&(pil.to_string())), expected); } @@ -1034,7 +1052,7 @@ namespace main_bin(128); "#; let file_name = "../test_data/asm/permutations/vm_to_block.asm"; let graph = parse_analyze_and_compile_file::(file_name); - let pil = link(graph).unwrap(); + let pil = link_native(graph).unwrap(); assert_eq!(extract_main(&format!("{pil}")), expected); } @@ -1191,7 +1209,7 @@ namespace main_submachine(32); "#; let file_name = "../test_data/asm/permutations/link_merging.asm"; let graph = parse_analyze_and_compile_file::(file_name); - let pil = link(graph).unwrap(); + let pil = link_native(graph).unwrap(); assert_eq!(extract_main(&format!("{pil}")), expected); } } From 1960eb588c9f39fd24159621cb0e4927dc962916 Mon Sep 17 00:00:00 2001 From: schaeff Date: Thu, 14 Nov 2024 17:26:06 +0100 Subject: [PATCH 5/5] typo --- linker/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/linker/src/lib.rs b/linker/src/lib.rs index 2bf9a36c2c..f644eb0879 100644 --- a/linker/src/lib.rs +++ b/linker/src/lib.rs @@ -64,7 +64,7 @@ pub enum DegreeMode { struct Linker { mode: LinkerMode, degrees: DegreeMode, - /// for each namespace, we store the statements resulting from processing the links separatly, because we need to make sure they do not come first. + /// for each namespace, we store the statements resulting from processing the links separately, because we need to make sure they do not come first. namespaces: BTreeMap, Vec)>, next_interaction_id: u32, }