-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add deconvolution layer #22. Extend examples.
- Loading branch information
1 parent
46c6a8b
commit ebe42c9
Showing
21 changed files
with
1,749 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
# Define the deconvolution layer | ||
class DeconvolutionLayer(nn.Module): | ||
def __init__(self, in_channels, out_channels, kernel_size, stride, padding): | ||
super(DeconvolutionLayer, self).__init__() | ||
self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding) | ||
|
||
def forward(self, x): | ||
return self.deconv(x) | ||
|
||
# Test the backward pass | ||
def test_backward_pass(): | ||
# Define the input parameters | ||
in_channels = 1 | ||
out_channels = 1 | ||
kernel_size = (3, 3) | ||
stride = (2, 2) | ||
padding = (1, 1) | ||
|
||
# Create the deconvolution layer | ||
layer = DeconvolutionLayer(in_channels, out_channels, kernel_size, stride, padding) | ||
|
||
# Define the input tensor | ||
input_tensor = torch.tensor([[[[1.0, 2.0, 3.0, 4.0], | ||
[5.0, 6.0, 7.0, 8.0], | ||
[9.0, 10.0, 11.0, 12.0], | ||
[13.0, 14.0, 15.0, 16.0]]]], requires_grad=True) | ||
|
||
# Define the gradient tensor | ||
grad_tensor = torch.tensor([[[[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], | ||
[0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5], | ||
[1.7, 1.8, 1.9, 2.0, 2.1, 2.2, 2.3], | ||
[2.5, 2.6, 2.7, 2.8, 2.9, 3.0, 3.1], | ||
[3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9], | ||
[4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 4.7], | ||
[4.9, 5.0, 5.1, 5.2, 5.3, 5.4, 5.5]]]]) | ||
|
||
# Forward pass | ||
output = layer(input_tensor) | ||
|
||
# Backward pass | ||
output.backward(grad_tensor) | ||
|
||
# Get the gradients | ||
input_grad = input_tensor.grad | ||
weight_grad = layer.deconv.weight.grad | ||
|
||
# Check the shapes | ||
assert input_grad.shape == input_tensor.shape | ||
assert weight_grad.shape == layer.deconv.weight.shape | ||
|
||
print("Input gradient shape:", input_grad.shape) | ||
print("Weight gradient shape:", weight_grad.shape) | ||
|
||
print("Kernel gradient:\n", weight_grad) | ||
|
||
# Run the test | ||
test_backward_pass() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
// Copyright (C) 2024 Hallvard Høyland Lavik | ||
|
||
use neurons::{activation, feedback, network, objective, optimizer, plot, tensor}; | ||
|
||
use std::{ | ||
fs::File, | ||
io::{BufRead, BufReader}, | ||
}; | ||
|
||
fn data( | ||
path: &str, | ||
) -> ( | ||
( | ||
Vec<tensor::Tensor>, | ||
Vec<tensor::Tensor>, | ||
Vec<tensor::Tensor>, | ||
), | ||
( | ||
Vec<tensor::Tensor>, | ||
Vec<tensor::Tensor>, | ||
Vec<tensor::Tensor>, | ||
), | ||
( | ||
Vec<tensor::Tensor>, | ||
Vec<tensor::Tensor>, | ||
Vec<tensor::Tensor>, | ||
), | ||
) { | ||
let reader = BufReader::new(File::open(&path).unwrap()); | ||
|
||
let mut x_train: Vec<tensor::Tensor> = Vec::new(); | ||
let mut y_train: Vec<tensor::Tensor> = Vec::new(); | ||
let mut class_train: Vec<tensor::Tensor> = Vec::new(); | ||
|
||
let mut x_test: Vec<tensor::Tensor> = Vec::new(); | ||
let mut y_test: Vec<tensor::Tensor> = Vec::new(); | ||
let mut class_test: Vec<tensor::Tensor> = Vec::new(); | ||
|
||
let mut x_val: Vec<tensor::Tensor> = Vec::new(); | ||
let mut y_val: Vec<tensor::Tensor> = Vec::new(); | ||
let mut class_val: Vec<tensor::Tensor> = Vec::new(); | ||
|
||
for line in reader.lines().skip(1) { | ||
let line = line.unwrap(); | ||
let record: Vec<&str> = line.split(',').collect(); | ||
|
||
let mut data: Vec<f32> = Vec::new(); | ||
for i in 0..571 { | ||
data.push(record.get(i).unwrap().parse::<f32>().unwrap()); | ||
} | ||
match record.get(573).unwrap() { | ||
&"Train" => { | ||
x_train.push(tensor::Tensor::single(data)); | ||
y_train.push(tensor::Tensor::single(vec![record | ||
.get(571) | ||
.unwrap() | ||
.parse::<f32>() | ||
.unwrap()])); | ||
class_train.push(tensor::Tensor::one_hot( | ||
record.get(572).unwrap().parse::<usize>().unwrap() - 1, // For zero-indexed. | ||
28, | ||
)); | ||
} | ||
&"Test" => { | ||
x_test.push(tensor::Tensor::single(data)); | ||
y_test.push(tensor::Tensor::single(vec![record | ||
.get(571) | ||
.unwrap() | ||
.parse::<f32>() | ||
.unwrap()])); | ||
class_test.push(tensor::Tensor::one_hot( | ||
record.get(572).unwrap().parse::<usize>().unwrap() - 1, // For zero-indexed. | ||
28, | ||
)); | ||
} | ||
&"Val" => { | ||
x_val.push(tensor::Tensor::single(data)); | ||
y_val.push(tensor::Tensor::single(vec![record | ||
.get(571) | ||
.unwrap() | ||
.parse::<f32>() | ||
.unwrap()])); | ||
class_val.push(tensor::Tensor::one_hot( | ||
record.get(572).unwrap().parse::<usize>().unwrap() - 1, // For zero-indexed. | ||
28, | ||
)); | ||
} | ||
_ => panic!("> Unknown class."), | ||
} | ||
} | ||
|
||
// let mut generator = random::Generator::create(12345); | ||
// let mut indices: Vec<usize> = (0..x.len()).collect(); | ||
// generator.shuffle(&mut indices); | ||
|
||
( | ||
(x_train, y_train, class_train), | ||
(x_test, y_test, class_test), | ||
(x_val, y_val, class_val), | ||
) | ||
} | ||
|
||
fn main() { | ||
// Load the ftir dataset | ||
let ((x_train, y_train, class_train), (x_test, y_test, class_test), (x_val, y_val, class_val)) = | ||
data("./examples/datasets/ftir.csv"); | ||
|
||
let x_train: Vec<&tensor::Tensor> = x_train.iter().collect(); | ||
let _y_train: Vec<&tensor::Tensor> = y_train.iter().collect(); | ||
let class_train: Vec<&tensor::Tensor> = class_train.iter().collect(); | ||
|
||
let x_test: Vec<&tensor::Tensor> = x_test.iter().collect(); | ||
let y_test: Vec<&tensor::Tensor> = y_test.iter().collect(); | ||
let class_test: Vec<&tensor::Tensor> = class_test.iter().collect(); | ||
|
||
let x_val: Vec<&tensor::Tensor> = x_val.iter().collect(); | ||
let _y_val: Vec<&tensor::Tensor> = y_val.iter().collect(); | ||
let class_val: Vec<&tensor::Tensor> = class_val.iter().collect(); | ||
|
||
println!("Train data {}x{}", x_train.len(), x_train[0].shape,); | ||
println!("Test data {}x{}", x_test.len(), x_test[0].shape,); | ||
println!("Validation data {}x{}", x_val.len(), x_val[0].shape,); | ||
|
||
// Create the network | ||
let mut network = network::Network::new(tensor::Shape::Single(571)); | ||
|
||
network.dense(100, activation::Activation::ReLU, false, None); | ||
network.convolution( | ||
1, | ||
(3, 3), | ||
(1, 1), | ||
(1, 1), | ||
activation::Activation::ReLU, | ||
None, | ||
); | ||
network.dense(28, activation::Activation::Softmax, false, None); | ||
|
||
network.set_accumulation(feedback::Accumulation::Mean); | ||
|
||
network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None)); | ||
network.set_objective(objective::Objective::CrossEntropy, None); | ||
|
||
println!("{}", network); | ||
|
||
// Train the network | ||
let (train_loss, val_loss, val_acc) = network.learn( | ||
&x_train, | ||
&class_train, | ||
Some((&x_val, &class_val, 50)), | ||
16, | ||
500, | ||
Some(100), | ||
); | ||
plot::loss( | ||
&train_loss, | ||
&val_loss, | ||
&val_acc, | ||
"FEEDBACK : FTIR", | ||
"./static/ftir-cnn-feedback.png", | ||
); | ||
|
||
// Validate the network | ||
let (val_loss, val_acc) = network.validate(&x_test, &class_test, 1e-6); | ||
println!( | ||
"Final validation accuracy: {:.2} % and loss: {:.5}", | ||
val_acc * 100.0, | ||
val_loss | ||
); | ||
|
||
// Use the network | ||
let prediction = network.predict(x_test.get(0).unwrap()); | ||
println!( | ||
"Prediction. Target: {}. Output: {}.", | ||
class_test[0].argmax(), | ||
prediction.argmax() | ||
); | ||
} |
Oops, something went wrong.