VoViT core
JuanFMontesinos committed Jul 4, 2022
1 parent 3dd0fdb commit 499a689
Showing 19 changed files with 1,428 additions and 18 deletions.
19 changes: 16 additions & 3 deletions vovit/__init__.py
@@ -1,4 +1,5 @@
import os
import yaml

import torch
from einops import rearrange
@@ -20,12 +21,15 @@ def __init__(self, *, model_name: str, debug: dict, pretrained: bool = True,

if self.extract_landmarks:
from .core.landmark_estimator.TDDFA_GPU import TDDFA
self.face_extractor = TDDFA()
cfg = yaml.load(open(utils.DEFAULT_CFG_PATH), Loader=yaml.SafeLoader)
cfg['checkpoint_fp'] = os.path.join(utils.LANDMARK_LIB_PATH, 'weights', 'mb1_120x120.pth')
cfg['bfm_fp'] = os.path.join(utils.LANDMARK_LIB_PATH, 'configs', 'bfm_noneck_v3.pkl')
self.face_extractor = TDDFA(**cfg)
self.register_buffer('mean_face',
torch.from_numpy(np_load(os.path.join(core_path, 'speech_mean_face.npy'))).float(),
persistent=False)

def forward(self, mixture, visuals):
def forward(self, mixture, visuals, extract_landmarks=False):
"""
:param mixture: torch.Tensor of shape (B,N)
:param visuals: torch.Tensor of shape (B,C,H,W) BGR format required
@@ -35,7 +39,7 @@ def forward(self, mixture, visuals):
raise NotImplementedError
else:
cropped_video = visuals
if self.extract_landmarks:
if extract_landmarks:
ld = self.face_extractor(cropped_video)
avg = (ld[:-2] + ld[1:-1] + ld[2:]) / 3
ld[:-2] = avg
@@ -58,10 +62,19 @@ def forward_unlimited(self, mixture, visuals):
Allows running inference on samples of unlimited duration (up to GPU memory constraints).
The results will be trimmed to multiples of 2 seconds (e.g. if your audio is 8.5 seconds long,
the result will be trimmed to 8 seconds).
Args:
    visuals: raw video if self.extract_landmarks is True, precomputed landmarks otherwise.
        landmarks are uint16 tensors of shape (T,3,68)
        raw video is a uint8 RGB tensor of shape (T,H,W,3) (values between 0-255)
    mixture: tensor of shape (N)
"""
fps = VIDEO_FRAMERATE
length = self.vovit.avse.av_se.ap._audio_length
n_chunks = visuals.shape[0] // (fps * 2)
if self.extract_landmarks:
visuals = self.face_extractor(visuals)
avg = (visuals[:-2] + visuals[1:-1] + visuals[2:]) / 3
visuals[:-2] = avg
visuals = visuals[:n_chunks * fps * 2].view(n_chunks, fps * 2, 3, 68)
mixture = mixture[:n_chunks * length].view(n_chunks, -1)
pred = self.forward(mixture, visuals)
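For reference, a minimal end-to-end sketch of forward_unlimited (not part of this commit; the wrapper class name End2EndVoViT, the model_name value and the example shapes are assumptions, not taken from this diff):

import torch
import vovit

# Hypothetical wrapper instantiation; class and argument names are assumed for illustration.
# Assumes the model was built with extract_landmarks=True, so raw video is accepted.
model = vovit.End2EndVoViT(model_name='VoViT_speech', debug={}).cuda().eval()

video = torch.randint(0, 256, (100, 160, 160, 3), dtype=torch.uint8).cuda()  # 4 s of face-cropped RGB frames at 25 fps
mixture = torch.rand(4 * 16384).cuda()                                       # 4 s of mono audio at 16384 Hz
with torch.no_grad():
    pred = model.forward_unlimited(mixture, video)                           # output trimmed to multiples of 2 s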
13 changes: 13 additions & 0 deletions vovit/core/__init__.py
@@ -0,0 +1,13 @@
AUDIO_SAMPLERATE = 16384
VIDEO_FRAMERATE = 25
N_FFT = 1022
HOP_LENGTH = 256
SP_FREQ_SHAPE = N_FFT // 2 + 1

fourier_defaults = {"audio_samplerate": AUDIO_SAMPLERATE,
"n_fft": N_FFT,
"sp_freq_shape": SP_FREQ_SHAPE,
"hop_length": HOP_LENGTH}
core_path = __path__[0]

from .models import VoViT
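For reference, a quick sketch of the spectrogram geometry these defaults imply (not part of this commit): a 2-second chunk at 16384 Hz spans 32768 samples, and an STFT with n_fft=1022 yields n_fft // 2 + 1 = 512 frequency bins.

import torch
from vovit.core import AUDIO_SAMPLERATE, N_FFT, HOP_LENGTH

chunk = torch.randn(2 * AUDIO_SAMPLERATE)                        # one 2-second chunk, 32768 samples
spec = torch.stft(chunk, n_fft=N_FFT, hop_length=HOP_LENGTH,
                  window=torch.hann_window(N_FFT), return_complex=True)
print(spec.shape)                                                # torch.Size([512, 129]) with the default centered STFT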
119 changes: 119 additions & 0 deletions vovit/core/kabsch.py
@@ -0,0 +1,119 @@
from typing import Union
import torch


def rigid_transform_3D(target_face: torch.tensor, mean_face: torch.tensor) -> torch.tensor:
"""
Compute a rigid transformation between two sets of landmarks using the Kabsch algorithm.
The Kabsch algorithm, named after Wolfgang Kabsch, is a method for calculating the optimal rotation matrix
that minimizes the RMSD (root mean squared deviation) between two paired sets of points.
args:
    target_face: torch.Tensor of shape (3,N)
    mean_face: torch.Tensor of shape (3,N)
returns:
    R: torch.Tensor of shape (3,3)
    t: torch.Tensor of shape (3,1)
source:
https://en.wikipedia.org/wiki/Kabsch_algorithm
"""
# Geometric transformations in 3D
# https://cseweb.ucsd.edu/classes/wi18/cse167-a/lec3.pdf

# Affine transformation (theoretical)
# http://learning.aols.org/aols/3D_Affine_Coordinate_Transformations.pdf

# Implementation from http://nghiaho.com/?page_id=671
#
assert target_face.shape == mean_face.shape
assert target_face.shape[0] == 3, "3D rigid transform only"

# find mean column wise
centroid_A = torch.mean(target_face, dim=1)
centroid_B = torch.mean(mean_face, dim=1)

# ensure centroids are 3x1
centroid_A = centroid_A.reshape(-1, 1)
centroid_B = centroid_B.reshape(-1, 1)

# subtract mean
Am = target_face - centroid_A
Bm = mean_face - centroid_B

H = Am @ Bm.T
# H = (Am.cpu() @ Bm.T.cpu())

# find rotation
U, S, Vt = torch.linalg.svd(H) # torch.svd differs from torch.linalg.svd
# https://pytorch.org/docs/stable/generated/torch.svd.html
R = Vt.T @ U.T

# special reflection case
if torch.linalg.det(R) < 0:
print("det(R) < R, reflection detected!, correcting for it ...")
Vt[2, :] *= -1
R = Vt.T @ U.T

t = -R @ centroid_A + centroid_B

return R, t


def apply_transformation(R, t, landmarks: torch.tensor) -> torch.tensor:
"""
Apply a rigid transformation to a set of landmarks.
args:
    R: torch.Tensor of shape (3,3)
    t: torch.Tensor of shape (3,1)
    landmarks: torch.Tensor of shape (3,N)
"""
assert landmarks.shape[0] == 3, "landmarks must be 3D"
assert R.shape == (3, 3), "R must be 3x3"
assert t.shape == (3, 1), "t must be 3x1"

# apply transformation
transformed_landmarks = R @ landmarks + t

return transformed_landmarks


def register_sequence_of_landmarks(target_sequence: torch.tensor, mean_face: torch.tensor, per_frame=False,
display_sequence: Union[torch.tensor, None] = None) -> torch.tensor:
"""
Register a sequence of landmarks to a mean face.
Computational complexity: O(3*N*T)
args:
    target_sequence: torch.Tensor of shape (T,3,N)
    mean_face: torch.Tensor of shape (3,N)
    per_frame: whether to estimate the transformation independently for each frame or once from the
        temporal mean of the sequence.
    display_sequence: (optional) torch.Tensor of shape (T,3,N'). If given, the transformation is estimated
        from target_sequence but applied to this sequence.
returns:
    registered_sequence: torch.Tensor of shape (T,3,N), or (T,3,N') if display_sequence is given
example:
    Estimating the transformation from the first 48 landmarks only, but applying it to all of them:
>>> registered_sequence = register_sequence_of_landmarks(landmarks[..., :48],
>>> mean_face[:, :48],
>>> display_sequence=landmarks)
"""
if display_sequence is None:
display_sequence = target_sequence

if not per_frame:
# Estimates the mean face
target_mean_face = torch.mean(target_sequence, dim=0)
# compute rigid transformation
R, t = rigid_transform_3D(target_mean_face, mean_face)

# apply transformation
registered_sequence = []
for x, y in zip(target_sequence, display_sequence):
if per_frame:
R, t = rigid_transform_3D(x, mean_face)
registered_sequence.append(apply_transformation(R, t, y))

return torch.stack(registered_sequence)
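For reference, a minimal usage sketch of the registration helpers above (not part of this commit; shapes are illustrative):

import torch
from vovit.core.kabsch import register_sequence_of_landmarks

landmarks = torch.rand(50, 3, 68)    # 2 s of 3D landmarks at 25 fps, shape (T, 3, N)
mean_face = torch.rand(3, 68)        # reference face, shape (3, N)
registered = register_sequence_of_landmarks(landmarks, mean_face)
print(registered.shape)              # torch.Size([50, 3, 68])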
37 changes: 24 additions & 13 deletions vovit/core/landmark_estimator/TDDFA_GPU.py
@@ -7,13 +7,13 @@
from torch import nn
from torchvision.transforms import Compose

import models
from bfm import BFMModel
from utils.io import _load
from utils.functions import (
from . import models
from .bfm import BFMModel
from .utils.io import _load
from .utils.functions import (
crop_video, reshape_fortran, parse_roi_box_from_bbox,
)
from utils.tddfa_util import (
from .utils.tddfa_util import (
load_model, _batched_parse_param, batched_similar_transform,
ToTensorGjz, NormalizeGjz
)
@@ -25,6 +25,7 @@ class TDDFA(nn.Module):
"""TDDFA: named Three-D Dense Face Alignment (TDDFA)"""

def __init__(self, **kvs):
super(TDDFA, self).__init__()
self.size = kvs.get('size', 120)

# load BFM
@@ -48,7 +49,6 @@ def __init__(self, **kvs):
)
model = load_model(model, kvs.get('checkpoint_fp'))


self.model = model

# data normalization
@@ -59,12 +59,8 @@

# params normalization config
r = _load(param_mean_std_fp)
self.param_mean = torch.from_numpy(r.get('mean'))
self.param_std = torch.from_numpy(r.get('std'))
self.param_mean = self.param_mean
self.param_std = self.param_std


self.register_buffer('param_mean', torch.from_numpy(r.get('mean')), persistent=False)
self.register_buffer('param_std', torch.from_numpy(r.get('std')), persistent=False)

def batched_inference(self, video_ori, bbox, **kvs):
"""The main call of TDDFA, given image and box / landmark, return 3DMM params and roi_box
@@ -75,7 +71,8 @@ def batched_inference(self, video_ori, bbox, **kvs):
"""
roi_box = parse_roi_box_from_bbox(bbox)
video = crop_video(video_ori, roi_box)
img = torch.nn.functional.interpolate(video, size=(self.size, self.size), mode='bilinear', align_corners=False)
img = torch.nn.functional.interpolate(video.float(), size=(self.size, self.size), mode='bilinear',
align_corners=False)

inp = self.transform_normalize(img)
param = self.model(inp)
@@ -96,3 +93,17 @@ def batched_recon_vers(self, param, roi_box, **kvs):
pts3d = batched_similar_transform(pts3d, roi_box, size)

return pts3d

def forward(self, video):
"""
:param video: RGB Video of shape (T,H,W,C) uint8 (values between 0-255). Video has to be cropped around the face
accurately (mainly to reduce GPU memory requirements).
:return:
"""
T, H, W, C = video.shape
assert C == 3, 'Video has to be RGB'
video = video.flip(-1) # BGR conversion
video = video.permute(0, 3, 1, 2) # T H W C -> T C H W
param, box_roi = self.batched_inference(video, [0, 0, W, H])
pts = self.batched_recon_vers(param, box_roi)
return pts
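For reference, a minimal landmark-extraction sketch mirroring the configuration code in vovit/__init__.py above (not part of this commit; the vovit.utils path constants and the GPU placement are assumptions):

import os
import torch
import yaml
from vovit import utils                                     # assumed to expose the path constants used above
from vovit.core.landmark_estimator.TDDFA_GPU import TDDFA

cfg = yaml.load(open(utils.DEFAULT_CFG_PATH), Loader=yaml.SafeLoader)
cfg['checkpoint_fp'] = os.path.join(utils.LANDMARK_LIB_PATH, 'weights', 'mb1_120x120.pth')
cfg['bfm_fp'] = os.path.join(utils.LANDMARK_LIB_PATH, 'configs', 'bfm_noneck_v3.pkl')
extractor = TDDFA(**cfg).cuda().eval()

video = torch.randint(0, 256, (50, 160, 160, 3), dtype=torch.uint8).cuda()   # 2 s face-cropped RGB clip
with torch.no_grad():
    landmarks = extractor(video)                            # per-frame 3D landmarks, (T, 3, 68) here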
1 change: 1 addition & 0 deletions vovit/core/landmark_estimator/bfm/bfm.py
@@ -22,6 +22,7 @@ def _to_ctype(arr):

class BFMModel(torch.nn.Module):
def __init__(self, bfm_fp, shape_dim=40, exp_dim=10):
super(BFMModel, self).__init__()
bfm = _load(bfm_fp)
if osp.split(bfm_fp)[-1] == 'bfm_noneck_v3.pkl':
self.tri = _load(make_abs_path('../configs/tri.pkl')) # this tri/face is re-built for bfm_noneck_v3
1 change: 1 addition & 0 deletions vovit/core/landmark_estimator/utils/functions.py
@@ -57,6 +57,7 @@ def reshape_fortran(x, shape):
if len(x.shape) > 0:
x = x.permute(*reversed(range(len(x.shape))))
return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))

def crop_video(video, roi_box):
bs, c, h, w = video.shape

5 changes: 3 additions & 2 deletions vovit/core/landmark_estimator/utils/tddfa_util.py
@@ -53,8 +53,9 @@ def __repr__(self):

class NormalizeGjz(torch.nn.Module):
def __init__(self, mean, std):
self.mean = mean
self.std = std
super(NormalizeGjz, self).__init__()
self.register_buffer('mean', torch.tensor(mean), persistent=False)
self.register_buffer('std', torch.tensor(std), persistent=False)

def __call__(self, tensor):
tensor.sub_(self.mean).div_(self.std)
2 changes: 2 additions & 0 deletions vovit/core/models/__init__.py
@@ -0,0 +1,2 @@
from .. import fourier_defaults, VIDEO_FRAMERATE
from .production_model import VoViT
Empty file.
80 changes: 80 additions & 0 deletions vovit/core/models/modules/gconv.py
@@ -0,0 +1,80 @@
# The basic unit of graph convolutional networks.

import torch
import torch.nn as nn


class ConvTemporalGraphical(nn.Module):
r"""The basic module for applying a graph convolution.
Args:
in_channels (int): Number of channels in the input sequence data
out_channels (int): Number of channels produced by the convolution
kernel_size (int): Size of the graph convolving kernel
t_kernel_size (int): Size of the temporal convolving kernel
t_stride (int, optional): Stride of the temporal convolution. Default: 1
t_padding (int, optional): Temporal zero-padding added to both sides of
the input. Default: 0
t_dilation (int, optional): Spacing between temporal kernel elements.
Default: 1
bias (bool, optional): If ``True``, adds a learnable bias to the output.
Default: ``True``
Shape:
- Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
- Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
- Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
- Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format
where
:math:`N` is a batch size,
:math:`K` is the spatial kernel size, as :math:`K == kernel\_size`,
:math:`T_{in}/T_{out}` is a length of input/output sequence,
:math:`V` is the number of graph nodes.
"""

def __init__(self,
in_channels,
out_channels,
kernel_size,
t_kernel_size=1,
t_stride=1,
t_padding=0,
t_dilation=1,
bias=True,
):

super().__init__()

self.kernel_size = kernel_size
self.conv = nn.Conv2d(in_channels,
out_channels * kernel_size,
kernel_size=(t_kernel_size, 1),
padding=(t_padding, 0),
stride=(t_stride, 1),
dilation=(t_dilation, 1),
bias=bias)

def forward(self, x, A):
if A.ndim == 4:
assert A.size(1) == self.kernel_size
elif A.ndim == 5:
assert A.size(2) == self.kernel_size
else:
assert A.size(0) == self.kernel_size

x = self.conv(x) # B,channels=3,T,J

n, kc, t, v = x.size()
x = x.view(n, self.kernel_size, kc // self.kernel_size, t, v)
if A.ndimension() == 3:
# static or dynamic
x = torch.einsum('nkctv,kvw->nctw', (x, A))
elif A.ndimension() == 4:
# Categorical
x = torch.einsum('nkctv,nkvw->nctw', (x, A))
elif A.ndimension() == 5:
x = torch.einsum('nkctv,ntkvw->nctw', (x, A))
else:
raise Exception('Adjacency matrix dimensionality is %d but should be 3, 4 or 5' % A.ndimension())
return x.contiguous(), A
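For reference, a minimal shape check for ConvTemporalGraphical with a static adjacency matrix (not part of this commit; sizes are illustrative):

import torch
from vovit.core.models.modules.gconv import ConvTemporalGraphical

K, V, T, B = 3, 68, 50, 2                            # spatial kernel size, graph nodes, frames, batch
gcn = ConvTemporalGraphical(in_channels=3, out_channels=64, kernel_size=K)
x = torch.randn(B, 3, T, V)                          # (N, C, T, V) landmark sequence
A = torch.rand(K, V, V)                              # static adjacency, one slice per kernel component
y, A = gcn(x, A)
print(y.shape)                                       # torch.Size([2, 64, 50, 68])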