Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip dispatcher for main VM #2504

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 1 addition & 14 deletions asm-to-pil/src/common.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
use powdr_ast::parsed::asm::Instruction;

/// Values which are common to many steps from asm to PIL
use crate::utils::parse_instruction;

/// The name for the `return` keyword in the PIL constraints
pub const RETURN_NAME: &str = "return";
/// The name for the `reset` instruction in the PIL constraints
Expand All @@ -13,7 +8,7 @@ pub fn instruction_flag(name: &str) -> String {
}

/// The names of the output assignment registers for `count` outputs. All `return` statements assign to these.
fn output_registers(count: usize) -> Vec<String> {
pub fn output_registers(count: usize) -> Vec<String> {
(0..count).map(output_at).collect()
}

Expand All @@ -26,11 +21,3 @@ pub fn input_at(i: usize) -> String {
pub fn output_at(i: usize) -> String {
format!("_output_{i}")
}

/// The return instruction for `output_count` outputs and `pc_name` the name of the pc
pub fn return_instruction(output_count: usize, pc_name: &str) -> Instruction {
parse_instruction(&format!(
"{} {{ {pc_name}' = 0 }}",
output_registers(output_count).join(", ")
))
}
15 changes: 13 additions & 2 deletions asm-to-pil/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use std::collections::BTreeMap;

use powdr_ast::asm_analysis::{AnalysisASMFile, Module, StatementReference, SubmachineDeclaration};
use powdr_ast::{
asm_analysis::{AnalysisASMFile, Module, StatementReference, SubmachineDeclaration},
parsed::asm::{parse_absolute_path, SymbolPath},
};
use powdr_number::FieldElement;
use romgen::generate_machine_rom;
use vm_to_constrained::ROM_SUBMACHINE_NAME;
Expand All @@ -9,9 +12,12 @@ mod romgen;
mod vm_to_constrained;

pub const ROM_SUFFIX: &str = "ROM";
const MAIN_MACHINE: &str = "::Main";

/// Remove all ASM from the machine tree, leaving only constrained machines
pub fn compile<T: FieldElement>(mut file: AnalysisASMFile) -> AnalysisASMFile {
let main_machine_path = parse_absolute_path(MAIN_MACHINE);

for (path, module) in &mut file.modules {
let mut new_machines = BTreeMap::default();
let (mut machines, statements, ordering) = std::mem::take(module).into_inner();
Expand All @@ -21,7 +27,12 @@ pub fn compile<T: FieldElement>(mut file: AnalysisASMFile) -> AnalysisASMFile {
match r {
StatementReference::MachineDeclaration(name) => {
let m = machines.remove(&name).unwrap();
let (m, rom) = generate_machine_rom::<T>(m);
let machine_path =
path.clone().join(SymbolPath::from_identifier(name.clone()));
// A machine is callable if it is not the main machine
// This is used to avoid generating a dispatcher for the main machine, since it's only called once, from the outside, and doesn't return
let is_callable = machine_path != main_machine_path;
let (m, rom) = generate_machine_rom::<T>(m, is_callable);
let (mut m, rom_machine) = vm_to_constrained::convert_machine::<T>(m, rom);

match rom_machine {
Expand Down
96 changes: 75 additions & 21 deletions asm-to-pil/src/romgen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use powdr_ast::parsed::{
use powdr_number::{BigUint, FieldElement};
use powdr_parser_util::SourceRef;

use crate::common::{instruction_flag, RETURN_NAME};
use crate::common::{instruction_flag, output_registers, RETURN_NAME};
use crate::{
common::{input_at, output_at, RESET_NAME},
utils::{
Expand Down Expand Up @@ -55,10 +55,55 @@ fn pad_return_arguments(s: &mut FunctionStatement, output_count: usize) {
};
}

pub fn generate_machine_rom<T: FieldElement>(mut machine: Machine) -> (Machine, Option<Rom>) {
/// Generate the ROM for a machine
/// Arguments:
/// - `machine`: the machine to generate the ROM for
/// - `is_callable`: whether the machine is callable.
/// - If it is, a dispatcher is generated and this machine can be used as a submachine
/// - If it is not, the machine must have a single function which never returns, so that the entire trace can be filled with a single block
pub fn generate_machine_rom<T: FieldElement>(
mut machine: Machine,
is_callable: bool,
) -> (Machine, Option<Rom>) {
if !machine.has_pc() {
// do nothing, there is no rom to be generated
(machine, None)
} else if !is_callable {
let pc = machine.pc().unwrap();
// if the machine is not callable, it must have a single function
assert_eq!(machine.callable.0.len(), 1);
let callable = machine.callable.iter_mut().next().unwrap();
let function = match callable.symbol {
CallableSymbol::Function(ref mut f) => f,
CallableSymbol::Operation(_) => unreachable!(),
};
// the function must have no inputs
assert!(function.params.inputs.is_empty());
// the function must have no outputs
assert!(function.params.outputs.is_empty());
// we implement `return` as an infinite loop
machine.instructions.push({
// `return` is a protected keyword, so we use a dummy name and replace it afterwards
let mut d = parse_instruction_definition(&format!("instr dummy {{ {pc}' = {pc} }}",));
d.name = RETURN_NAME.into();
d
});

let rom = std::mem::take(&mut function.body.statements)
.into_iter_batches()
.collect();

*callable.symbol = OperationSymbol {
source: SourceRef::unknown(),
id: OperationId { id: None },
params: Params {
inputs: vec![],
outputs: vec![],
},
}
.into();

(machine, Some(Rom { statements: rom }))
} else {
// all callables in the machine must be functions
assert!(machine.callable.is_only_functions());
Expand All @@ -67,24 +112,6 @@ pub fn generate_machine_rom<T: FieldElement>(mut machine: Machine) -> (Machine,

let pc = machine.pc().unwrap();

// add the necessary embedded instructions
let embedded_instructions = [
parse_instruction_definition(&format!(
"instr _jump_to_operation {{ {pc}' = {operation_id} }}",
)),
parse_instruction_definition(&format!(
"instr {RESET_NAME} {{ {} }}",
machine
.write_register_names()
.map(|w| format!("{w}' = 0"))
.collect::<Vec<_>>()
.join(", ")
)),
parse_instruction_definition(&format!("instr _loop {{ {pc}' = {pc} }}")),
];

machine.instructions.extend(embedded_instructions);

// generate the rom
// the functions are already batched, we just batch the dispatcher manually here

Expand Down Expand Up @@ -124,6 +151,33 @@ pub fn generate_machine_rom<T: FieldElement>(mut machine: Machine) -> (Machine,
input_assignment_registers_declarations.chain(output_assignment_registers_declarations),
);

// add the necessary embedded instructions
let embedded_instructions = [
parse_instruction_definition(&format!(
"instr _jump_to_operation {{ {pc}' = {operation_id} }}",
)),
parse_instruction_definition(&format!(
"instr {RESET_NAME} {{ {} }}",
machine
.write_register_names()
.map(|w| format!("{w}' = 0"))
.collect::<Vec<_>>()
.join(", ")
)),
parse_instruction_definition(&format!("instr _loop {{ {pc}' = {pc} }}")),
{
// `return` is a protected keyword, so we use a dummy name and replace it afterwards
let mut d = parse_instruction_definition(&format!(
"instr dummy {} {{ {pc}' = 0 }}",
output_registers(output_count).join(", ")
));
d.name = RETURN_NAME.into();
d
},
];

machine.instructions.extend(embedded_instructions);

// turn each function into an operation, setting the operation_id to the current position in the ROM
for callable in machine.callable.iter_mut() {
let operation_id = BigUint::from(rom.len() as u64);
Expand Down Expand Up @@ -264,7 +318,7 @@ mod tests {
let checked = powdr_analysis::machine_check::check(parsed).unwrap();
checked
.into_machines()
.map(|(name, m)| (name, generate_machine_rom::<T>(m)))
.map(|(name, m)| (name, generate_machine_rom::<T>(m, true)))
.collect()
}

Expand Down
50 changes: 7 additions & 43 deletions asm-to-pil/src/vm_to_constrained.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use powdr_ast::{
RegisterDeclarationStatement, RegisterTy, Rom,
},
parsed::{
self,
asm::{
CallableParams, CallableRef, InstructionBody, InstructionParams, LinkDeclaration,
OperationId, Param, Params,
Expand All @@ -29,20 +28,15 @@ use powdr_number::{BigUint, FieldElement, LargeInt};
use powdr_parser_util::SourceRef;

use crate::{
common::{instruction_flag, return_instruction, RETURN_NAME},
common::{instruction_flag, RETURN_NAME},
utils::parse_pil_statement,
};

pub fn convert_machine<T: FieldElement>(
machine: Machine,
rom: Option<Rom>,
) -> (Machine, Option<Machine>) {
let output_count = machine
.operations()
.map(|f| f.params.outputs.len())
.max()
.unwrap_or_default();
VMConverter::<T>::with_output_count(output_count).convert_machine(machine, rom)
VMConverter::<T>::default().convert_machine(machine, rom)
}

pub enum Input {
Expand Down Expand Up @@ -128,19 +122,10 @@ struct VMConverter<T> {
line_lookup: Vec<(String, String)>,
/// Names of fixed columns that contain the rom.
rom_constant_names: Vec<String>,
/// the maximum number of inputs in all functions
output_count: usize,
_phantom: std::marker::PhantomData<T>,
}

impl<T: FieldElement> VMConverter<T> {
fn with_output_count(output_count: usize) -> Self {
Self {
output_count,
..Default::default()
}
}

fn convert_machine(
mut self,
mut input: Machine,
Expand Down Expand Up @@ -168,16 +153,6 @@ impl<T: FieldElement> VMConverter<T> {
self.handle_instruction_def(&mut input, instr);
}

// introduce `return` instruction
self.handle_instruction_def(
&mut input,
InstructionDefinitionStatement {
source: SourceRef::unknown(),
name: RETURN_NAME.into(),
instruction: self.return_instruction(),
},
);

let assignment_registers = self
.assignment_register_names()
.cloned()
Expand All @@ -202,29 +177,22 @@ impl<T: FieldElement> VMConverter<T> {
use RegisterTy::*;
match reg.ty {
// Force pc to zero on first row.
Pc => {
Pc | Write => {
// introduce an intermediate witness polynomial to keep the degree of polynomial identities at 2
// this may not be optimal for backends which support higher degree constraints
let pc_update_name = format!("{name}_update");
let update_name = format!("{name}_update");
vec![
witness_column(
SourceRef::unknown(),
pc_update_name.clone(),
None,
),
witness_column(SourceRef::unknown(), update_name.clone(), None),
PilStatement::Expression(
SourceRef::unknown(),
build::identity(
direct_reference(pc_update_name.clone()),
rhs,
),
build::identity(direct_reference(update_name.clone()), rhs),
),
PilStatement::Expression(
SourceRef::unknown(),
build::identity(
lhs,
(Expression::from(1) - next_reference("first_step"))
* direct_reference(pc_update_name),
* direct_reference(update_name),
),
),
]
Expand Down Expand Up @@ -1138,10 +1106,6 @@ impl<T: FieldElement> VMConverter<T> {
.filter_map(|(n, r)| r.ty.is_read_only().then_some(n))
}

fn return_instruction(&self) -> parsed::asm::Instruction {
return_instruction(self.output_count, self.pc_name.as_ref().unwrap())
}

/// Return an expression of degree at most 1 whose value matches that of `expr`.
/// Intermediate witness columns can be introduced, with names starting with `prefix` optionally followed by a suffix
/// Suffixes are defined as follows: "", "_1", "_2", "_3" etc.
Expand Down
Loading