Merge pull request #54 from kaseris/init
Init parameters and apply tanh
kaseris committed Dec 14, 2023
2 parents 34cd2c1 + 127a0dc commit 28ccd3e
Showing 1 changed file with 44 additions and 3 deletions.
src/skelcast/models/rnn/pvred.py
@@ -14,7 +14,7 @@ def __init__(self, rnn_type: str = 'rnn',
                  input_dim: int = 75,
                  hidden_dim: int = 256,
                  batch_first: bool = True,
-                 dropout: float = 0.2,
+                 dropout: float = 0.5,
                  use_residual: bool = True) -> None:
         super().__init__()
         assert rnn_type in ['lstm', 'gru'], f'rnn_type must be one of lstm, gru, got {rnn_type}'
@@ -29,6 +29,26 @@ def __init__(self, rnn_type: str = 'rnn',
         self.linear = nn.Linear(hidden_dim, input_dim)
         self.dropout = nn.Dropout(dropout)
 
+        if self.rnn_type == 'lstm':
+            for name, param in self.rnn.named_parameters():
+                if 'weight_ih' in name:
+                    torch.nn.init.xavier_uniform_(param.data)
+                elif 'weight_hh' in name:
+                    torch.nn.init.orthogonal_(param.data)
+                elif 'bias' in name:
+                    param.data.fill_(0)
+
+        elif self.rnn_type == 'gru':
+            for name, param in self.rnn.named_parameters():
+                if 'weight_ih' in name:
+                    torch.nn.init.kaiming_normal_(param.data)
+                elif 'weight_hh' in name:
+                    torch.nn.init.orthogonal_(param.data)
+                elif 'bias' in name:
+                    param.data.fill_(0)
+
+
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         out, hidden = self.rnn(x)
         out = self.dropout(out)
@@ -44,7 +64,7 @@ def __init__(self, rnn_type: str = 'rnn',
                  input_dim: int = 75,
                  hidden_dim: int = 256,
                  batch_first: bool = True,
-                 dropout: float = 0.2,
+                 dropout: float = 0.5,
                  use_residual: bool = True) -> None:
         super().__init__()
         assert rnn_type in ['lstm', 'gru'], f'rnn_type must be one of lstm, gru, got {rnn_type}'
@@ -57,16 +77,37 @@ def __init__(self, rnn_type: str = 'rnn',
         elif self.rnn_type == 'gru':
             self.rnn = nn.GRU(input_size=input_dim, hidden_size=hidden_dim, batch_first=batch_first)
         self.linear = nn.Linear(hidden_dim, input_dim)
+        self.tanh = nn.Tanh()
         self.dropout = nn.Dropout(dropout)
 
+        if self.rnn_type == 'lstm':
+            for name, param in self.rnn.named_parameters():
+                if 'weight_ih' in name:
+                    torch.nn.init.xavier_uniform_(param.data)
+                elif 'weight_hh' in name:
+                    torch.nn.init.orthogonal_(param.data)
+                elif 'bias' in name:
+                    param.data.fill_(0)
+
+        elif self.rnn_type == 'gru':
+            for name, param in self.rnn.named_parameters():
+                if 'weight_ih' in name:
+                    torch.nn.init.kaiming_normal_(param.data)
+                elif 'weight_hh' in name:
+                    torch.nn.init.orthogonal_(param.data)
+                elif 'bias' in name:
+                    param.data.fill_(0)
+
+
     def forward(self, x: torch.Tensor, hidden: torch.Tensor = None, timesteps_to_predict: int = 5) -> torch.Tensor:
         predictions = []
         for _ in range(timesteps_to_predict):
             out, hidden = self.rnn(x, hidden)
             out = self.dropout(out)
             out = self.linear(out)
+            out = self.tanh(out)
             if self.use_residual:
-                out = out + x
+                out = self.tanh(out + x)
             predictions.append(out)
             x = out
         return torch.cat(predictions, dim=1)
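For reference, the initialization scheme introduced above can be factored into a standalone helper. The following is a minimal sketch, not code from this commit: the name init_rnn_weights is hypothetical, and it assumes single-layer nn.LSTM/nn.GRU modules like the ones constructed in pvred.py.

    import torch.nn as nn

    def init_rnn_weights(rnn: nn.Module, rnn_type: str) -> None:
        """Sketch of the scheme used in this commit (hypothetical helper).

        - weight_ih_*: Xavier uniform for LSTM, Kaiming normal for GRU
        - weight_hh_*: orthogonal (keeps recurrent dynamics well-conditioned)
        - bias_*:      zeros
        """
        for name, param in rnn.named_parameters():
            if 'weight_ih' in name:
                if rnn_type == 'lstm':
                    nn.init.xavier_uniform_(param.data)
                else:  # 'gru'
                    nn.init.kaiming_normal_(param.data)
            elif 'weight_hh' in name:
                nn.init.orthogonal_(param.data)
            elif 'bias' in name:
                param.data.fill_(0)

    # Usage on a module shaped like the ones in the diff:
    encoder_rnn = nn.LSTM(input_size=75, hidden_size=256, batch_first=True)
    init_rnn_weights(encoder_rnn, 'lstm')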

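The decoder change routes every predicted frame through tanh before it is fed back as the next input, so the autoregressive rollout stays bounded in [-1, 1]. Below is a minimal sketch of that rollout under stated assumptions: a single-layer GRU stand-in for the decoder, dropout omitted, and batch-first tensors matching the diff's defaults (input_dim=75, hidden_dim=256).

    import torch
    import torch.nn as nn

    # Stand-in modules mirroring the decoder in the diff (names assumed from context).
    rnn = nn.GRU(input_size=75, hidden_size=256, batch_first=True)
    linear = nn.Linear(256, 75)
    tanh = nn.Tanh()

    def rollout(x: torch.Tensor, timesteps_to_predict: int = 5, use_residual: bool = True) -> torch.Tensor:
        """Autoregressive rollout as in the patched forward(), minus dropout."""
        predictions, hidden = [], None
        for _ in range(timesteps_to_predict):
            out, hidden = rnn(x, hidden)
            out = tanh(linear(out))      # squash the frame estimate
            if use_residual:
                out = tanh(out + x)      # residual connection, re-squashed
            predictions.append(out)
            x = out                      # feed the prediction back in
        return torch.cat(predictions, dim=1)

    seed = torch.randn(2, 1, 75)         # batch of 2, one seed frame, 75 coordinates
    future = rollout(seed)
    print(future.shape)                  # torch.Size([2, 5, 75])

Note that with use_residual enabled the patch applies tanh twice per step, once to the frame estimate and once after the residual sum, which guarantees bounded feedback at the cost of some extra saturation.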