-
Notifications
You must be signed in to change notification settings - Fork 0
/
check_training_data.py
69 lines (58 loc) · 2.66 KB
/
check_training_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import numpy as np
import h5py
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.gridspec import GridSpec
from datetime import datetime
# Date stamp formatted as YYYY-MM-DD, captured once when the module is imported.
today = datetime.now().strftime('%Y-%m-%d')
import argparse
def _load_npy_dict(fdir, fname):
    """Load ``fdir/fname`` and return the single object stored in it.

    NOTE(review): ``allow_pickle=True`` unpickles arbitrary Python objects, so
    only point this at trusted data files.
    """
    return np.load(os.path.join(fdir, fname), allow_pickle=True).item()


def _plot_panel(ax, image, title, show_y):
    """Draw one 2-D image on ``ax`` with the shared km-axis styling.

    Tick positions 0/20/40/60 are relabelled 0/2/4/6 km, i.e. the labels
    assume 10 pixels per km -- TODO confirm against the data resolution.
    When ``show_y`` is False the Y ticks are hidden and no Y label is drawn.
    """
    ax.tick_params(direction='out', length=10, width=2)
    ax.imshow(image, cmap='jet')
    ax.set_title(title, fontsize=30)
    ax.set_xticks([0, 20, 40, 60])
    ax.set_xticklabels([0, 2, 4, 6], fontsize=24, fontweight="bold")
    ax.set_xlabel('X [km]', fontsize=28, fontweight="bold")
    if show_y:
        ax.set_yticks([0, 20, 40, 60])
        ax.set_yticklabels([0, 2, 4, 6], fontsize=24, fontweight="bold")
        ax.set_ylabel('Y [km]', fontsize=28, fontweight="bold")
    else:
        ax.set_yticks([])


def visualize_data(fdir, key='data_10', outdir=None):
    """Check data before training the model.

    Plots the radiance, IPA COT and true (binned) COT images of one sample
    side by side and saves the figure as
    ``sample_training_data_<YYYY-MM-DD>.png``.

    Parameters
    ----------
    fdir : str
        Directory containing ``inp_radiance.npy``, ``out_cot_1d.npy`` and
        ``out_cot_3d.npy``, each holding a dict keyed by image name.
    key : str, optional
        Image name to plot. Defaults to ``'data_10'`` (the value that used to
        be hard-coded).
    outdir : str or None, optional
        Directory for the PNG. ``None`` (default) writes to the current
        working directory, as before.

    Returns
    -------
    str
        Path of the saved PNG.

    Raises
    ------
    ValueError
        If the three files do not list the same image names in the same order.
    KeyError
        If ``key`` is not one of the image names.
    """
    radiance = _load_npy_dict(fdir, 'inp_radiance.npy')
    cot_true = _load_npy_dict(fdir, 'out_cot_3d.npy')
    cot_ipa = _load_npy_dict(fdir, 'out_cot_1d.npy')

    # Explicit raises instead of `assert`, which is stripped under `python -O`.
    # The order-sensitive comparison of the original check is kept.
    names = list(radiance.keys())
    if names != list(cot_true.keys()):
        raise ValueError('Image names of radiance and cot are different')
    if names != list(cot_ipa.keys()):
        raise ValueError('Image names of radiance and IPA cot are different')
    if key not in radiance:
        raise KeyError('Image {!r} not found in {}'.format(key, fdir))

    # One row of panels: radiance, IPA COT, true COT.
    panels = [
        (radiance[key], 'Radiance'),
        (cot_ipa[key], 'IPA COT'),
        (cot_true[key], 'True COT (Binned)'),
    ]
    fig = plt.figure(figsize=(25, 10))
    gs = GridSpec(1, 3, figure=fig)
    for col, (image, title) in enumerate(panels):
        # Only the left-most panel carries Y ticks and the Y label.
        _plot_panel(fig.add_subplot(gs[0, col]), image, title, show_y=(col == 0))

    # Date taken at call time so a long-lived session does not reuse a stale date.
    fname = 'sample_training_data_{}.png'.format(datetime.today().strftime('%Y-%m-%d'))
    out_path = fname if outdir is None else os.path.join(outdir, fname)
    try:
        fig.savefig(out_path, dpi=100, bbox_inches='tight', pad_inches=0.25)
    finally:
        # Release the figure; pyplot keeps a reference to it until closed.
        plt.close(fig)
    print('Saved image as {}'.format(out_path))
    return out_path
def main(argv=None):
    """Command-line entry point: parse ``--fdir`` and plot one training sample.

    ``argv`` defaults to ``sys.argv[1:]`` (argparse behaviour) when None.
    Exits with status 2 and a usage message if ``--fdir`` is not a directory.
    """
    parser = argparse.ArgumentParser(
        description='Visualize a sample of the training data before training the model.')
    parser.add_argument('--fdir', default='data/npy/rgb_radiance', type=str,
                        help="Path to directory containing npy files")
    args = parser.parse_args(argv)
    # Report a missing directory as a usage error rather than a np.load traceback.
    if not os.path.isdir(args.fdir):
        parser.error('directory not found: {}'.format(args.fdir))
    visualize_data(args.fdir)


if __name__ == '__main__':
    main()