From 1a062f395ddc5bec6acef1770a4f616a357b4b0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pascal=20Have=CC=81?= Date: Sun, 15 Oct 2023 23:07:37 +0200 Subject: [PATCH] Add LinearRegression demo --- src/rust/Cargo.lock | 463 +++++++++++++++++++++++++++++- src/rust/Cargo.toml | 12 +- src/rust/src/lib.rs | 8 + src/rust/src/linear_regression.rs | 162 +++++++++++ 4 files changed, 642 insertions(+), 3 deletions(-) create mode 100644 src/rust/src/linear_regression.rs diff --git a/src/rust/Cargo.lock b/src/rust/Cargo.lock index 18bb817..1187358 100644 --- a/src/rust/Cargo.lock +++ b/src/rust/Cargo.lock @@ -2,6 +2,48 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "aho-corasick" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" +dependencies = [ + "memchr", +] + +[[package]] +name = "anyhow" +version = "1.0.75" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" + +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "bytemuck" +version = "1.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + [[package]] name = "extendr-api" version = "0.4.0" @@ -32,9 +74,121 @@ checksum = "09bf0849f0d48209be8163378248137fed5ccb5f464d171cf93a19f31a9e6c67" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.109", ] +[[package]] +name = "futures" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" + +[[package]] +name = "futures-executor" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" + +[[package]] +name = "futures-macro" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.32", +] + +[[package]] +name = "futures-sink" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" + +[[package]] +name = "futures-task" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" + +[[package]] +name = "futures-timer" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c" + +[[package]] +name = "futures-util" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + +[[package]] +name = "getrandom" +version = "0.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + [[package]] name = "lazy_static" version = "1.4.0" @@ -47,11 +201,111 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cd728a97b9b0975f546bc865a7413e0ce6f98a8f6cea52e77dc5ee0bcea00adf" +[[package]] +name = "libc" +version = "0.2.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a08173bc88b7955d1b3145aa561539096c421ac8debde8cbc3612ec635fee29b" + [[package]] name = "libkrigingtemplate" version = "0.1.0" dependencies = [ + "anyhow", "extendr-api", + "nalgebra", + "rand", + "rand_distr", + "rstest", +] + +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" + +[[package]] +name = "matrixmultiply" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "memchr" +version = "2.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" + +[[package]] +name = "nalgebra" +version = "0.32.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "307ed9b18cc2423f29e83f84fd23a8e73628727990181f18641a8b5dc2ab1caa" +dependencies = [ + "approx", + "matrixmultiply", + "nalgebra-macros", + "num-complex", + "num-rational", + "num-traits", + "simba", + "typenum", +] + +[[package]] +name = "nalgebra-macros" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91761aed67d03ad966ef783ae962ef9bbaca728d2dd7ceb7939ec110fffad998" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "num-complex" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +dependencies = [ + "autocfg", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" +dependencies = [ + "autocfg", + "libm", ] [[package]] @@ -60,6 +314,24 @@ version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" +[[package]] +name = "pin-project-lite" +version = "0.2.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + [[package]] name = "proc-macro2" version = "1.0.66" @@ -78,6 +350,162 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "regex" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aaac441002f822bc9705a681810a4dd2963094b9ca0ddc41cb963a4c189189ea" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5011c7e263a695dc8ca064cddb722af1be54e517a280b12a5356f98366899e5d" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" + +[[package]] +name = "relative-path" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c707298afce11da2efef2f600116fa93ffa7a032b5d7b628aa17711ec81383ca" + +[[package]] +name = "rstest" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97eeab2f3c0a199bc4be135c36c924b6590b88c377d416494288c14f2db30199" +dependencies = [ + "futures", + "futures-timer", + "rstest_macros", + "rustc_version", +] + +[[package]] +name = "rstest_macros" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d428f8247852f894ee1be110b375111b586d4fa431f6c46e64ba5a0dcccbe605" +dependencies = [ + "cfg-if", + "glob", + "proc-macro2", + "quote", + "regex", + "relative-path", + "rustc_version", + "syn 2.0.32", + "unicode-ident", +] + +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + +[[package]] +name = "safe_arch" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f398075ce1e6a179b46f51bd88d0598b92b00d3551f1a2d4ac49e771b56ac354" +dependencies = [ + "bytemuck", +] + +[[package]] +name = "semver" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "836fa6a3e1e547f9a2c4040802ec865b5d85f4014efe00555d7090a3dcaa1090" + +[[package]] +name = "simba" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "061507c94fc6ab4ba1c9a0305018408e312e17c041eb63bef8aa726fa33aceae" +dependencies = [ + "approx", + "num-complex", + "num-traits", + "paste", + "wide", +] + +[[package]] +name = "slab" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" +dependencies = [ + "autocfg", +] + [[package]] name = "syn" version = "1.0.109" @@ -89,8 +517,41 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "syn" +version = "2.0.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "239814284fd6f1a4ffe4ca893952cdd93c224b6a1571c9a9eadd670295c0c9e2" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + [[package]] name = "unicode-ident" version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "301abaae475aa91687eb82514b328ab47a211a533026cb25fc3e519b86adfc3c" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wide" +version = "0.7.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebecebefc38ff1860b4bc47550bbfa63af5746061cf0d29fcd7fa63171602598" +dependencies = [ + "bytemuck", + "safe_arch", +] diff --git a/src/rust/Cargo.toml b/src/rust/Cargo.toml index a15f592..ad25637 100644 --- a/src/rust/Cargo.toml +++ b/src/rust/Cargo.toml @@ -1,10 +1,18 @@ [package] name = 'libkrigingtemplate' version = '0.1.0' -edition = '2018' +edition = '2021' [lib] -crate-type = [ 'staticlib' ] +crate-type = ['staticlib'] [dependencies] +anyhow = "1.0" extendr-api = '*' +#nalgebra = "0.32" +nalgebra = "0.32" +rand = "0.8" +rand_distr = "0.4" + +[dev-dependencies] +rstest = "0.18" diff --git a/src/rust/src/lib.rs b/src/rust/src/lib.rs index 789e6e0..ae4e335 100644 --- a/src/rust/src/lib.rs +++ b/src/rust/src/lib.rs @@ -1,3 +1,5 @@ +mod linear_regression; + use extendr_api::prelude::*; /// Return string `"Hello world!"` to R. @@ -14,3 +16,9 @@ extendr_module! { mod libkrigingtemplate; fn hello_world; } + +#[cfg(test)] +mod tests { + #[test] + fn rs_compile() {} +} diff --git a/src/rust/src/linear_regression.rs b/src/rust/src/linear_regression.rs new file mode 100644 index 0000000..f7976a3 --- /dev/null +++ b/src/rust/src/linear_regression.rs @@ -0,0 +1,162 @@ +use anyhow::{anyhow, Result}; + +type ColVec = nalgebra::DVector; + +type Mat = nalgebra::DMatrix; + +pub fn solve1(x: Mat, y: &ColVec) -> Result { + // println!("x : {:?}", x.shape()); + let qr = x.qr(); + // println!("q: {:?}", qr.q().shape()); + + let mut y = Mat::transpose(&qr.q()) * y; + // println!("tr(q)*y: {}", y); + + let r = qr.unpack_r(); + // println!("r: {:?}", r.shape()); + + let solved = r.solve_upper_triangular_mut(&mut y); + if solved { + Ok(y) + } else { + Err(anyhow!("Cannot solve")) + } +} + +pub fn solve(x: Mat, y: &ColVec) -> Result { + // println!("x : {:?}", x.shape()); + let qr = x.qr(); + + // println!("q: {:?}", qr.q().shape()); + + let mut y = y.clone(); + qr.q_tr_mul(&mut y); + // println!("tr(q)*y: {}", y); + + let r = qr.unpack_r(); + // println!("r: {:?}", r.shape()); + + y.resize_vertically_mut(r.shape().0, f64::NAN); + + let solved = r.solve_upper_triangular_mut(&mut y); + if solved { + Ok(y) + } else { + Err(anyhow!("Cannot solve")) + } +} + +struct LinearRegression { + coef: ColVec, + sig2: f64, + stderrest: ColVec, +} + +trait FitAndPredict { + fn fit(v: &ColVec, x: Mat) -> Result + where + Self: Sized; + + fn predict(&self, x: &Mat) -> (ColVec, ColVec); +} + +impl LinearRegression { + pub fn coef(&self) -> &ColVec { + &self.coef + } + + pub fn sig2(&self) -> f64 { + self.sig2 + } + + pub fn stderrest(&self) -> &ColVec { + &self.stderrest + } +} + +impl FitAndPredict for LinearRegression { + fn fit(y: &ColVec, x: Mat) -> Result { + let (n, k) = x.shape(); + + let coef = solve(x.clone(), y)?; + // println!("coef = {coef}"); + let resid = y - &x * &coef; + let sig2_mat = (resid.transpose() * resid) / (n as f64 - k as f64); + assert_eq!(sig2_mat.shape(), (1, 1)); // should be scalar + let sig2 = *sig2_mat.as_scalar(); + let stderrest = Mat::map_diagonal( + &Mat::try_inverse(x.transpose() * &x) + .ok_or_else(|| anyhow!("Cannot inverse x^t * x"))?, + f64::sqrt, + ); + + Ok(LinearRegression { + coef, + sig2, + stderrest, + }) + } + + fn predict(&self, x: &Mat) -> (ColVec, ColVec) { + // should test that X.n_cols == fit.X.n_cols + let y = x * &self.coef; + let stderr_v = Mat::map_diagonal( + &(x * Mat::from_diagonal(&self.stderrest) * Mat::transpose(x)), + f64::sqrt, + ); + (y, stderr_v) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::thread_rng; + use rand_distr::Distribution; + use rstest::rstest; + + #[rstest] + fn test(#[values(40, 100, 1000)] n: usize, #[values(3, 6)] m: usize) { + let mut rng = thread_rng(); + + let normal = rand_distr::Normal::new(0., 1.).unwrap(); + + let sol = ColVec::from_fn(m, |_, _| normal.sample(&mut rng)); + + let mut x = Mat::from_fn(n, m, |_, _| normal.sample(&mut rng)); + x.fill_column(0, 1.); + + // WHEN value is perfectly computed + { + let y = &x * / + + // println!("sol = {sol}"); + // println!("x = {x}"); + // println!("y = {y}"); + + let rl = LinearRegression::fit(&y, x.clone()).unwrap(); + let (y_pred, _) = rl.predict(&x); + let eps = 1e-5; + let norm_inf = (y - y_pred).abs().max(); + assert!(norm_inf < 10. * eps); + } + + // WHEN value is computed with noise + { + let e = 1e-8; + let noise = rand_distr::Normal::new(1.0, e).unwrap(); + + let y = (&x * &sol).map(|v| v * noise.sample(&mut rng)); + + // println!("sol = {sol}"); + // println!("x = {x}"); + // println!("y = {y}"); + + let rl = LinearRegression::fit(&y, x.clone()).unwrap(); + let (y_pred, _) = rl.predict(&x); + let eps = 1e-5; + let norm_inf = (y - y_pred).abs().max(); + assert!(norm_inf < 10. * eps + 10. * e); + } + } +}