Skip to content

Commit

Permalink
Iris Code Rotation Comparison Improvement (#48)
Browse files Browse the repository at this point in the history
* rotation improvement

* removed duplicated code

* clippy

* cleanup
  • Loading branch information
0xForerunner authored Feb 23, 2024
1 parent db09aac commit c8f273c
Show file tree
Hide file tree
Showing 12 changed files with 80 additions and 216 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ aws-config = "1.1.4"
aws-sdk-sqs = "1.12.0"
axum = "0.7.4"
base64 = "0.21.7"
bitvec = "1"
bytemuck = { version = "1.14.1", features = ["derive"] }
clap = { version = "4.4.18", features = ["derive", "env"] }
config = "0.13.4"
Expand Down Expand Up @@ -64,7 +65,7 @@ name = "e2e"
path = "bin/e2e/e2e.rs"

[[bench]]
name = "example"
name = "rotation_comparison"
harness = false

[profile.release]
Expand Down
12 changes: 0 additions & 12 deletions benches/example.rs

This file was deleted.

26 changes: 26 additions & 0 deletions benches/rotation_comparison.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use criterion::{criterion_group, criterion_main, Criterion};
use mpc::template::Template;

const A: &str = indoc::indoc! { r#"{
"code": "UGQwOHvPdn5w70Tf52jWiso1Upw4taZzWbvoQeYD2JocpibnzBLJJHCBhtFSDfCSxHc7RFXcVCRfNocX4KZdXxx3Vhn+yVHOcskVi6/2/JsaGztJJDCpo+oI+0+DEP+pYPMqOIYA2Ii3VJxnkcQKcdnTKuBOOXV5SXCj8yo/XZhcjvRj+OHeNvgh8UKvFsEVgD0GL7uq+0GAb0s6WQz6wzAmnlkGfV9E7gxT9QcIoezUZm8ZTuHyUcnpDhQ8ptiTbdIWsWL4VIY9PPntGPOL5pMXAQ7Tc7estxTf1za962XDAXwv4l3NkintjfowFJQaLTzDsHzFk4EQwmMCggiQ04nw2YUMc/HtYqbKjopNoMRrvQFrXHW1w2Gipa+RPw5aPiIP7fUF2G9uSemoPvdTHCAZXc887T30fucHUaSv9/2h656KPCA35j1icisHzDn3OEAGVBKr1yhwnzkcM1DNSMcquzDHtSNy7Xn/+6WVVrMNaN9vZg9IMREqn0qqMKrNubfCNHAMOUpIhNIa3Sr75k2xlsb0j+c3cCnkyLD5wHTVBvyzkjCM2siFl4TLZ0EK/DJ+DhNglXSgBfwgzE13F0I1mTyekXov2GYUZUS0ZJrJQyJRWZq+Kx3EfFKFUS9BKu6RFsBovYOpvWUMeWwsC93je6dKQIjRCahKnGSjoLPLNfJBFVdts4ZRoDTL6r8HOTIS+UZM3p75EUeK4LpfZaEC3I7vRLhtiMt64U1VQpWDeJdutbGzmC91p9Xxr3mEyVzjWQbvWEF9/qcYehsxLhlG4Uq2AQGBuP8Ro23wSPbeBhZIu/FaRvYT0HmtaqDgVTZBJkLxR9+qYO1rdY2HFTnQynZyRq15Ri0mIFFrlSv234r0wro49SfF7q/e6oaz3gcvMouqOAABTxLdv8xTSVZxaes5EGrBDr2TY8MJmIu6zYZ729rmr1njAWCZijz2g3lk74X0hyN1+MDgbeGzxXDjcxcAtHswxH58HN9oQ0E92dnoUsQzDX5JWmRcjh5LdsdJmxZYbzuqUDJFselLM06ZJhfHJ6P8oCod2aWVLNX5fkSdT0BfCNw63PlyfifWJbLuL2kbMB5jzvjjJtc428m5JF5+6yMHHc1+7Pic0tTm+D/vtJ4PQrY/YrQDbBIhZCqIB9fjcBqktC75e81JyUsHNM1qRuySG1maGxTgwRfzJgUoKYSaK+g/IKfDPU04AMhHSNi5BXijjicbf4vVLsRVNT5hEFl/XRr0YAxaNdcDPzys3NZ37hc/8RGkPLbNBzsIhd0KKdHbTxGm+at4Sa2a1bpOWx1crPx9QXq7asC0YJ5yE8zWSf/VD0VEJ9LJ7mrgsU7SF2iJRmhPZVkdOv0QkCk9hR3PGLpkCUsUMMpZh+oSHF4n1GGUKLinThXMDsVW+BNZ0ukE9TmBKLXOpPxTYkaPBjeyqJurGXA5mlCDKGGzFHLHh/D/DqR10OK/XhwmXoML6rWFUrai3FPdJBmRTzFga3TrBD0td1x0iBeqhmdz48NoBEjOTdLkB5AezSJ2+o4xQll+9fzdirWV/QbqOidaZMGW3RwlE03FqYNrH+izh46oxstQhmf4GxWLnUvlqF/K9djhB9ziTX5e0RCJ+rctPGD1uhOqcHIWASvQWBKKtxPnyjx0wMIBvQUj43hXpX1L3YQfKe0ZacoHigu9n2DYAfgHU5bT9N/MU2pcdPukVckPd2ExHAFNUi+/VyxKvla6tZP4khc9pa7dEKhNxyRDxGr35tT01SGTziaEyf61HTeeMO5WxEVg8wx0GX6ibXkFIb2UHZYagJs+VX59xzU9DkRvrP7IcjScvJ7M5pef4ZhkwzpyLJ9Q+zi2FyDGqc/n+isaT6owkh+X3kn7GZ6BtDvRlO10MmOq+lKFWt/ZhfADtBmnuJV0UC4YvEjqfxD88Yx7yeAHEwiUTY2yCrJi1hpLnq0bVE8r3RkcXDamk/kNmcvZJ0iNMWszSptlR24UVTcXvEHlKrTsl5m0qGskwFn7xg6CPHRSgfhP8tD0ksM6xzsEQmjXkrt1S3NIWRz26bcBTrwmmrb7rsunEdO+GHGfHWBssD9c4G+OOc8vXeXsHxmMqcj8EwvFjDtdBrR9wvd/irYDEPYZvw==",
"mask":"/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////w=="
}"# };

const B: &str = indoc::indoc! { r#"{
"code": "NBmIZfruaQ6PwrJ34jZriTKuhEwPqn7TlbY33BrnQxSS6cLrGy8ZwI6lfekKyTgV627R70xHb9HDPi4brHiBhGGouyWdGUjulBVEJ/nEOjWMY0GyV4c+sZVIONy5TtT66xJgfOG+8SVWkycZge+L26nc22vpcETqc5fyArVXFCaD2dS+d4X8jb5sMdzDNYLILeCCCQUy8Wz9Kr0wef1vcjnINru5JZ/QmBaWlggUspWxpxDNf692RiGoj/ue7Py45jqdr+/SgOcbl9zDvnVk7rd3B8K4BgHMq+YFuiBREVFvzcuI1cl5oLJOzAJhpi+sfYe1tOeYWEYQJBMvGEIOmz8AKWoLKESVbh/9lIYSW9JE+bGJKEsPPZxCpo9vfFCYwROL6Rxw4wgxeNQ/VUX8HYqbSEUDpDIvZVOOAEOGxwa6v/P91ZOTEUwZ6V/Vz+46Z+oAUyccEcikKvreMBqZJw6x3PY3pESIz6JYoclDkjUOXTBjReCEv6smAv+YfRICRQavmVsT1j2L0GCmAH4bNxCATyAPq2YqdtH5VwrKZPczMb5ExAryRzQyiOWOVTBkPjUElhxeprilmyuV4AkHW+xPgDnGitgpQ6hk/s2eL8/SwBMtRUygc/tz6bdszyt5WhEBy5qfsQ71/A6GRkndCIjXj//W5lhcdiqobrZ70WFfQy/NRgL3SFLMEYZQrTfonp+mRDI0DwQeFKojEjjpx7T9DdSBMBLhRs/PwOCVLcQZbeduRYxPrz4GOqWAcw7p6v56i6+YJOLm0mxcjVnhgIlTM1LJkp+UF2Ke6LmKFeFeuSjpA/GH2TiDlLKl04++V4aspmmRYZU5BS02FFh4V7P8u4m4SPMJXJkUV9bXMP67aqMAVkFzQBVg7wqzPw8eSvpRw6NticywIYf8n+RPjrHpPH3EZJriFRHQW8qqMQCfE3nVNSOLsTWkc+QFvaANPZJDUA5RvPQTFyqd2Jl8sSw0rnlKex1Rn5tVzoMof/aq1ARscRP49XEjaNDH9axziDyoWTcdmLDt4QcWURMTu5Hu6S4Fd/SW1qDQGt8jWlUABqmPNwDOmX3mgiErrE+RvAhwe1k4wQc9uB22O/MiBTSd/gwcqElvvxUSHIlgNyjsi2YHFucLpiNAPk9zaQsA1RP8gtLsKk+OKiIEJX+nNQHFYbPpcpSuY/XA2QFmNvzN0G9y25FVMir4+bBhvD9UKjHoxtfEGWW5Tq4UrQTaR4PzY2Vns3d0vz1KmLooIVpUNfQ0uVh2rODbhCNWkt/8skRSjAf9e+zNXWj5oMNaKNKVyWiS5VAUpITivy3yCpdwjFNWAREsufl1IOMo861bl1gJrKS6RxPPOXZJ36wHxSVRrsrQcoz0oqx2S6eE6sIssOUxaR31SaAXxZ0C/rRSDcUW/psYgUXeR6AxzyxkbXxxIdPnubRG/v1im4chRVxvy3r6YVsAZYR4uIJE2KPtCQYbJojD3/F11YnDXDA15L8hCxcMoCpjz+IzVfdBBM+jJK0KAqIWItWSOHPDbANzUuY4jxKVTbfY7XHE8OXE+fcCKk6j5eoTTiOKWneDn0Ym2c8sgIxwzUOPmCy0GhexARPqYrKPCRNgvOJ6gCUTDK7KOYzn8UmzFQ9/hOWc1rfFx7HxcK/8kSzzD8EroLccyAz6ZT7fa3YrmosEiCjTNc9DKF4j5dSgYBCFDXPpHAxPTKk8pZSNKfuLWtXQV/OjRky9XEl7g85ejwZKm/w7zpE7gtFUBYH2rAu7cCyai5OGbWUB8Jw0sOeOHRsuFxdyQknYazomA8ZQD+aQglwfMfHG2YkiNSM2eHR+4jBU4TUoTXrqgEQBMuVfHROdSLZdUPK+4kd5VM5cfDJsmKYBXSIetLwl/OEQ9faB5RjoDmcRcP123Yv3hHt/VozmTI/4BiUypkGigGfwMTNQaO6HcIu/mfaBCtoUgy9xmD4YfDIKpQjsI66MrC4SGCBjp9s6Eo+N9R6kS976dzPg/6EJPfPQdUVnAfscyonUNLgjLvcGehZ4S9v6SABx9AzVXGPPJ19qv5Y6nSxJ8fgo38ULCIVP8cMbYZ5SNfHlVREqccQjopkcM+8YIEYBp05c99cii0h6bpEcJtwja4U6gEDgaw==",
"mask":"/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////w=="
}"# };

pub fn rotation_comparison(c: &mut Criterion) {
let a: Template = serde_json::from_str(A).unwrap();
let b: Template = serde_json::from_str(B).unwrap();
c.bench_function("rotation comparison", |bencher| {
bencher.iter(|| {
// Benchmark logic here
a.distance(&b);
})
});
}

criterion_group!(benches, rotation_comparison);
criterion_main!(benches);
5 changes: 1 addition & 4 deletions src/arch/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@ pub fn denominators<'a>(
) -> impl Iterator<Item = [u16; 31]> + 'a {
const BATCH: usize = 10_000;

// Prepare 31 rotations of query in advance
let rotations: Box<[_]> = ROTATIONS.map(|r| query.rotated(r)).collect();

// Iterate over a batch of database entries
db.chunks(BATCH).flat_map(move |chunk| {
// Parallel computation over batch
Expand All @@ -52,7 +49,7 @@ pub fn denominators<'a>(
.map(|(entry)| {
let mut result = [0_u16; 31];
// Compute dot product for each rotation
for (d, rotation) in result.iter_mut().zip(rotations.iter()) {
for (d, rotation) in result.iter_mut().zip(query.rotations()) {
*d = rotation.dot(entry);
}
result
Expand Down
4 changes: 2 additions & 2 deletions src/arch/reference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ pub fn denominators<'a>(
) -> impl Iterator<Item = [u16; 31]> + 'a {
db.iter().map(|entry| {
let mut result = [0_u16; 31];
for (d, r) in result.iter_mut().zip(ROTATIONS) {
*d = query.rotated(r).dot(entry);
for (d, r) in result.iter_mut().zip(query.rotations()) {
*d = r.dot(entry);
}
result
})
Expand Down
63 changes: 38 additions & 25 deletions src/bits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,38 +5,57 @@ use std::ops::Index;

use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use bitvec::prelude::*;
use bytemuck::{cast_slice_mut, Pod, Zeroable};
use rand::distributions::{Distribution, Standard};
use rand::Rng;
use serde::de::Error as _;
use serde::{Deserialize, Serialize};

use crate::{iris, slice_utils};
use crate::distance::ROTATION_DISTANCE;

pub const COLS: usize = 200;
pub const STEP_MULTI: usize = 4;
pub const ROWS: usize = 4 * 16;
pub const BITS: usize = ROWS * COLS;
const LIMBS: usize = BITS / 64;
const BYTES_PER_COL: usize = COLS * STEP_MULTI;

#[repr(transparent)]
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct Bits(pub [u64; LIMBS]);

impl Bits {
pub fn rotated(&self, amount: i32) -> Self {
// Convert to big-endian bytes
let mut unpacked = iris::unpack_iris_code64(&self.0);

// Rotate byte chunks
for chunk in unpacked.array_chunks_mut::<BYTES_PER_COL>() {
let rot = amount * STEP_MULTI as i32;
slice_utils::rotate_slice(chunk, rot as i64);
}
/// Returns an unordered iterator over the 31 possible rotations.
/// Rotations are done consecutively because the underlying `rotate_left\right`
/// methods are less efficient for larger rotations.
pub fn rotations(&self) -> impl Iterator<Item = Self> + '_ {
let mut left = *self;
let iter_left = (0..ROTATION_DISTANCE).map(move |_| {
left.rotate_left(1);
left
});
let mut right = *self;
let iter_right = (0..ROTATION_DISTANCE).map(move |_| {
right.rotate_left(1);
right
});
std::iter::once(*self).chain(iter_left).chain(iter_right)
}

let limbs = iris::pack_iris_code(&unpacked);
Bits::try_from(limbs).expect("Invalid bits size")
pub fn rotate_right(&mut self, by: usize) {
BitSlice::<_, Lsb0>::from_slice_mut(&mut self.0)
.chunks_exact_mut(COLS)
.for_each(|chunk| chunk.rotate_right(by));
}

/// For some insane reason, chunks_exact_mut benchmarks faster than manually indexing
/// for rotate_right but not for rotate_left. Compilers are weird.
pub fn rotate_left(&mut self, by: usize) {
let bit_slice = BitSlice::<_, Lsb0>::from_slice_mut(&mut self.0);
for row in 0..ROWS {
let row_slice = &mut bit_slice[row * COLS..(row + 1) * COLS];
row_slice.rotate_left(by);
}
}

pub fn count_ones(&self) -> u16 {
Expand Down Expand Up @@ -258,12 +277,10 @@ mod tests {
use rand::{thread_rng, Rng};

use super::*;
use crate::distance::ROTATIONS;

#[test]
fn limbs_exact() {
assert_eq!(LIMBS * 64, BITS);
assert_eq!(BYTES_PER_COL, COLS * STEP_MULTI);
}

#[test]
Expand All @@ -285,16 +302,12 @@ mod tests {
#[test]
fn test_rotated_inverse() {
let mut rng = thread_rng();
for _ in 0..100 {
let bits: Bits = rng.gen();
for amount in ROTATIONS {
assert_eq!(
bits.rotated(amount).rotated(-amount),
bits,
"Rotation failed for {amount}"
)
}
}
let bits: Bits = rng.gen();
let mut other = bits;
other.rotate_left(1);
other.rotate_right(1);

assert_eq!(bits, other)
}

#[test]
Expand Down
8 changes: 3 additions & 5 deletions src/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub const ROWS: usize = 4 * 16;
pub const BITS: usize = ROWS * COLS;

pub const ROTATIONS: RangeInclusive<i32> = -15..=15;
pub const ROTATION_DISTANCE: usize = 15;

/// Generate a [`EncodedBits`] such that values are $\{-1,0,1\}$, representing
/// unset, masked and set.
Expand Down Expand Up @@ -99,11 +100,8 @@ pub struct MasksEngine {

impl MasksEngine {
pub fn new(query: &Bits) -> Self {
let rotations = ROTATIONS
.map(|r| query.rotated(r))
.collect::<Box<[Bits]>>()
.try_into()
.unwrap();
let rotations =
query.rotations().collect::<Box<[_]>>().try_into().unwrap();
Self { rotations }
}

Expand Down
88 changes: 0 additions & 88 deletions src/iris.rs

This file was deleted.

2 changes: 0 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ pub mod db;
pub mod distance;
pub mod encoded_bits;
pub mod health_check;
pub mod iris;
pub mod participant;
pub mod slice_utils;
pub mod template;
pub mod utils;
63 changes: 0 additions & 63 deletions src/slice_utils.rs

This file was deleted.

Loading

0 comments on commit c8f273c

Please sign in to comment.