diff --git a/face_alignment/api.py b/face_alignment/api.py index 97df85d5..10455abe 100644 --- a/face_alignment/api.py +++ b/face_alignment/api.py @@ -52,11 +52,13 @@ def __int__(self): class FaceAlignment: def __init__(self, landmarks_type, network_size=NetworkSize.LARGE, - device='cuda', flip_input=False, face_detector='sfd', verbose=False): + device='cuda', flip_input=False, face_detector='sfd', verbose=False, + return_face_results=False): self.device = device self.flip_input = flip_input self.landmarks_type = landmarks_type self.verbose = verbose + self.return_face_results = return_face_results network_size = int(network_size) @@ -183,6 +185,8 @@ def get_landmarks_from_image(self, image_or_path, detected_faces=None): landmarks.append(pts_img.numpy()) + if self.return_face_results: + return landmarks, detected_faces return landmarks @torch.no_grad() @@ -257,6 +261,9 @@ def get_landmarks_from_batch(self, image_batch, detected_faces=None): landmark_set = np.concatenate(landmark_set, axis=0) landmarks.append(landmark_set) + + if self.return_face_results: + return landmarks, detected_faces return landmarks def get_landmarks_from_directory(self, path, extensions=['.jpg', '.png'], recursive=True, show_progress_bar=True):