Skip to content

Commit

Permalink
fix predict
Browse files Browse the repository at this point in the history
  • Loading branch information
Atashnezhad committed Aug 7, 2023
1 parent 374ee15 commit 17b8d7d
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,14 +146,14 @@ if __name__ == "__main__":
```

## Using TransferLearning Module

```python
from neural_network_model.transfer_learning import TransferModel
from pathlib import Path


transfer_model = TransferModel(
dataset_address=Path(__file__).parent / "dataset"
)
dataset_address=Path(__file__).parent / "dataset"
)

transfer_model.plot_classes_number()
transfer_model.analyze_image_names()
Expand All @@ -164,7 +164,7 @@ transfer_model.train_model(epochs=3,
transfer_model.plot_metrics_results()
transfer_model.results()
# one can pass the model address to the predict_test method
transfer_model.predcit_test()
transfer_model.predict_test()
transfer_model.grad_cam_viz(num_rows=3, num_cols=2)
```

Expand Down
6 changes: 3 additions & 3 deletions neural_network_model/transfer_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ def results(self):
print(" ## Test Loss: {:.5f}".format(results[0]))
print("## Accuracy on the test set: {:.2f}%".format(results[1] * 100))

def predcit_test(self, model_path: str = None, **kwargs):
def predict_test(self, model_path: str = None, **kwargs):
(
train_generator,
test_generator,
Expand Down Expand Up @@ -635,7 +635,7 @@ def predcit_test(self, model_path: str = None, **kwargs):

train_df, test_df = self.train_test_split()
y_test = list(test_df.Label)
print(classification_report(y_test, pred))
print("classification_report\n", classification_report(y_test, pred))

cf_matrix = confusion_matrix(y_test, pred, normalize="true")
plt.figure(figsize=(10, 6))
Expand Down Expand Up @@ -832,5 +832,5 @@ def grad_cam_viz(self, *args, **kwargs):
transfer_model.plot_metrics_results()
transfer_model.results()
# one can pass the model address to the predict_test method
transfer_model.predcit_test()
transfer_model.predict_test()
transfer_model.grad_cam_viz(num_rows=3, num_cols=2)
2 changes: 1 addition & 1 deletion tests/test_transfer_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_run():
transfer_model.train_model()
transfer_model.plot_metrics_results()
transfer_model.results()
transfer_model.predcit_test()
transfer_model.predict_test()
transfer_model.grad_cam_viz(num_rows=3, num_cols=2)

assert True

0 comments on commit 17b8d7d

Please sign in to comment.