Skip to content

Commit

Permalink
perf: optimize register parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
delehef committed Nov 27, 2023
1 parent 8664ff4 commit db9f961
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 28 deletions.
7 changes: 3 additions & 4 deletions src/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,14 +159,13 @@ impl Value {
}

pub(crate) fn vector_sub_assign(&mut self, other: &Value) {
let msg = format!("{} .- {}", self, other);
match (self, other) {
(Value::BigInt(ref mut i1), Value::BigInt(ref i2)) => *i1 -= i2,
(Value::BigInt(_), Value::Native(_)) => todo!(),
(Value::BigInt(_), Value::ExoNative(_)) => todo!(),
(Value::Native(_), Value::BigInt(_)) => todo!(),
(Value::Native(ref mut f1), Value::Native(ref f2)) => f1.sub_assign(f2),
(Value::Native(_), Value::ExoNative(_)) => todo!("{}", msg),
(Value::Native(_), Value::ExoNative(_)) => todo!(),
(Value::ExoNative(_), Value::BigInt(_)) => todo!(),
(Value::ExoNative(_), Value::Native(_)) => todo!(),
(Value::ExoNative(f1s), Value::ExoNative(f2s)) => f1s
Expand Down Expand Up @@ -265,7 +264,7 @@ impl Value {
pub(crate) fn to_bi_variant(&self) -> Value {
match self {
Value::BigInt(_) => self.clone(),
Value::Native(fr) => Value::BigInt(self.to_bi()),
Value::Native(_) => Value::BigInt(self.to_bi()),
_ => unimplemented!(),
}
}
Expand Down Expand Up @@ -727,7 +726,7 @@ impl<'a> Iterator for ValueBackingIter<'a> {
v.get(self.i as usize).cloned()
}
}
ValueBacking::Expression { spilling, .. } => {
ValueBacking::Expression { .. } => {
if self.i >= self.len {
None
} else {
Expand Down
74 changes: 50 additions & 24 deletions src/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use log::*;
use logging_timer::time;
use num_bigint::{BigInt, Sign};
use owo_colors::OwoColorize;
use rayon::prelude::*;
#[cfg(not(all(target_arch = "x86_64", target_feature = "avx")))]
use serde_json::Value;
#[cfg(all(target_arch = "x86_64", target_feature = "avx"))]
Expand Down Expand Up @@ -87,11 +88,19 @@ impl<Data: AsRef<[u8]>> TraceReader<Data> {
}

fn header(&mut self) -> Result<RegisterHeader> {
let handle_length = self.i16()?;
let handle_str = self.string(handle_length as usize)?;
let handle_length = self
.i16()
.with_context(|| anyhow!("parsing a register name length"))?;
let handle_str = self
.string(handle_length as usize)
.with_context(|| anyhow!("parsing a register name"))?;
let mut splitted = handle_str.splitn(2, '.');
let bytes_per_element = self.i8()? as usize;
let length = self.i32()?;
let bytes_per_element =
self.i8()
.with_context(|| anyhow!("parsing BPE for {}", handle_str))? as usize;
let length = self
.i32()
.with_context(|| anyhow!("parsing length of {}", handle_str))?;

Ok(RegisterHeader {
handle: Handle::new(splitted.next().unwrap(), splitted.next().unwrap()),
Expand All @@ -101,37 +110,53 @@ impl<Data: AsRef<[u8]>> TraceReader<Data> {
}

fn map(&mut self) -> Result<Vec<RegisterHeader>> {
let column_count = self.i32()?;
(0..column_count).map(|_| self.header()).collect()
let register_count = self.i32().with_context(|| "parsing register count")?;
(0..register_count).map(|_| self.header()).collect()
}
}

#[time("info", "Parsing binary traces")]
pub fn parse_flat_trace(tracefile: &str, cs: &mut ConstraintSet) -> Result<()> {
let mut trace_reader =
TraceReader::from(unsafe { memmap2::MmapOptions::new().map(&File::open(tracefile)?)? });
for column in trace_reader.map()?.into_iter() {
let column_ref: ColumnRef = column.handle.clone().into();
let mut xs = std::iter::once(Ok(CValue::zero()))
.chain((0..column.length).map(|_| {
trace_reader
.slice(column.bytes_per_element)
.map(|bs| CValue::from(BigInt::from_bytes_be(Sign::Plus, &bs)))
}))
.collect::<Result<Vec<_>>>()?;
for register in trace_reader.map()?.into_iter() {
let column_ref: ColumnRef = register.handle.clone().into();
let register_bytes =
trace_reader.slice(register.length as usize * register.bytes_per_element)?;
let mut xs = (-1..register.length)
.into_par_iter()
.map(|i| {
if i == -1 {
Ok(CValue::zero())
} else {
let i = i as usize;
register_bytes
.get(i * register.bytes_per_element..(i + 1) * register.bytes_per_element)
.map(|bs| CValue::from(BigInt::from_bytes_be(Sign::Plus, &bs)))
.with_context(|| anyhow!("reading {}th element", i))
}
})
.collect::<Result<Vec<_>>>()
.with_context(|| {
anyhow!(
"reading data for {} ({} elts. expected)",
register.handle.pretty(),
register.length
)
})?;

let module_min_len = cs
.columns
.min_len
.get(&column.handle.module)
.get(&register.handle.module)
.cloned()
.unwrap_or(0);

if let Some(Register { magma, .. }) = cs.columns.register(&column_ref) {
debug!("Importing {}", column.handle.pretty());
debug!("Importing {}", register.handle.pretty());
let module_spilling = cs
.spilling_for_column(&column_ref)
.ok_or_else(|| anyhow!("no spilling found for {}", column.handle.pretty()))?;
.ok_or_else(|| anyhow!("no spilling found for {}", register.handle.pretty()))?;
// If the parsed column is not long enought w.r.t. the
// minimal module length, prepend it with as many zeroes as
// required.
Expand All @@ -143,20 +168,21 @@ pub fn parse_flat_trace(tracefile: &str, cs: &mut ConstraintSet) -> Result<()> {
xs.reverse();
}

let module_raw_size = cs.effective_len_or_set(&column.handle.module, xs.len() as isize);
let module_raw_size =
cs.effective_len_or_set(&register.handle.module, xs.len() as isize);
if xs.len() as isize != module_raw_size {
bail!(
"{} has an incorrect length: expected {}, found {}",
column.handle.to_string().blue(),
register.handle.to_string().blue(),
module_raw_size.to_string().red().bold(),
xs.len().to_string().yellow().bold(),
);
}

cs.columns
.set_register_value(&column.handle.into(), xs, module_spilling)?
.set_register_value(&register.handle.into(), xs, module_spilling)?
} else {
debug!("ignoring unknown column {}", column.handle.pretty());
debug!("ignoring unknown column {}", register.handle.pretty());
}
}

Expand Down Expand Up @@ -335,7 +361,7 @@ pub fn fill_traces_from_json(
.ok_or_else(|| anyhow!("no spilling found for {}", handle.pretty()))?;

let mut xs = parse_column(xs, handle.as_handle(), *t)
.with_context(|| anyhow!("while importing {}", handle))?;
.with_context(|| anyhow!("importing {}", handle))?;

// If the parsed column is not long enought w.r.t. the
// minimal module length, prepend it with as many zeroes as
Expand Down Expand Up @@ -368,7 +394,7 @@ pub fn fill_traces_from_json(
.ok_or_else(|| anyhow!("no spilling found for {}", handle.pretty()))?;

let mut xs = parse_column(xs, handle.as_handle(), *magma)
.with_context(|| anyhow!("while importing {}", handle))?;
.with_context(|| anyhow!("importing {}", handle))?;

// If the parsed column is not long enought w.r.t. the
// minimal module length, prepend it with as many zeroes as
Expand Down

0 comments on commit db9f961

Please sign in to comment.