From 886b5631b7c4a8e2aac2dfa903d77e53195a564a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Meyer?= Date: Wed, 10 Jan 2024 19:59:10 +0100 Subject: [PATCH] In Naive Bayes, avoid using `Option::unwrap` and so avoid panicking from NaN values (#274) --- .../hyper_tuning/grid_search.rs | 4 +- src/naive_bayes/mod.rs | 94 +++++++++++++++++-- 2 files changed, 86 insertions(+), 12 deletions(-) diff --git a/src/model_selection/hyper_tuning/grid_search.rs b/src/model_selection/hyper_tuning/grid_search.rs index 3c914e48..74242c60 100644 --- a/src/model_selection/hyper_tuning/grid_search.rs +++ b/src/model_selection/hyper_tuning/grid_search.rs @@ -3,9 +3,9 @@ use crate::{ api::{Predictor, SupervisedEstimator}, error::{Failed, FailedError}, - linalg::basic::arrays::{Array2, Array1}, - numbers::realnum::RealNumber, + linalg::basic::arrays::{Array1, Array2}, numbers::basenum::Number, + numbers::realnum::RealNumber, }; use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult}; diff --git a/src/naive_bayes/mod.rs b/src/naive_bayes/mod.rs index e7ab7f6d..11614d14 100644 --- a/src/naive_bayes/mod.rs +++ b/src/naive_bayes/mod.rs @@ -40,7 +40,7 @@ use crate::linalg::basic::arrays::{Array1, Array2, ArrayView1}; use crate::numbers::basenum::Number; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -use std::marker::PhantomData; +use std::{cmp::Ordering, marker::PhantomData}; /// Distribution used in the Naive Bayes classifier. pub(crate) trait NBDistribution: Clone { @@ -92,11 +92,10 @@ impl, Y: Array1, D: NBDistribution Result { let y_classes = self.distribution.classes(); - let (rows, _) = x.shape(); - let predictions = (0..rows) - .map(|row_index| { - let row = x.get_row(row_index); - let (prediction, _probability) = y_classes + let predictions = x + .row_iter() + .map(|row| { + y_classes .iter() .enumerate() .map(|(class_index, class)| { @@ -106,11 +105,26 @@ impl, Y: Array1, D: NBDistribution ordering, + None => { + if p1.is_nan() { + Ordering::Less + } else if p2.is_nan() { + Ordering::Greater + } else { + Ordering::Equal + } + } + }) + .map(|(prediction, _probability)| *prediction) + .ok_or_else(|| Failed::predict("Failed to predict, there is no result")) }) - .collect::>(); + .collect::, Failed>>()?; let y_hat = Y::from_vec_slice(&predictions); Ok(y_hat) } @@ -119,3 +133,63 @@ pub mod bernoulli; pub mod categorical; pub mod gaussian; pub mod multinomial; + +#[cfg(test)] +mod tests { + use super::*; + use crate::linalg::basic::arrays::Array; + use crate::linalg::basic::matrix::DenseMatrix; + use num_traits::float::Float; + + type Model<'d> = BaseNaiveBayes, Vec, TestDistribution<'d>>; + + #[derive(Debug, PartialEq, Clone)] + struct TestDistribution<'d>(&'d Vec); + + impl<'d> NBDistribution for TestDistribution<'d> { + fn prior(&self, _class_index: usize) -> f64 { + 1. + } + + fn log_likelihood<'a>( + &'a self, + class_index: usize, + _j: &'a Box + 'a>, + ) -> f64 { + match self.0.get(class_index) { + &v @ 2 | &v @ 10 | &v @ 20 => v as f64, + _ => f64::nan(), + } + } + + fn classes(&self) -> &Vec { + &self.0 + } + } + + #[test] + fn test_predict() { + let matrix = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]]); + + let val = vec![]; + match Model::fit(TestDistribution(&val)).unwrap().predict(&matrix) { + Ok(_) => panic!("Should return error in case of empty classes"), + Err(err) => assert_eq!( + err.to_string(), + "Predict failed: Failed to predict, there is no result" + ), + } + + let val = vec![1, 2, 3]; + match Model::fit(TestDistribution(&val)).unwrap().predict(&matrix) { + Ok(r) => assert_eq!(r, vec![2, 2, 2]), + Err(_) => panic!("Should success in normal case with NaNs"), + } + + let val = vec![20, 2, 10]; + match Model::fit(TestDistribution(&val)).unwrap().predict(&matrix) { + Ok(r) => assert_eq!(r, vec![20, 20, 20]), + Err(_) => panic!("Should success in normal case without NaNs"), + } + } +}