Skip to content
This repository has been archived by the owner on Jan 23, 2024. It is now read-only.

Commit

Permalink
Bug fix, improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
commanderxa committed Oct 8, 2023
1 parent 02268ae commit b9c52eb
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
1 change: 0 additions & 1 deletion nn/dscovry/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def __init__(self, annotation_files: list[str]) -> None:
)
)
self.data = pd.concat(data, axis=0, ignore_index=True)
print(len(self.data))

def __len__(self) -> int:
return len(self.data)
Expand Down
7 changes: 4 additions & 3 deletions nn/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def my_app(cfg: DictConfig) -> None:
# model = torch.compile(model)

# load the model
checkpoint = torch.load(f"models/{cfg.model.name}_{cfg.hyper.n_hidden}_a.pt")
checkpoint = torch.load(f"models/{cfg.model.name}_{cfg.hyper.n_hidden}.pt")
model.load_state_dict(checkpoint["model"])
model.eval()

Expand Down Expand Up @@ -69,8 +69,9 @@ def evaluate(model: DSCOVRYModel, dataloader: DataLoader, cfg):

accuracy = epoch_accuracy / len(dataloader)
loss = epoch_loss / len(dataloader)
print(f"Accuracy ({Config.cfg.hyper.tolerance} tolearnce): {accuracy}%")
print(f"Loss: {loss}")
print(f"\n|\n| Accuracy: {accuracy:.2f}%")
print(f"| Accuracy tolerance: {Config.cfg.hyper.tolerance}")
print(f"| Loss: {loss}\n|\n")


if __name__ == "__main__":
Expand Down

0 comments on commit b9c52eb

Please sign in to comment.