Skip to content

Commit

Permalink
fix on_epoch_end
Browse files Browse the repository at this point in the history
  • Loading branch information
TjarkMiener committed Nov 27, 2024
1 parent 2a928f0 commit 2c0419c
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions dl1_data_handler/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ def __init__(
self.tasks = tasks
self.batch_size = batch_size
self.random_seed = random_seed
if self.random_seed is not None:
self.on_epoch_end()
self.on_epoch_end()

# Get the input shape for the convolutional neural network
self.image_shape = self.DLDataReader.image_mappers[self.DLDataReader.cam_name].image_shape
Expand All @@ -44,9 +43,10 @@ def __len__(self):
return int(np.floor(len(self.indices) / self.batch_size))

def on_epoch_end(self):
"Updates indexes after each epoch"
np.random.seed(self.random_seed)
np.random.shuffle(self.indices)
"Updates indexes after each epoch if random seed is set"
if self.random_seed is not None:
np.random.seed(self.random_seed)
np.random.shuffle(self.indices)

def __getitem__(self, index):
"Generate one batch of data"
Expand Down

0 comments on commit 2c0419c

Please sign in to comment.