Skip to content

Commit

Permalink
remove augmentations on prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
RocketFlash committed Jan 7, 2020
1 parent 241ecfa commit f38dfa5
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
12 changes: 2 additions & 10 deletions embedding_net/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,22 +381,14 @@ def predict(self, image):
return predicted_label

def predict_knn(self, image, with_top5=False):
import albumentations as A
augmentations = A.Compose([
A.CenterCrop(p=1, height=2*self.input_shape[1]//3, width=2*self.input_shape[0]//3),
A.Resize(p=1, height=self.input_shape[1], width=self.input_shape[0])
], p=1)


if type(image) is str:
img = cv2.imread(image)
else:
img = image
img = cv2.resize(img, (self.input_shape[0], self.input_shape[1]))
img = augmentations(image=img)['image']

encoding = self.base_model.predict(np.expand_dims(img, axis=0))
predicted_label = self.encoded_training_data['knn_classifier'].predict(
encoding)
predicted_label = self.encoded_training_data['knn_classifier'].predict(encoding)
if with_top5:
prediction_top5_idx = self.encoded_training_data['knn_classifier'].kneighbors(encoding, n_neighbors=5)
prediction_top5 = [self.encoded_training_data['labels'][prediction_top5_idx[1][0][i]] for i in range(5)]
Expand Down
7 changes: 7 additions & 0 deletions test_network.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6025,11 +6025,18 @@
],
"source": [
"import pandas as pd\n",
"import albumentations as A\n",
"import os\n",
"\n",
"d = {'id':[],\n",
" 'label':[]}\n",
"\n",
"input_shape=(128, 128, 3)\n",
"augmentations = A.Compose([\n",
" A.CenterCrop(p=1, height=2*input_shape[1]//3, width=2*input_shape[0]//3),\n",
" A.Resize(p=1, height=input_shape[1], width=input_shape[0])\n",
"], p=1)\n",
"\n",
"for i in range(744):\n",
" f = '0' * (4-len(str(i))) + str(i) + '.jpg'\n",
" if f.endswith('.jpg') or f.endswith('.png'):\n",
Expand Down

0 comments on commit f38dfa5

Please sign in to comment.