Skip to content

Commit

Permalink
fixing: transformer forward method
Browse files Browse the repository at this point in the history
  • Loading branch information
Jorgedavyd committed Jul 10, 2024
1 parent d3ac3f6 commit 6fbef5b
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions lightorch/nn/transformer/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(
fc: Optional[nn.Module] = None,
n_layers: int = 1,
) -> None:
assert (encoder is not None or decoder is not None), "Not valid parameters, must be at least one encoder or decoder."
super().__init__()
self.embedding = embedding_layer
self.pe = positional_encoding
Expand Down Expand Up @@ -110,8 +111,8 @@ def forward(self, x: Tensor) -> Tensor:
else:
for decoder in self.decoder:
out = decoder(x)

x = self.fc(out)
if self.fc:
x = self.fc(out)

return x

Expand Down

0 comments on commit 6fbef5b

Please sign in to comment.