-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathvisualize.py
59 lines (48 loc) · 1.95 KB
/
visualize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import numpy as np
from matplotlib import animation
import matplotlib.pyplot as plt
# Matplotlib colour used for each particle-type id. Only ids listed here get a
# plot line, so particles of any other type are never drawn by this module.
# Insertion order matters: one line is created per entry in this order, and with
# equal zorder matplotlib draws later-added artists on top of earlier ones.
# NOTE(review): the id meanings in the comments below are assumptions (they match
# DeepMind's learning_to_simulate renderer) — confirm against the dataset metadata.
TYPE_TO_COLOR = {
    3: "black",  # presumably boundary / wall particles — TODO confirm
    0: "green",  # presumably rigid-solid particles — TODO confirm
    7: "magenta",  # presumably goop particles — TODO confirm
    6: "gold",  # presumably sand particles — TODO confirm
    5: "blue",  # presumably water particles — TODO confirm
}
def visualize_prepare(ax, particle_type, position, metadata):
    """Configure ``ax`` for drawing particles and create one empty marker line per type.

    Args:
        ax: matplotlib Axes to configure and draw on.
        particle_type: not used by this function.
        position: returned unchanged, so the caller keeps it paired with the
            axes it will be drawn on.
        metadata: mapping whose ``"bounds"`` entry is indexed as
            ``bounds[0] == (x_min, x_max)`` and ``bounds[1] == (y_min, y_max)``.

    Returns:
        Tuple ``(ax, position, points)``. ``points`` maps each particle-type id in
        ``TYPE_TO_COLOR`` to its initially empty ``Line2D``. Callers fill the data
        later with ``set_data``.
    """
    bounds = metadata["bounds"]
    x_range, y_range = bounds[0], bounds[1]
    ax.set_xlim(x_range[0], x_range[1])
    ax.set_ylim(y_range[0], y_range[1])

    # No tick marks, and equal x/y scaling so particle spacing is not distorted.
    for clear_ticks in (ax.set_xticks, ax.set_yticks):
        clear_ticks([])
    ax.set_aspect(1.0)

    points = {}
    for type_id, color in TYPE_TO_COLOR.items():
        # ax.plot returns a list of lines; a single x/y pair yields exactly one.
        points[type_id] = ax.plot([], [], "o", ms=2, color=color)[0]
    return ax, position, points
def visualize_single(particle_type, position, metadata, *, interval=10):
    """Animate one particle rollout on a single 5x5-inch axes.

    Args:
        particle_type: per-particle type ids. They are compared elementwise with
            each key of ``TYPE_TO_COLOR`` to build a boolean mask per type.
        position: indexed as ``position[step, particle, coord]``. Coords 0 and 1
            are drawn as x and y. Anything with ``.shape`` and boolean-mask
            indexing works, including NumPy arrays and torch tensors. The original
            code used ``.size(0)``, which only works for torch tensors.
        metadata: mapping whose ``"bounds"`` entry sets the axis limits
            (see ``visualize_prepare``).
        interval: delay between frames in milliseconds, forwarded to
            ``FuncAnimation``. Keyword-only; defaults to the previously
            hard-coded 10.

    Returns:
        A ``matplotlib.animation.FuncAnimation`` with one frame per step of
        ``position``. Keep a reference to it, or the animation may be
        garbage-collected before it plays.
    """
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    _, position, points = visualize_prepare(ax, particle_type, position, metadata)

    # Particle types do not change between frames, so build each mask once
    # instead of re-comparing the whole array on every frame.
    masks = {type_: particle_type == type_ for type_ in TYPE_TO_COLOR}

    def update(step_i):
        # Move every per-type line to this step's x/y coordinates.
        outputs = []
        for type_, line in points.items():
            mask = masks[type_]
            line.set_data(position[step_i, mask, 0], position[step_i, mask, 1])
            outputs.append(line)
        return outputs

    # .shape[0] (unlike .size(0)) is valid for both NumPy arrays and torch tensors.
    num_steps = position.shape[0]
    return animation.FuncAnimation(
        fig, update, frames=np.arange(0, num_steps), interval=interval
    )
def visualize_pair(particle_type, position_pred, position_gt, metadata, *, interval=10):
    """Animate ground-truth and predicted rollouts side by side.

    The left axes shows ``position_gt`` ("Ground truth") and the right axes
    shows ``position_pred`` ("Prediction"). Both share ``particle_type`` and the
    axis bounds from ``metadata``.

    Args:
        particle_type: per-particle type ids, shared by both rollouts. They are
            compared elementwise with each key of ``TYPE_TO_COLOR``.
        position_pred: predicted positions, indexed as
            ``[step, particle, coord]``. Coords 0 and 1 are drawn as x and y.
        position_gt: ground-truth positions, with the same indexing as
            ``position_pred``.
        metadata: mapping whose ``"bounds"`` entry sets the axis limits
            (see ``visualize_prepare``).
        interval: delay between frames in milliseconds, forwarded to
            ``FuncAnimation``. Keyword-only; defaults to the previously
            hard-coded 10.

    Returns:
        A ``matplotlib.animation.FuncAnimation`` covering the steps present in
        both rollouts.
    """
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    plot_info = [
        visualize_prepare(axes[0], particle_type, position_gt, metadata),
        visualize_prepare(axes[1], particle_type, position_pred, metadata),
    ]
    axes[0].set_title("Ground truth")
    axes[1].set_title("Prediction")

    # Type masks are identical for every frame and both panels; build them once.
    masks = {type_: particle_type == type_ for type_ in TYPE_TO_COLOR}

    # Animate only the steps both rollouts have. Previously the frame count came
    # from position_gt alone, so a shorter prediction raised IndexError part-way
    # through. .shape[0] also works for NumPy arrays, unlike torch's .size(0).
    num_steps = min(position_gt.shape[0], position_pred.shape[0])

    def update(step_i):
        # Move every per-type line in both panels to this step's x/y coordinates.
        outputs = []
        for _, position, points in plot_info:
            for type_, line in points.items():
                mask = masks[type_]
                line.set_data(position[step_i, mask, 0], position[step_i, mask, 1])
                outputs.append(line)
        return outputs

    return animation.FuncAnimation(
        fig, update, frames=np.arange(0, num_steps), interval=interval
    )