Skip to content

Commit

Permalink
fix: columns import/export for BLS
Browse files Browse the repository at this point in the history
  • Loading branch information
delehef committed Nov 9, 2023
1 parent a55d240 commit 8bd7937
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 125 deletions.
40 changes: 19 additions & 21 deletions src/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,23 +220,11 @@ impl Value {
}
}

pub(crate) fn to_repr(&self) -> impl Iterator<Item = u64> {
let us = match &self {
Value::Native(f) => f.0 .0.to_vec(),
Value::ExoNative(fs) => fs.iter().flat_map(|f| f.0 .0.iter()).cloned().collect(),
Value::BigInt(_) => todo!(),
};
us.into_iter()
}

pub(crate) fn to_bytes(&self) -> Vec<u8> {
match &self {
Value::Native(f) => f.0 .0.iter().flat_map(|u| u.to_be_bytes()).collect(),
Value::ExoNative(fs) => fs
.iter()
.flat_map(|f| f.0 .0.iter().flat_map(|u| u.to_be_bytes()))
.collect(),
Value::BigInt(bi) => bi.to_bytes_be().1,
Value::Native(f) => f.into_bigint().to_bytes_be(),
Value::ExoNative(_) => todo!(),
}
}

Expand Down Expand Up @@ -649,16 +637,16 @@ impl ValueBacking {
}
}
.cloned(),
ValueBacking::Expression { e, spilling } => e.eval(
i + spilling,
ValueBacking::Expression { e, .. } => e.eval(
i,
|handle, j, _| {
cs.get(handle, j, false)
.or_else(|| cs.column(handle).unwrap().padding_value.as_ref().cloned())
},
&mut None,
&EvalSettings { wrap: false },
),
ValueBacking::Function { f, spilling } => f(i + spilling, cs),
ValueBacking::Function { f, .. } => f(i, cs),
}
}

Expand Down Expand Up @@ -689,10 +677,11 @@ impl ValueBacking {
}
}

pub fn iter<'a>(&'a self, columns: &'a ColumnSet) -> ValueBackingIter<'a> {
pub fn iter<'a>(&'a self, columns: &'a ColumnSet, len: isize) -> ValueBackingIter<'a> {
ValueBackingIter {
value: self,
i: 0,
len,
columns,
}
}
Expand Down Expand Up @@ -724,6 +713,7 @@ impl ValueBacking {
pub struct ValueBackingIter<'a> {
value: &'a ValueBacking,
columns: &'a ColumnSet,
len: isize,
i: isize,
}

Expand All @@ -741,9 +731,17 @@ impl<'a> Iterator for ValueBackingIter<'a> {
v.get(self.i as usize).cloned()
}
}
ValueBacking::Expression { .. } => {
self.i += 1;
self.value.get_raw(self.i - 1, false, self.columns)
ValueBacking::Expression { spilling, .. } => {
if self.i >= self.len {
None
} else {
self.i += 1;
Some(
self.value
.get(self.i - 1, false, self.columns)
.unwrap_or_default(),
)
}
}
ValueBacking::Function { f, .. } => {
self.i += 1;
Expand Down
3 changes: 2 additions & 1 deletion src/compiler/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,7 @@ impl ConstraintSet {
let empty_backing: ValueBacking = ValueBacking::default();
while let Some((r, column)) = current_col.next() {
let handle = &column.handle;
let module_size = self.effective_len_for(&handle.module).unwrap();
trace!("Writing {}", handle);
let backing = self.columns.backing(&r).unwrap_or_else(|| &empty_backing);
let padding: Value = if let Some(v) = column.padding_value.as_ref() {
Expand Down Expand Up @@ -904,7 +905,7 @@ impl ConstraintSet {
out.write_all(format!("\"{}\":{{\n", handle).as_bytes())?;
out.write_all("\"values\":[".as_bytes())?;

let mut value = backing.iter(&self.columns).peekable();
let mut value = backing.iter(&self.columns, module_size).peekable();
while let Some(x) = value.next() {
out.write_all(
cache
Expand Down
16 changes: 14 additions & 2 deletions src/compiler/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -995,8 +995,20 @@ impl Display for Node {

match self.e() {
Expression::Const(x) => write!(f, "{}", x),
Expression::Column { handle, .. } | Expression::ExoColumn { handle, .. } => {
write!(f, "{}", handle.to_string_short())
Expression::Column { handle, shift, .. }
| Expression::ExoColumn { handle, shift, .. } => {
write!(
f,
"{}{}",
handle.to_string_short(),
if *shift > 0 {
format!("₊{}", crate::pretty::subscript(&shift.to_string()))
} else if *shift < 0 {
crate::pretty::subscript(&shift.to_string())
} else {
Default::default()
}
)
}
Expression::ArrayColumn { handle, domain, .. } => {
write!(f, "{}{}", handle.to_string_short(), domain)
Expand Down
35 changes: 19 additions & 16 deletions src/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,6 @@ pub fn fill_traces(
if path.len() >= 2 {
let module = path[path.len() - 2].to_string();
let handle: ColumnRef = Handle::new(&module, &path[path.len() - 1]).into();
// The first column sets the size of its module
let module_raw_size = cs.effective_len_or_set(&module, xs.len() as isize);

// The min length can be set if the module contains range
// proofs, that require a minimal length of a certain power of 2
Expand All @@ -197,16 +195,6 @@ pub fn fill_traces(
let module_spilling = module_spilling
.ok_or_else(|| anyhow!("no spilling found for {}", handle.pretty()))?;

if xs.len() as isize != module_raw_size {
bail!(
"{} has an incorrect length: expected {} (from {}), found {}",
handle.to_string().blue(),
module_raw_size.to_string().red().bold(),
initiator.as_ref().unwrap(),
xs.len().to_string().yellow().bold(),
);
}

let mut xs = parse_column(xs, handle.as_handle(), *t)
.with_context(|| anyhow!("while importing {}", handle))?;

Expand All @@ -222,11 +210,9 @@ pub fn fill_traces(
});
xs.reverse();
}
cs.columns.set_column_value(&handle, xs, module_spilling)?
} else if let Some(Register { magma, .. }) = cs.columns.register(&handle) {
let module_spilling = module_spilling
.ok_or_else(|| anyhow!("no spilling found for {}", handle.pretty()))?;

// The first column sets the size of its module
let module_raw_size = cs.effective_len_or_set(&module, xs.len() as isize);
if xs.len() as isize != module_raw_size {
bail!(
"{} has an incorrect length: expected {} (from {}), found {}",
Expand All @@ -237,6 +223,11 @@ pub fn fill_traces(
);
}

cs.columns.set_column_value(&handle, xs, module_spilling)?
} else if let Some(Register { magma, .. }) = cs.columns.register(&handle) {
let module_spilling = module_spilling
.ok_or_else(|| anyhow!("no spilling found for {}", handle.pretty()))?;

let mut xs = parse_column(xs, handle.as_handle(), *magma)
.with_context(|| anyhow!("while importing {}", handle))?;

Expand All @@ -250,6 +241,18 @@ pub fn fill_traces(
xs.resize(module_min_len, CValue::zero()); // TODO: register padding values
xs.reverse();
}

let module_raw_size = cs.effective_len_or_set(&module, xs.len() as isize);
if xs.len() as isize != module_raw_size {
bail!(
"{} has an incorrect length: expected {} (from {}), found {}",
handle.to_string().blue(),
module_raw_size.to_string().red().bold(),
initiator.as_ref().unwrap(),
xs.len().to_string().yellow().bold(),
);
}

cs.columns
.set_register_value(&handle, xs, module_spilling)?
} else {
Expand Down
106 changes: 22 additions & 84 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,13 @@ fn cstr_to_string<'a>(s: *const c_char) -> &'a str {
name.to_str().unwrap()
}

const EMPTY_MARKER: [u64; 4] = [0, u64::MAX, 0, u64::MAX];
const EMPTY_MARKER: [u8; 32] = [
2, 4, 8, 16, 32, 64, 128, 255, 255, 128, 64, 32, 16, 8, 4, 2, 2, 4, 8, 16, 32, 64, 128, 255,
255, 128, 64, 32, 16, 8, 4, 2,
];
struct ComputedColumn {
padding_value: [u64; 4],
values: Vec<[u64; 4]>,
padding_value: [u8; 32],
values: Vec<[u8; 32]>,
}
impl ComputedColumn {
fn empty() -> Self {
Expand All @@ -118,7 +121,7 @@ pub struct Trace {
ids: Vec<String>,
}
impl Trace {
fn from_constraints(c: &Corset, convert_to_be: bool) -> Self {
fn from_constraints(c: &Corset) -> Self {
let mut r = Trace {
..Default::default()
};
Expand All @@ -132,6 +135,11 @@ impl Trace {

let column = c.columns.column(cref).unwrap();
let handle = &column.handle;
let module_size = c
.effective_len_for(&handle.module)
// If the module is empty, use its spilling
.or_else(|| c.spilling_of(&handle.module))
.unwrap();
trace!("Writing {}", handle);
let backing = c.columns.backing(cref).unwrap_or(&empty_backing);
let padding: Value = if let Some(v) = column.padding_value.as_ref() {
Expand Down Expand Up @@ -162,23 +170,10 @@ impl Trace {
(
ComputedColumn {
values: backing
.iter(&c.columns)
.map(|x| {
let mut v = x.to_repr().collect::<Vec<_>>().try_into().unwrap();
if convert_to_be {
reverse_fr(&mut v);
}
v
})
.iter(&c.columns, module_size)
.map(|x| x.to_bytes().try_into().unwrap())
.collect(),
padding_value: {
let mut padding =
padding.to_repr().collect::<Vec<_>>().try_into().unwrap();
if convert_to_be {
reverse_fr(&mut padding);
}
padding
},
padding_value: { padding.to_bytes().try_into().unwrap() },
},
c.handle(cref).to_string(),
)
Expand All @@ -204,56 +199,6 @@ impl Trace {
}
}

fn reverse_fr(v: &mut [u64; 4]) {
#[cfg(target_arch = "aarch64")]
reverse_fr_aarch64(v);
#[cfg(target_arch = "x86_64")]
reverse_fr_x86_64(v);
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
reverse_fr_fallback(v);
}

fn reverse_fr_fallback(v: &mut [u64; 4]) {
for vi in v.iter_mut() {
*vi = vi.swap_bytes();
}
v.swap(0, 3);
v.swap(1, 2);
}

#[cfg(target_arch = "aarch64")]
fn reverse_fr_aarch64(v: &mut [u64; 4]) {
for vi in v.iter_mut() {
*vi = vi.swap_bytes();
}
v.swap(0, 3);
v.swap(1, 2);
}

#[cfg(target_arch = "x86_64")]
fn reverse_fr_x86_64(v: &mut [u64; 4]) {
if is_x86_feature_detected!("avx2") {
unsafe {
use std::arch::x86_64::*;
let inverter = _mm256_set_epi64x(
0x0001020304050607,
0x08090a0b0c0d0e0f,
0x0001020304050607,
0x08090a0b0c0d0e0f,
);
let value = _mm256_loadu_si256(v.as_ptr() as *const __m256i);
let x = _mm256_shuffle_epi8(value, inverter);
*v = std::mem::transmute(_mm256_permute2f128_si256(x, x, 0x01));
}
} else {
for vi in v.iter_mut() {
*vi = vi.swap_bytes();
}
v.swap(0, 3);
v.swap(1, 2);
}
}

fn make_corset(mut constraints: ConstraintSet) -> Result<Corset> {
transformer::expand_to(
&mut constraints,
Expand Down Expand Up @@ -284,23 +229,21 @@ fn _corset_from_str(zkevmstr: &str) -> Result<Corset> {
fn _compute_trace_from_file(
constraints: &mut Corset,
tracefile: &str,
convert_to_be: bool,
fail_on_missing: bool,
) -> Result<Trace> {
compute::compute_trace(tracefile, constraints, fail_on_missing)
.with_context(|| format!("while computing from file `{}`", tracefile))?;
Ok(Trace::from_constraints(constraints, convert_to_be))
Ok(Trace::from_constraints(constraints))
}

fn _compute_trace_from_str(
constraints: &mut Corset,
tracestr: &str,
convert_to_be: bool,
fail_on_missing: bool,
) -> Result<Trace> {
compute::compute_trace_str(tracestr.as_bytes(), constraints, fail_on_missing)
.with_context(|| format!("while computing from string `{}`", tracestr))?;
Ok(Trace::from_constraints(constraints, convert_to_be))
Ok(Trace::from_constraints(constraints))
}

#[no_mangle]
Expand Down Expand Up @@ -415,16 +358,14 @@ pub extern "C" fn trace_compute_from_file(
corset: *mut Corset,
tracefile: *const c_char,
threads: c_uint,
convert_to_be: bool,
fail_on_missing: bool,
) -> *mut Trace {
match init_rayon(threads) {
Result::Ok(tp) => {
let tracefile = cstr_to_string(tracefile);
let constraints = Corset::mut_from_ptr(corset);
let r = tp.install(|| {
_compute_trace_from_file(constraints, tracefile, convert_to_be, fail_on_missing)
});
let r =
tp.install(|| _compute_trace_from_file(constraints, tracefile, fail_on_missing));
match r {
Err(e) => {
eprintln!("{:?}", e);
Expand All @@ -446,7 +387,6 @@ pub extern "C" fn trace_compute_from_string(
corset: *mut Corset,
tracestr: *const c_char,
threads: c_uint,
convert_to_be: bool,
fail_on_missing: bool,
) -> *mut Trace {
match init_rayon(threads) {
Expand All @@ -458,9 +398,7 @@ pub extern "C" fn trace_compute_from_string(
}

let constraints = Corset::mut_from_ptr(corset);
let r = tp.install(|| {
_compute_trace_from_str(constraints, tracestr, convert_to_be, fail_on_missing)
});
let r = tp.install(|| _compute_trace_from_str(constraints, tracestr, fail_on_missing));
match r {
Err(e) => {
eprintln!("{:?}", e);
Expand Down Expand Up @@ -510,8 +448,8 @@ pub extern "C" fn trace_column_names(trace: *const Trace) -> *const *mut c_char

#[repr(C)]
pub struct ColumnData {
padding_value: [u64; 4],
values: *const [u64; 4],
padding_value: [u8; 32],
values: *const [u8; 32],
values_len: u64,
}
impl Default for ColumnData {
Expand Down
Loading

0 comments on commit 8bd7937

Please sign in to comment.