Skip to content

Commit

Permalink
Update bit_vision.py
Browse files Browse the repository at this point in the history
fix the predict
  • Loading branch information
Atashnezhad committed Jun 22, 2023
1 parent 23ef848 commit fcdd3a6
Showing 1 changed file with 22 additions and 22 deletions.
44 changes: 22 additions & 22 deletions neural_network_model/bit_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,10 +281,10 @@ def train_model(
callbacks=[self._check_points(model_save_address, model_name)],
)

self.model.save(
model_save_address / model_name
or SETTING.MODEL_SETTING.MODEL_PATH / SETTING.MODEL_SETTING.MODEL_NAME
)
# self.model.save(
# model_save_address / model_name
# or SETTING.MODEL_SETTING.MODEL_PATH / SETTING.MODEL_SETTING.MODEL_NAME
# )
logger.info(f"Model saved to {SETTING.MODEL_SETTING.MODEL_PATH}")

def plot_history(self, *args, **kwargs):
Expand Down Expand Up @@ -351,11 +351,11 @@ def predict(self, *args, **kwargs):
if model_path is None:
logger.info(f"model_path from SETTING is was used - {model_path}")

test_folder_dir = kwargs.get(
"test_folder_dir", SETTING.DATA_ADDRESS_SETTING.TEST_DIR_ADDRESS
)
if test_folder_dir is None:
raise ValueError("test_folder_address is None")
# test_folder_dir = kwargs.get(
# "test_folder_dir", SETTING.DATA_ADDRESS_SETTING.TEST_DIR_ADDRESS
# )
# if test_folder_dir is None:
# raise ValueError("test_folder_address is None")

model = keras.models.load_model(model_path)
logger.info(f"Model loaded from {model_path}")
Expand All @@ -365,20 +365,20 @@ def predict(self, *args, **kwargs):
number_of_cols = SETTING.FIGURE_SETTING.NUM_COLS_IN_PRED_MODEL
number_of_rows = SETTING.FIGURE_SETTING.NUM_ROWS_IN_PRED_MODEL
number_of_test_to_pred = SETTING.MODEL_SETTING.NUMBER_OF_TEST_TO_PRED
if test_folder_dir:
train_test_val_dir = (
test_folder_dir
or SETTING.PREPROCESSING_SETTING.TRAIN_TEST_VAL_SPLIT_DIR_ADDRESS
)
else:
train_test_val_dir = (
self.train_test_val_dir
or SETTING.PREPROCESSING_SETTING.TRAIN_TEST_VAL_SPLIT_DIR_ADDRESS
)
# if test_folder_dir:
# train_test_val_dir = (
# test_folder_dir
# or SETTING.PREPROCESSING_SETTING.TRAIN_TEST_VAL_SPLIT_DIR_ADDRESS
# )
# else:
# train_test_val_dir = (
# self.train_test_val_dir
# or SETTING.PREPROCESSING_SETTING.TRAIN_TEST_VAL_SPLIT_DIR_ADDRESS
# )

# get the list of test images
test_images_list = os.listdir(
train_test_val_dir
self.train_test_val_dir
/ SETTING.PREPROCESSING_SETTING.TRAIN_TEST_SPLIT_DIR_NAMES[1]
/ category
)
Expand All @@ -387,7 +387,7 @@ def predict(self, *args, **kwargs):

for i, img in enumerate(test_images_list[0:number_of_test_to_pred]):
path_to_img = (
train_test_val_dir
self.train_test_val_dir
/ SETTING.PREPROCESSING_SETTING.TRAIN_TEST_SPLIT_DIR_NAMES[1]
/ category
/ str(img)
Expand Down Expand Up @@ -432,7 +432,7 @@ def predict(self, *args, **kwargs):

datagen = image.ImageDataGenerator(SETTING.DATA_GEN_SETTING.RESCALE)
DoubleCheck_generator = datagen.flow_from_directory(
directory=test_folder_dir / "test",
directory=self.train_test_val_dir / "test",
target_size=SETTING.FLOW_FROM_DIRECTORY_SETTING.TARGET_SIZE,
color_mode=SETTING.FLOW_FROM_DIRECTORY_SETTING.COLOR_MODE,
classes=None,
Expand Down

0 comments on commit fcdd3a6

Please sign in to comment.