Commit: fixing linting errors

mk314k committed Nov 21, 2023
1 parent 695ffcd · commit b9dc569
Showing 3 changed files with 7 additions and 8 deletions.
models/attention.py (2 changes: 1 addition & 1 deletion)

@@ -17,7 +17,7 @@ class R3DAttention(nn.Module):
     """

     def __init__(self, hidden_size: int, num_heads: int, dropout=0.1):
-        super(R3DAttention, self).__init__()
+        super().__init__()

         # Calculating the size of each attention head
         head_size = int(hidden_size / num_heads)
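
For reference, the attention.py change addresses pylint's super-with-arguments check: in Python 3, super() called inside a method needs no explicit class or instance. A minimal sketch of the two spellings (the classes here are illustrative, not from this repository):

import torch.nn as nn

class LegacyModule(nn.Module):
    def __init__(self):
        # Python 2 style call; pylint flags this as super-with-arguments
        super(LegacyModule, self).__init__()

class ModernModule(nn.Module):
    def __init__(self):
        # Zero-argument form; behaves identically in Python 3
        super().__init__()
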
models/encoder.py (4 changes: 2 additions & 2 deletions)

@@ -35,13 +35,13 @@ def forward(self, x: torch.Tensor):
         return out


-class R3DEncoder(nn.Module):
+class R3DEncoder(nn.Module): # pylint: disable=too-many-instance-attributes
     """_summary_
     Args:
         nn (_type_): _description_
     """
-    def __init__(
+    def __init__( # pylint: disable=too-many-arguments
         self,
         in_channel=1,
         num_patches=128,
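
The encoder.py edits keep the design checks enabled globally and silence them only where they fire, by appending inline pylint disables to the flagged class and constructor (pylint's default limits are 7 instance attributes and 5 arguments). A small sketch of the same pattern with a made-up class:

class WideSettings:  # pylint: disable=too-many-instance-attributes
    """Made-up example: eight attributes would exceed the default limit."""

    def __init__(self, a, b, c, d, e, f, g, h):  # pylint: disable=too-many-arguments
        # The inline disables silence the named messages for these
        # definitions only; the checks stay active for the rest of the module.
        self.a, self.b, self.c, self.d = a, b, c, d
        self.e, self.f, self.g, self.h = e, f, g, h
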
train.py (9 changes: 4 additions & 5 deletions)

@@ -44,12 +44,10 @@ def vae_loss(recon_x, label):
     return F.binary_cross_entropy(recon_x[0], label, reduction='sum')


-def train(
+def train( # pylint: disable=too-many-locals
     train_x:torch.Tensor,
     train_y:torch.Tensor,
-    model_e:nn.Module,
-    model_d:nn.Module,
-    model_g:nn.Module,
+    models,
     optims,
     num_epochs=10
 ):
@@ -68,6 +66,7 @@ def train(
         _type_: _description_
     """
     vae_optimizer, gan_optimizer = optims
+    model_e, model_g, model_d = models
     train_losses = []
     for _ in tqdm.tqdm(range(num_epochs)):
         for i in range(70):
@@ -117,7 +116,7 @@ def train(
     weight_decay=WD,
     betas=betas
 )
-train_loss = train(train_data, train_label, MODEL_E, MODEL_G, MODEL_D, (vae_optim, gan_optim))
+train_loss = train(train_data, train_label, (MODEL_E, MODEL_G, MODEL_D), (vae_optim, gan_optim))
 # Plotting training losses
 plt.figure(figsize=(16, 6))
 plt.plot(range(len(train_loss)), [tloss[0] for tloss in train_loss], label='Training VAE loss')
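
The train.py refactor groups the three models into a single tuple parameter that is unpacked at the top of the function, so the call site passes (MODEL_E, MODEL_G, MODEL_D) as one argument; note that the new unpacking order model_e, model_g, model_d matches that tuple, whereas the old positional signature listed model_d before model_g. A minimal sketch of the same pattern with placeholder values standing in for the repository's actual models and optimizers:

def train(models, optims, num_epochs=2):
    """Accept grouped dependencies and unpack them in one place."""
    model_e, model_g, model_d = models
    vae_optimizer, gan_optimizer = optims
    for _ in range(num_epochs):
        pass  # the real loop would use the unpacked models and optimizers here

# The call site mirrors the unpacking order; strings stand in for real modules.
train(("encoder", "generator", "discriminator"), ("vae_opt", "gan_opt"))
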
