Skip to content
This repository has been archived by the owner on Jun 26, 2021. It is now read-only.

remove timestamp from save path #200

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
11 changes: 4 additions & 7 deletions delira/training/base_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import logging
import pickle
import os
from datetime import datetime

import copy

Expand All @@ -13,7 +12,7 @@

from delira.data_loading import BaseDataManager
from delira.models import AbstractNetwork

from delira.training.utils import check_save_path
from delira.training.parameters import Parameters
from delira.training.base_trainer import BaseNetworkTrainer
from delira.training.predictor import Predictor
Expand Down Expand Up @@ -109,12 +108,10 @@ def __init__(self,
if save_path is None:
save_path = os.path.abspath(".")

self.save_path = os.path.join(save_path, name,
str(datetime.now().strftime(
"%y-%m-%d_%H-%M-%S")))
duplicate_number, self.save_path = check_save_path(os.path.join(save_path, name))

if os.path.isdir(self.save_path):
logger.warning("Save Path %s already exists")
if duplicate_number:
print('Save path is a duplicate and got changed to {}'.format(save_path))

os.makedirs(self.save_path, exist_ok=True)

Expand Down
12 changes: 11 additions & 1 deletion delira/training/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import collections
import numpy as np

import os

def recursively_convert_elements(element, check_type, conversion_fn):
"""
Expand Down Expand Up @@ -98,3 +98,13 @@ def convert_to_numpy_identity(*args, **kwargs):
_correct_zero_shape)

return args, kwargs

def check_save_path(save_path):
mibaumgartner marked this conversation as resolved.
Show resolved Hide resolved
i = 0
gedoensmax marked this conversation as resolved.
Show resolved Hide resolved
new_path = save_path
run_name = os.path.basename(save_path)
dir_name = os.path.dirname(save_path)
while os.path.isdir(new_path):
i +=1
new_path = os.path.join(dir_name,run_name + '_{:02d}'.format(i))
return i , new_path