Skip to content

Commit

Permalink
Merge pull request #69 from Ebanflo42/xla/mnist
Browse files Browse the repository at this point in the history
Basic MNIST example working with XLA :)
  • Loading branch information
BradenEverson authored Apr 4, 2024
2 parents a569ba5 + d77efc0 commit 7ee66bc
Show file tree
Hide file tree
Showing 16 changed files with 1,302 additions and 664 deletions.
95 changes: 48 additions & 47 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ backtrace = "0.3"
smallvec = "1.13"
strum = "0.26"
strum_macros = "0.26"
xla = { git = "https://github.com/Ebanflo42/xla-rs", version = "0.1.6" , branch = "dev" }
xla = { git = "https://github.com/Ebanflo42/xla-rs", version = "0.1.6" , branch = "main" }
thiserror = "1"
half = "2.4.0"
byteorder = "1.5"

[features]
default = ["util"]
Expand Down
31 changes: 25 additions & 6 deletions examples/mnist.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
use unda::{core::{data::{input::Input, matrix::Matrix}, network::Sequential, layer::{layers::{LayerTypes, InputTypes}, methods::{activations::Activations, errors::ErrorTypes}}}, util::{mnist::MnistEntry, categorical::to_categorical}};
use unda::{
core::{
data::{input::Input, matrix::Matrix},
layer::{
layers::{InputTypes, LayerTypes},
methods::{activations::Activations, errors::ErrorTypes},
},
network::Sequential,
},
util::{categorical::to_categorical, mnist::MnistEntry},
};

fn main() {
let mut inputs: Vec<&dyn Input> = vec![];

let mut true_outputs: Vec<Vec<f32>> = vec![];

let inputs_undyn: Vec<Matrix>;
Expand All @@ -13,7 +23,7 @@ fn main() {
println!("Done Generating MNIST");

let outputs: Vec<Vec<f32>> = to_categorical(outputs_uncat);
for i in 0..600{
for i in 0..600 {
inputs.push(&inputs_undyn[i]);
true_outputs.push(outputs[i].clone());
}
Expand All @@ -28,8 +38,17 @@ fn main() {

network.compile();

network.fit(&inputs, &true_outputs, 1, ErrorTypes::CategoricalCrossEntropy);
for i in 0..10{
println!("predicted: {:?} \n\nactual: {:?}\n\n\n", network.predict(inputs[i]), true_outputs[i]);
network.fit(
&inputs,
&true_outputs,
1,
ErrorTypes::CategoricalCrossEntropy,
);
for i in 0..10 {
println!(
"predicted: {:?} \n\nactual: {:?}\n\n\n",
network.predict(inputs[i]),
true_outputs[i]
);
}
}
Loading

0 comments on commit 7ee66bc

Please sign in to comment.