add some test funcs for dropout and cudnnlstm
1 parent 974f1a8, commit d560a82
Showing 4 changed files with 428 additions and 11 deletions.
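The two new test modules can be run on their own with pytest. A minimal sketch for doing so, assuming pytest is installed, the working directory is the repository root, and a CUDA device is available (several tests move models and inputs to "cuda:0"); the paths follow the FilePath headers in the new files:

import pytest

# Hedged sketch: run only the two new test modules, verbosely.
# Assumes execution from the repository root with a CUDA-capable GPU.
pytest.main(["tests/test_cudnnlstm.py", "tests/test_dropout.py", "-v"])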
tests/test_cudnnlstm.py (new file, 197 lines added)
""" | ||
Author: Wenyu Ouyang | ||
Date: 2024-10-09 14:22:18 | ||
LastEditTime: 2024-10-09 16:53:57 | ||
LastEditors: Wenyu Ouyang | ||
Description: Test functions for CudnnLstmModel | ||
FilePath: \torchhydro\tests\test_cudnnlstm.py | ||
Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved. | ||
""" | ||
|
||
import torch | ||
import pytest | ||
from torchhydro.models.cudnnlstm import CudnnLstmModel | ||
from torchhydro.models.cudnnlstm import CudnnLstmModel, CudnnLstm | ||
|
||
|
||
def test_mc_dropout_eval():
    # Monte Carlo Dropout sampling during evaluation
    model = CudnnLstmModel(
        n_input_features=10, n_output_features=1, n_hidden_states=50, dr=0.5
    )
    model = model.to("cuda:0")
    # simulate input data
    input_data = torch.randn(20, 5, 10)  # [seq_len, batch_size, input_size]
    input_data = input_data.to("cuda:0")
    # multiple forward passes to estimate the model's uncertainty using Monte Carlo Dropout
    num_samples = 10
    mc_outputs = []
    # set to training mode to enable Monte Carlo Dropout
    model.train()
    # during evaluation, we don't need to calculate gradients
    with torch.no_grad():
        for _ in range(num_samples):
            # force the model to use dropout
            output = model(input_data, do_drop_mc=True)
            mc_outputs.append(output)

    # stack the outputs along the first dimension
    mc_outputs = torch.stack(mc_outputs)

    # calculate the mean and variance of the outputs to estimate the model's uncertainty
    mean_output = mc_outputs.mean(dim=0)
    variance_output = mc_outputs.var(dim=0)

    print("mean value: ", mean_output)
    print("var value: ", variance_output)


def test_setstate_with_all_weights():
    model = CudnnLstm(input_size=10, hidden_size=20, dr=0.5)
    # state_dict returns a dictionary containing all weights and biases of the module
    state_dict = model.state_dict()
    state_dict["all_weights"] = [["w_ih", "w_hh", "b_ih", "b_hh"]]

    model.__setstate__(state_dict)

    assert model._all_weights == [["w_ih", "w_hh", "b_ih", "b_hh"]]


def test_setstate_without_all_weights():
    model = CudnnLstm(input_size=10, hidden_size=20, dr=0.5)
    state_dict = model.state_dict()

    model.__setstate__(state_dict)

    assert model._all_weights == [["w_ih", "w_hh", "b_ih", "b_hh"]]


def test_setstate_with_non_string_all_weights():
    model = CudnnLstm(input_size=10, hidden_size=20, dr=0.5)
    state_dict = model.state_dict()
    state_dict["all_weights"] = [[0, 1, 2, 3]]

    model.__setstate__(state_dict)

    assert model._all_weights == [["w_ih", "w_hh", "b_ih", "b_hh"]]


def test_setstate_with_data_ptrs():
    model = CudnnLstm(input_size=10, hidden_size=20, dr=0.5)
    state_dict = model.state_dict()
    state_dict["_data_ptrs"] = [1, 2, 3]

    model.__setstate__(state_dict)

    assert model.__dict__.get("_data_ptrs") == [1, 2, 3]


def test_reset_mask():
    model = CudnnLstm(input_size=10, hidden_size=20, dr=0.5)
    model.reset_mask()

    assert model.mask_w_ih is not None, "mask_w_ih should not be None after reset_mask"
    assert model.mask_w_hh is not None, "mask_w_hh should not be None after reset_mask"
    assert (
        model.mask_w_ih.shape == model.w_ih.shape
    ), "mask_w_ih should have the same shape as w_ih"
    assert (
        model.mask_w_hh.shape == model.w_hh.shape
    ), "mask_w_hh should have the same shape as w_hh"


def test_reset_mask_with_different_dropout():
    model = CudnnLstm(input_size=10, hidden_size=20, dr=0.3)
    model.reset_mask()

    assert model.mask_w_ih is not None, "mask_w_ih should not be None after reset_mask"
    assert model.mask_w_hh is not None, "mask_w_hh should not be None after reset_mask"
    assert (
        model.mask_w_ih.shape == model.w_ih.shape
    ), "mask_w_ih should have the same shape as w_ih"
    assert (
        model.mask_w_hh.shape == model.w_hh.shape
    ), "mask_w_hh should have the same shape as w_hh"


def test_reset_mask_with_zero_dropout():
    model = CudnnLstm(input_size=10, hidden_size=20, dr=0.0)
    model.reset_mask()

    assert model.mask_w_ih is not None, "mask_w_ih should not be None after reset_mask"
    assert model.mask_w_hh is not None, "mask_w_hh should not be None after reset_mask"
    assert (
        model.mask_w_ih.shape == model.w_ih.shape
    ), "mask_w_ih should have the same shape as w_ih"
    assert (
        model.mask_w_hh.shape == model.w_hh.shape
    ), "mask_w_hh should have the same shape as w_hh"


def test_forward_with_mc_dropout():
    model = CudnnLstm(input_size=10, hidden_size=20, dr=0.5)
    model = model.to("cuda:0")
    input_data = torch.randn(20, 5, 10)  # [seq_len, batch_size, input_size]
    input_data = input_data.to("cuda:0")
    model.train()  # Ensure dropout is enabled
    output, (hy, cy) = model(input_data, do_drop_mc=True)

    assert output.shape == (20, 5, 20), "Output shape mismatch"
    assert hy.shape == (1, 5, 20), "Hidden state shape mismatch"
    assert cy.shape == (1, 5, 20), "Cell state shape mismatch"


def test_forward_with_dropout():
    model = CudnnLstm(input_size=10, hidden_size=20, dr=0.5)
    input_data = torch.randn(20, 5, 10)  # [seq_len, batch_size, input_size]
    model = model.to("cuda:0")
    input_data = input_data.to("cuda:0")
    model.train()
    output, (hy, cy) = model(input_data, do_drop_mc=False)

    assert output.shape == (20, 5, 20), "Output shape mismatch"
    assert hy.shape == (1, 5, 20), "Hidden state shape mismatch"
    assert cy.shape == (1, 5, 20), "Cell state shape mismatch"


def test_forward_with_zero_dropout():
    model = CudnnLstm(input_size=10, hidden_size=20, dr=0.0)
    input_data = torch.randn(20, 5, 10)  # [seq_len, batch_size, input_size]
    model = model.to("cuda:0")
    input_data = input_data.to("cuda:0")
    model.train()  # Dropout rate is zero, so dropout should be disabled
    output, (hy, cy) = model(input_data, do_drop_mc=True)

    assert output.shape == (20, 5, 20), "Output shape mismatch"
    assert hy.shape == (1, 5, 20), "Hidden state shape mismatch"
    assert cy.shape == (1, 5, 20), "Cell state shape mismatch"


def test_forward_with_dropout_false():
    model = CudnnLstm(input_size=10, hidden_size=20, dr=0.5)
    input_data = torch.randn(20, 5, 10)  # [seq_len, batch_size, input_size]
    model = model.to("cuda:0")
    input_data = input_data.to("cuda:0")
    model.train()  # Ensure dropout is enabled
    output, (hy, cy) = model(input_data, do_drop_mc=False, dropout_false=True)

    assert output.shape == (20, 5, 20), "Output shape mismatch"
    assert hy.shape == (1, 5, 20), "Hidden state shape mismatch"
    assert cy.shape == (1, 5, 20), "Cell state shape mismatch"


def test_forward_with_initial_states():
    model = CudnnLstm(input_size=10, hidden_size=20, dr=0.5)
    input_data = torch.randn(20, 5, 10)  # [seq_len, batch_size, input_size]
    hx = torch.randn(1, 5, 20)
    cx = torch.randn(1, 5, 20)
    model = model.to("cuda:0")
    input_data = input_data.to("cuda:0")
    hx = hx.to("cuda:0")
    cx = cx.to("cuda:0")
    model.train()  # Ensure dropout is enabled
    output, (hy, cy) = model(input_data, hx=hx, cx=cx, do_drop_mc=True)

    assert output.shape == (20, 5, 20), "Output shape mismatch"
    assert hy.shape == (1, 5, 20), "Hidden state shape mismatch"
    assert cy.shape == (1, 5, 20), "Cell state shape mismatch"
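test_mc_dropout_eval above only prints the Monte Carlo statistics rather than asserting on them. As a hedged sketch, the same pattern could be factored into a reusable helper; the name mc_dropout_predict is hypothetical, and it assumes that CudnnLstmModel's forward accepts do_drop_mc exactly as in the test and that a CUDA device is available for the usage shown:

import torch
from torchhydro.models.cudnnlstm import CudnnLstmModel


def mc_dropout_predict(model, input_data, num_samples=10):
    """Hypothetical helper: repeat stochastic forward passes and summarize them."""
    model.train()  # keep dropout active so each pass samples a different mask
    outputs = []
    with torch.no_grad():  # no gradients are needed for uncertainty estimation
        for _ in range(num_samples):
            outputs.append(model(input_data, do_drop_mc=True))
    stacked = torch.stack(outputs)  # [num_samples, *single-pass output shape]
    return stacked.mean(dim=0), stacked.var(dim=0)


# Usage mirroring the test above (assumes a CUDA device):
# model = CudnnLstmModel(
#     n_input_features=10, n_output_features=1, n_hidden_states=50, dr=0.5
# ).to("cuda:0")
# x = torch.randn(20, 5, 10).to("cuda:0")  # [seq_len, batch_size, input_size]
# mean, var = mc_dropout_predict(model, x, num_samples=10)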
tests/test_dropout.py (new file, 134 lines added)
""" | ||
Author: Wenyu Ouyang | ||
Date: 2024-10-09 14:34:41 | ||
LastEditTime: 2024-10-09 15:44:55 | ||
LastEditors: Wenyu Ouyang | ||
Description: Test functions for dropout | ||
FilePath: \torchhydro\tests\test_dropout.py | ||
Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved. | ||
""" | ||
|
||
import pytest | ||
import torch | ||
from torch.autograd import Function | ||
from torchhydro.models.dropout import create_mask | ||
from torchhydro.models.dropout import DropMask | ||
|
||
|
||
def test_create_mask_shape():
    x = torch.randn(10, 10)
    dr = 0.5
    mask = create_mask(x, dr)
    assert mask.shape == x.shape, "Mask shape should match input shape"


def test_create_mask_values():
    x = torch.randn(10, 10)
    dr = 0.5
    mask = create_mask(x, dr)
    assert torch.all((mask == 0) | (mask == 2)), "Mask values should be either 0 or 2"


def test_create_mask_dropout_rate():
    x = torch.randn(1000, 1000)
    dr = 0.5
    mask = create_mask(x, dr)
    dropout_rate = (mask == 0).float().mean().item()
    assert (
        abs(dropout_rate - dr) < 0.05
    ), "Dropout rate should be close to the specified rate"


def test_create_mask_no_dropout():
    x = torch.randn(10, 10)
    dr = 0.0
    mask = create_mask(x, dr)
    assert torch.all(mask == 1), "Mask should be all ones when dropout rate is 0"


def test_create_mask_full_dropout():
    x = torch.randn(10, 10)
    dr = 1.0
    mask = create_mask(x, dr)
    assert torch.all(mask == 0), "Mask should be all zeros when dropout rate is 1"


class MockContext(Function):
    def __init__(self):
        super(MockContext, self).__init__()
        self.master_train = None
        self.inplace = None
        self.mask = None


def test_forward_train_inplace():
    ctx = MockContext()
    input = torch.randn(10, 10)
    mask = torch.randint(0, 2, (10, 10)).float()
    ctx.master_train = True
    ctx.inplace = True
    output = DropMask.forward(ctx, input, mask, train=True, inplace=True)
    assert torch.equal(
        output, input * mask
    ), "Output should be input multiplied by mask when training and inplace"


def test_forward_train_not_inplace():
    ctx = MockContext()
    input = torch.randn(10, 10)
    mask = torch.randint(0, 2, (10, 10)).float()
    ctx.master_train = True
    ctx.inplace = False
    output = DropMask.forward(ctx, input, mask, train=True, inplace=False)
    assert torch.equal(
        output, input * mask
    ), "Output should be input multiplied by mask when training and not inplace"
    assert not torch.equal(
        output, input
    ), "Output should not be the same as input when not inplace"


def test_forward_no_train():
    ctx = MockContext()
    input = torch.randn(10, 10)
    mask = torch.randint(0, 2, (10, 10)).float()
    ctx.master_train = False
    ctx.inplace = False
    output = DropMask.forward(ctx, input, mask, train=False, inplace=False)
    assert torch.equal(
        output, input
    ), "Output should be the same as input when not training"


def test_forward_no_train_inplace():
    ctx = MockContext()
    input = torch.randn(10, 10)
    mask = torch.randint(0, 2, (10, 10)).float()
    ctx.master_train = False
    ctx.inplace = True
    output = DropMask.forward(ctx, input, mask, train=False, inplace=True)
    assert torch.equal(
        output, input
    ), "Output should be the same as input when not training and inplace"


def test_backward_train():
    ctx = MockContext()
    ctx.master_train = True
    ctx.mask = torch.randint(0, 2, (10, 10)).float()
    grad_output = torch.randn(10, 10)
    grad_input, _, _, _ = DropMask.backward(ctx, grad_output)
    assert torch.equal(
        grad_input, grad_output * ctx.mask
    ), "Gradient input should be masked when training"


def test_backward_no_train():
    ctx = MockContext()
    ctx.master_train = False
    ctx.mask = torch.randint(0, 2, (10, 10)).float()
    grad_output = torch.randn(10, 10)
    grad_input, _, _, _ = DropMask.backward(ctx, grad_output)
    assert torch.equal(
        grad_input, grad_output
    ), "Gradient input should be the same as gradient output when not training"
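Taken together, the two modules describe a DropConnect-style scheme: create_mask draws an inverted-dropout mask and DropMask multiplies it onto a tensor during training while routing gradients through the same mask, which is presumably how CudnnLstm uses the masks built by reset_mask (they share the shapes of w_ih and w_hh). A hedged sketch of combining the two, calling DropMask.forward directly with a minimal context object just as the tests do; the _Ctx class and the weight tensor here are illustrative only:

import torch
from torch.autograd import Function
from torchhydro.models.dropout import create_mask, DropMask


class _Ctx(Function):
    # Minimal context object, mirroring MockContext in the tests above.
    def __init__(self):
        super().__init__()
        self.master_train = None
        self.inplace = None
        self.mask = None


# Draw a mask for a weight tensor and apply it the way the forward tests call
# DropMask directly; train=True keeps dropout active, inplace=False clones first.
w = torch.randn(80, 20)
mask = create_mask(w, 0.5)  # entries are 0 (dropped) or 2 (kept, scaled by 1/(1 - dr))
w_dropped = DropMask.forward(_Ctx(), w, mask, train=True, inplace=False)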