-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3dd0fdb
commit 499a689
Showing
19 changed files
with
1,428 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Audio / video processing constants shared across the package.
AUDIO_SAMPLERATE = 16384
VIDEO_FRAMERATE = 25
N_FFT = 1022
HOP_LENGTH = 256
# With N_FFT = 1022 this evaluates to 512.
# NOTE(review): presumably the number of one-sided STFT frequency bins -- confirm.
SP_FREQ_SHAPE = N_FFT // 2 + 1

# The Fourier-related constants above, bundled under lowercase keys.
fourier_defaults = dict(
    audio_samplerate=AUDIO_SAMPLERATE,
    n_fft=N_FFT,
    sp_freq_shape=SP_FREQ_SHAPE,
    hop_length=HOP_LENGTH,
)
# First entry of this package's search path (`__path__` only exists in a package __init__).
core_path = __path__[0]
|
||
from .models import VoViT |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
from typing import Union | ||
import torch | ||
|
||
|
||
def rigid_transform_3D(target_face: torch.Tensor, mean_face: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
    """
    Compute a rigid transformation between two sets of landmarks by using Kabsch algorithm.
    The Kabsch algorithm, named after Wolfgang Kabsch, is a method for calculating the optimal rotation matrix
    that minimizes the RMSD (root mean squared deviation) between two paired sets of points.
    The returned pair (R, t) is such that ``R @ target_face + t`` approximates ``mean_face``.
    args:
        target_face: torch tensor of shape (3,N), the points to be aligned
        mean_face: torch tensor of shape (3,N), the reference points
    returns:
        R: torch tensor of shape (3,3), a proper rotation (det(R) > 0)
        t: torch tensor of shape (3,1), the translation
    raises:
        AssertionError: if the two inputs differ in shape or their first dimension is not 3
    source:
        https://en.wikipedia.org/wiki/Kabsch_algorithm
    """
    # Geometric transformations in 3D
    # https://cseweb.ucsd.edu/classes/wi18/cse167-a/lec3.pdf

    # Affine transformation (theoretical)
    # http://learning.aols.org/aols/3D_Affine_Coordinate_Transformations.pdf

    # Implementation from http://nghiaho.com/?page_id=671
    assert target_face.shape == mean_face.shape
    assert target_face.shape[0] == 3, "3D rigid transform only"

    # Centroid of each point set, kept as a 3x1 column so it broadcasts over the N points
    centroid_A = torch.mean(target_face, dim=1).reshape(-1, 1)
    centroid_B = torch.mean(mean_face, dim=1).reshape(-1, 1)

    # Centre both point sets on their centroid
    Am = target_face - centroid_A
    Bm = mean_face - centroid_B

    # 3x3 cross-covariance between the centred sets
    H = Am @ Bm.T

    # find rotation
    # torch.linalg.svd returns V already transposed (Vh), unlike the legacy torch.svd
    # https://pytorch.org/docs/stable/generated/torch.svd.html
    U, _, Vt = torch.linalg.svd(H)
    R = Vt.T @ U.T

    # special reflection case: a negative determinant means the best orthogonal fit is a mirror image
    if torch.linalg.det(R) < 0:
        print("det(R) < 0, reflection detected!, correcting for it ...")
        # Flip the last row of Vt out of place, so the SVD output is never mutated
        Vt = torch.cat((Vt[:2], -Vt[2:]), dim=0)
        R = Vt.T @ U.T

    t = -R @ centroid_A + centroid_B

    return R, t
|
||
|
||
def apply_transformation(R, t, landmarks: torch.tensor) -> torch.tensor:
    """
    Map a set of 3D landmarks through the rigid transformation (R, t).
    args:
        R: rotation, tensor of shape (3,3)
        t: translation, tensor of shape (3,1)
        landmarks: tensor of shape (3,N)
    returns:
        ``R @ landmarks + t``
    """
    n_coords = landmarks.shape[0]
    assert n_coords == 3, "landmarks must be 3D"
    assert R.shape == (3, 3), "R must be 3x3"
    assert t.shape == (3, 1), "t must be 3x1"

    # Rotate first, then shift; t broadcasts across the landmark columns
    rotated = torch.matmul(R, landmarks)
    return rotated + t
|
||
|
||
def register_sequence_of_landmarks(target_sequence: torch.tensor, mean_face: torch.tensor, per_frame=False,
                                   display_sequence: Union[torch.tensor, None] = None) -> torch.tensor:
    """
    Register a sequence of landmarks to a mean face.
    The rigid transformation is estimated from `target_sequence` and applied to `display_sequence`
    (which defaults to `target_sequence` itself).
    args:
        target_sequence: tensor of shape (T,3,N), frames used to estimate the transformation
        mean_face: tensor of shape (3,N), the reference landmarks
        per_frame: if true, one transformation is estimated for every frame; otherwise a single
                   transformation is estimated from the mean over frames of `target_sequence`.
        display_sequence: (optional) tensor of shape (T',3,N'), frames the transformation is applied to.
                   Frames are paired with those of `target_sequence`; pairing stops at the shorter one.
    returns:
        registered_sequence: the transformed frames of `display_sequence`, stacked along a new first axis
    example:
        Estimate the transformation from the first 48 landmarks only, but apply it to all of them
        >>> registered_sequence = register_sequence_of_landmarks(landmarks[..., :48],
        ...                                                      mean_face[:, :48],
        ...                                                      display_sequence=landmarks)
    """
    if display_sequence is None:
        display_sequence = target_sequence

    paired_frames = zip(target_sequence, display_sequence)

    if per_frame:
        # A fresh (R, t) is fitted for each frame before applying it
        registered = [
            apply_transformation(*rigid_transform_3D(frame, mean_face), shown)
            for frame, shown in paired_frames
        ]
    else:
        # One global (R, t), fitted on the time-averaged landmarks
        R, t = rigid_transform_3D(torch.mean(target_sequence, dim=0), mean_face)
        registered = [apply_transformation(R, t, shown) for _, shown in paired_frames]

    return torch.stack(registered)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .. import fourier_defaults, VIDEO_FRAMERATE | ||
from .production_model import VoViT |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
# The based unit of graph convolutional networks. | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
|
||
class ConvTemporalGraphical(nn.Module):
    r"""Graph convolution building block.

    A 2D convolution (temporal kernel, width 1 over the node axis) produces
    ``kernel_size`` groups of ``out_channels`` feature maps; each group is then
    multiplied by its own adjacency matrix and the groups are summed.

    Args:
        in_channels (int): Number of channels in the input sequence data
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int): Size of the graph convolving kernel
        t_kernel_size (int): Size of the temporal convolving kernel
        t_stride (int, optional): Stride of the temporal convolution. Default: 1
        t_padding (int, optional): Temporal zero-padding added to both sides of
            the input. Default: 0
        t_dilation (int, optional): Spacing between temporal kernel elements.
            Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output.
            Default: ``True``

    Shape:
        - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
        - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)`,
          :math:`(N, K, V, V)` or :math:`(N, T_{out}, K, V, V)` format
        - Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
        - Output[1]: The adjacency matrix, returned unchanged

        where
            :math:`N` is a batch size,
            :math:`K` is the spatial kernel size, as :math:`K == kernel_size`,
            :math:`T_{in}/T_{out}` is a length of input/output sequence,
            :math:`V` is the number of graph nodes.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 t_kernel_size=1,
                 t_stride=1,
                 t_padding=0,
                 t_dilation=1,
                 bias=True,
                 ):
        super().__init__()

        self.kernel_size = kernel_size
        # One bank of out_channels feature maps per spatial kernel slot
        self.conv = nn.Conv2d(
            in_channels,
            out_channels * kernel_size,
            kernel_size=(t_kernel_size, 1),
            padding=(t_padding, 0),
            stride=(t_stride, 1),
            dilation=(t_dilation, 1),
            bias=bias,
        )

    def forward(self, x, A):
        rank = A.ndim
        # Axis of A that indexes the spatial kernel: 1 for 4-D, 2 for 5-D, 0 otherwise
        kernel_axis = {4: 1, 5: 2}.get(rank, 0)
        assert A.size(kernel_axis) == self.kernel_size

        x = self.conv(x)

        # Split the channel axis into (kernel slot, channels per slot)
        n, kc, t, v = x.size()
        x = x.view(n, self.kernel_size, kc // self.kernel_size, t, v)

        equations = {
            3: 'nkctv,kvw->nctw',  # one adjacency shared by the whole batch
            4: 'nkctv,nkvw->nctw',  # one adjacency per sample
            5: 'nkctv,ntkvw->nctw',  # one adjacency per sample and time step
        }
        equation = equations.get(rank)
        if equation is None:
            raise Exception('Adjacency matrix dimensionalty is %d but should be 3,4 or 5' % rank)

        x = torch.einsum(equation, x, A)
        return x.contiguous(), A
Oops, something went wrong.