diff --git a/crates/gmw/src/parse/lut_circuit.rs b/crates/gmw/src/parse/lut_circuit.rs index b1c2dd2..dd814aa 100644 --- a/crates/gmw/src/parse/lut_circuit.rs +++ b/crates/gmw/src/parse/lut_circuit.rs @@ -96,7 +96,7 @@ pub struct WireMask { } #[derive(Debug, Clone, Eq, PartialEq)] pub struct WireOutput { - pub unexpanded: BitVec, + pub unexpanded: BitVec, } #[derive(Debug, PartialEq)] @@ -356,33 +356,41 @@ fn wire_output<'w, 'i>( wire_mask: &'w BitSlice, ) -> impl Fn(&'i str) -> IResult<&'i str, WireOutput, LutParseError<&'i str>> + 'w { move |i: &str| { + // drop the 0x prefix and parse the hexadecimal number into a BigUint let (i, out_mask) = map_res(preceded(tag("0x"), hex_digit1), |hex_val| { BigUint::from_str_radix(hex_val, 16) })(i)?; - let le_bytes = out_mask.to_bytes_le(); - let mut out_bit_mask = BitVec::from_slice(&le_bytes); - out_bit_mask.truncate(2_usize.pow(wire_mask.count_ones() as u32)); - let trailing_zeros = out_bit_mask.trailing_zeros(); - out_bit_mask.shift_right(trailing_zeros); - // if le_bytes.len() == 1 { - // let shift = trailing_zeros.saturating_sub(4); - // out_bit_mask.shift_right(shift) - // } else { - // out_bit_mask.shift_right(trailing_zeros); - // } - // out_bit_mask.reverse(); - // let mut out_bit_mask = match out_mask.to_bytes_be()[..] { - // [one_byte] => { - // let upper = 2_usize.pow(wire_mask.count_ones() as u32); - // let lower = upper.saturating_sub(8); - // let mut bv = BitVec::repeat(false, upper); - // bv[lower..upper].store(one_byte); - // bv - // } - // ref bytes => BitVec::from_slice(bytes), - // }; - // TODO check that this is actually correct and then make it less confusing - // out_bit_mask.reverse(); + // the big integer is split into its bytes in big endian order + // this means, the most significant bytes come first + let be_bytes = out_mask.to_bytes_be(); + // these big endian bytes are put into a bit vector with most significant **bit** first + // ordering (Msb0) + let mut out_bit_mask = BitVec::::from_slice(&be_bytes); + // because the relevant bits of the parse number can be smaller than a single byte (for an + // LUT with two active output wires, output bit mask should have 4 bits), or the bit we need + // can be larger than initial hex number as a bitvec, we calculate the required shift to + // the right or left + let used_bits = 2_usize.pow(wire_mask.count_ones() as u32); + let shift = used_bits as i32 - out_bit_mask.len() as i32; + if shift < 0 { + // out_bit_mask.len() > used_bits + // this happens if we have two active output wires (wire_mask.count_ones() == 2) but + // the minimum size of the out_bit_mask is 8, as we always get at least one byte + // from out_mask.to_bytes_be() + // therefore, we shift the out_bit_mask by the absolute shift amount to the left + // (thus removing superflous 0 bits) and truncate the bitvec to the correct size + out_bit_mask.shift_left(shift.abs() as usize); + out_bit_mask.truncate(used_bits); + } else if shift > 0 { + // out_bit_mask.len() < used_bits + // this happens if we require a large output_bit_mask, e.g. for a LUT with 8 outputs + // which would require a size 256 output, but the initial hex number was something + // small like 0x8, which only results in a single byte. Thus, we first resize the + // bitmask to the needed length and then shift_right the original out_bit_mask so the + // used bits are at the end of the bitvec + out_bit_mask.resize(used_bits, false); + out_bit_mask.shift_right(shift as usize); + } Ok(( i, WireOutput { @@ -516,8 +524,10 @@ impl LutParseError { } } -fn expand(unexpanded: &BitSlice, wire_mask: &BitSlice) -> BitVec { - let mut expanded = unexpanded.to_bitvec(); +fn expand(unexpanded: &BitSlice, wire_mask: &BitSlice) -> BitVec { + let used_bits = 2_usize.pow(wire_mask.len() as u32); + let mut expanded = BitVec::with_capacity(used_bits); + expanded.extend_from_bitslice(unexpanded); let mut new_expanded = BitVec::repeat(false, 2_usize.pow(wire_mask.len() as u32)); for (idx, bit) in wire_mask.iter().rev().enumerate() { if *bit { @@ -594,175 +604,182 @@ fn get_prefix<'a, 'b>(with_prefix: &'a str, sub_str: &'b str) -> &'a str { &with_prefix[..with_prefix.len() - sub_str.len()] } -// #[cfg(test)] -// mod tests { -// use super::*; -// use crate::private_test_utils::init_tracing; -// use bitvec::order::Msb0; -// use bitvec::{bits, bitvec}; -// -// fn inp(num: u32) -> Input { -// Input(num.to_string()) -// } -// -// fn out(num: u32) -> Output { -// Output(num.to_string()) -// } -// -// fn in_cache(wires: &[u32]) -> HashSet { -// wires.iter().copied().map(inp).collect() -// } -// -// fn out_cache(wires: &[u32]) -> HashSet { -// wires.iter().copied().map(out).collect() -// } -// -// #[test] -// fn test_expand() { -// let val = bitvec![u8, Msb0; 0,1,1,0,1,0,0,1,0,0,0,0,0,0,0,0]; -// let wire_mask = bitvec![u8, Msb0; 1, 1, 0, 1]; -// let expected = bits![u8, Msb0; 0,1,0,1,1,0,1,0,1,0,1,0,0,1,0,1]; -// let expanded = expand(&val, &wire_mask[0..4]); -// assert_eq!(expected, &expanded); -// -// let val = bitvec![u8, Msb0; 1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0]; -// let wire_mask = bitvec![u8, Msb0; 0, 1, 0, 0]; -// let expected = bits![u8, Msb0; 1; 16]; -// let expanded = expand(&val, &wire_mask[0..4]); -// assert_eq!(expected, &expanded); -// -// let val = bitvec![u8, Msb0; 0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0]; -// let wire_mask = bitvec![u8, Msb0; 0, 1, 1, 0]; -// let expected = bits![u8, Msb0; 0,0,0,0,0,0,1,1,0,0,0,0,0,0,1,1]; -// let expanded = expand(&val, &wire_mask[0..4]); -// assert_eq!(expected, &expanded); -// -// let val = bitvec![u8, Msb0; 0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0]; -// let wire_mask = bitvec![u8, Msb0; 1, 1, 1, 1]; -// let expected = bits![u8, Msb0; 0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0]; -// let expanded = expand(&val, &wire_mask[0..4]); -// assert_eq!(expected, &expanded); -// -// let val = bitvec![u8, Msb0; 1,0,1,0,1,0,1,0,0,0,1,1,1,0,0,1]; -// let wire_mask = bitvec![u8, Msb0; 1, 1, 1, 1]; -// let expected = bits![u8, Msb0; 1,0,1,0,1,0,1,0,0,0,1,1,1,0,0,1]; -// let expanded = expand(&val, &wire_mask[0..4]); -// assert_eq!(expected, &expanded); -// } -// -// #[test] -// fn parse_wire() { -// let in_cache = in_cache(&[0][..]); -// let out_cache = out_cache(&[42][..]); -// let wire = wire(&in_cache, &out_cache); -// assert_eq!(Ok(("", Wire::Input(inp(0)))), wire("0")); -// assert_eq!(Ok(("", Wire::Output(out(42)))), wire("42")); -// assert_eq!(Ok(("", Wire::Internal(Internal(42)))), wire("n42")); -// } -// -// #[test] -// fn parse_wire_output() { -// assert_eq!( -// Ok(( -// "", -// WireOutput { -// unexpanded: bitvec![u8, Msb0; 1,0,] -// } -// )), -// wire_output(bits![u8, Msb0; 1])("0x2") -// ); -// } -// -// #[test] -// fn parse_wire_mask() { -// let mask = bitvec![u8, Msb0; 1,0,1,1]; -// assert_eq!(Ok(("", WireMask { mask })), wire_mask(4)("3 1011")); -// } -// -// #[test] -// fn parse_masked_lut() { -// let in_cache = in_cache(&[42][..]); -// let out_cache = out_cache(&[5][..]); -// let exp_masked_lut = MaskedLut { -// wire_mask: WireMask { -// mask: bitvec![u8, Msb0; 1,0,0,0], -// }, -// output: WireOutput { -// unexpanded: bitvec![u8, Msb0; 1,1], -// }, -// out_wire: Wire::Output(out(5)), -// }; -// let parsed = masked_lut(&in_cache, &out_cache, 4)("1 1000 0x3 5").unwrap(); -// -// assert_eq!(("", exp_masked_lut), parsed); -// assert_eq!(bits![u8, Msb0; 1; 16], &parsed.1.expanded()) -// } -// -// #[test] -// fn parse_lut() { -// let exp_masked_lut = MaskedLut { -// wire_mask: WireMask { -// mask: bitvec![u8, Msb0; 1,0,0,0], -// }, -// output: WireOutput { -// unexpanded: bitvec![u8, Msb0; 1,1,], -// }, -// out_wire: Wire::Output(out(5)), -// }; -// -// let expected_lut = Lut { -// input_wires: SmallVec::from([ -// Wire::Input(inp(0)), -// Wire::Input(inp(1)), -// Wire::Input(inp(2)), -// Wire::Internal(Internal(6)), -// ]), -// masked_luts: SmallVec::from_elem(exp_masked_lut, 2), -// }; -// -// let in_cache = in_cache(&[0, 1, 2][..]); -// let out_cache = out_cache(&[5][..]); -// -// assert_eq!( -// Ok(("", expected_lut)), -// lut(&in_cache, &out_cache)("LUT 4 2 0 1 2 n6 1 1000 0x3 5 1 1000 0x3 5") -// ); -// } -// -// #[test] -// fn parse_sample_circuit() { -// let circ = Circuit::load(Path::new("test_resources/Sample LUT file.lut")).unwrap(); -// dbg!(circ); -// } -// -// #[test] -// fn gate_execute_lut() { -// let wire = |num| Wire::Internal(Internal(num)); -// let gate = Gate::Lut(Lut { -// input_wires: SmallVec::from_vec(vec![wire(1), wire(2), wire(3), wire(4)]), -// masked_luts: SmallVec::from(vec![MaskedLut { -// wire_mask: WireMask { -// mask: bitvec![u8, Msb0; 1,1,0,1], -// }, -// output: WireOutput { -// unexpanded: bitvec![u8, Msb0; 0,1,1,0,1,1,0,1], -// }, -// out_wire: wire(5), -// }]), -// }); -// let mut wire_vals = [(wire(1), false), (wire(2), false), (wire(4), false)] -// .into_iter() -// .collect(); -// gate.execute(&mut wire_vals); -// assert_eq!(false, wire_vals[&wire(4)]); -// } -// -// #[test] -// fn circuit_execute() { -// let _g = init_tracing(); -// let circ = Circuit::load(Path::new("test_resources/lut_circuits/lfa32_4.lut")).unwrap(); -// let inp = bitvec![u8, Msb0; 0;64]; -// let out = circ.execute(&inp); -// assert_eq!(bits![u8, Msb0; 0;33], out); -// } -// } +#[cfg(test)] +mod tests { + use super::*; + use crate::private_test_utils::init_tracing; + use bitvec::order::Msb0; + use bitvec::{bits, bitvec}; + + fn inp(num: u32) -> Input { + Input(num.to_string()) + } + + fn out(num: u32) -> Output { + Output(num.to_string()) + } + + fn in_cache(wires: &[u32]) -> HashSet { + wires.iter().copied().map(inp).collect() + } + + fn out_cache(wires: &[u32]) -> HashSet { + wires.iter().copied().map(out).collect() + } + + #[test] + fn test_expand() { + let val = bitvec![u8, Msb0; 0,1,1,0,1,0,0,1,0,0,0,0,0,0,0,0]; + let wire_mask = bitvec![u8, Msb0; 1, 1, 0, 1]; + let expected = bits![u8, Msb0; 0,1,0,1,1,0,1,0,1,0,1,0,0,1,0,1]; + let expanded = expand(&val, &wire_mask[0..4]); + assert_eq!(expected, &expanded); + + let val = bitvec![u8, Msb0; 1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0]; + let wire_mask = bitvec![u8, Msb0; 0, 1, 0, 0]; + let expected = bits![u8, Msb0; 1; 16]; + let expanded = expand(&val, &wire_mask[0..4]); + assert_eq!(expected, &expanded); + + let val = bitvec![u8, Msb0; 0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0]; + let wire_mask = bitvec![u8, Msb0; 0, 1, 1, 0]; + let expected = bits![u8, Msb0; 0,0,0,0,0,0,1,1,0,0,0,0,0,0,1,1]; + let expanded = expand(&val, &wire_mask[0..4]); + assert_eq!(expected, &expanded); + + let val = bitvec![u8, Msb0; 0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0]; + let wire_mask = bitvec![u8, Msb0; 1, 1, 1, 1]; + let expected = bits![u8, Msb0; 0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0]; + let expanded = expand(&val, &wire_mask[0..4]); + assert_eq!(expected, &expanded); + + let val = bitvec![u8, Msb0; 1,0,1,0,1,0,1,0,0,0,1,1,1,0,0,1]; + let wire_mask = bitvec![u8, Msb0; 1, 1, 1, 1]; + let expected = bits![u8, Msb0; 1,0,1,0,1,0,1,0,0,0,1,1,1,0,0,1]; + let expanded = expand(&val, &wire_mask[0..4]); + assert_eq!(expected, &expanded); + } + + #[test] + fn parse_wire() { + let in_cache = in_cache(&[0][..]); + let out_cache = out_cache(&[42][..]); + let wire = wire(&in_cache, &out_cache); + assert_eq!(Ok(("", Wire::Input(inp(0)))), wire("0")); + assert_eq!(Ok(("", Wire::Output(out(42)))), wire("42")); + assert_eq!( + Ok(("", Wire::Internal(Internal("n42".into())))), + wire("n42") + ); + } + + #[test] + fn parse_wire_output() { + assert_eq!( + Ok(( + "", + WireOutput { + unexpanded: bitvec![u8, Msb0; 1,0,] + } + )), + wire_output(bits![u8, Msb0; 1])("0x2") + ); + } + + #[test] + fn parse_wire_mask() { + let mask = bitvec![u8, Msb0; 1,0,1,1]; + assert_eq!( + Ok(("", WireMask { mask })), + wire_mask(&[false; 4])("3 1011") + ); + } + + #[test] + fn parse_masked_lut() { + let in_cache = in_cache(&[42][..]); + let out_cache = out_cache(&[5][..]); + let exp_masked_lut = MaskedLut { + wire_mask: WireMask { + mask: bitvec![u8, Msb0; 1,0,0,0], + }, + output: WireOutput { + unexpanded: bitvec![u8, Msb0; 1,1], + }, + out_wire: Wire::Output(out(5)), + }; + let parsed = masked_lut(&in_cache, &out_cache, &[false; 4])("1 1000 0x3 5").unwrap(); + + assert_eq!(("", exp_masked_lut), parsed); + assert_eq!(bits![u8, Msb0; 1; 16], &parsed.1.expanded()) + } + + #[test] + fn parse_lut() { + let exp_masked_lut = MaskedLut { + wire_mask: WireMask { + mask: bitvec![u8, Msb0; 1,0,0,0], + }, + output: WireOutput { + unexpanded: bitvec![u8, Msb0; 1,1,], + }, + out_wire: Wire::Output(out(5)), + }; + + let expected_lut = Lut { + input_wires: SmallVec::from([ + Wire::Input(inp(0)), + Wire::Input(inp(1)), + Wire::Input(inp(2)), + Wire::Internal(Internal("n6".into())), + ]), + masked_luts: SmallVec::from_elem(exp_masked_lut, 2), + }; + + let in_cache = in_cache(&[0, 1, 2][..]); + let out_cache = out_cache(&[5][..]); + + assert_eq!( + Ok(("", expected_lut)), + lut(&in_cache, &out_cache)("LUT 4 2 0 1 2 n6 1 1000 0x3 5 1 1000 0x3 5") + ); + } + + #[test] + fn parse_sample_circuit() { + let circ = + Circuit::load(Path::new("test_resources/lut_circuits/Sample LUT file.lut")).unwrap(); + dbg!(circ); + } + + #[test] + fn gate_execute_lut() { + let wire = |num: i32| Wire::Internal(Internal(num.to_string())); + let gate = Gate::Lut(Lut { + input_wires: SmallVec::from_vec(vec![wire(1), wire(2), wire(3), wire(4)]), + masked_luts: SmallVec::from(vec![MaskedLut { + wire_mask: WireMask { + mask: bitvec![u8, Msb0; 1,1,0,1], + }, + output: WireOutput { + unexpanded: bitvec![u8, Msb0; 0,1,1,0,1,1,0,1], + }, + out_wire: wire(5), + }]), + }); + let mut wire_vals = [(wire(1), false), (wire(2), false), (wire(4), false)] + .into_iter() + .collect(); + gate.execute(&mut wire_vals); + assert_eq!(false, wire_vals[&wire(4)]); + } + + // #[test] + // fn circuit_execute() { + // let _g = init_tracing(); + // let circ = Circuit::load(Path::new("test_resources/lut_circuits/lfa32.lut")).unwrap(); + // let inp = bitvec![u8, Lsb0; 0;64]; + // let out = circ.execute(&inp); + // assert_eq!(bits![u8, Lsb0; 0;33], out); + // } +} diff --git a/crates/gmw/src/protocols/aby2_lut.rs b/crates/gmw/src/protocols/aby2_lut.rs index 7e8ee57..526a93d 100644 --- a/crates/gmw/src/protocols/aby2_lut.rs +++ b/crates/gmw/src/protocols/aby2_lut.rs @@ -14,7 +14,7 @@ use crate::share_wrapper::ShareWrapper; use crate::{bristol, BooleanGate, Circuit, CircuitBuilder, GateId, SubCircuitGate}; use ahash::AHashMap; use async_trait::async_trait; -use bitvec::order::Lsb0; +use bitvec::order::{Lsb0, Msb0}; use bitvec::view::BitView; use bitvec::{bitvec, slice, vec}; use itertools::Itertools; @@ -70,7 +70,7 @@ pub enum Msg { pub enum LutGate { Base(BaseGate), Lut { - output_mask: vec::BitVec, + output_mask: vec::BitVec, inputs: u8, }, Xor, @@ -371,8 +371,7 @@ impl LutGate { fn and() -> Self { LutGate::Lut { - // TODO is the order of bits correct? - output_mask: bitvec![u8, Lsb0; 0,0,0,1], + output_mask: bitvec![u8, Msb0; 0,0,0,1], inputs: 2, } } @@ -947,7 +946,7 @@ where fn expand( input_size: u8, - lut_output: &slice::BitSlice, + lut_output: &slice::BitSlice, input: &slice::BitSlice, ) -> Vec> { let lut_set_bits = lut_output.count_ones(); @@ -1197,7 +1196,7 @@ mod tests { #[test] fn expand_output() { - let expanded = expand(3, bits![u8, Lsb0; 0,1,1,0,1,1,0,1], bits![u8, Lsb0; 0;8]); + let expanded = expand(3, bits![u8, Msb0; 0,1,1,0,1,1,0,1], bits![u8, Lsb0; 0;8]); let expected = vec![ bitvec![u8, Msb0; 1,1,0,0,0], bitvec![u8, Msb0; 1,0,1,1,0],