add wandb support and move to tf2.2
RocketFlash committed Jul 1, 2020
1 parent 8bac495 commit 369b4ae
Showing 9 changed files with 13,177 additions and 569 deletions.
1 change: 1 addition & 0 deletions .gitignore
```diff
@@ -115,4 +115,5 @@ plots/
 sub.csv
 core
 work_dirs/
+wandb/
 *.csv
```
2 changes: 1 addition & 1 deletion README.md
```diff
@@ -12,7 +12,7 @@ git clone git@github.com:RocketFlash/EmbeddingNet.git
 
 ### Requirements
 
-- tensorflow=2.0.0
+- tensorflow=2.2.0
 - scikit-learn
 - opencv
 - matplotlib
```
27 changes: 14 additions & 13 deletions configs/road_signs_apollo.yml
```diff
@@ -3,7 +3,7 @@ MODEL:
   encodings_len: 256
   mode : 'triplet'
   distance_type : 'l1'
-  backbone_name : 'efficientnet-b0'
+  backbone_name : 'efficientnet-b1'
   backbone_weights : 'noisy-student'
   freeze_backbone : False
   embeddings_normalization: True
@@ -19,7 +19,7 @@ DATALOADER:
 
 GENERATOR:
   negatives_selection_mode : 'semihard'
-  k_classes: 10
+  k_classes: 20
   k_samples: 3
   margin: 0.5
   batch_size : 8
@@ -29,9 +29,9 @@ GENERATOR:
 TRAIN:
   # optimizer parameters
   optimizer : 'radam'
-  learning_rate : 0.0001
-  decay_factor : 0.99
-  step_size : 1
+  learning_rate : 0.001
+  decay_factor : 0.1
+  step_size : 5
 
   # embeddings learning training parameters
   n_epochs : 1000
@@ -42,14 +42,14 @@ TRAIN:
 # SOFTMAX_PRETRAINING:
 # # softmax pretraining parameters
 # optimizer : 'radam'
-# learning_rate : 0.0001
-# decay_factor : 0.99
-# step_size : 1
+# learning_rate : 0.02
+# decay_factor : 0.1
+# step_size : 5
 
 # batch_size : 16
-# val_steps : 200
-# steps_per_epoch : 1000
-# n_epochs : 50
+# val_steps : 100
+# steps_per_epoch : 500
+# n_epochs : 5
 
 ENCODINGS:
   # encodings parameters
@@ -59,6 +59,7 @@ ENCODINGS:
   knn_k : 1
 
 GENERAL:
-  project_name : 'road_signs_efnb0'
+  project_name : 'road_signs_efnb1'
   work_dir : 'work_dirs/'
-  tensorboard_callback: False
+  tensorboard_callback: False
+  wandb_callback: True
```
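The new TRAIN values make the embedding-training schedule much more aggressive: a 10x higher starting learning rate that decays sharply. The scheduler lambda in embedding_net/backbones.py is truncated in this view, so the step-decay formula in the sketch below is an assumption consistent with the parameter names, not the repo's verbatim code.

```python
# Sketch of a step-decay schedule matching the config's parameter names.
# The `decay_factor ** (epoch // step_size)` formula is an assumption.
from tensorflow.keras.callbacks import LearningRateScheduler

learning_rate = 0.001   # TRAIN values from configs/road_signs_apollo.yml
decay_factor = 0.1
step_size = 5

lr_schedule = LearningRateScheduler(
    lambda epoch: learning_rate * decay_factor ** (epoch // step_size))
# With these values: epochs 0-4 train at 1e-3, epochs 5-9 at 1e-4, and so on.
```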
2 changes: 1 addition & 1 deletion embedding_net/backbones.py
```diff
@@ -178,7 +178,7 @@ def pretrain_backbone_softmax(backbone_model, data_loader, params_softmax, para
                                  params_save_paths['work_dir'],
                                  params_save_paths['project_name'],
                                  'pretraining_model/weights/',
-                                 params_save_paths['project_name']+'_{epoch:03d}_{val_acc:03f}' +'.h5')
+                                 params_save_paths['project_name']+'_{epoch:03d}' +'.h5')
 
     callbacks = [
         LearningRateScheduler(lambda x: learning_rate *
```
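Dropping `{val_acc:03f}` from the checkpoint filename tracks the TF 2.2 move: tf.keras reports accuracy under the key `val_accuracy`, not `val_acc`, so formatting the old template would raise a KeyError at save time. A minimal sketch of the template mechanics:

```python
# Keras fills the ModelCheckpoint filepath template from the logs dict,
# so every placeholder must match a logged metric name exactly.
from tensorflow.keras.callbacks import ModelCheckpoint

# Epoch-only template, as this commit uses:
checkpoint = ModelCheckpoint('weights/model_{epoch:03d}.h5', verbose=1)

# If the metric should stay in the name, the TF 2.x spelling would be:
# ModelCheckpoint('weights/model_{epoch:03d}_{val_accuracy:.3f}.h5')
```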
34 changes: 24 additions & 10 deletions embedding_net/models.py
```diff
@@ -25,7 +25,7 @@ def __init__(self, params):
         self.params_model = params['model']
         self.params_dataloader = params['dataloader']
         self.params_generator = params['generator']
-        self.params_save_paths = params['general']
+        self.params_general = params['general']
         self.params_train = params['train']
         if 'softmax' in params:
             self.params_softmax = params['softmax']
@@ -34,6 +34,9 @@ def __init__(self, params):
         self.backbone_model = None
         self.model = None
 
+        self.workdir_path = os.path.join(self.params_general['work_dir'],
+                                         self.params_general['project_name'])
+
         self.encoded_training_data = {}
 
     def _create_base_model(self):
@@ -68,7 +71,7 @@ def generate_encodings(self, data_loader, max_n_samples=10,
                 data_list = data_list[:max_n_samples]
 
             data_paths += data_list
-            imgs = get_images(data_list)
+            imgs = get_images(data_list, self.params_model['input_shape'])
             encods = self._generate_encodings(imgs)
             for encod in encods:
                 data_encodings.append(encod)
@@ -81,21 +84,32 @@
         return encoded_training_data
 
     def save_encodings(self, encoded_training_data,
+                       save_folder='./',
                        save_file_name='encodings.pkl'):
-        with open(save_file_name, "wb") as f:
+        with open(os.path.join(save_folder, save_file_name), "wb") as f:
             pickle.dump(encoded_training_data, f)
 
     def load_model(self, file_path):
+        import efficientnet.tfkeras as efn
         self.model = load_model(file_path, compile=False)
 
         self.input_shape = list(self.model.inputs[0].shape[1:])
-        self.base_model = Model(inputs=[self.model.layers[2].get_input_at(0)],
-                                outputs=[self.model.layers[2].layers[-1].output])
-        self.classification_model = Model(inputs=[self.model.layers[3].get_input_at(0)],
-                                          outputs=[self.model.layers[-1].output])
-        self.classification_model._make_predict_function()
-        self.base_model._make_predict_function()
+        self.base_model = Model(inputs=[self.model.get_layer('model').input],
+                                outputs=[self.model.get_layer('model').output])
+        # self.classification_model = Model(inputs=[self.model.layers[3].get_input_at(0)],
+        #                                   outputs=[self.model.layers[-1].output])
+        # self.classification_model._make_predict_function()
+        # self.base_model._make_predict_function()
 
     def save_base_model(self, save_folder):
         self.base_model.save(f'{save_folder}base_model.h5')
 
+    def save_onnx(self, save_folder, save_name='base_model.onnx'):
+        os.environ["TF_KERAS"] = '1'
+        import efficientnet.tfkeras as efn
+        import keras2onnx
+        onnx_model = keras2onnx.convert_keras(self.base_model, self.base_model.name)
+        keras2onnx.save_model(onnx_model, os.path.join(save_folder, save_name))
+
     def predict(self, image):
         if type(image) is str:
```
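The new save_onnx method exports the embedding extractor through keras2onnx (setting TF_KERAS=1 makes keras2onnx treat the model as tf.keras). A hypothetical round-trip check with onnxruntime; the `model` variable, the paths, and the 240x240 input size are assumptions, not repo code:

```python
# Hypothetical usage of the new export path; `model` is assumed to be a
# trained EmbeddingNet instance with a loaded base_model.
import numpy as np
import onnxruntime as ort

model.save_onnx('work_dirs/road_signs_efnb1/', save_name='base_model.onnx')

sess = ort.InferenceSession('work_dirs/road_signs_efnb1/base_model.onnx')
input_name = sess.get_inputs()[0].name
dummy = np.random.rand(1, 240, 240, 3).astype(np.float32)  # input shape is an assumption
embeddings = sess.run(None, {input_name: dummy})[0]
print(embeddings.shape)  # expected: (1, 256), the config's encodings_len
```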
12 changes: 6 additions & 6 deletions embedding_net/utils.py
```diff
@@ -112,9 +112,9 @@ def plot_batch_simple(data, targets, class_names):
     class_names = [class_names[i] for i in indxs]
 
     for i in range(num_imgs):
-        full_img[:,i*img_w:(i+1)*img_w,:] = data[0][i,:,:,:]
-        cv2.putText(full_img, class_names[i], (img_w*i + 10, 20), cv2.FONT_HERSHEY_SIMPLEX,
-                    0.5, (0, 255, 0), 1, cv2.LINE_AA)
+        full_img[:,i*img_w:(i+1)*img_w,:] = data[0][i,:,:,::-1]*255
+        cv2.putText(full_img, class_names[i], (img_w*i + 5, 20), cv2.FONT_HERSHEY_SIMPLEX,
+                    0.2, (0, 255, 0), 1, cv2.LINE_AA)
     plt.figure(figsize = (20,2))
     plt.imshow(full_img)
     plt.show()
@@ -131,10 +131,10 @@ def plot_batch(data, targets):
     i = 0
     for img_idx, targ in zip(range(num_imgs), targets):
         for j in range(it_val):
-            img = cv2.cvtColor(data[j][img_idx].astype(
-                np.uint8), cv2.COLOR_BGR2RGB)
+            image = data[j][img_idx]*255
+            img = cv2.cvtColor(image.astype(np.uint8), cv2.COLOR_BGR2RGB)
             axs[i+j].imshow(img)
-            axs[i+j].set_title(targ)
+            # axs[i+j].set_title(targ)
         i += it_val
 
     plt.show()
```
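The `*255` rescaling and `::-1` channel flip suggest the batch generator now yields cv2-loaded (BGR) images normalized to [0, 1], so display code must undo both before handing them to matplotlib. A small sketch of that assumption:

```python
# Assumption based on this diff: batch images arrive as BGR floats in [0, 1].
import numpy as np
import matplotlib.pyplot as plt

batch_img = np.random.rand(64, 64, 3)        # stand-in for data[0][i]
display_img = batch_img[:, :, ::-1] * 255    # flip BGR -> RGB, restore 0-255 range
plt.imshow(display_img.astype(np.uint8))
plt.show()
```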
13,647 changes: 13,116 additions & 531 deletions examples/test_network.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion requirements.txt
```diff
@@ -1,5 +1,5 @@
 image-classifiers
-tensorflow-gpu==2.0.0
+tensorflow-gpu==2.2.0
 matplotlib
 albumentations
 scikit-learn
```
19 changes: 13 additions & 6 deletions tools/train.py
```diff
@@ -86,13 +86,19 @@ def main():
                            save_best_only=True,
                            verbose=1)
     ]
 
+    print('CREATE DATALOADER')
+    data_loader = ENDataLoader(**params_dataloader)
+    print('DATALOADER CREATED!')
+
     if cfg_params['general']['tensorboard_callback']:
         callbacks.append(TensorBoard(log_dir=tensorboard_save_path))
 
-    print('CREATE DATALOADER')
-    data_loader = ENDataLoader(**params_dataloader)
-    print('DATALOADER CREATED!')
+    if cfg_params['general']['wandb_callback']:
+        import wandb
+        from wandb.keras import WandbCallback
+        wandb.init()
+        callbacks.append(WandbCallback(data_type="image", labels=data_loader.class_names))
 
     val_generator = None
 
@@ -124,14 +130,15 @@ def main():
         metric = ['accuracy']
     print('DONE')
 
-    if args.resume_from is not None:
-        model.load_model(args.resume_from)
-
     print('COMPILE MODEL')
     model.model.compile(loss=losses,
                         optimizer=params_train['optimizer'],
                         metrics=metric)
 
+    if args.resume_from is not None:
+        model.load_model(args.resume_from)
+
     if 'softmax' in cfg_params:
         params_softmax = cfg_params['softmax']
         params_save_paths = cfg_params['general']
```
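The new `wandb_callback` config flag gates the W&B integration above. A minimal standalone sketch of the same wiring; the project name and label list are illustrative (the commit calls wandb.init() with no arguments and pulls labels from the data loader):

```python
# Minimal sketch of the W&B wiring this commit adds to tools/train.py.
import wandb
from wandb.keras import WandbCallback

wandb.init(project='road_signs_efnb1')  # project name is an assumption

callbacks = [WandbCallback(data_type="image", labels=['stop', 'yield'])]
# Passing `callbacks` to model.fit(...) streams losses and metrics to the
# W&B run; data_type="image" also logs sample validation inputs.
```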
