Skip to content

Commit

Permalink
Update MSE and BinaryCrossEntropy.
Browse files Browse the repository at this point in the history
  • Loading branch information
hallvardnmbu committed May 19, 2024
1 parent 1354d14 commit ed5b132
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 32 deletions.
6 changes: 3 additions & 3 deletions examples/example_train.rs → examples/train.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,16 @@ fn main() {
let nodes = vec![2, 2, 1];
let biases = vec![false, false];
let activations = vec![Activation::Sigmoid, Activation::Sigmoid];
let lr = 0.9f32;
let lr = 9.0f32;
let optimizer = Optimizer::SGD;
let objective = Objective::BinaryCrossEntropy;
let objective = Objective::MSE;

let mut net = network::Network::create(
nodes, biases, activations, lr, optimizer, objective
);

// Train the network
let epoch_loss = net.train(&inputs, &targets, 1000);
let _epoch_loss = net.train(&inputs, &targets, 1000);

// Validate the network
let val_loss = net.validate(&inputs, &targets);
Expand Down
58 changes: 29 additions & 29 deletions src/objective.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ use std::fmt::Display;
pub enum Objective {
AE,
MAE,
// MSE,
MSE,
RMSE,
// BinaryCrossEntropy,
BinaryCrossEntropy,
CategoricalCrossEntropy,
}

Expand All @@ -34,9 +34,9 @@ impl Display for Function {
match self.objective {
Objective::AE => write!(f, "AE"),
Objective::MAE => write!(f, "MAE"),
// Objective::MSE => write!(f, "MSE"),
Objective::MSE => write!(f, "MSE"),
Objective::RMSE => write!(f, "RMSE"),
// Objective::BinaryCrossEntropy => write!(f, "BinaryCrossEntropy"),
Objective::BinaryCrossEntropy => write!(f, "BinaryCrossEntropy"),
Objective::CategoricalCrossEntropy => write!(f, "CategoricalCrossEntropy"),
}
}
Expand All @@ -47,9 +47,9 @@ impl Function {
match objective {
Objective::AE => Function { objective: Objective::AE },
Objective::MAE => Function { objective: Objective::MAE },
// Objective::MSE => Function { objective: Objective::MSE },
Objective::MSE => Function { objective: Objective::MSE },
Objective::RMSE => Function { objective: Objective::RMSE },
// Objective::BinaryCrossEntropy => Function { objective: Objective::BinaryCrossEntropy },
Objective::BinaryCrossEntropy => Function { objective: Objective::BinaryCrossEntropy },
Objective::CategoricalCrossEntropy => Function { objective: Objective::CategoricalCrossEntropy },
}
}
Expand Down Expand Up @@ -88,15 +88,15 @@ impl Function {
).collect();
(loss, gradient)
},
// Objective::MSE => {
// let loss: f32 = y.iter().zip(out.iter())
// .map(|(actual, predicted)| (actual - predicted).powi(2))
// .sum::<f32>() / y.len() as f32;
// let gradient: Vec<f32> = y.iter().zip(out.iter())
// .map(|(actual, predicted)| 2.0 * (actual - predicted) / y.len() as f32)
// .collect();
// (loss, gradient)
// },
Objective::MSE => {
let loss: f32 = y.iter().zip(out.iter())
.map(|(actual, predicted)| (actual - predicted).powi(2))
.sum::<f32>() / y.len() as f32;
let gradient: Vec<f32> = y.iter().zip(out.iter())
.map(|(actual, predicted)| -2.0 * (actual - predicted) / y.len() as f32)
.collect();
(loss, gradient)
},
Objective::RMSE => {
let loss: f32 = y.iter().zip(out.iter())
.map(|(actual, predicted)| (actual - predicted).powi(2))
Expand All @@ -112,20 +112,20 @@ impl Function {
).collect();
(loss, gradient)
},
// Objective::BinaryCrossEntropy => {
// let eps: f32 = 1e-7;
// let loss: f32 = -y.iter().zip(out.iter())
// .map(|(actual, predicted)| {
// let predicted = predicted.clamp(eps, 1.0 - eps);
// actual * predicted.ln() + (1.0 - actual) * (1.0 - predicted).ln()
// }).sum::<f32>() / y.len() as f32;
// let gradient: Vec<f32> = y.iter().zip(out.iter())
// .map(|(actual, predicted)| {
// let predicted = predicted.clamp(eps, 1.0 - eps);
// (predicted - actual) / (predicted * (1.0 - predicted))
// }).collect();
// (loss, gradient)
// },
Objective::BinaryCrossEntropy => {
let eps: f32 = 1e-7;
let loss: f32 = -y.iter().zip(out.iter())
.map(|(actual, predicted)| {
let predicted = predicted.clamp(eps, 1.0 - eps);
actual * predicted.ln() + (1.0 - actual) * (1.0 - predicted).ln()
}).sum::<f32>() / y.len() as f32;
let gradient: Vec<f32> = y.iter().zip(out.iter())
.map(|(actual, predicted)| {
let predicted = predicted.clamp(eps, 1.0 - eps);
(predicted - actual) / (predicted * (1.0 - predicted))
}).collect();
(loss, gradient)
},
Objective::CategoricalCrossEntropy => {
let eps: f32 = 1e-7;
let loss: f32 = -y.iter().zip(out.iter())
Expand Down

0 comments on commit ed5b132

Please sign in to comment.