diff --git a/neural_network_model/bit_vision.py b/neural_network_model/bit_vision.py index 0e4a222..9320e09 100644 --- a/neural_network_model/bit_vision.py +++ b/neural_network_model/bit_vision.py @@ -161,26 +161,30 @@ def compile_model(self, *args, **kwargs) -> None: self.model.summary() def plot_image_category(self, *args, **kwargs) -> None: - """ - This function is used to plot images. - :param images: list of images - :param labels: list of labels - :return: - """ nrows = kwargs.get("nrows", 1) number_of_categories = len(self.categories) ncols = kwargs.get("ncols", number_of_categories) subdir = kwargs.get("subdir", "train") fig_size = kwargs.get("fig_size", (17, 10)) - # get one image for each category in train data and plot them + # Get one image for each category in train data and plot them fig, axs = plt.subplots(nrows, ncols, figsize=fig_size) - for category in self.categories: + + if nrows == 1 and ncols > 1: + axs = axs.reshape((ncols,)) # Reshape axs to handle 1 row and multiple columns + + for i, category in enumerate(self.categories): category_path = self.train_test_val_dir / subdir / category image_path = category_path / os.listdir(category_path)[1] img = load_img(image_path) - axs[self.categories.index(category)].imshow(img) - axs[self.categories.index(category)].set_title(category) + + if nrows == 1 and ncols > 1: + axs[i].imshow(img) # Use 1D indexing for 1 row and multiple columns + axs[i].set_title(category) + else: + axs[i // ncols, i % ncols].imshow(img) # Adjusted indexing + axs[i // ncols, i % ncols].set_title(category) # Adjusted indexing + plt.show() def _rescaling(self) -> None: @@ -351,19 +355,23 @@ def predict(self, *args, **kwargs): if model_path is None: logger.info(f"model_path from SETTING is was used - {model_path}") - # test_folder_dir = kwargs.get( - # "test_folder_dir", SETTING.DATA_ADDRESS_SETTING.TEST_DIR_ADDRESS - # ) - # if test_folder_dir is None: - # raise ValueError("test_folder_address is None") - model = keras.models.load_model(model_path) logger.info(f"Model loaded from {model_path}") + # Default settings + default_num_rows = SETTING.FIGURE_SETTING.NUM_ROWS_IN_PRED_MODEL + default_num_cols = SETTING.FIGURE_SETTING.NUM_COLS_IN_PRED_MODEL + default_figsize = SETTING.FIGURE_SETTING.FIGURE_SIZE_IN_PRED_MODEL + + # Extract details from kwargs or use default settings + num_rows = kwargs.get('num_rows', default_num_rows) + num_cols = kwargs.get('num_cols', default_num_cols) + figure_size = kwargs.get('figure_size', default_figsize) + for category in self.categories: - plt.figure(figsize=SETTING.FIGURE_SETTING.FIGURE_SIZE_IN_PRED_MODEL) - number_of_cols = SETTING.FIGURE_SETTING.NUM_COLS_IN_PRED_MODEL - number_of_rows = SETTING.FIGURE_SETTING.NUM_ROWS_IN_PRED_MODEL + plt.figure(figsize=figure_size) + number_of_cols = num_cols + number_of_rows = num_rows number_of_test_to_pred = SETTING.MODEL_SETTING.NUMBER_OF_TEST_TO_PRED # get the list of test images @@ -624,18 +632,18 @@ def return_best_model_name( obj = BitVision( train_test_val_dir=Path(__file__).parent / ".." / "dataset_train_test_val" ) - # print(obj.categories) - # print(obj.data_details) - # obj.plot_image_category() + print(obj.categories) + print(obj.data_details) + obj.plot_image_category(nrows=3, ncols=3) # obj.compile_model() # obj.train_model(epochs=8) # obj.plot_history() - obj.predict() - obj.grad_cam_viz( - gradcam_fig_name="test.png", - print_layer_names=True, - test_folder_dir=Path(__file__).parent - / ".." - / "dataset_train_test_val" - / "test", - ) + # obj.predict(num_rows=2, num_cols=2, figsize=(4, 10)) + # obj.grad_cam_viz( + # gradcam_fig_name="test.png", + # print_layer_names=True, + # test_folder_dir=Path(__file__).parent + # / ".." + # / "dataset_train_test_val" + # / "test", + # )