From ed5b132861dd401d469ebc5dca3959639783d223 Mon Sep 17 00:00:00 2001 From: Hallvard Lavik Date: Sun, 19 May 2024 21:14:31 +0200 Subject: [PATCH] Update `MSE` and `BinaryCrossEntropy`. --- examples/{example_train.rs => train.rs} | 6 +-- src/objective.rs | 58 ++++++++++++------------- 2 files changed, 32 insertions(+), 32 deletions(-) rename examples/{example_train.rs => train.rs} (93%) diff --git a/examples/example_train.rs b/examples/train.rs similarity index 93% rename from examples/example_train.rs rename to examples/train.rs index 496f461..38de450 100644 --- a/examples/example_train.rs +++ b/examples/train.rs @@ -32,16 +32,16 @@ fn main() { let nodes = vec![2, 2, 1]; let biases = vec![false, false]; let activations = vec![Activation::Sigmoid, Activation::Sigmoid]; - let lr = 0.9f32; + let lr = 9.0f32; let optimizer = Optimizer::SGD; - let objective = Objective::BinaryCrossEntropy; + let objective = Objective::MSE; let mut net = network::Network::create( nodes, biases, activations, lr, optimizer, objective ); // Train the network - let epoch_loss = net.train(&inputs, &targets, 1000); + let _epoch_loss = net.train(&inputs, &targets, 1000); // Validate the network let val_loss = net.validate(&inputs, &targets); diff --git a/src/objective.rs b/src/objective.rs index b80f6ba..d172dca 100644 --- a/src/objective.rs +++ b/src/objective.rs @@ -19,9 +19,9 @@ use std::fmt::Display; pub enum Objective { AE, MAE, - // MSE, + MSE, RMSE, - // BinaryCrossEntropy, + BinaryCrossEntropy, CategoricalCrossEntropy, } @@ -34,9 +34,9 @@ impl Display for Function { match self.objective { Objective::AE => write!(f, "AE"), Objective::MAE => write!(f, "MAE"), - // Objective::MSE => write!(f, "MSE"), + Objective::MSE => write!(f, "MSE"), Objective::RMSE => write!(f, "RMSE"), - // Objective::BinaryCrossEntropy => write!(f, "BinaryCrossEntropy"), + Objective::BinaryCrossEntropy => write!(f, "BinaryCrossEntropy"), Objective::CategoricalCrossEntropy => write!(f, "CategoricalCrossEntropy"), } 
} @@ -47,9 +47,9 @@ impl Function { match objective { Objective::AE => Function { objective: Objective::AE }, Objective::MAE => Function { objective: Objective::MAE }, - // Objective::MSE => Function { objective: Objective::MSE }, + Objective::MSE => Function { objective: Objective::MSE }, Objective::RMSE => Function { objective: Objective::RMSE }, - // Objective::BinaryCrossEntropy => Function { objective: Objective::BinaryCrossEntropy }, + Objective::BinaryCrossEntropy => Function { objective: Objective::BinaryCrossEntropy }, Objective::CategoricalCrossEntropy => Function { objective: Objective::CategoricalCrossEntropy }, } } @@ -88,15 +88,15 @@ impl Function { ).collect(); (loss, gradient) }, - // Objective::MSE => { - // let loss: f32 = y.iter().zip(out.iter()) - // .map(|(actual, predicted)| (actual - predicted).powi(2)) - // .sum::<f32>() / y.len() as f32; - // let gradient: Vec<f32> = y.iter().zip(out.iter()) - // .map(|(actual, predicted)| 2.0 * (actual - predicted) / y.len() as f32) - // .collect(); - // (loss, gradient) - // }, + Objective::MSE => { + let loss: f32 = y.iter().zip(out.iter()) + .map(|(actual, predicted)| (actual - predicted).powi(2)) + .sum::<f32>() / y.len() as f32; + let gradient: Vec<f32> = y.iter().zip(out.iter()) + .map(|(actual, predicted)| -2.0 * (actual - predicted) / y.len() as f32) + .collect(); + (loss, gradient) + }, Objective::RMSE => { let loss: f32 = y.iter().zip(out.iter()) .map(|(actual, predicted)| (actual - predicted).powi(2)) @@ -112,20 +112,20 @@ impl Function { ).collect(); (loss, gradient) }, - // Objective::BinaryCrossEntropy => { - // let eps: f32 = 1e-7; - // let loss: f32 = -y.iter().zip(out.iter()) - // .map(|(actual, predicted)| { - // let predicted = predicted.clamp(eps, 1.0 - eps); - // actual * predicted.ln() + (1.0 - actual) * (1.0 - predicted).ln() - // }).sum::<f32>() / y.len() as f32; - // let gradient: Vec<f32> = y.iter().zip(out.iter()) - // .map(|(actual, predicted)| { - // let predicted = predicted.clamp(eps, 1.0 - eps); - // 
(predicted - actual) / (predicted * (1.0 - predicted)) - // }).collect(); - // (loss, gradient) - // }, + Objective::BinaryCrossEntropy => { + let eps: f32 = 1e-7; + let loss: f32 = -y.iter().zip(out.iter()) + .map(|(actual, predicted)| { + let predicted = predicted.clamp(eps, 1.0 - eps); + actual * predicted.ln() + (1.0 - actual) * (1.0 - predicted).ln() + }).sum::<f32>() / y.len() as f32; + let gradient: Vec<f32> = y.iter().zip(out.iter()) + .map(|(actual, predicted)| { + let predicted = predicted.clamp(eps, 1.0 - eps); + (predicted - actual) / (predicted * (1.0 - predicted)) + }).collect(); + (loss, gradient) + }, Objective::CategoricalCrossEntropy => { let eps: f32 = 1e-7; let loss: f32 = -y.iter().zip(out.iter())