utils/mmr: safety & performance rework #274

Draft
wants to merge 4 commits into base: main
Changes from 1 commit
139 changes: 75 additions & 64 deletions crates/util/mmr/src/lib.rs
@@ -9,82 +9,89 @@
use error::MerkleError;
use hasher::{Hash, MerkleHasher};

- fn zero() -> Hash {
- [0; 32]
- }
-
- fn is_zero(h: Hash) -> bool {
- h.iter().all(|b| *b == 0)
- }

/// Compact representation of the MMR that should be borsh serializable easily.
#[derive(Debug, Clone, PartialEq, Eq, BorshSerialize, BorshDeserialize, Arbitrary)]
pub struct CompactMmr {
Comment on lines 12 to 14
Contributor

It would be really cool to write the borsh serialization by hand, so that we don't have to store the length of the peaks list and can infer it from the element count instead (the length is redundant; it's just the popcnt of the element count).
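A minimal sketch of the idea in this comment, using only the standard library rather than borsh itself. It assumes the compact form stores only the populated peaks (one per set bit of the element count), which is where the popcnt relation holds; the padded layout produced by `to_compact` in this diff would instead need `ilog2 + 1` entries. The names `populated_peaks`, `encode_compact` and `decode_compact` are hypothetical and not part of the crate, and the byte layout (little-endian `u64` followed by raw 32-byte peaks) is an assumption for illustration.

```rust
type Hash = [u8; 32];

/// Number of populated peaks: one per set bit of the element count.
fn populated_peaks(element_count: u64) -> usize {
    element_count.count_ones() as usize
}

/// Hypothetical length-free encoding: the element count as a little-endian
/// u64, followed by the populated peaks with no length prefix.
/// Returns None if the peak list does not match the element count.
fn encode_compact(element_count: u64, populated: &[Hash]) -> Option<Vec<u8>> {
    if populated.len() != populated_peaks(element_count) {
        return None;
    }
    let mut out = Vec::with_capacity(8 + 32 * populated.len());
    out.extend_from_slice(&element_count.to_le_bytes());
    for peak in populated {
        out.extend_from_slice(peak);
    }
    Some(out)
}

/// Inverse of `encode_compact`: the number of peaks to read is inferred
/// from the popcount of the decoded element count.
fn decode_compact(bytes: &[u8]) -> Option<(u64, Vec<Hash>)> {
    let head: [u8; 8] = bytes.get(..8)?.try_into().ok()?;
    let element_count = u64::from_le_bytes(head);
    let rest = &bytes[8..];
    if rest.len() != populated_peaks(element_count) * 32 {
        return None;
    }
    let peaks = rest
        .chunks_exact(32)
        .map(|chunk| {
            let mut hash = [0u8; 32];
            hash.copy_from_slice(chunk);
            hash
        })
        .collect();
    Some((element_count, peaks))
}
```

With 5 elements (binary `101`) the encoder expects exactly two peaks and emits 8 + 64 bytes; the decoder rejects any input whose length disagrees with the element count.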

- entries: u64,
- cap_log2: u8,
- roots: Vec<Hash>,
+ element_count: u64,
+ peaks: Vec<Hash>,
}

+ const ZERO: [u8; 32] = [0; 32];

#[derive(Clone)]
pub struct MerkleMr<H: MerkleHasher + Clone> {
- // number of elements inserted into mmr
- pub num: u64,
+ /// number of elements inserted into mmr
+ pub element_count: u64,
// Buffer of all possible peaks in mmr. only some of them will be valid at a time
pub peaks: Box<[Hash]>,
// phantom data for hasher
pub hasher: PhantomData<H>,
}

impl<H: MerkleHasher + Clone> MerkleMr<H> {
- pub fn new(cap_log2: usize) -> Self {
+ pub fn new(peak_count: usize) -> Self {
Self {
- num: 0,
- peaks: vec![[0; 32]; cap_log2].into_boxed_slice(),
+ element_count: 0,
+ peaks: vec![ZERO; peak_count].into_boxed_slice(),
hasher: PhantomData,
}
}

- pub fn from_compact(compact: &CompactMmr) -> Self {
- // FIXME this is somewhat inefficient, we could consume the vec and just
- // slice out its elements, but this is fine for now
- let mut roots = vec![zero(); compact.cap_log2 as usize];
- let mut at = 0;
- for i in 0..compact.cap_log2 {
- if compact.entries >> i & 1 != 0 {
- roots[i as usize] = compact.roots[at as usize];
- at += 1;
- }
+ /// returns the minimum peaks needed to store an mmr of `element_count` elements
+ #[inline]
+ fn min_peaks(element_count: u64) -> Option<usize> {
+ match element_count {
+ 0 => None,

Codecov / codecov/patch: added line #L44 in crates/util/mmr/src/lib.rs was not covered by tests.
+ c => Some(c.ilog2() as usize + 1),
}
+ }

+ /// restores from a compacted mmr
+ pub fn from_compact(compact: CompactMmr) -> Self {
+ let required_peaks = Self::min_peaks(compact.element_count);
+ let peaks = match required_peaks {
+ None => vec![ZERO],

Codecov / codecov/patch: added line #L53 in crates/util/mmr/src/lib.rs was not covered by tests.
+ Some(required) => {
+ let mut peaks = compact.peaks;
+ if peaks.len() < required {
+ let num_to_add = required - peaks.len();
+ peaks.reserve_exact(num_to_add);
+ // note we add in a loop so we don't need to make 2 allocs
+ // (the ZEROs will be stack allocated... i think)
Member

Suggested change
- // (the ZEROs will be stack allocated... i think)
+ // (the ZEROs might be stack allocated...)

Make it impersonal.

Also, I don't understand the note.
In one sense, yes: ZERO is not only const but also Sized, so the compiler can trivially place it on the stack.
However, peaks is a Vec. MAYBE the compiler can infer its total size and every possible interaction at compile time, but there is no such guarantee at run time.
Hence, I would not be surprised if peaks is heap-allocated.

Contributor Author

AFAIK a Vec's size cannot be inferred at build time; its buffer is always allocated at runtime.

What I mean is this: the alternative is to build a vec![ZERO; num_to_add] and then extend peaks from it. That temporary Vec is heap-allocated (correct, all Vecs are), so it costs an extra allocation on top of the one that grows peaks. We can avoid creating the temporary entirely by using a loop:

(0..num_to_add).for_each(|_| peaks.push(ZERO))

Each iteration of the loop puts a ZERO on the stack, then memcpy's it to the end of peaks. This avoids the second allocation that would be required if we initialized another vector and then extended from it.
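To make the comparison in this thread concrete, here is a small self-contained sketch of three ways to pad the vector. The function names are hypothetical and exist only for illustration; `pad_with_loop` mirrors the loop in the diff, `pad_with_temp_vec` is the variant being argued against, and `pad_with_resize` is a standard-library alternative that also avoids a temporary `Vec` (unlike `reserve_exact`, it makes no promise about the exact resulting capacity).

```rust
type Hash = [u8; 32];
const ZERO: Hash = [0; 32];

/// Pads `peaks` with ZERO up to `required` entries using the push loop
/// from the diff: one reservation, then in-place pushes.
fn pad_with_loop(mut peaks: Vec<Hash>, required: usize) -> Vec<Hash> {
    if peaks.len() < required {
        let num_to_add = required - peaks.len();
        peaks.reserve_exact(num_to_add);
        (0..num_to_add).for_each(|_| peaks.push(ZERO));
    }
    peaks
}

/// The variant the comment argues against: a temporary Vec is allocated
/// on the heap and then copied into `peaks`.
fn pad_with_temp_vec(mut peaks: Vec<Hash>, required: usize) -> Vec<Hash> {
    if peaks.len() < required {
        let filler = vec![ZERO; required - peaks.len()]; // extra allocation
        peaks.extend_from_slice(&filler);
    }
    peaks
}

/// Alternative without a temporary Vec: `Vec::resize` grows the vector
/// in place and fills the new slots with clones of ZERO.
fn pad_with_resize(mut peaks: Vec<Hash>, required: usize) -> Vec<Hash> {
    if peaks.len() < required {
        peaks.resize(required, ZERO);
    }
    peaks
}
```

All three produce the same contents; they differ only in how many heap allocations they may perform along the way.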

+ (0..num_to_add).for_each(|_| peaks.push(ZERO))

Codecov / codecov/patch: added lines #L57 - L61 in crates/util/mmr/src/lib.rs were not covered by tests.
+ }
+ peaks
+ }
+ };

Self {
- num: compact.entries,
- peaks: roots.into(),
+ element_count: compact.element_count,
+ peaks: peaks.into(),
hasher: PhantomData,
}
}

+ /// exports a "compact" version of the mmr that can be easily serialized
pub fn to_compact(&self) -> CompactMmr {
CompactMmr {
- entries: self.num,
- cap_log2: self.peaks.len() as u8,
- roots: self
- .peaks
- .iter()
- .filter(|h| !is_zero(**h))
- .copied()
- .collect(),
+ element_count: self.element_count,
+ peaks: match Self::min_peaks(self.element_count) {
+ Some(required) => self.peaks.iter().take(required).copied().collect(),
+ None => vec![ZERO],

Codecov / codecov/patch: added line #L80 in crates/util/mmr/src/lib.rs was not covered by tests.
+ },
}
}

/// Adds a leaf's hash to the MMR
pub fn add_leaf(&mut self, hash_arr: Hash) {
- if self.num == 0 {
+ if self.element_count == 0 {
self.peaks[0] = hash_arr;
- self.num += 1;
+ self.element_count += 1;
return;
}

// the number of elements in MMR is also the mask of peaks
- let peak_mask = self.num;
+ let peak_mask = self.element_count;

let mut current_node = hash_arr;
// we iterate through the height
@@ -101,26 +108,26 @@
}

self.peaks[current_height] = current_node;
- self.num += 1;
+ self.element_count += 1;
}

pub fn get_single_root(&self) -> Result<Hash, MerkleError> {
- if self.num == 0 {
+ if self.element_count == 0 {
return Err(MerkleError::NoElements);
}
- if !self.num.is_power_of_two() && self.num != 1 {
+ if !self.element_count.is_power_of_two() && self.element_count != 1 {
return Err(MerkleError::NotPowerOfTwo);
}

- Ok(self.peaks[(self.num.ilog2()) as usize])
+ Ok(self.peaks[(self.element_count.ilog2()) as usize])
}

pub fn add_leaf_updating_proof(
&mut self,
next: Hash,
proof: &MerkleProof<H>,
) -> MerkleProof<H> {
- if self.num == 0 {
+ if self.element_count == 0 {
self.add_leaf(next);
return MerkleProof {
cohashes: vec![],
@@ -130,8 +137,8 @@
}
let mut updated_proof = proof.clone();

- let new_leaf_index = self.num;
- let peak_mask = self.num;
+ let new_leaf_index = self.element_count;
+ let peak_mask = self.element_count;
let mut current_node = next;
let mut current_height = 0;
while (peak_mask >> current_height) & 1 == 1 {
@@ -152,7 +159,7 @@
}

self.peaks[current_height] = current_node;
- self.num += 1;
+ self.element_count += 1;

updated_proof
}
@@ -184,7 +191,7 @@
next: Hash,
proof_list: &mut [MerkleProof<H>],
) -> MerkleProof<H> {
- if self.num == 0 {
+ if self.element_count == 0 {
self.add_leaf(next);
return MerkleProof {
cohashes: vec![],
@@ -194,12 +201,12 @@
}
let mut new_proof = MerkleProof {
cohashes: vec![],
- index: self.num,
+ index: self.element_count,
_pd: PhantomData,
};

- let new_leaf_index = self.num;
- let peak_mask = self.num;
+ let new_leaf_index = self.element_count;
+ let peak_mask = self.element_count;
let mut current_node = next;
let mut current_height = 0;
while (peak_mask >> current_height) & 1 == 1 {
@@ -230,7 +237,7 @@
}

self.peaks[current_height] = current_node;
- self.num += 1;
+ self.element_count += 1;

new_proof
}
@@ -267,7 +274,7 @@
proof_list: &[MerkleProof<H>],
index: u64,
) -> Result<Option<MerkleProof<H>>, MerkleError> {
- if index > self.num {
+ if index > self.element_count {
return Err(MerkleError::IndexOutOfBounds);
}

@@ -315,7 +322,9 @@
use super::{hasher::Hash, MerkleMr, MerkleProof};
use crate::error::MerkleError;

- fn generate_for_n_integers(n: usize) -> (MerkleMr<Sha256>, Vec<MerkleProof<Sha256>>) {
+ fn generate_for_n_integers(
+ n: usize,
+ ) -> (MerkleMr<Sha256>, Vec<MerkleProof<Sha256>>, Vec<[u8; 32]>) {
let mut mmr: MerkleMr<Sha256> = MerkleMr::new(14);

let mut proof = Vec::new();
@@ -325,7 +334,7 @@
let new_proof = mmr.add_leaf_updating_proof_list(list_of_hashes[i], &mut proof);
proof.push(new_proof);
});
- (mmr, proof)
+ (mmr, proof, list_of_hashes)
}

fn generate_hashes_for_n_integers(n: usize) -> Vec<Hash> {
@@ -335,7 +344,7 @@
}

fn mmr_proof_for_specific_nodes(n: usize, specific_nodes: Vec<u64>) {
- let (mmr, proof_list) = generate_for_n_integers(n);
+ let (mmr, proof_list, _) = generate_for_n_integers(n);
let proof: Vec<MerkleProof<Sha256>> = specific_nodes
.iter()
.map(|i| {
@@ -368,7 +377,7 @@

#[test]
fn check_single_element() {
- let (mmr, proof_list) = generate_for_n_integers(1);
+ let (mmr, proof_list, _) = generate_for_n_integers(1);

let proof = mmr
.gen_proof(&proof_list, 0)
@@ -438,7 +447,7 @@

#[test]
fn check_invalid_proof() {
- let (mmr, _) = generate_for_n_integers(5);
+ let (mmr, ..) = generate_for_n_integers(5);
let invalid_proof = MerkleProof::<Sha256> {
cohashes: vec![],
index: 6,
@@ -485,18 +494,20 @@

#[test]
fn check_compact_and_non_compact() {
- let (mmr, _) = generate_for_n_integers(5);
+ let (mmr, proofs, hashes) = generate_for_n_integers(5);

let compact_mmr = mmr.to_compact();
- let deserialized_mmr = MerkleMr::<Sha256>::from_compact(&compact_mmr);
+ let deserialized_mmr = MerkleMr::<Sha256>::from_compact(compact_mmr);

- assert_eq!(mmr.num, deserialized_mmr.num);
- assert_eq!(mmr.peaks, deserialized_mmr.peaks);
+ assert_eq!(mmr.element_count, deserialized_mmr.element_count);
+ for (i, proof) in proofs.into_iter().enumerate() {
+ assert!(deserialized_mmr.verify(&proof, &hashes[i]))
+ }
}

#[test]
fn arbitrary_index_proof() {
- let (mut mmr, _) = generate_for_n_integers(20);
+ let (mut mmr, ..) = generate_for_n_integers(20);
// update proof for 21st element
let mut proof: MerkleProof<Sha256> = MerkleProof {
cohashes: vec![],
@@ -518,7 +529,7 @@

#[test]
fn update_proof_list_from_arbitrary_index() {
- let (mut mmr, _) = generate_for_n_integers(20);
+ let (mut mmr, ..) = generate_for_n_integers(20);
// update proof for 21st element
let mut proof_list = Vec::new();
