-
Notifications
You must be signed in to change notification settings - Fork 0
/
environment.py
79 lines (63 loc) · 2.94 KB
/
environment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from __future__ import annotations
from typing import List, Dict
from distributions import Distribution
from python_linq import From
class Map:
    """A graph of named nodes whose edges draw their costs from named distributions.

    ``nodes`` maps node names to ``Node`` objects. Each edge stores the *name*
    of a distribution, which is resolved through ``distributions``.
    """

    def __init__(self):
        self.nodes: Dict[str, Node] = {}
        self.distributions: Dict[str, Distribution] = {}

    def getEdgeDistribution(self, fr: str, to: str) -> Distribution:
        """Return the distribution attached to the edge ``fr -> to``.

        Raises:
            KeyError: if ``fr`` is not a node, ``fr`` has no edge to ``to``, or
                the edge names a distribution that was not loaded.
        """
        return self.distributions[self.nodes[fr].edges[to].distribution]

    def getEdgeCost(self, fr: str, to: str) -> float:
        """Return the actual cost of the edge ``fr -> to``.

        The cost is the value reported by the edge distribution's
        ``getObservation()``.
        """
        return self.getEdgeDistribution(fr, to).getObservation()

    def reset(self):
        """Reset the environment for a new run by calling ``reset()`` on every distribution."""
        for distribution in self.distributions.values():
            distribution.reset()

    @staticmethod
    def loadFromFile(filename: str) -> Map:
        """Load a map from a UTF-8 text file.

        The file may contain a ``NODES`` ... ``END NODES`` section (one node per
        line) and a ``DISTS`` ... ``END DISTS`` section (one distribution per
        line). Lines outside these sections are ignored. Section markers are
        matched ignoring surrounding whitespace.

        Raises:
            ValueError: if a section is opened but never terminated, or an edge
                points to a node that is not defined.
        """
        result = Map()
        with open(filename, mode="r", encoding="utf-8") as file:
            lines = file.read().splitlines()

        i = 0
        while i < len(lines):
            header = lines[i].strip()
            if header == "NODES":
                end = Map._findSectionEnd(lines, i, "END NODES")
                result._addNodes(lines, i + 1, end)
                i = end + 1  # resume scanning after the END marker
            elif header == "DISTS":
                end = Map._findSectionEnd(lines, i, "END DISTS")
                result._addDists(lines, i + 1, end)
                i = end + 1  # resume scanning after the END marker
            else:
                i += 1
        return result

    @staticmethod
    def _findSectionEnd(lines: List[str], start: int, marker: str) -> int:
        """Return the index of the first line after ``start`` equal to ``marker``.

        Lines are compared ignoring surrounding whitespace.

        Raises:
            ValueError: if no such line exists (the section is unterminated).
        """
        for k in range(start + 1, len(lines)):
            if lines[k].strip() == marker:
                return k
        raise ValueError(
            f"Section opened at line {start + 1} has no matching '{marker}' line."
        )

    @staticmethod
    def _sectionLines(lines: List[str], fr: int, to: int) -> List[str]:
        """Return ``lines[fr:to]`` with surrounding whitespace stripped and blank lines dropped."""
        return [line.strip() for line in lines[fr:to] if line.strip()]

    def _addNodes(self, lines: List[str], fr: int, to: int):
        """Parse ``lines[fr:to]`` into nodes and record each node's parents.

        Replaces any previously loaded nodes.

        Raises:
            ValueError: if a node has an edge to a node that is not defined in
                the same section.
        """
        self.nodes = (
            From(Map._sectionLines(lines, fr, to))
            .select(Node)
            .toDict(key=lambda node: node.name)
        )
        # Second pass: every edge source becomes a parent of its target.
        for node in self.nodes.values():
            for target in node.edges:
                if target not in self.nodes:
                    raise ValueError(
                        f"Node '{node.name}' has an edge to undefined node '{target}'."
                    )
                self.nodes[target].parents.append(node.name)

    def _addDists(self, lines: List[str], fr: int, to: int):
        """Parse ``lines[fr:to]`` into distributions keyed by name.

        Replaces any previously loaded distributions.
        """
        self.distributions = (
            From(Map._sectionLines(lines, fr, to))
            .select(Distribution.getDistribution)
            .toDict(key=lambda dist: dist.name)
        )
class Node:
    """A graph node parsed from one line of a map file's NODES section.

    Expected line format: ``<name> <target>:<distribution> ...``, i.e. the node
    name followed by zero or more whitespace-separated edge tokens.
    """

    def __init__(self, in_str: str):
        # split() with no argument collapses runs of whitespace and ignores
        # leading/trailing whitespace, so stray spaces no longer produce empty
        # tokens that Edge cannot parse.
        tokens = in_str.split()
        # A blank line still yields a node with an empty name, as before.
        self.name: str = tokens[0] if tokens else ""
        # Outgoing edges, keyed by target node name.
        self.edges: Dict[str, Edge] = From(tokens[1:]).select(Edge).toDict(key=lambda edge: edge.target)
        # Names of nodes that have an edge into this one (populated by Map._addNodes).
        self.parents: List[str] = []
class Edge:
    """A directed edge parsed from a ``<target>:<distribution>`` token."""

    def __init__(self, in_str: str):
        splits = in_str.split(':')
        if len(splits) < 2:
            # Previously this surfaced as an opaque IndexError.
            raise ValueError(
                f"Malformed edge '{in_str}': expected '<target>:<distribution>'."
            )
        # Name of the node this edge points to.
        self.target: str = splits[0]
        # Name of the distribution governing this edge's cost (looked up in
        # Map.distributions). Anything after a second ':' is ignored.
        self.distribution: str = splits[1]