-
-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #70 from vyomakesh09/master
[mod]
- Loading branch information
Showing
6 changed files
with
50 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import os | ||
import shutil | ||
import sys | ||
|
||
|
||
def delete_pycache(directory): | ||
for root, dirs, files in os.walk(directory): | ||
if "__pycache__" in dirs: | ||
shutil.rmtree(os.path.join(root, "__pycache__")) | ||
|
||
|
||
if __name__ == "__main__": | ||
if len(sys.argv) != 2: | ||
print("Usage: python delete_pycache.py <directory>") | ||
sys.exit(1) | ||
|
||
directory = sys.argv[1] | ||
delete_pycache(directory) | ||
print(f"__pycache__ directories deleted in {directory}") |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,40 +1,36 @@ | ||
import torch | ||
import torch.nn as nn | ||
import unittest | ||
|
||
import pytest | ||
from zeta.nn.modules.dense_connect import DenseBlock | ||
|
||
|
||
class DenseBlockTestCase(unittest.TestCase): | ||
def setUp(self): | ||
self.submodule = nn.Linear(10, 5) | ||
self.dense_block = DenseBlock(self.submodule) | ||
@pytest.fixture | ||
def dense_block(): | ||
submodule = nn.Linear(10, 5) | ||
return DenseBlock(submodule) | ||
|
||
|
||
def test_forward(self): | ||
x = torch.randn(32, 10) | ||
output = self.dense_block(x) | ||
def test_forward(dense_block): | ||
x = torch.randn(32, 10) | ||
output = dense_block(x) | ||
|
||
self.assertEqual(output.shape, (32, 15)) # Check output shape | ||
self.assertTrue( | ||
torch.allclose(output[:, :10], x) | ||
) # Check if input is preserved | ||
self.assertTrue( | ||
torch.allclose(output[:, 10:], self.submodule(x)) | ||
) # Check submodule output | ||
assert output.shape == (32, 15) # Check output shape | ||
assert torch.allclose(output[:, :10], x) # Check if input is preserved | ||
assert torch.allclose( | ||
output[:, 10:], dense_block.submodule(x) | ||
) # Check submodule output | ||
|
||
def test_initialization(self): | ||
self.assertEqual( | ||
self.dense_block.submodule, self.submodule | ||
) # Check submodule assignment | ||
|
||
def test_docstrings(self): | ||
self.assertIsNotNone( | ||
DenseBlock.__init__.__doc__ | ||
) # Check if __init__ has a docstring | ||
self.assertIsNotNone( | ||
DenseBlock.forward.__doc__ | ||
) # Check if forward has a docstring | ||
def test_initialization(dense_block): | ||
assert isinstance(dense_block.submodule, nn.Linear) # Check submodule type | ||
assert dense_block.submodule.in_features == 10 # Check input features | ||
assert dense_block.submodule.out_features == 5 # Check output features | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() | ||
def test_docstrings(): | ||
assert ( | ||
DenseBlock.__init__.__doc__ is not None | ||
) # Check if __init__ has a docstring | ||
assert ( | ||
DenseBlock.forward.__doc__ is not None | ||
) # Check if forward has a docstring |