From ea9f4f6ca22a7dc7719d0e96bf051e35e0dd7aa1 Mon Sep 17 00:00:00 2001 From: sragss Date: Thu, 9 May 2024 15:27:14 -0700 Subject: [PATCH 1/3] almost working with const fn; BinaryField::mul has no constant analog --- jolt-core/src/field/ark.rs | 4 ++ jolt-core/src/field/binius.rs | 104 +++++++++++++++++++++++++++++++++- jolt-core/src/field/mod.rs | 1 + jolt-core/src/lib.rs | 6 ++ 4 files changed, 113 insertions(+), 2 deletions(-) diff --git a/jolt-core/src/field/ark.rs b/jolt-core/src/field/ark.rs index e0c6ed94e..03bb2d828 100644 --- a/jolt-core/src/field/ark.rs +++ b/jolt-core/src/field/ark.rs @@ -37,4 +37,8 @@ impl JoltField for ark_bn254::Fr { assert_eq!(bytes.len(), Self::NUM_BYTES); ark_bn254::Fr::from_le_bytes_mod_order(bytes) } + + fn from_count_index(index: u64) -> Self { + ::from_u64(index).unwrap() + } } diff --git a/jolt-core/src/field/binius.rs b/jolt-core/src/field/binius.rs index f76f9fac9..354bd4cb2 100644 --- a/jolt-core/src/field/binius.rs +++ b/jolt-core/src/field/binius.rs @@ -1,5 +1,5 @@ use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; -use binius_field::{BinaryField128b, BinaryField128bPolyval}; +use binius_field::{BinaryField, BinaryField128b, BinaryField128bPolyval, TowerField}; use std::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign}; use super::JoltField; @@ -22,8 +22,37 @@ impl BiniusSpecific for BinaryField128bPolyval {} /// Trait for BiniusField functionality specific to each impl. pub trait BiniusSpecific: binius_field::TowerField + BiniusConstructable + bytemuck::Pod {} -pub trait BiniusConstructable { +const LOG_TABLE_SIZE: usize = 16; +const TABLE_SIZE: usize = 1 << LOG_TABLE_SIZE; + +pub trait BiniusConstructable: BinaryField { fn new(n: u64) -> Self; + + + /// Binius counts are constructed from multiplicities of a Binary Field multiplicative generator. + /// Precomputing all required counts [0, 2^32] is prohibitively expensive and using iterative multiplication + /// or square and multiply is still excessively costly. + /// Utilizes a two-table lookup method to handle counts up to `2^32` efficiently: + /// - Decompose count `x` as `x = a * 2^16 + b` where `a` and `b` are within the range `[0, 2^16]`. + /// - Two Precomptued Lookup Tables: + /// - One table stores `2^16` powers of `g^{2^16}`. + /// - Another table stores `2^16` powers of `g`. + /// - Computes any count up to `2^32` using: + /// ``` + /// g^x = (g^{2^16})^a * g^b + /// ``` + /// This is achieved with two lookups (one from each table) and a single multiplication. + const PRECOMPUTED_GENERATOR_MULTIPLES_LOW: [Self; TABLE_SIZE] = compute_powers::(Self::MULTIPLICATIVE_GENERATOR); + const PRECOMPUTED_GENERATOR_MULTIPLES_HIGH: [Self; TABLE_SIZE] = compute_powers_starting_from::(Self::MULTIPLICATIVE_GENERATOR, 16); + + fn from_count_index(n: u64) -> Self { + const MAX_COUNTS: usize = 1 << (2 * LOG_TABLE_SIZE); + assert!(n <= MAX_COUNTS as u64); + let high_index = (n >> LOG_TABLE_SIZE) as usize; + let low_index = (n & ((1 << LOG_TABLE_SIZE) - 1)) as usize; + // TODO(sragss): mul_0_optimized + Self::PRECOMPUTED_GENERATOR_MULTIPLES_HIGH[high_index] * Self::PRECOMPUTED_GENERATOR_MULTIPLES_LOW[low_index] + } } #[derive(Default, Debug, Copy, Clone, Eq, PartialEq)] @@ -67,6 +96,10 @@ impl JoltField for BiniusField { let field_element = bytemuck::try_from_bytes::(bytes).unwrap(); Self(field_element.to_owned()) } + + fn from_count_index(index: u64) -> Self { + Self(::from_count_index(index)) + } } impl Neg for BiniusField { @@ -183,3 +216,70 @@ impl ark_serialize::Valid for BiniusField { todo!() } } + +/// Computes `N` sequential powers of base^{[0, ... N]} +const fn compute_powers(base: F) -> [F; N] { + let mut powers = [F::ZERO; N]; + powers[0] = base; + let mut i = 1; + while i < N { + powers[i] = powers[i - 1] * base; + i += 1; + } + powers +} + +/// Computes `N` sequential powers of base^{2^starting_power}: {base^{2^starting_power}}^{[1, ... N]} +const fn compute_powers_starting_from(base: F, starting_power: usize) -> [F; N] { + let mut starting = base; + let mut count = 0; + // // Repeated squaring + while count < starting_power { + starting = starting.mul(starting); + count += 1; + } + compute_powers::(starting) +} + + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_compute_powers() { + let base = BinaryField128b::MULTIPLICATIVE_GENERATOR; + let powers = compute_powers::(base); + let expected_powers = [ + base, + base * base, + base * base * base, + base * base * base * base, + base * base * base * base * base + ]; + assert_eq!(powers, expected_powers); + } + + #[test] + fn test_compute_powers_starting_from() { + let base = BinaryField128b::MULTIPLICATIVE_GENERATOR; + let powers = compute_powers_starting_from::(base, 2); + let expected_starting_base = base * base * base * base; + let expected_powers = [ + expected_starting_base, + expected_starting_base * expected_starting_base, + expected_starting_base * expected_starting_base * expected_starting_base, + expected_starting_base * expected_starting_base * expected_starting_base * expected_starting_base, + expected_starting_base * expected_starting_base * expected_starting_base * expected_starting_base * expected_starting_base, + ]; + assert_eq!(powers, expected_powers); + } + + #[test] + fn test_from_count_index() { + let actual = BiniusField::::from_count_index(0); + let expected = BinaryField128b::MULTIPLICATIVE_GENERATOR; + assert_eq!(actual.0, expected); + } +} + diff --git a/jolt-core/src/field/mod.rs b/jolt-core/src/field/mod.rs index 904d35ce6..d04459478 100644 --- a/jolt-core/src/field/mod.rs +++ b/jolt-core/src/field/mod.rs @@ -37,6 +37,7 @@ pub trait JoltField: fn from_u64(n: u64) -> Option; fn square(&self) -> Self; fn from_bytes(bytes: &[u8]) -> Self; + fn from_count_index(index: u64) -> Self; } pub mod ark; diff --git a/jolt-core/src/lib.rs b/jolt-core/src/lib.rs index ad2b670b5..214d0ec57 100644 --- a/jolt-core/src/lib.rs +++ b/jolt-core/src/lib.rs @@ -8,6 +8,12 @@ #![feature(iter_next_chunk)] #![allow(long_running_const_eval)] +// Note: Used exclusively by const fn BiniusConstructable::compute_powers. +// Can be removed with a manual const fn for BinaryField multiplication. +#![feature(const_trait_impl)] +#![feature(effects)] +#![feature(const_refs_to_cell)] + pub mod benches; pub mod field; pub mod host; From d2671eee5c35ad8afb3ba6d73b998cf80069f73f Mon Sep 17 00:00:00 2001 From: sragss Date: Thu, 9 May 2024 16:07:41 -0700 Subject: [PATCH 2/3] switch to OnceCell for statics --- jolt-core/Cargo.toml | 1 + jolt-core/src/field/binius.rs | 126 ++++++++++++++++++++++++++++------ jolt-core/src/lib.rs | 3 +- 3 files changed, 107 insertions(+), 23 deletions(-) diff --git a/jolt-core/Cargo.toml b/jolt-core/Cargo.toml index a376fb78d..1e90679bf 100644 --- a/jolt-core/Cargo.toml +++ b/jolt-core/Cargo.toml @@ -69,6 +69,7 @@ common = { path = "../common" } tracer = { path = "../tracer" } bincode = "1.3.3" bytemuck = "1.15.0" +once_cell = "1.19.0" [build-dependencies] common = { path = "../common" } diff --git a/jolt-core/src/field/binius.rs b/jolt-core/src/field/binius.rs index 354bd4cb2..4bb958e8a 100644 --- a/jolt-core/src/field/binius.rs +++ b/jolt-core/src/field/binius.rs @@ -1,19 +1,55 @@ use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; -use binius_field::{BinaryField, BinaryField128b, BinaryField128bPolyval, TowerField}; +use binius_field::{BinaryField, BinaryField128b, BinaryField128bPolyval}; use std::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign}; use super::JoltField; +use once_cell::sync::Lazy; + +static PRECOMPUTED_LOW_128B: Lazy<[BinaryField128b; TABLE_SIZE]> = Lazy::new(|| { + compute_powers::(BinaryField128b::MULTIPLICATIVE_GENERATOR) +}); +static PRECOMPUTED_HIGH_128B: Lazy<[BinaryField128b; TABLE_SIZE]> = Lazy::new(|| { + compute_powers_starting_from::( + BinaryField128b::MULTIPLICATIVE_GENERATOR, + 16, + ) +}); impl BiniusConstructable for BinaryField128b { fn new(n: u64) -> Self { Self::new(n as u128) } + + fn precomputed_generator_multiples( + ) -> (&'static [Self; TABLE_SIZE], &'static [Self; TABLE_SIZE]) { + (&PRECOMPUTED_LOW_128B, &PRECOMPUTED_HIGH_128B) + } } +static PRECOMPUTED_LOW_128B_POLYVAL: Lazy<[BinaryField128bPolyval; TABLE_SIZE]> = Lazy::new(|| { + compute_powers::( + BinaryField128bPolyval::MULTIPLICATIVE_GENERATOR, + ) +}); +static PRECOMPUTED_HIGH_128B_POLYVAL: Lazy<[BinaryField128bPolyval; TABLE_SIZE]> = + Lazy::new(|| { + compute_powers_starting_from::( + BinaryField128bPolyval::MULTIPLICATIVE_GENERATOR, + 16, + ) + }); impl BiniusConstructable for BinaryField128bPolyval { fn new(n: u64) -> Self { Self::new(n as u128) } + + fn precomputed_generator_multiples( + ) -> (&'static [Self; TABLE_SIZE], &'static [Self; TABLE_SIZE]) { + ( + &PRECOMPUTED_LOW_128B_POLYVAL, + &PRECOMPUTED_HIGH_128B_POLYVAL, + ) + } } impl BiniusSpecific for BinaryField128b {} @@ -28,8 +64,7 @@ const TABLE_SIZE: usize = 1 << LOG_TABLE_SIZE; pub trait BiniusConstructable: BinaryField { fn new(n: u64) -> Self; - - /// Binius counts are constructed from multiplicities of a Binary Field multiplicative generator. + /// Binius counts are constructed from multiplicities of a Binary Field multiplicative generator. /// Precomputing all required counts [0, 2^32] is prohibitively expensive and using iterative multiplication /// or square and multiply is still excessively costly. /// Utilizes a two-table lookup method to handle counts up to `2^32` efficiently: @@ -37,21 +72,29 @@ pub trait BiniusConstructable: BinaryField { /// - Two Precomptued Lookup Tables: /// - One table stores `2^16` powers of `g^{2^16}`. /// - Another table stores `2^16` powers of `g`. - /// - Computes any count up to `2^32` using: - /// ``` - /// g^x = (g^{2^16})^a * g^b - /// ``` - /// This is achieved with two lookups (one from each table) and a single multiplication. - const PRECOMPUTED_GENERATOR_MULTIPLES_LOW: [Self; TABLE_SIZE] = compute_powers::(Self::MULTIPLICATIVE_GENERATOR); - const PRECOMPUTED_GENERATOR_MULTIPLES_HIGH: [Self; TABLE_SIZE] = compute_powers_starting_from::(Self::MULTIPLICATIVE_GENERATOR, 16); + /// - Computes any count up to `2^32` using: `g^x = (g^{2^16})^a * g^b` + /// This is achieved with two lookups (one from each table) and a single multiplication. + /// + /// `precomputed_generator_multiples() -> (&low_table, &high_table)` + fn precomputed_generator_multiples( + ) -> (&'static [Self; TABLE_SIZE], &'static [Self; TABLE_SIZE]); fn from_count_index(n: u64) -> Self { const MAX_COUNTS: usize = 1 << (2 * LOG_TABLE_SIZE); - assert!(n <= MAX_COUNTS as u64); + let n = (n) as usize; + assert!(n <= MAX_COUNTS); + assert!(n != 0); + let high_index = (n >> LOG_TABLE_SIZE) as usize; let low_index = (n & ((1 << LOG_TABLE_SIZE) - 1)) as usize; + let (precomputed_low, precomputed_high) = Self::precomputed_generator_multiples(); // TODO(sragss): mul_0_optimized - Self::PRECOMPUTED_GENERATOR_MULTIPLES_HIGH[high_index] * Self::PRECOMPUTED_GENERATOR_MULTIPLES_LOW[low_index] + if n < TABLE_SIZE { + precomputed_low[n] + } else { + println!("low[{low_index}] * high[{high_index}]"); + precomputed_low[low_index] * precomputed_high[high_index] + } } } @@ -220,8 +263,9 @@ impl ark_serialize::Valid for BiniusField { /// Computes `N` sequential powers of base^{[0, ... N]} const fn compute_powers(base: F) -> [F; N] { let mut powers = [F::ZERO; N]; - powers[0] = base; - let mut i = 1; + powers[0] = F::ONE; + powers[1] = base; + let mut i = 2; while i < N { powers[i] = powers[i - 1] * base; i += 1; @@ -230,7 +274,10 @@ const fn compute_powers(base: F) - } /// Computes `N` sequential powers of base^{2^starting_power}: {base^{2^starting_power}}^{[1, ... N]} -const fn compute_powers_starting_from(base: F, starting_power: usize) -> [F; N] { +const fn compute_powers_starting_from( + base: F, + starting_power: usize, +) -> [F; N] { let mut starting = base; let mut count = 0; // // Repeated squaring @@ -241,21 +288,21 @@ const fn compute_powers_starting_from(starting) } - #[cfg(test)] mod tests { use super::*; + use binius_field::Field; #[test] fn test_compute_powers() { let base = BinaryField128b::MULTIPLICATIVE_GENERATOR; let powers = compute_powers::(base); let expected_powers = [ + BinaryField128b::ONE, base, base * base, base * base * base, base * base * base * base, - base * base * base * base * base ]; assert_eq!(powers, expected_powers); } @@ -266,20 +313,57 @@ mod tests { let powers = compute_powers_starting_from::(base, 2); let expected_starting_base = base * base * base * base; let expected_powers = [ + BinaryField128b::ONE, expected_starting_base, expected_starting_base * expected_starting_base, expected_starting_base * expected_starting_base * expected_starting_base, - expected_starting_base * expected_starting_base * expected_starting_base * expected_starting_base, - expected_starting_base * expected_starting_base * expected_starting_base * expected_starting_base * expected_starting_base, + expected_starting_base + * expected_starting_base + * expected_starting_base + * expected_starting_base, ]; assert_eq!(powers, expected_powers); } #[test] fn test_from_count_index() { - let actual = BiniusField::::from_count_index(0); + let actual = BiniusField::::from_count_index(1); let expected = BinaryField128b::MULTIPLICATIVE_GENERATOR; assert_eq!(actual.0, expected); + + let actual = BiniusField::::from_count_index(2); + let expected = + BinaryField128b::MULTIPLICATIVE_GENERATOR * BinaryField128b::MULTIPLICATIVE_GENERATOR; + assert_eq!(actual.0, expected); + + let actual = BiniusField::::from_count_index(3); + let expected = BinaryField128b::MULTIPLICATIVE_GENERATOR + * BinaryField128b::MULTIPLICATIVE_GENERATOR + * BinaryField128b::MULTIPLICATIVE_GENERATOR; + assert_eq!(actual.0, expected); + + let actual = BiniusField::::from_count_index(1 << 17); + let mut expected = BinaryField128b::MULTIPLICATIVE_GENERATOR; + for _ in 0..17 { + expected = expected.square(); + } + assert_eq!(actual.0, expected); + + let actual = BiniusField::::from_count_index((1 << 17) + 1); + let mut expected = BinaryField128b::MULTIPLICATIVE_GENERATOR; + for _ in 0..17 { + expected = expected.square(); + } + expected *= BinaryField128b::MULTIPLICATIVE_GENERATOR; + assert_eq!(actual.0, expected); + + let actual = BiniusField::::from_count_index((1 << 18) + 2); + let mut expected = BinaryField128b::MULTIPLICATIVE_GENERATOR; + for _ in 0..18 { + expected = expected.square(); + } + expected *= BinaryField128b::MULTIPLICATIVE_GENERATOR; + expected *= BinaryField128b::MULTIPLICATIVE_GENERATOR; + assert_eq!(actual.0, expected); } } - diff --git a/jolt-core/src/lib.rs b/jolt-core/src/lib.rs index 214d0ec57..faa22f075 100644 --- a/jolt-core/src/lib.rs +++ b/jolt-core/src/lib.rs @@ -7,8 +7,7 @@ #![feature(generic_const_exprs)] #![feature(iter_next_chunk)] #![allow(long_running_const_eval)] - -// Note: Used exclusively by const fn BiniusConstructable::compute_powers. +// Note: Used exclusively by const fn BiniusConstructable::compute_powers. // Can be removed with a manual const fn for BinaryField multiplication. #![feature(const_trait_impl)] #![feature(effects)] From 1241739007964a4e2c2ba71fc07e76082a0ca925 Mon Sep 17 00:00:00 2001 From: sragss Date: Thu, 9 May 2024 16:20:40 -0700 Subject: [PATCH 3/3] refactor lazy computations to seperate file --- jolt-core/src/field/binius.rs | 125 ++++++----------------------- jolt-core/src/field/mod.rs | 1 + jolt-core/src/field/precomputed.rs | 97 ++++++++++++++++++++++ 3 files changed, 123 insertions(+), 100 deletions(-) create mode 100644 jolt-core/src/field/precomputed.rs diff --git a/jolt-core/src/field/binius.rs b/jolt-core/src/field/binius.rs index 4bb958e8a..9df93d60d 100644 --- a/jolt-core/src/field/binius.rs +++ b/jolt-core/src/field/binius.rs @@ -2,49 +2,37 @@ use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use binius_field::{BinaryField, BinaryField128b, BinaryField128bPolyval}; use std::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign}; -use super::JoltField; - -use once_cell::sync::Lazy; - -static PRECOMPUTED_LOW_128B: Lazy<[BinaryField128b; TABLE_SIZE]> = Lazy::new(|| { - compute_powers::(BinaryField128b::MULTIPLICATIVE_GENERATOR) -}); -static PRECOMPUTED_HIGH_128B: Lazy<[BinaryField128b; TABLE_SIZE]> = Lazy::new(|| { - compute_powers_starting_from::( - BinaryField128b::MULTIPLICATIVE_GENERATOR, - 16, - ) -}); +use super::precomputed; +use super::{ + precomputed::{ + PRECOMPUTED_HIGH_128B, PRECOMPUTED_HIGH_128B_POLYVAL, PRECOMPUTED_LOW_128B, + PRECOMPUTED_LOW_128B_POLYVAL, + }, + JoltField, +}; + impl BiniusConstructable for BinaryField128b { fn new(n: u64) -> Self { Self::new(n as u128) } - fn precomputed_generator_multiples( - ) -> (&'static [Self; TABLE_SIZE], &'static [Self; TABLE_SIZE]) { + fn precomputed_generator_multiples() -> ( + &'static [Self; precomputed::TABLE_SIZE], + &'static [Self; precomputed::TABLE_SIZE], + ) { (&PRECOMPUTED_LOW_128B, &PRECOMPUTED_HIGH_128B) } } -static PRECOMPUTED_LOW_128B_POLYVAL: Lazy<[BinaryField128bPolyval; TABLE_SIZE]> = Lazy::new(|| { - compute_powers::( - BinaryField128bPolyval::MULTIPLICATIVE_GENERATOR, - ) -}); -static PRECOMPUTED_HIGH_128B_POLYVAL: Lazy<[BinaryField128bPolyval; TABLE_SIZE]> = - Lazy::new(|| { - compute_powers_starting_from::( - BinaryField128bPolyval::MULTIPLICATIVE_GENERATOR, - 16, - ) - }); impl BiniusConstructable for BinaryField128bPolyval { fn new(n: u64) -> Self { Self::new(n as u128) } - fn precomputed_generator_multiples( - ) -> (&'static [Self; TABLE_SIZE], &'static [Self; TABLE_SIZE]) { + fn precomputed_generator_multiples() -> ( + &'static [Self; precomputed::TABLE_SIZE], + &'static [Self; precomputed::TABLE_SIZE], + ) { ( &PRECOMPUTED_LOW_128B_POLYVAL, &PRECOMPUTED_HIGH_128B_POLYVAL, @@ -58,9 +46,6 @@ impl BiniusSpecific for BinaryField128bPolyval {} /// Trait for BiniusField functionality specific to each impl. pub trait BiniusSpecific: binius_field::TowerField + BiniusConstructable + bytemuck::Pod {} -const LOG_TABLE_SIZE: usize = 16; -const TABLE_SIZE: usize = 1 << LOG_TABLE_SIZE; - pub trait BiniusConstructable: BinaryField { fn new(n: u64) -> Self; @@ -76,23 +61,23 @@ pub trait BiniusConstructable: BinaryField { /// This is achieved with two lookups (one from each table) and a single multiplication. /// /// `precomputed_generator_multiples() -> (&low_table, &high_table)` - fn precomputed_generator_multiples( - ) -> (&'static [Self; TABLE_SIZE], &'static [Self; TABLE_SIZE]); + fn precomputed_generator_multiples() -> ( + &'static [Self; precomputed::TABLE_SIZE], + &'static [Self; precomputed::TABLE_SIZE], + ); fn from_count_index(n: u64) -> Self { - const MAX_COUNTS: usize = 1 << (2 * LOG_TABLE_SIZE); + const MAX_COUNTS: usize = 1 << (2 * precomputed::LOG_TABLE_SIZE); let n = (n) as usize; assert!(n <= MAX_COUNTS); assert!(n != 0); - let high_index = (n >> LOG_TABLE_SIZE) as usize; - let low_index = (n & ((1 << LOG_TABLE_SIZE) - 1)) as usize; + let high_index = (n >> precomputed::LOG_TABLE_SIZE) as usize; + let low_index = (n & ((1 << precomputed::LOG_TABLE_SIZE) - 1)) as usize; let (precomputed_low, precomputed_high) = Self::precomputed_generator_multiples(); - // TODO(sragss): mul_0_optimized - if n < TABLE_SIZE { + if n < precomputed::TABLE_SIZE { precomputed_low[n] } else { - println!("low[{low_index}] * high[{high_index}]"); precomputed_low[low_index] * precomputed_high[high_index] } } @@ -260,71 +245,11 @@ impl ark_serialize::Valid for BiniusField { } } -/// Computes `N` sequential powers of base^{[0, ... N]} -const fn compute_powers(base: F) -> [F; N] { - let mut powers = [F::ZERO; N]; - powers[0] = F::ONE; - powers[1] = base; - let mut i = 2; - while i < N { - powers[i] = powers[i - 1] * base; - i += 1; - } - powers -} - -/// Computes `N` sequential powers of base^{2^starting_power}: {base^{2^starting_power}}^{[1, ... N]} -const fn compute_powers_starting_from( - base: F, - starting_power: usize, -) -> [F; N] { - let mut starting = base; - let mut count = 0; - // // Repeated squaring - while count < starting_power { - starting = starting.mul(starting); - count += 1; - } - compute_powers::(starting) -} - #[cfg(test)] mod tests { use super::*; use binius_field::Field; - #[test] - fn test_compute_powers() { - let base = BinaryField128b::MULTIPLICATIVE_GENERATOR; - let powers = compute_powers::(base); - let expected_powers = [ - BinaryField128b::ONE, - base, - base * base, - base * base * base, - base * base * base * base, - ]; - assert_eq!(powers, expected_powers); - } - - #[test] - fn test_compute_powers_starting_from() { - let base = BinaryField128b::MULTIPLICATIVE_GENERATOR; - let powers = compute_powers_starting_from::(base, 2); - let expected_starting_base = base * base * base * base; - let expected_powers = [ - BinaryField128b::ONE, - expected_starting_base, - expected_starting_base * expected_starting_base, - expected_starting_base * expected_starting_base * expected_starting_base, - expected_starting_base - * expected_starting_base - * expected_starting_base - * expected_starting_base, - ]; - assert_eq!(powers, expected_powers); - } - #[test] fn test_from_count_index() { let actual = BiniusField::::from_count_index(1); diff --git a/jolt-core/src/field/mod.rs b/jolt-core/src/field/mod.rs index d04459478..edae6c5d5 100644 --- a/jolt-core/src/field/mod.rs +++ b/jolt-core/src/field/mod.rs @@ -42,3 +42,4 @@ pub trait JoltField: pub mod ark; pub mod binius; +mod precomputed; diff --git a/jolt-core/src/field/precomputed.rs b/jolt-core/src/field/precomputed.rs new file mode 100644 index 000000000..9bbeac1f9 --- /dev/null +++ b/jolt-core/src/field/precomputed.rs @@ -0,0 +1,97 @@ +use binius_field::{BinaryField, BinaryField128b, BinaryField128bPolyval}; +use once_cell::sync::Lazy; + +pub const LOG_TABLE_SIZE: usize = 16; +pub const TABLE_SIZE: usize = 1 << LOG_TABLE_SIZE; + +// BinaryField128b +pub static PRECOMPUTED_LOW_128B: Lazy<[BinaryField128b; TABLE_SIZE]> = Lazy::new(|| { + compute_powers::(BinaryField128b::MULTIPLICATIVE_GENERATOR) +}); +pub static PRECOMPUTED_HIGH_128B: Lazy<[BinaryField128b; TABLE_SIZE]> = Lazy::new(|| { + compute_powers_starting_from::( + BinaryField128b::MULTIPLICATIVE_GENERATOR, + 16, + ) +}); + +// BinaryField128bPolyval +pub static PRECOMPUTED_LOW_128B_POLYVAL: Lazy<[BinaryField128bPolyval; TABLE_SIZE]> = + Lazy::new(|| { + compute_powers::( + BinaryField128bPolyval::MULTIPLICATIVE_GENERATOR, + ) + }); +pub static PRECOMPUTED_HIGH_128B_POLYVAL: Lazy<[BinaryField128bPolyval; TABLE_SIZE]> = + Lazy::new(|| { + compute_powers_starting_from::( + BinaryField128bPolyval::MULTIPLICATIVE_GENERATOR, + 16, + ) + }); + +/// Computes `N` sequential powers of base^{[0, ... N]} +const fn compute_powers(base: F) -> [F; N] { + let mut powers = [F::ZERO; N]; + powers[0] = F::ONE; + powers[1] = base; + let mut i = 2; + while i < N { + powers[i] = powers[i - 1] * base; + i += 1; + } + powers +} + +/// Computes `N` sequential powers of base^{2^starting_power}: {base^{2^starting_power}}^{[1, ... N]} +const fn compute_powers_starting_from( + base: F, + starting_power: usize, +) -> [F; N] { + let mut starting = base; + let mut count = 0; + // // Repeated squaring + while count < starting_power { + starting = starting.mul(starting); + count += 1; + } + compute_powers::(starting) +} + +#[cfg(test)] +mod tests { + use super::*; + use binius_field::{BinaryField, Field}; + + #[test] + fn test_compute_powers() { + let base = BinaryField128b::MULTIPLICATIVE_GENERATOR; + let powers = compute_powers::(base); + let expected_powers = [ + BinaryField128b::ONE, + base, + base * base, + base * base * base, + base * base * base * base, + ]; + assert_eq!(powers, expected_powers); + } + + #[test] + fn test_compute_powers_starting_from() { + let base = BinaryField128b::MULTIPLICATIVE_GENERATOR; + let powers = compute_powers_starting_from::(base, 2); + let expected_starting_base = base * base * base * base; + let expected_powers = [ + BinaryField128b::ONE, + expected_starting_base, + expected_starting_base * expected_starting_base, + expected_starting_base * expected_starting_base * expected_starting_base, + expected_starting_base + * expected_starting_base + * expected_starting_base + * expected_starting_base, + ]; + assert_eq!(powers, expected_powers); + } +}