Skip to content

Commit

Permalink
grad cam ret
Browse files Browse the repository at this point in the history
  • Loading branch information
Atashnezhad committed Sep 6, 2023
1 parent 98df14a commit 40d3364
Showing 1 changed file with 3 additions and 30 deletions.
33 changes: 3 additions & 30 deletions neural_network_model/transfer_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,43 +892,16 @@ def grad_cam_viz(self, *args, **kwargs):
num_rows = kwargs.get("num_rows", None)
num_cols = kwargs.get("num_cols", None)
last_conv_layer_name = kwargs.get("last_conv_layer_name", "Conv_1")
img_size = kwargs.get("img_size", SETTING.FLOW_FROM_DIRECTORY_SETTING.TARGET_SIZE)
img_size = kwargs.get("img_size", (224, 224))
gard_cam_image_name = kwargs.get("gard_cam_image_name", "transf_cam.jpg")
figsize = kwargs.get("figsize", (8, 6))
title_lable_size = kwargs.get("title_lable_size", 8)
model_path = kwargs.get("model_path", None)
test_dataset_address = kwargs.get("test_dataset_address", None)
x_col = TRANSFER_LEARNING_SETTING.DF_X_COL_NAME,
y_col = TRANSFER_LEARNING_SETTING.DF_Y_COL_NAME,
save_path = kwargs.get("save_path", Path(__file__).parent / ".." / "figures" / "grad_cam.png", )

if model_path:
logger.info(f"Loading the model from {model_path}")
self.model = tf.keras.models.load_model(model_path)
else:
logger.info(f"Using the self.model from memory")
# Remove last layer's softmax
self.model.layers[-1].activation = None

# Display the part of the pictures used by the neural network to classify the pictures
if test_dataset_address:
# Get filepaths and labels
filepaths = list(test_dataset_address.glob(r"**/*.png"))
# add those with jpg extension
filepaths.extend(list(test_dataset_address.glob(r"**/*.jpg")))
# add those with jpeg extension
filepaths.extend(list(test_dataset_address.glob(r"**/*.jpeg")))
labels = [path.stem for path in filepaths]

filepaths = pd.Series(filepaths, name=x_col).astype(str)
labels = pd.Series(labels, name=y_col)

# Concatenate filepaths and labels
test_df = pd.concat([filepaths, labels], axis=1)
# test_df, _ = train_test_split(test_df, train_size=1)

else:
_, test_df = self._train_test_split()
_, test_df = self._train_test_split()

if not num_rows and not num_cols:
# Get the number of rows and columns for subplots
Expand Down Expand Up @@ -977,7 +950,7 @@ def grad_cam_viz(self, *args, **kwargs):
plt.tight_layout()
# save the figure
plt.savefig(
save_path,
Path(__file__).parent / ".." / "figures" / "grad_cam.png",
bbox_inches="tight",
)
plt.show()
Expand Down

0 comments on commit 40d3364

Please sign in to comment.