Skip to content

Commit

Permalink
adding unsafe ba version
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 committed May 24, 2024
1 parent fe84881 commit a2b8fbe
Show file tree
Hide file tree
Showing 4 changed files with 445 additions and 231 deletions.
130 changes: 100 additions & 30 deletions enzyme/benchmarks/ReverseMode/adbench/ba.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,19 +127,16 @@ extern "C" {
double* reproj_err,
double* w_err
);

void rust2_ba_objective(
int n,
int m,
int p,
double const* cams,
double const* X,
double const* w,
int const* obs,
double const* feats,
double* reproj_err,
double* w_err
);

void rust2_unsafe_ba_objective(int n, int m, int p, double const *cams,
double const *X, double const *w,
int const *obs, double const *feats,
double *reproj_err, double *w_err);

void rust2_ba_objective(int n, int m, int p, double const *cams,
double const *X, double const *w, int const *obs,
double const *feats, double *reproj_err,
double *w_err);

void dcompute_reproj_error(
double const* cam,
Expand Down Expand Up @@ -183,17 +180,17 @@ extern "C" {

void adept_compute_zach_weight_error(double const* w, double* dw, double* err, double* derr);

void rust_dcompute_reproj_error(
double const* cam,
double * dcam,
double const* X,
double * dX,
double const* w,
double * wb,
double const* feat,
double *err,
double *derr
);
void rust_unsafe_dcompute_reproj_error(double const *cam, double *dcam,
double const *X, double *dX,
double const *w, double *wb,
double const *feat, double *err,
double *derr);

void rust_dcompute_reproj_error(double const *cam, double *dcam,
double const *X, double *dX,
double const *w, double *wb,
double const *feat, double *err,
double *derr);

void rust_dcompute_zach_weight_error(double const* w, double* dw, double* err, double* derr);
}
Expand Down Expand Up @@ -362,10 +359,22 @@ int main(const int argc, const char* argv[]) {
std::string path = "/mnt/Data/git/Enzyme/apps/ADBench/data/ba/ba1_n49_m7776_p31843.txt";

std::vector<std::string> paths = {
"ba10_n1197_m126327_p563734.txt", "ba14_n356_m226730_p1255268.txt", "ba18_n1936_m649673_p5213733.txt", "ba2_n21_m11315_p36455.txt", "ba6_n539_m65220_p277273.txt", "test.txt",
"ba11_n1723_m156502_p678718.txt", "ba15_n1102_m780462_p4052340.txt", "ba19_n4585_m1324582_p9125125.txt", "ba3_n161_m48126_p182072.txt", "ba7_n93_m61203_p287451.txt",
"ba12_n253_m163691_p899155.txt", "ba16_n1544_m942409_p4750193.txt", "ba1_n49_m7776_p31843.txt", "ba4_n372_m47423_p204472.txt", "ba8_n88_m64298_p383937.txt",
"ba13_n245_m198739_p1091386.txt", "ba17_n1778_m993923_p5001946.txt", "ba20_n13682_m4456117_p2987644.txt", "ba5_n257_m65132_p225911.txt", "ba9_n810_m88814_p393775.txt",
"ba10_n1197_m126327_p563734.txt",
"ba14_n356_m226730_p1255268.txt", // "ba18_n1936_m649673_p5213733.txt",
// "ba2_n21_m11315_p36455.txt",
// "ba6_n539_m65220_p277273.txt",
// "test.txt",
// "ba11_n1723_m156502_p678718.txt",
// "ba15_n1102_m780462_p4052340.txt",
// "ba19_n4585_m1324582_p9125125.txt",
// "ba3_n161_m48126_p182072.txt", "ba7_n93_m61203_p287451.txt",
// "ba12_n253_m163691_p899155.txt",
// "ba16_n1544_m942409_p4750193.txt", "ba1_n49_m7776_p31843.txt",
// "ba4_n372_m47423_p204472.txt", "ba8_n88_m64298_p383937.txt",
// "ba13_n245_m198739_p1091386.txt",
// "ba17_n1778_m993923_p5001946.txt",
// "ba20_n13682_m4456117_p2987644.txt",
// "ba5_n257_m65132_p225911.txt", "ba9_n810_m88814_p393775.txt",
};

std::ofstream jsonfile("results.json", std::ofstream::trunc);
Expand Down Expand Up @@ -571,7 +580,40 @@ int main(const int argc, const char* argv[]) {
}
}


{
struct BAInput input;
read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams,
input.X, input.w, input.obs, input.feats);

struct BAOutput result = {std::vector<double>(2 * input.p),
std::vector<double>(input.p),
BASparseMat(input.n, input.m, input.p)};
{

struct timeval start, end;
gettimeofday(&start, NULL);
rust2_unsafe_ba_objective(input.n, input.m, input.p, input.cams.data(),
input.X.data(), input.w.data(),
input.obs.data(), input.feats.data(),
result.reproj_err.data(), result.w_err.data());
gettimeofday(&end, NULL);
printf("primal unsafe rust t=%0.6f\n", tdiff(&start, &end));
json enzyme;
enzyme["name"] = "primal unsafe rust";
enzyme["runtime"] = tdiff(&start, &end);
for (unsigned i = 0; i < 5; i++) {
printf("%f ", result.reproj_err[i]);
enzyme["result"].push_back(result.reproj_err[i]);
}
for (unsigned i = 0; i < 5; i++) {
printf("%f ", result.w_err[i]);
enzyme["result"].push_back(result.w_err[i]);
}
printf("\n");
test_suite["tools"].push_back(enzyme);
}
}

{
struct BAInput input;
read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams, input.X, input.w, input.obs, input.feats);
Expand Down Expand Up @@ -626,6 +668,35 @@ int main(const int argc, const char* argv[]) {
BASparseMat(input.n, input.m, input.p)
};

{
struct timeval start, end;
gettimeofday(&start, NULL);
calculate_jacobian<rust_unsafe_dcompute_reproj_error,
rust_dcompute_zach_weight_error>(input, result);
gettimeofday(&end, NULL);
printf("Enzyme unsafe rust combined %0.6f\n", tdiff(&start, &end));
json enzyme;
enzyme["name"] = "Enzyme unsafe rust combined";
enzyme["runtime"] = tdiff(&start, &end);
for (unsigned i = 0; i < 5; i++) {
printf("%f ", result.J.vals[i]);
enzyme["result"].push_back(result.J.vals[i]);
}
printf("\n");
test_suite["tools"].push_back(enzyme);
}
}

{

struct BAInput input;
read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams,
input.X, input.w, input.obs, input.feats);

struct BAOutput result = {std::vector<double>(2 * input.p),
std::vector<double>(input.p),
BASparseMat(input.n, input.m, input.p)};

{
struct timeval start, end;
gettimeofday(&start, NULL);
Expand All @@ -642,7 +713,6 @@ int main(const int argc, const char* argv[]) {
printf("\n");
test_suite["tools"].push_back(enzyme);
}

}

test_suite["llvm-version"] = __clang_version__;
Expand Down
204 changes: 3 additions & 201 deletions enzyme/benchmarks/ReverseMode/ba/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,87 +2,10 @@
#![feature(slice_first_last_chunk)]
#![allow(non_snake_case)]

//#define BA_NCAMPARAMS 11
static BA_NCAMPARAMS: usize = 11;

fn sqsum(x: &[f64]) -> f64 {
x.iter().map(|&v| v * v).sum()
}

#[inline]
fn cross(a: &[f64; 3], b: &[f64; 3]) -> [f64; 3] {
[
a[1] * b[2] - a[2] * b[1],
a[2] * b[0] - a[0] * b[2],
a[0] * b[1] - a[1] * b[0],
]
}

fn radial_distort(rad_params: &[f64], proj: &mut [f64]) {
let rsq = sqsum(proj);
let l = 1. + rad_params[0] * rsq + rad_params[1] * rsq * rsq;
proj[0] = proj[0] * l;
proj[1] = proj[1] * l;
}

fn rodrigues_rotate_point(rot: &[f64; 3], pt: &[f64; 3], rotated_pt: &mut [f64; 3]) {
let sqtheta = sqsum(rot);
if sqtheta != 0. {
let theta = sqtheta.sqrt();
let costheta = theta.cos();
let sintheta = theta.sin();
let theta_inverse = 1. / theta;
let mut w = [0.; 3];
for i in 0..3 {
w[i] = rot[i] * theta_inverse;
}
let w_cross_pt = cross(&w, &pt);
let tmp = (w[0] * pt[0] + w[1] * pt[1] + w[2] * pt[2]) * (1. - costheta);
for i in 0..3 {
rotated_pt[i] = pt[i] * costheta + w_cross_pt[i] * sintheta + w[i] * tmp;
}
} else {
let rot_cross_pt = cross(&rot, &pt);
for i in 0..3 {
rotated_pt[i] = pt[i] + rot_cross_pt[i];
}
}
}

fn project(cam: &[f64; 11], X: &[f64; 3], proj: &mut [f64; 2]) {
let C = &cam[3..6];
let mut Xo = [0.; 3];
let mut Xcam = [0.; 3];

Xo[0] = X[0] - C[0];
Xo[1] = X[1] - C[1];
Xo[2] = X[2] - C[2];

rodrigues_rotate_point(cam.first_chunk::<3>().unwrap(), &Xo, &mut Xcam);

proj[0] = Xcam[0] / Xcam[2];
proj[1] = Xcam[1] / Xcam[2];

radial_distort(&cam[9..], proj);

proj[0] = proj[0] * cam[6] + cam[7];
proj[1] = proj[1] * cam[6] + cam[8];
}
pub mod safe;
pub mod r#unsafe;

#[no_mangle]
pub extern "C" fn rust_dcompute_reproj_error(
cam: *const [f64; 11],
dcam: *mut [f64; 11],
x: *const [f64; 3],
dx: *mut [f64; 3],
w: *const [f64; 1],
wb: *mut [f64; 1],
feat: *const [f64; 2],
err: *mut [f64; 2],
derr: *mut [f64; 2],
) {
dcompute_reproj_error(cam, dcam, x, dx, w, wb, feat, err, derr);
}
static BA_NCAMPARAMS: usize = 11;

#[no_mangle]
pub extern "C" fn rust_dcompute_zach_weight_error(
Expand All @@ -94,130 +17,9 @@ pub extern "C" fn rust_dcompute_zach_weight_error(
dcompute_zach_weight_error(w, dw, err, derr);
}

#[autodiff(
dcompute_reproj_error,
Reverse,
Duplicated,
Duplicated,
Duplicated,
Const,
Duplicated
)]
pub fn compute_reproj_error(
cam: *const [f64; 11],
x: *const [f64; 3],
w: *const [f64; 1],
feat: *const [f64; 2],
err: *mut [f64; 2],
) {
let cam = unsafe { &*cam };
let w = unsafe { *(*w).get_unchecked(0) };
let x = unsafe { &*x };
let feat = unsafe { &*feat };
let mut err = unsafe { &mut *err };
let mut proj = [0.; 2];
project(cam, x, &mut proj);
err[0] = w * (proj[0] - feat[0]);
err[1] = w * (proj[1] - feat[1]);
}

#[autodiff(dcompute_zach_weight_error, Reverse, Duplicated, Duplicated)]
pub fn compute_zach_weight_error(w: *const f64, err: *mut f64) {
let w = unsafe { *w };
unsafe { *err = 1. - w * w; }
}

// n number of cameras
// m number of points
// p number of observations
// cams: 11*n cameras in format [r1 r2 r3 C1 C2 C3 f u0 v0 k1 k2]
// r1, r2, r3 are angle - axis rotation parameters(Rodrigues)
// [C1 C2 C3]' is the camera center
// f is the focal length in pixels
// [u0 v0]' is the principal point
// k1, k2 are radial distortion parameters
// X: 3*m points
// obs: 2*p observations (pairs cameraIdx, pointIdx)
// feats: 2*p features (x,y coordinates corresponding to observations)
// reproj_err: 2*p errors of observations
// w_err: p weight "error" terms
fn rust_ba_objective(
n: usize,
m: usize,
p: usize,
cams: &[f64],
x: &[f64],
w: &[f64],
obs: &[i32],
feats: &[f64],
reproj_err: &mut [f64],
w_err: &mut [f64],
) {
assert_eq!(cams.len(), n * 11);
assert_eq!(x.len(), m * 3);
assert_eq!(w.len(), p);
assert_eq!(obs.len(), p * 2);
assert_eq!(feats.len(), p * 2);
assert_eq!(reproj_err.len(), p * 2);
assert_eq!(w_err.len(), p);

for i in 0..p {
let cam_idx = obs[i * 2 + 0] as usize;
let pt_idx = obs[i * 2 + 1] as usize;
let start = cam_idx * BA_NCAMPARAMS;
let cam: &[f64; 11] = unsafe {
cams[start..]
.get_unchecked(..11)
.try_into()
.unwrap_unchecked()
};
let x: &[f64; 3] = unsafe {
x[pt_idx * 3..]
.get_unchecked(..3)
.try_into()
.unwrap_unchecked()
};
let w: &[f64; 1] = unsafe { w[i..].get_unchecked(..1).try_into().unwrap_unchecked() };
let feat: &[f64; 2] = unsafe {
feats[i * 2..]
.get_unchecked(..2)
.try_into()
.unwrap_unchecked()
};
let reproj_err: &mut [f64; 2] = unsafe {
reproj_err[i * 2..]
.get_unchecked_mut(..2)
.try_into()
.unwrap_unchecked()
};
compute_reproj_error(cam, x, w, feat, reproj_err);
}

for i in 0..p {
let w_err: &mut f64 = unsafe { w_err.get_unchecked_mut(i) };
compute_zach_weight_error(w[i..].as_ptr(), w_err as *mut f64);
}
}

#[no_mangle]
extern "C" fn rust2_ba_objective(
n: usize,
m: usize,
p: usize,
cams: *const f64,
x: *const f64,
w: *const f64,
obs: *const i32,
feats: *const f64,
reproj_err: *mut f64,
w_err: *mut f64,
) {
let cams = unsafe { std::slice::from_raw_parts(cams, n * 11) };
let x = unsafe { std::slice::from_raw_parts(x, m * 3) };
let w = unsafe { std::slice::from_raw_parts(w, p) };
let obs = unsafe { std::slice::from_raw_parts(obs, p * 2) };
let feats = unsafe { std::slice::from_raw_parts(feats, p * 2) };
let reproj_err = unsafe { std::slice::from_raw_parts_mut(reproj_err, p * 2) };
let w_err = unsafe { std::slice::from_raw_parts_mut(w_err, p) };
rust_ba_objective(n, m, p, cams, x, w, obs, feats, reproj_err, w_err);
}
Loading

0 comments on commit a2b8fbe

Please sign in to comment.