Skip to content

Commit

Permalink
Use tinyvec instead of a fixed-size array in RoundId
Browse files Browse the repository at this point in the history
  • Loading branch information
fjarri committed Nov 13, 2024
1 parent 891bd04 commit 600cf85
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 72 deletions.
17 changes: 17 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions manul/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ rand_core = { version = "0.6.4", default-features = false }
tracing = { version = "0.1", default-features = false }
displaydoc = { version = "0.2", default-features = false }
derive-where = "1"
tinyvec = { version = "1", default-features = false, features = ["alloc", "serde"] }

rand = { version = "0.8", default-features = false, optional = true }
serde-persistent-deserializer = { version = "0.3", optional = true }
Expand Down
57 changes: 16 additions & 41 deletions manul/src/protocol/round.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use alloc::{
collections::{BTreeMap, BTreeSet},
format,
string::String,
vec,
vec::Vec,
};
use core::{
Expand All @@ -12,6 +13,7 @@ use core::{

use rand_core::CryptoRngCore;
use serde::{Deserialize, Serialize};
use tinyvec::{tiny_vec, TinyVec};

use super::{
errors::{FinalizeError, LocalError, MessageValidationError, ProtocolValidationError, ReceiveError},
Expand All @@ -29,24 +31,18 @@ pub enum FinalizeOutcome<Id: PartyId, P: Protocol> {
Result(P::Result),
}

// Maximum depth of group nesting in RoundIds.
// We need this to be limited to allow the nesting to be performed in `const` context
// (since we cannot use heap there).
const ROUND_ID_DEPTH: usize = 8;

/// A round identifier.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct RoundId {
depth: u8,
round_nums: [u8; ROUND_ID_DEPTH],
round_nums: TinyVec<[u8; 4]>,
is_echo: bool,
}

impl Display for RoundId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
write!(f, "Round ")?;
for i in (0..self.depth as usize).rev() {
write!(f, "{}", self.round_nums.get(i).expect("Depth within range"))?;
for (i, round_num) in self.round_nums.iter().enumerate().rev() {
write!(f, "{}", round_num)?;
if i != 0 {
write!(f, "-")?;
}
Expand All @@ -60,36 +56,18 @@ impl Display for RoundId {

impl RoundId {
/// Creates a new round identifier.
pub const fn new(round_num: u8) -> Self {
let mut round_nums = [0u8; ROUND_ID_DEPTH];
#[allow(clippy::indexing_slicing)]
{
round_nums[0] = round_num;
}
pub fn new(round_num: u8) -> Self {
Self {
depth: 1,
round_nums,
round_nums: tiny_vec!(round_num, 0, 0, 0),
is_echo: false,
}
}

/// Prefixes this round ID (possibly already nested) with a group number.
///
/// **Warning:** the maximum nesting depth is 8. Panics if this nesting overflows it.
pub(crate) const fn group_under(&self, round_num: u8) -> Self {
if self.depth as usize == ROUND_ID_DEPTH {
panic!("Maximum depth reached");
}
let mut round_nums = self.round_nums;

// Would use `expect("Depth within range")` here, but `expect()` in const fns is unstable.
#[allow(clippy::indexing_slicing)]
{
round_nums[self.depth as usize] = round_num;
}

pub(crate) fn group_under(&self, round_num: u8) -> Self {
let mut round_nums = self.round_nums.clone();
round_nums.push(round_num);
Self {
depth: self.depth + 1,
round_nums,
is_echo: self.is_echo,
}
Expand All @@ -99,13 +77,12 @@ impl RoundId {
///
/// Returns the `Err` variant if the round ID is not nested.
pub(crate) fn ungroup(&self) -> Result<Self, LocalError> {
if self.depth == 1 {
if self.round_nums.len() == 1 {
Err(LocalError::new("This round ID is not in a group"))
} else {
let mut round_nums = self.round_nums;
*round_nums.get_mut(self.depth as usize - 1).expect("Depth within range") = 0;
let mut round_nums = self.round_nums.clone();
round_nums.pop().expect("vector size greater than 1");
Ok(Self {
depth: self.depth - 1,
round_nums,
is_echo: self.is_echo,
})
Expand All @@ -127,8 +104,7 @@ impl RoundId {
panic!("This is already an echo round ID");
}
Self {
depth: self.depth,
round_nums: self.round_nums,
round_nums: self.round_nums.clone(),
is_echo: true,
}
}
Expand All @@ -143,8 +119,7 @@ impl RoundId {
panic!("This is already an non-echo round ID");
}
Self {
depth: self.depth,
round_nums: self.round_nums,
round_nums: self.round_nums.clone(),
is_echo: false,
}
}
Expand Down
32 changes: 16 additions & 16 deletions manul/src/session/evidence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ where
.iter()
.map(|round_id| {
transcript
.get_echo_broadcast(*round_id, verifier)
.map(|echo| (*round_id, echo))
.get_echo_broadcast(round_id.clone(), verifier)
.map(|echo| (round_id.clone(), echo))
})
.collect::<Result<BTreeMap<_, _>, _>>()?;

Expand All @@ -110,8 +110,8 @@ where
.iter()
.map(|round_id| {
transcript
.get_normal_broadcast(*round_id, verifier)
.map(|bc| (*round_id, bc))
.get_normal_broadcast(round_id.clone(), verifier)
.map(|bc| (round_id.clone(), bc))
})
.collect::<Result<BTreeMap<_, _>, _>>()?;

Expand All @@ -120,8 +120,8 @@ where
.iter()
.map(|round_id| {
transcript
.get_direct_message(*round_id, verifier)
.map(|dm| (*round_id, dm))
.get_direct_message(round_id.clone(), verifier)
.map(|dm| (round_id.clone(), dm))
})
.collect::<Result<BTreeMap<_, _>, _>>()?;

Expand All @@ -131,7 +131,7 @@ where
.map(|round_id| {
transcript
.get_normal_broadcast(round_id.echo(), verifier)
.map(|dm| (*round_id, dm))
.map(|dm| (round_id.clone(), dm))
})
.collect::<Result<BTreeMap<_, _>, _>>()?;

Expand Down Expand Up @@ -470,12 +470,12 @@ where
for (round_id, direct_message) in self.direct_messages.iter() {
let verified_direct_message = direct_message.clone().verify::<SP>(verifier)?;
let metadata = verified_direct_message.metadata();
if metadata.session_id() != session_id || metadata.round_id() != *round_id {
if metadata.session_id() != session_id || &metadata.round_id() != round_id {
return Err(EvidenceError::InvalidEvidence(
"Invalid attached message metadata".into(),
));
}
verified_direct_messages.insert(*round_id, verified_direct_message.payload().clone());
verified_direct_messages.insert(round_id.clone(), verified_direct_message.payload().clone());
}

let verified_echo_broadcast = self.echo_broadcast.clone().verify::<SP>(verifier)?.payload().clone();
Expand All @@ -500,31 +500,31 @@ where
for (round_id, echo_broadcast) in self.echo_broadcasts.iter() {
let verified_echo_broadcast = echo_broadcast.clone().verify::<SP>(verifier)?;
let metadata = verified_echo_broadcast.metadata();
if metadata.session_id() != session_id || metadata.round_id() != *round_id {
if metadata.session_id() != session_id || &metadata.round_id() != round_id {
return Err(EvidenceError::InvalidEvidence(
"Invalid attached message metadata".into(),
));
}
verified_echo_broadcasts.insert(*round_id, verified_echo_broadcast.payload().clone());
verified_echo_broadcasts.insert(round_id.clone(), verified_echo_broadcast.payload().clone());
}

let mut verified_normal_broadcasts = BTreeMap::new();
for (round_id, normal_broadcast) in self.normal_broadcasts.iter() {
let verified_normal_broadcast = normal_broadcast.clone().verify::<SP>(verifier)?;
let metadata = verified_normal_broadcast.metadata();
if metadata.session_id() != session_id || metadata.round_id() != *round_id {
if metadata.session_id() != session_id || &metadata.round_id() != round_id {
return Err(EvidenceError::InvalidEvidence(
"Invalid attached message metadata".into(),
));
}
verified_normal_broadcasts.insert(*round_id, verified_normal_broadcast.payload().clone());
verified_normal_broadcasts.insert(round_id.clone(), verified_normal_broadcast.payload().clone());
}

let mut combined_echos = BTreeMap::new();
for (round_id, combined_echo) in self.combined_echos.iter() {
let verified_combined_echo = combined_echo.clone().verify::<SP>(verifier)?;
let metadata = verified_combined_echo.metadata();
if metadata.session_id() != session_id || metadata.round_id().non_echo() != *round_id {
if metadata.session_id() != session_id || &metadata.round_id().non_echo() != round_id {
return Err(EvidenceError::InvalidEvidence(
"Invalid attached message metadata".into(),
));
Expand All @@ -537,14 +537,14 @@ where
for (other_verifier, echo_broadcast) in echo_set.echo_broadcasts.iter() {
let verified_echo_broadcast = echo_broadcast.clone().verify::<SP>(other_verifier)?;
let metadata = verified_echo_broadcast.metadata();
if metadata.session_id() != session_id || metadata.round_id() != *round_id {
if metadata.session_id() != session_id || &metadata.round_id() != round_id {
return Err(EvidenceError::InvalidEvidence(
"Invalid attached message metadata".into(),
));
}
verified_echo_set.push(verified_echo_broadcast.payload().clone());
}
combined_echos.insert(*round_id, verified_echo_set);
combined_echos.insert(round_id.clone(), verified_echo_set);
}

Ok(self.error.verify_messages_constitute_error(
Expand Down
2 changes: 1 addition & 1 deletion manul/src/session/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ impl MessageMetadata {
}

pub fn round_id(&self) -> RoundId {
self.round_id
self.round_id.clone()
}
}

Expand Down
18 changes: 9 additions & 9 deletions manul/src/session/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ where
}
MessageFor::ThisRound
} else if self.possible_next_rounds.contains(&message_round_id) {
if accum.message_is_cached(from, message_round_id) {
if accum.message_is_cached(from, &message_round_id) {
let err = format!("Message for {:?} is already cached", message_round_id);
accum.register_unprovable_error(from, RemoteError::new(&err))?;
trace!("{key:?} {err}");
Expand Down Expand Up @@ -354,7 +354,7 @@ where
match message_for {
MessageFor::ThisRound => {
accum.mark_processing(&verified_message)?;
Ok(PreprocessOutcome::ToProcess(verified_message))
Ok(PreprocessOutcome::ToProcess(Box::new(verified_message)))
}
MessageFor::NextRound => {
debug!("{key:?}: Caching message from {from:?} for {message_round_id}");
Expand Down Expand Up @@ -406,7 +406,7 @@ where
) -> Result<SessionReport<P, SP>, LocalError> {
let round_id = self.round_id();
let transcript = self.transcript.update(
round_id,
&round_id,
accum.echo_broadcasts,
accum.normal_broadcasts,
accum.direct_messages,
Expand Down Expand Up @@ -446,7 +446,7 @@ where
let round_id = self.round_id();

let transcript = self.transcript.update(
round_id,
&round_id,
accum.echo_broadcasts,
accum.normal_broadcasts,
accum.direct_messages,
Expand Down Expand Up @@ -604,9 +604,9 @@ where
self.processing.contains(from)
}

fn message_is_cached(&self, from: &SP::Verifier, round_id: RoundId) -> bool {
fn message_is_cached(&self, from: &SP::Verifier, round_id: &RoundId) -> bool {
if let Some(entry) = self.cached.get(from) {
entry.contains_key(&round_id)
entry.contains_key(round_id)
} else {
false
}
Expand Down Expand Up @@ -745,7 +745,7 @@ where
let from = message.from().clone();
let round_id = message.metadata().round_id();
let cached = self.cached.entry(from.clone()).or_default();
if cached.insert(round_id, message).is_some() {
if cached.insert(round_id.clone(), message).is_some() {
return Err(LocalError::new(format!(
"A message from for {:?} has already been cached",
round_id
Expand All @@ -771,7 +771,7 @@ pub struct ProcessedMessage<P: Protocol, SP: SessionParameters> {
#[derive(Debug, Clone)]
pub enum PreprocessOutcome<Verifier> {
/// The message was successfully verified, pass it on to [`Session::process_message`].
ToProcess(VerifiedMessage<Verifier>),
ToProcess(Box<VerifiedMessage<Verifier>>),
/// The message was intended for the next round and was cached.
///
/// No action required now, cached messages will be returned on successful [`Session::finalize_round`].
Expand All @@ -795,7 +795,7 @@ impl<Verifier> PreprocessOutcome<Verifier> {
/// so the user may choose to ignore them if no logging is desired.
pub fn ok(self) -> Option<VerifiedMessage<Verifier>> {
match self {
Self::ToProcess(message) => Some(message),
Self::ToProcess(message) => Some(*message),
_ => None,
}
}
Expand Down
Loading

0 comments on commit 600cf85

Please sign in to comment.