Commit

Setup diffusion modules.

kaseris committed Feb 12, 2024
1 parent 981d9c6 commit c26468f
Showing 3 changed files with 110 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/skelcast/models/diffusers/__init__.py
@@ -0,0 +1 @@
from .embedding import DiffusionEmbedding
58 changes: 58 additions & 0 deletions src/skelcast/models/diffusers/embedding.py
@@ -0,0 +1,58 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class DiffusionEmbedding(nn.Module):
    """Sinusoidal embedding of the diffusion step, followed by a two-layer SiLU MLP (adapted from CSDI)."""

    def __init__(self, num_steps, embedding_dim=128, projection_dim=None):
        super().__init__()
        if projection_dim is None:
            projection_dim = embedding_dim
        # Precompute the (num_steps, embedding_dim) sinusoidal table once; it is not a learnable parameter.
        self.register_buffer(
            "embedding",
            self._build_embedding(num_steps, embedding_dim // 2),
            persistent=False,
        )
        self.projection1 = nn.Linear(embedding_dim, projection_dim)
        self.projection2 = nn.Linear(projection_dim, projection_dim)

    def forward(self, diffusion_step):
        # Look up the precomputed embedding for the given step index (or batch of indices).
        x = self.embedding[diffusion_step]
        x = self.projection1(x)
        x = F.silu(x)
        x = self.projection2(x)
        x = F.silu(x)
        return x

    def _build_embedding(self, num_steps, dim=64):
        steps = torch.arange(num_steps).unsqueeze(1)  # (T, 1)
        # Geometric frequency ladder spanning 10^0 to 10^4, as in CSDI.
        frequencies = 10.0 ** (torch.arange(dim) / (dim - 1) * 4.0).unsqueeze(0)  # (1, dim)
        table = steps * frequencies  # (T, dim)
        table = torch.cat([torch.sin(table), torch.cos(table)], dim=1)  # (T, dim*2)
        return table


class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding over the sequence dimension, followed by dropout."""

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
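
A minimal usage sketch for the two modules above; the batch size, step count, and tensor shapes here are illustrative assumptions, not part of the commit:

    import torch
    from skelcast.models.diffusers.embedding import DiffusionEmbedding, PositionalEncoding

    # DiffusionEmbedding maps integer diffusion-step indices to dense vectors.
    step_embed = DiffusionEmbedding(num_steps=50, embedding_dim=128)
    t = torch.randint(0, 50, (16,))      # one step index per sample in a batch of 16
    e = step_embed(t)                    # -> (16, 128)

    # PositionalEncoding expects (seq_len, batch, d_model), adds the sinusoidal table, and applies dropout.
    pos = PositionalEncoding(d_model=128)
    x = torch.randn(24, 16, 128)         # 24 time steps, batch of 16
    y = pos(x)                           # -> (24, 16, 128)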
51 changes: 51 additions & 0 deletions src/skelcast/models/diffusers/series_diffuser.py
@@ -0,0 +1,51 @@
import torch
import torch.nn as nn

from skelcast.models.diffusers.embedding import DiffusionEmbedding, PositionalEncoding


class Series_Denoiser(nn.Module):
    """Denoiser that applies a spatial (per-feature) and a temporal transformer encoder in sequence."""

    def __init__(self, input_dim, qkv_dim, num_layers, num_heads, prefix_len, pred_len, diff_steps):
        super().__init__()

        self.input_dim = input_dim
        self.qkv_dim = qkv_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.prefix_len = prefix_len
        self.pred_len = pred_len
        self.diff_steps = diff_steps

        # Spatial branch: attends across feature dimensions; each feature's full time series is projected to qkv_dim.
        self.spatial_layer = nn.TransformerEncoderLayer(d_model=qkv_dim, nhead=num_heads,
                                                        dim_feedforward=qkv_dim, activation='gelu')
        self.spatial_TF = nn.TransformerEncoder(self.spatial_layer, num_layers=num_layers)
        self.spatial_inp_fc = nn.Linear(prefix_len + pred_len, qkv_dim)
        self.spatial_out_fc = nn.Linear(qkv_dim, prefix_len + pred_len)

        # Temporal branch: attends across time steps.
        self.temporal_layer = nn.TransformerEncoderLayer(d_model=qkv_dim, nhead=num_heads,
                                                         dim_feedforward=qkv_dim, activation='gelu')
        self.temporal_TF = nn.TransformerEncoder(self.temporal_layer, num_layers=num_layers)
        self.temporal_inp_fc = nn.Linear(input_dim, qkv_dim)
        self.temporal_out_fc = nn.Linear(qkv_dim, input_dim)

        self.pos_encoder = PositionalEncoding(qkv_dim)

        self.step_temporal_encoder = DiffusionEmbedding(diff_steps, qkv_dim)
        self.step_spatial_encoder = DiffusionEmbedding(diff_steps, qkv_dim)

    def forward(self, noise, x, t):
        step_spatial_embed = self.step_spatial_encoder(t)
        step_temporal_embed = self.step_temporal_encoder(t)
        # Concatenate the conditioning prefix and the noisy prediction window along the time axis: (L, B, D).
        window = torch.cat([x, noise], dim=0)

        spatial = self.spatial_inp_fc(window.permute(2, 1, 0)) + step_spatial_embed  # (L, B, D) -> (D, B, L)
        spatial = self.pos_encoder(spatial)
        spatial = self.spatial_TF(spatial)
        spatial = self.spatial_out_fc(spatial).permute(2, 1, 0)  # back to (L, B, D)

        temporal = self.temporal_inp_fc(spatial) + step_temporal_embed
        temporal = self.pos_encoder(temporal)
        temporal = self.temporal_TF(temporal)
        temporal = self.temporal_out_fc(temporal)

        # Return only the denoised prediction window, dropping the prefix part.
        return temporal[x.shape[0]:]
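
A hedged sketch of driving Series_Denoiser for one denoising step; the window lengths, feature dimension, and the way noise and step indices are sampled are assumptions for illustration only:

    import torch
    from skelcast.models.diffusers.series_diffuser import Series_Denoiser

    prefix_len, pred_len, input_dim = 10, 25, 75   # assumed observed/predicted window sizes and joint-coordinate dim
    model = Series_Denoiser(input_dim=input_dim, qkv_dim=128, num_layers=4, num_heads=8,
                            prefix_len=prefix_len, pred_len=pred_len, diff_steps=50)

    x = torch.randn(prefix_len, 8, input_dim)      # observed prefix: (prefix_len, batch, input_dim)
    noise = torch.randn(pred_len, 8, input_dim)    # noisy prediction window: (pred_len, batch, input_dim)
    t = torch.randint(0, 50, (8,))                 # one diffusion step index per sample
    out = model(noise, x, t)                       # -> (pred_len, batch, input_dim)

Note that both transformer stacks use the default sequence-first layout, so inputs are kept as (seq_len, batch, features) throughout.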
