-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Ml pipeline rewrite #40
Merged
Merged
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
d0fbf8c
refactor: remove file read in pipeline
shailx383 73f194f
fix: remove preprocess from facenet
shailx383 64c5cdf
fix: adjust typing of np array
shailx383 0d156df
refactor: rewrite embedding pipeline
shailx383 6689611
test: add test for new pipeline
shailx383 0f0b6c5
docs: edit docstring of piepline
shailx383 e0416a8
fix: revert typing of array in clusterer
shailx383 396d481
perf: change default transform dim to 224
shailx383 a6c6fcd
chore: remove tqdm
shailx383 f369f54
test: update pipeline test nb
shailx383 54d1144
chore: fix typing of array in clusterer
shailx383 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,81 +1,102 @@ | ||
import torchvision | ||
from exception import FaceNotFoundError | ||
from model import FaceNet | ||
from facenet_pytorch import MTCNN | ||
from PIL import Image | ||
from triplet_dataset import create_transform | ||
import torch | ||
from facenet_pytorch import MTCNN, InceptionResnetV1 | ||
import numpy as np | ||
from exception import FaceNotFoundError | ||
|
||
|
||
class EmbeddingPipeline: | ||
""" | ||
Pipeline class for detecting faces from a single image and converting to embedding vectors | ||
|
||
Attributes: | ||
detector (MTCNN): MTCNN detector object. Defaults to MTCNN(keep_all=True, device=device).eval(). | ||
resnet (InceptionResnetV1): InceptionResnetV1 object. Defaults to InceptionResnetV1(pretrained=pretrained, device=device).eval(). | ||
device (str): device to run the model on. Defaults to 'cpu'. | ||
pretrained (str): pretrained model to use. Defaults to 'vggface2'. | ||
resize (float): resize factor for the image. Defaults to None. | ||
resnet (InceptionResnetV1): InceptionResnetV1 object. Defaults to InceptionResnetV1(pretrained='vggface2').eval(). | ||
|
||
""" | ||
def __init__(self, detector: MTCNN = None, resnet: InceptionResnetV1 = None, device: str = 'cpu', pretrained: str = 'vggface2', resize: float = None): | ||
|
||
def __init__(self, detector: MTCNN = None, default_resnet: bool = True, embedding_size: int = 512): | ||
""" | ||
Constructor for EmbeddingPipeline class | ||
|
||
Args: | ||
detector (MTCNN): MTCNN detector object. Defaults to MTCNN(keep_all=True, device=device).eval(). | ||
detector (MTCNN): MTCNN detector object. Defaults to MTCNN(keep_all=True, device=device).eval(). | ||
resnet (InceptionResnetV1): InceptionResnetV1 object. Defaults to InceptionResnetV1(pretrained=pretrained, device=device).eval(). | ||
device (str): device to run the model on. Defaults to 'cpu'. | ||
pretrained (str): pretrained model to use. Defaults to 'vggface2'. | ||
resize (float): resize factor for the image. Defaults to None. | ||
Constructor for EmbeddingPipeline class | ||
|
||
Args: | ||
detector (MTCNN): MTCNN detector object. Defaults to MTCNN(keep_all=True, device=device).eval(). | ||
default_resnet (bool): boolean. If True defaults resnet used to InceptionResnetV1(pretrained='vggface2').eval(). | ||
embedding_size (int): size of embedding vector. Defaults to 512. | ||
|
||
""" | ||
|
||
DETECTOR = MTCNN(keep_all=True, device=device).eval() | ||
RESNET = InceptionResnetV1(pretrained=pretrained, device=device).eval() | ||
DETECTOR = MTCNN(keep_all=True).eval() | ||
RESNET = FaceNet(embedding_size=embedding_size, use_default=default_resnet) | ||
|
||
self.detector = DETECTOR if detector is None else detector | ||
self.resize = resize | ||
self.resnet = RESNET if resnet is None else resnet | ||
|
||
self.resnet = RESNET | ||
|
||
def __call__(self, filepath: str): | ||
def __call__(self, image_path: str): | ||
""" | ||
Reads the image, processes it, optionally resizes it and detects faces | ||
Reads the image, processes it and detects faces | ||
|
||
Args: | ||
filepath (str): path to the image file | ||
image_path (str): path to the image file | ||
|
||
Returns: | ||
numpy array of embedding vectors of size {torch.Size([1, 512])} | ||
|
||
""" | ||
faces = self._detect_faces(image_path) | ||
embeddings = self._create_embeddings(faces) | ||
return embeddings | ||
|
||
def _detect_faces(self, path: str): | ||
""" detects faces from image in provided path | ||
|
||
Args: | ||
path (str): path to image | ||
|
||
Raises: | ||
FaceNotFoundError: in case of no face found in image | ||
|
||
Returns: | ||
list[PIL.Image.Image]: list of faces detected as PIL images | ||
""" | ||
|
||
img = Image.open(filepath) | ||
images = [] | ||
|
||
if self.resize is not None: | ||
img = img.resize([int(d*self.resize) for d in img.size]) | ||
transform_to_image = torchvision.transforms.ToPILImage() | ||
image = Image.open(path) | ||
faces = self.detector(image) | ||
if faces is None: | ||
raise FaceNotFoundError(path) | ||
else: | ||
for face in faces: | ||
images.append(transform_to_image(face)) | ||
return images | ||
|
||
def _create_embeddings(self, faces: list[Image.Image], transform_height: int = 224, transform_width: int = 224): | ||
"""transforms and coverts faces into embedding vectors | ||
|
||
Args: | ||
faces (list[Image.Image]): list of PIL Images of faces | ||
transform_height (int, optional): height of transformed image. Defaults to 224. | ||
transform_width (int, optional): width of transformed image. Defaults to 224. | ||
|
||
Returns: | ||
torch.Tensor: tensor of embedding vectors of faces | ||
""" | ||
|
||
transform = create_transform(transform_height, transform_width) | ||
|
||
transformed_faces = [] | ||
|
||
for face in faces: | ||
converted = face.convert('RGB') | ||
transformed_face = transform(converted) | ||
transformed_faces.append(transformed_face) | ||
|
||
detected_faces = self.detector(img) | ||
transformed_faces = torch.stack(transformed_faces) | ||
|
||
embeddings = self._create_embeddings(detected_faces, filepath) | ||
embeddings = self.resnet.embed(transformed_faces) | ||
|
||
return embeddings | ||
return embeddings | ||
|
||
def _create_embeddings(self, faces: list[torch.tensor], filepath: str): | ||
""" | ||
Converts array of faces to embedding vectors | ||
|
||
Args: | ||
faces (list[torch.tensor]): list of tensors of detected faces | ||
filepath (str): path of image | ||
|
||
Returns: | ||
numpy array of embedding vectors of size {torch.Size([1, 512])} | ||
|
||
""" | ||
if faces is not None: | ||
embeddings = np.array([self.resnet(torch.unsqueeze(face, 0)).detach().numpy() for face in faces]) | ||
return embeddings | ||
else: | ||
raise FaceNotFoundError(filepath) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
line 90 still throws an error when only
point.data
is written. code runs perfectly withpoint.data.numpy()