Skip to content

Commit

Permalink
Merge pull request #630 from DEUCE1957/EpisodeData
Browse files Browse the repository at this point in the history
Debug: Fix Missing Legal and Ambigious Arrays in CompactEpisodeData
  • Loading branch information
BDonnot authored Sep 23, 2024
2 parents d3a1c06 + 2bc55f8 commit f508dc6
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 8 deletions.
4 changes: 3 additions & 1 deletion grid2op/Episode/CompactEpisodeData.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,11 @@ def __init__(self, env, obs, exp_dir, ep_id:str=None):
"""
if exp_dir is not None:
self.exp_dir = p(exp_dir)
self.exp_dir = self.exp_dir / "CompactEpisodeData"
self.exp_dir.mkdir(parents=False, exist_ok=True)
else:
self.exp_dir = None
self.array_names = ("actions", "env_actions", "attacks", "observations", "rewards", "other_rewards", "disc_lines", "times")
self.array_names = ("actions", "env_actions", "attacks", "observations", "rewards", "other_rewards", "disc_lines", "times", "legal", "ambiguous")
self.space_names = ("observation_space", "action_space", "attack_space", "env_modification_space")
if ep_id is None:
self.ep_id = env.chronics_handler.get_name()
Expand Down
2 changes: 1 addition & 1 deletion grid2op/Runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1728,7 +1728,7 @@ def run(
)
else:
if add_detailed_output and (_IS_WINDOWS or _IS_MACOS):
self.logger.warn(
self.logger.warning(
"Parallel run are not fully supported on windows or macos when "
'"add_detailed_output" is True. So we decided '
"to fully deactivate them."
Expand Down
17 changes: 11 additions & 6 deletions grid2op/tests/test_CompactEpisodeData.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import unittest

import grid2op
from grid2op.Agent import OneChangeThenNothing
from grid2op.Agent import DoNothingAgent, OneChangeThenNothing
from grid2op.tests.helper_path_test import *
from grid2op.Chronics import Multifolder
from grid2op.Reward import L2RPNReward
Expand Down Expand Up @@ -140,6 +140,8 @@ def act(self, observation, reward, done=False):
assert len(episode_data.observations) == self.max_iter + 1
assert len(episode_data.env_actions) == self.max_iter
assert len(episode_data.attacks) == self.max_iter
assert len(episode_data.ambiguous) == self.max_iter
assert len(episode_data.legal) == self.max_iter

def test_one_episode_with_saving(self):
f = tempfile.mkdtemp()
Expand All @@ -163,6 +165,7 @@ def test_collection_wrapper_after_run(self):
OneChange = OneChangeThenNothing.gen_next(
{"set_bus": {"lines_or_id": [(1, -1)]}}
)
# env.reset(options=)
runner = Runner(
init_grid_path=self.init_grid_path,
init_env_path=self.init_grid_path,
Expand All @@ -178,9 +181,11 @@ def test_collection_wrapper_after_run(self):
agentClass=OneChange,
use_compact_episode_data=True,
)
ep_id, ep_name, cum_reward, timestep, max_ts, episode_data = runner.run_one_episode(
max_iter=self.max_iter, detailed_output=True
)
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
*_, episode_data = runner.run_one_episode(
max_iter=self.max_iter, detailed_output=True,
)
# Check that the type of first action is set bus
assert episode_data.action_space.from_vect(episode_data.actions[0]).get_types()[2]

Expand Down Expand Up @@ -257,7 +262,7 @@ def test_with_opponent(self):
)

episode_data = CompactEpisodeData.from_disk(path=f, ep_id=res[0][1])
lines_impacted, subs_impacted = episode_data.attack_space.from_vect(episode_data.attacks[0]).get_topological_impact()
lines_impacted, _ = episode_data.attack_space.from_vect(episode_data.attacks[0]).get_topological_impact()
assert lines_impacted[3]

def test_can_return_ep_data(self):
Expand Down Expand Up @@ -296,4 +301,4 @@ def test_can_return_ep_data(self):


if __name__ == "__main__":
unittest.main()
unittest.main()

0 comments on commit f508dc6

Please sign in to comment.