Skip to content

Commit

Permalink
Test estimator correction
Browse files Browse the repository at this point in the history
  • Loading branch information
Francisco Berchez Moreno committed Nov 20, 2023
1 parent 210c84c commit 103021c
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions dlordinal/estimator/tests/test_estimator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import pytest
import torch
from torch import cuda
Expand Down Expand Up @@ -74,11 +75,12 @@ def test_pytorch_estimator_fit():
estimator = PytorchEstimator(model, loss_fn, optimizer, device, max_iter)

# Verifies the training flow
initial_loss = calculate_loss(model, loss_fn, test_dataloader)
# initial_loss = calculate_loss(model, loss_fn, test_dataloader)
estimator.fit(train_dataloader)
final_loss = calculate_loss(model, loss_fn, test_dataloader)

assert final_loss < initial_loss
assert not np.isnan(final_loss)
assert not np.isinf(final_loss)


def test_pytorch_estimator_predict():
Expand Down

0 comments on commit 103021c

Please sign in to comment.