-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
rename repo and add instructions in README file
- Loading branch information
1 parent
5b0f3c5
commit 0ab9710
Showing
17 changed files
with
9,478 additions
and
9,249 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,81 @@ | ||
# SiameseNet | ||
Siamese network for image classification | ||
# Siamese and Triplet networks for image classification | ||
|
||
This repository contains implementation of a deep neural networks for embeddings learning using Siamese and Triplets approaches with different negative samples mining strategies. | ||
|
||
# Installation | ||
|
||
## Install dependencies | ||
|
||
### Requirements | ||
|
||
- keras | ||
- tensorflow | ||
- scikit-learn | ||
- opencv | ||
- matplotlib | ||
- plotly - for interactive t-SNE plot visualization | ||
- [albumentations](https://github.com/albu/albumentations) - for online augmentation during training | ||
- [image-classifiers](https://github.com/qubvel/classification_models) - for different backbone models | ||
- [keras-rectified-adam](https://github.com/CyberZHG/keras-radam) - for cool state-of-the-art optimization | ||
|
||
```bash | ||
$ pip3 install -r requirements.txt | ||
``` | ||
|
||
# Train | ||
|
||
In the training dataset, the data for training and validation should be in separate folders, in each of which folders with images for each class. Dataset should have the following structure: | ||
|
||
``` | ||
Dataset | ||
└───train | ||
│ └───class_1 | ||
│ │ image1.jpg | ||
│ │ image2.jpg | ||
│ │ ... | ||
│ └───class_2 | ||
│ | image1.jpg | ||
│ │ image2.jpg | ||
│ │ ... | ||
│ └───class_N | ||
│ │ ... | ||
│ | ||
└───val | ||
│ └───class_1 | ||
│ │ image1.jpg | ||
│ │ image2.jpg | ||
│ │ ... | ||
│ └───class_2 | ||
│ | image1.jpg | ||
│ │ image2.jpg | ||
│ │ ... | ||
│ └───class_N | ||
│ │ ... | ||
``` | ||
|
||
For training, it is necessary to create a configuration file in which all network parameters and training parameters will be indicated. Examples of configuration files can be found in the **configs** folder. | ||
|
||
After the configuration file is created, you can modify **train.py** file, and then start training: | ||
|
||
```bash | ||
$ python3 train.py | ||
``` | ||
|
||
# Test | ||
|
||
The trained model can be tested using the following command: | ||
|
||
```bash | ||
$ python3 test.py [--weights (path to trained model weights file)] | ||
[--encodings (path to trained model encodings file)] | ||
[--image (path to image file)] | ||
``` | ||
|
||
Is is also possible to use [test_network.ipynb](https://github.com/RocketFlash/SiameseNet/blob/master/test_network.ipynb) notebook to test the trained network and visualize input data as well as output encodings. | ||
|
||
# Embeddings visualization | ||
|
||
Result encodings could be visualized interactively using **plot_tsne_interactive** function in [utils.py](https://github.com/RocketFlash/SiameseNet/blob/master/embedding_net/utils.py). | ||
|
||
t-SNE plot of russian traffic sign images embeddings (107 classes): | ||
![t-SNE example](images/t-sne.png) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
input_shape : [48, 48, 3] | ||
encodings_len: 256 | ||
margin: 0.7 | ||
mode : 'triplet' | ||
distance_type : 'l1' | ||
backbone : 'resnet18' | ||
backbone_weights : 'imagenet' | ||
optimizer : 'radam' | ||
learning_rate : 0.0001 | ||
project_name : 'road_signs/' | ||
freeze_backbone : False | ||
embeddings_normalization: True | ||
|
||
#paths | ||
dataset_path : '/home/rauf/datasets/road_signs_merged/' | ||
tensorboard_log_path : 'tf_log/' | ||
weights_save_path : 'weights/' | ||
plots_path : 'plots/' | ||
encodings_path : 'encodings/' | ||
model_save_name : 'best_model_resnet18_merged.h5' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,4 +3,6 @@ keras | |
tensorflow-gpu | ||
matplotlib | ||
albumentations | ||
pydot | ||
scikit-learn | ||
opencv-python | ||
keras-rectified-adam |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,22 @@ | ||
from siamese_net.model import SiameseNet | ||
from embedding_net.model import EmbeddingNet | ||
import argparse | ||
|
||
model = SiameseNet() | ||
model.load_model('weights/road_signs/best_model_4.h5') | ||
model.load_encodings('encodings/road_signs/encodings.pkl') | ||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--weights", type=str, | ||
help="path to trained model weights file") | ||
parser.add_argument("--encodings", type=str, | ||
help="path to trained model encodings file") | ||
parser.add_argument("--image", type=str, help="path to image file") | ||
opt = parser.parse_args() | ||
|
||
image_path = '/home/rauf/datasets/road_signs/road_signs_separated/val/1_1/rtsd-r1_train_006470.png' | ||
model_prediction = model.predict(image_path) | ||
print('Model prediction: {}'.format(model_prediction)) | ||
weights_path = opt.weights | ||
encodings_path = opt.encodings | ||
image_path = opt.image | ||
|
||
model = EmbeddingNet() | ||
model.load_model(weights_path) | ||
model.load_encodings(encodings_path) | ||
|
||
model_prediction = model.predict(image_path) | ||
print('Model prediction: {}'.format(model_prediction)) |
Oops, something went wrong.