
Commit

Added Error type and support for MeanAbsolute and MeanSquared (not working yet, I don't think)
BradenEverson committed Feb 1, 2024
1 parent 69235fa commit 10fc2e7
Showing 2 changed files with 9 additions and 7 deletions.
12 changes: 7 additions & 5 deletions src/core/network.rs
@@ -1,4 +1,5 @@
 use super::layer::layers::{Layer, LayerTypes};
+use super::layer::methods::errors::ErrorTypes;
 use super::layer::methods::pair::GradientPair;
 use super::data::matrix::Matrix;
 use super::data::input::Input;
@@ -203,15 +204,16 @@ impl Network{
     ///bias updating is different as well
     ///
     ///When constructing a neural network, be cautious that your layers behave well with each other
-    fn back_propegate(&mut self, outputs: &Vec<f32>, target_obj: &Box<dyn Input>) {
-        let mut parsed = Matrix::from(outputs.to_param_2d());
+    fn back_propegate(&mut self, outputs: &Vec<f32>, target_obj: &Box<dyn Input>, loss: &ErrorTypes) {
+        //let mut parsed = Matrix::from(outputs.to_param_2d());

         if let None = self.layers[self.layers.len()-1].get_activation() {
             panic!("Output layer is not a dense layer");
         }

         let mut gradients: Box<dyn Input>;
-        let mut errors: Box<dyn Input> = Box::new((parsed - &Matrix::from(target_obj.to_param_2d())).transpose());
+        let actual: Box<dyn Input> = Box::new(outputs.clone());
+        let mut errors: Box<dyn Input> = loss.get_error(&actual, target_obj, 1); //Box::new((parsed - &Matrix::from(target_obj.to_param_2d())).transpose());

         for i in (0..self.layers.len() - 1).rev() {
             gradients = self.layers[i + 1].update_gradient();
@@ -249,7 +251,7 @@ impl Network{
     ///compared to what is actually derived during back propegation
     ///* `epochs` - How many epochs you want your model training for
     ///
-    pub fn fit(&mut self, train_in: &Vec<&dyn Input>, train_out: &Vec<Vec<f32>>, epochs: usize) {
+    pub fn fit(&mut self, train_in: &Vec<&dyn Input>, train_out: &Vec<Vec<f32>>, epochs: usize, error_fn: ErrorTypes) {
         self.loss_train = vec![];

         let mut loss: f32;
@@ -284,7 +286,7 @@ impl Network{
                 let input: Box<dyn Input> = train_in[input_index].to_box();
                 let output: Box<dyn Input> = Box::new(train_out[input_index].clone());
                 let outputs = self.feed_forward(&input);
-                self.back_propegate(&outputs, &output);
+                self.back_propegate(&outputs, &output, &error_fn);

                 for i in 0..outputs.len() {
                     loss_on_input += (outputs[i] - train_out[input_index].to_param()[i]).powi(2);
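
The errors module itself is not part of this diff, so the shape of ErrorTypes can only be inferred from the loss.get_error(&actual, target_obj, 1) call above. A minimal, self-contained sketch of what such an enum might look like, assuming get_error produces the element-wise loss derivative that back propagation seeds the output layer with, and using plain slices instead of the crate's Box<dyn Input>:

// Hypothetical sketch only: the real ErrorTypes in errors.rs may differ.
#[derive(Clone, Copy)]
pub enum ErrorTypes {
    MeanAbsolute,
    MeanSquared,
}

impl ErrorTypes {
    /// Element-wise derivative of the loss with respect to each output,
    /// which back propagation uses as the initial error for the output layer.
    pub fn get_error(&self, actual: &[f32], target: &[f32]) -> Vec<f32> {
        actual
            .iter()
            .zip(target.iter())
            .map(|(a, t)| match self {
                // d/da |a - t| = sign(a - t)
                ErrorTypes::MeanAbsolute => (a - t).signum(),
                // d/da (a - t)^2 = 2(a - t)
                ErrorTypes::MeanSquared => 2.0 * (a - t),
            })
            .collect()
    }
}

Dispatching on an enum like this keeps the fit signature simple while letting the caller pick the loss per training run.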
4 changes: 2 additions & 2 deletions src/main.rs
@@ -1,4 +1,4 @@
-use triton_grow::core::{data::{input::Input, matrix::Matrix, matrix3d::Matrix3D}, network::Network, layer::{layers::LayerTypes, methods::activations::Activations}};
+use triton_grow::core::{data::{input::Input, matrix::Matrix, matrix3d::Matrix3D}, network::Network, layer::{layers::LayerTypes, methods::{activations::Activations, errors::ErrorTypes}}};


 #[tokio::main]
@@ -65,7 +65,7 @@ async fn main() {
     println!("0 and 0: {:?}", new_net.predict(&vec![0.0,0.0])[0]);


-    new_net.fit(&inputs, &outputs, 2);
+    new_net.fit(&inputs, &outputs, 2, ErrorTypes::MeanAbsolute);
     println!("1 and 0: {:?}", new_net.predict(&vec![1.0,0.0])[0]);
     println!("0 and 1: {:?}", new_net.predict(&vec![0.0,1.0])[0]);
     println!("1 and 1: {:?}", new_net.predict(&vec![1.0,1.0])[0]);
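
With this change the loss is chosen at the call site, so switching to the other variant named in the commit message would presumably be a one-argument change, e.g. new_net.fit(&inputs, &outputs, 2, ErrorTypes::MeanSquared); with no other code touched.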
