Skip to content

Commit

Permalink
fix the fig for predict method
Browse files Browse the repository at this point in the history
  • Loading branch information
Atashnezhad committed Jun 25, 2023
1 parent 488c83b commit f742d5b
Showing 1 changed file with 39 additions and 31 deletions.
70 changes: 39 additions & 31 deletions neural_network_model/bit_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,26 +161,30 @@ def compile_model(self, *args, **kwargs) -> None:
self.model.summary()

def plot_image_category(self, *args, **kwargs) -> None:
"""
This function is used to plot images.
:param images: list of images
:param labels: list of labels
:return:
"""
nrows = kwargs.get("nrows", 1)
number_of_categories = len(self.categories)
ncols = kwargs.get("ncols", number_of_categories)
subdir = kwargs.get("subdir", "train")
fig_size = kwargs.get("fig_size", (17, 10))

# get one image for each category in train data and plot them
# Get one image for each category in train data and plot them
fig, axs = plt.subplots(nrows, ncols, figsize=fig_size)
for category in self.categories:

if nrows == 1 and ncols > 1:
axs = axs.reshape((ncols,)) # Reshape axs to handle 1 row and multiple columns

for i, category in enumerate(self.categories):
category_path = self.train_test_val_dir / subdir / category
image_path = category_path / os.listdir(category_path)[1]
img = load_img(image_path)
axs[self.categories.index(category)].imshow(img)
axs[self.categories.index(category)].set_title(category)

if nrows == 1 and ncols > 1:
axs[i].imshow(img) # Use 1D indexing for 1 row and multiple columns
axs[i].set_title(category)
else:
axs[i // ncols, i % ncols].imshow(img) # Adjusted indexing
axs[i // ncols, i % ncols].set_title(category) # Adjusted indexing

plt.show()

def _rescaling(self) -> None:
Expand Down Expand Up @@ -351,19 +355,23 @@ def predict(self, *args, **kwargs):
if model_path is None:
logger.info(f"model_path from SETTING is was used - {model_path}")

# test_folder_dir = kwargs.get(
# "test_folder_dir", SETTING.DATA_ADDRESS_SETTING.TEST_DIR_ADDRESS
# )
# if test_folder_dir is None:
# raise ValueError("test_folder_address is None")

model = keras.models.load_model(model_path)
logger.info(f"Model loaded from {model_path}")

# Default settings
default_num_rows = SETTING.FIGURE_SETTING.NUM_ROWS_IN_PRED_MODEL
default_num_cols = SETTING.FIGURE_SETTING.NUM_COLS_IN_PRED_MODEL
default_figsize = SETTING.FIGURE_SETTING.FIGURE_SIZE_IN_PRED_MODEL

# Extract details from kwargs or use default settings
num_rows = kwargs.get('num_rows', default_num_rows)
num_cols = kwargs.get('num_cols', default_num_cols)
figure_size = kwargs.get('figure_size', default_figsize)

for category in self.categories:
plt.figure(figsize=SETTING.FIGURE_SETTING.FIGURE_SIZE_IN_PRED_MODEL)
number_of_cols = SETTING.FIGURE_SETTING.NUM_COLS_IN_PRED_MODEL
number_of_rows = SETTING.FIGURE_SETTING.NUM_ROWS_IN_PRED_MODEL
plt.figure(figsize=figure_size)
number_of_cols = num_cols
number_of_rows = num_rows
number_of_test_to_pred = SETTING.MODEL_SETTING.NUMBER_OF_TEST_TO_PRED

# get the list of test images
Expand Down Expand Up @@ -624,18 +632,18 @@ def return_best_model_name(
obj = BitVision(
train_test_val_dir=Path(__file__).parent / ".." / "dataset_train_test_val"
)
# print(obj.categories)
# print(obj.data_details)
# obj.plot_image_category()
print(obj.categories)
print(obj.data_details)
obj.plot_image_category(nrows=3, ncols=3)
# obj.compile_model()
# obj.train_model(epochs=8)
# obj.plot_history()
obj.predict()
obj.grad_cam_viz(
gradcam_fig_name="test.png",
print_layer_names=True,
test_folder_dir=Path(__file__).parent
/ ".."
/ "dataset_train_test_val"
/ "test",
)
# obj.predict(num_rows=2, num_cols=2, figsize=(4, 10))
# obj.grad_cam_viz(
# gradcam_fig_name="test.png",
# print_layer_names=True,
# test_folder_dir=Path(__file__).parent
# / ".."
# / "dataset_train_test_val"
# / "test",
# )

0 comments on commit f742d5b

Please sign in to comment.