-
Notifications
You must be signed in to change notification settings - Fork 0
/
util_scikit.py
156 lines (118 loc) · 5.16 KB
/
util_scikit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""
AI4ER GTC - Sea Ice Classification
Utility functions for loading and preprocessing the data for input to the models
"""
import pandas as pd
import numpy as np
import xarray as xr
import rioxarray as rxr
from numpy import ndarray
from xarray.core.dataarray import DataArray
def define_band3(sar: DataArray, sar_band3: str = 'angle') -> DataArray:
    """
    Selects what the third channel of a SAR DataArray holds.

    When sar_band3 is "ratio", band 3 is overwritten in place with
    band 1 / (band 2 + 0.0001). For any other value (including the default
    "angle") the image is returned unchanged.

    Parameters:
        sar (xarray.core.dataarray.DataArray): SAR image whose `band` dimension
            is labelled 1, 2 and 3
        sar_band3 (str): Name of 3rd band to return in the DataArray

    Returns:
        sar (xarray.core.dataarray.DataArray): The same DataArray object, with
            the 3rd band as specified
    """
    if sar_band3 != "ratio":
        # Nothing to compute: band 3 stays exactly as it was passed in.
        return sar

    numerator = sar.sel(band=1).values
    # The small offset presumably guards against division by zero.
    denominator = sar.sel(band=2).values + 0.0001

    ratio_band = sar.sel(band=3)
    ratio_band.values = numerator / denominator
    sar.loc[{"band": 3}] = ratio_band
    return sar
def normalize_sar(
    sar: DataArray,
    sar_band3: str = 'angle',
    metrics_path: str = "metrics.csv",
) -> DataArray:
    """
    Normalises a SAR image in place with the means and standard deviations
    stored in the first row of a metrics CSV file (according to the original
    documentation, these are computed over the whole training dataset).

    Each band is transformed as (band - mean) / std:
      * band index 0 uses the `hh_mean` / `hh_std` columns
      * band index 1 uses the `hv_mean` / `hv_std` columns
      * band index 2 uses `angle_mean` / `angle_std` when sar_band3 is "angle",
        or `hh_hv_mean` / `hh_hv_std` when sar_band3 is "ratio"

    Parameters:
        sar (xarray.core.dataarray.DataArray): SAR image, indexed by band on
            its first axis
        sar_band3 (str): Name of 3rd band to use for metrics ("angle" or "ratio")
        metrics_path (str): Path of the CSV file holding the statistics.
            Defaults to "metrics.csv" in the current working directory.

    Returns:
        sar (xarray.core.dataarray.DataArray): The same object with all the
            bands normalised

    Raises:
        ValueError: If sar_band3 is neither "angle" nor "ratio"
    """
    band3_columns = {
        "angle": ("angle_mean", "angle_std"),
        "ratio": ("hh_hv_mean", "hh_hv_std"),
    }
    # Validate before touching `sar`: the normalisation is done in place, so
    # failing halfway would leave the caller with a partially normalised image.
    if sar_band3 not in band3_columns:
        raise ValueError(
            f"Unsupported sar_band3 {sar_band3!r}; expected one of {sorted(band3_columns)}"
        )

    # Only the first row of the metrics file is used.
    metrics = pd.read_csv(metrics_path, delimiter=",").iloc[0]
    band3_mean_column, band3_std_column = band3_columns[sar_band3]

    sar[0] = (sar[0] - metrics["hh_mean"]) / metrics["hh_std"]
    sar[1] = (sar[1] - metrics["hv_mean"]) / metrics["hv_std"]
    sar[2] = (sar[2] - metrics[band3_mean_column]) / metrics[band3_std_column]
    return sar
def recategorize_chart(chart: DataArray, class_categories: dict) -> DataArray:
    """
    Assigns new categories to an ice chart image, in place, and returns it.

    Every element whose original value appears in `class_categories[new_label]`
    is set to `new_label`. Matching is always done against the labels the chart
    had on entry, so the result does not depend on the order of the dictionary
    and a freshly written label is never relabelled again by a later entry.

    Parameters:
        chart (xarray.core.dataarray.DataArray): Ice chart image. Within this
            file it is called with the underlying numpy array (`chart.values`).
        class_categories (dict): Maps each new class label to the collection of
            original labels it replaces. If None or empty, the chart is
            returned unchanged.

    Returns:
        chart (xarray.core.dataarray.DataArray): The same object with the new
            class labels
    """
    if not class_categories:
        return chart

    # Snapshot of the labels before any replacement, so that masks are computed
    # on the original values rather than on partially relabelled data.
    original_labels = np.array(chart)
    for new_label, old_labels in class_categories.items():
        chart[np.isin(original_labels, old_labels)] = new_label
    return chart
def load_sar(file_path: str, sar_band3: str, parse_coordinates: bool = True) -> ndarray:
    """
    Loads a SAR raster and runs it through the SAR preprocessing steps.

    The file is opened with rioxarray, its third band is set up by
    `define_band3`, the bands are normalised by `normalize_sar`, and the
    underlying array of the result is returned.

    Parameters:
        file_path (str): Path to raster file
        sar_band3 (str): Name of 3rd band to return in the array
        parse_coordinates (bool): Parses the coordinates of the file, if any

    Returns:
        sar (numpy.ndarray): Values of the normalised SAR image
    """
    raster = rxr.open_rasterio(file_path, parse_coordinates=parse_coordinates)
    return normalize_sar(define_band3(raster, sar_band3), sar_band3).values
def load_chart(
    file_path: str,
    class_categories: dict,
    parse_coordinates: bool = True,
    masked: bool = True,
    flip_vertically: bool = False,
) -> ndarray:
    """
    Wrapper of the loading and processing functions for ice chart images.
    Returns an ndarray.

    Parameters:
        file_path (str): Path to raster file
        class_categories (dict): Dictionary of new class labels to be used
        parse_coordinates (bool): Parses the coordinates of the file, if any
        masked (bool): Forwarded to rioxarray.open_rasterio; when True the
            raster's nodata values are masked out (read as NaN)
        flip_vertically (bool): Whether to flip the chart along its y axis
            before the values are extracted. Requires the raster to expose a
            `y` coordinate.

    Returns:
        chart (numpy.ndarray): Chart values with the new class labels
    """
    chart = rxr.open_rasterio(file_path, parse_coordinates=parse_coordinates, masked=masked)
    if flip_vertically:
        # reindex returns a new object instead of modifying `chart`, so the
        # result has to be kept for the flip to take effect.
        chart = chart.reindex(y=chart.y[::-1])
    return recategorize_chart(chart.values, class_categories)
def crop_image(raster: DataArray, height_size: int, width_size: int) -> ndarray:
    """
    Crops the centre of a 3-dimensional image to the requested size.

    The window is centred on the last two axes; when the margin to remove is
    odd, the extra row/column is dropped from the bottom/right side. The first
    axis is kept whole. The result is a slice of `raster`, so it has the same
    type as the input.

    Parameters:
        raster (xarray.core.dataarray.DataArray): Image with three axes; the
            last two are treated as (height, width)
        height_size (int): Size in the y-axis for the resulting array
        width_size (int): Size in the x-axis for the resulting array

    Returns:
        raster (numpy.ndarray): Cropped image with the specified size

    Raises:
        ValueError: If a requested size is negative or larger than the image.
            Without this check the negative offset would make the slice
            silently return a wrong, smaller window.
    """
    _, height, width = raster.shape
    if not 0 <= height_size <= height or not 0 <= width_size <= width:
        raise ValueError(
            f"Cannot crop an image of size ({height}, {width}) "
            f"to ({height_size}, {width_size})"
        )

    top = (height - height_size) // 2
    left = (width - width_size) // 2
    return raster[:, top:top + height_size, left:left + width_size]