diff --git a/tutorials/02-intermediate/recurrent_neural_network/main.py b/tutorials/02-intermediate/recurrent_neural_network/main.py index 9b8685ca..c37ac4b4 100644 --- a/tutorials/02-intermediate/recurrent_neural_network/main.py +++ b/tutorials/02-intermediate/recurrent_neural_network/main.py @@ -42,7 +42,7 @@ def __init__(self, input_size, hidden_size, num_layers, num_classes): super(RNN, self).__init__() self.hidden_size = hidden_size self.num_layers = num_layers - self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True) + self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True) self.fc = nn.Linear(hidden_size, num_classes) def forward(self, x): @@ -51,7 +51,7 @@ def forward(self, x): c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device) # Forward propagate LSTM - out, _ = self.lstm(x, (h0, c0)) # out: tensor of shape (batch_size, seq_length, hidden_size) + out, _ = self.rnn(x, (h0, c0)) # out: tensor of shape (batch_size, seq_length, hidden_size) # Decode the hidden state of the last time step out = self.fc(out[:, -1, :]) @@ -99,4 +99,4 @@ def forward(self, x): print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total)) # Save the model checkpoint -torch.save(model.state_dict(), 'model.ckpt') \ No newline at end of file +torch.save(model.state_dict(), 'model.ckpt')