Skip to content

Commit

Permalink
[Tests] skip if numba is not installed
Browse files Browse the repository at this point in the history
  • Loading branch information
fedebotu committed Jun 20, 2024
1 parent 57ace54 commit 1315201
Showing 1 changed file with 34 additions and 8 deletions.
42 changes: 34 additions & 8 deletions tests/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
DACT,
MDAM,
N2S,
POMO,
ActiveSearch,
AttentionModelPolicy,
DeepACO,
Expand All @@ -29,7 +30,6 @@
NeuOpt,
PolyNet,
SymNCO,
POMO
)
from rl4co.utils import RL4COTrainer
from rl4co.utils.meta_trainer import ReptileCallback
Expand Down Expand Up @@ -130,20 +130,45 @@ def test_mdam():
trainer.fit(model)
trainer.test(model)


def test_pomo_reptile():
env = TSPEnv(generator_params=dict(num_loc=20))
policy = AttentionModelPolicy(env_name=env.name, embed_dim=128,
num_encoder_layers=6, num_heads=8,
normalization="instance", use_graph_context=False)
model = POMO(env, policy, batch_size=5, train_data_size=5*3, val_data_size=10, test_data_size=10)
policy = AttentionModelPolicy(
env_name=env.name,
embed_dim=128,
num_encoder_layers=6,
num_heads=8,
normalization="instance",
use_graph_context=False,
)
model = POMO(
env,
policy,
batch_size=5,
train_data_size=5 * 3,
val_data_size=10,
test_data_size=10,
)
meta_callback = ReptileCallback(
data_type="size", sch_bar=0.9, num_tasks=2, alpha = 0.99,
alpha_decay = 0.999, min_size = 20, max_size =50
data_type="size",
sch_bar=0.9,
num_tasks=2,
alpha=0.99,
alpha_decay=0.999,
min_size=20,
max_size=50,
)
trainer = RL4COTrainer(
max_epochs=2,
callbacks=[meta_callback],
devices=1,
accelerator=accelerator,
limit_train_batches=3,
)
trainer = RL4COTrainer(max_epochs=2, callbacks=[meta_callback], devices=1, accelerator=accelerator, limit_train_batches=3)
trainer.fit(model)
trainer.test(model)


@pytest.mark.parametrize("SearchMethod", [ActiveSearch, EASEmb, EASLay])
def test_search_methods(SearchMethod):
env = TSPEnv(generator_params=dict(num_loc=20))
Expand Down Expand Up @@ -175,6 +200,7 @@ def test_nargnn():
@pytest.mark.skipif(
"torch_geometric" not in sys.modules, reason="PyTorch Geometric not installed"
)
@pytest.mark.skipif("numba" not in sys.modules, reason="Numba not installed")
def test_deepaco():
env = TSPEnv(generator_params=dict(num_loc=20))
model = DeepACO(env, train_data_size=10, val_data_size=10, test_data_size=10)
Expand Down

0 comments on commit 1315201

Please sign in to comment.