forked from iceberg-project/Seals
-
Notifications
You must be signed in to change notification settings - Fork 0
/
create_trainingset.py
221 lines (188 loc) · 10.1 KB
/
create_trainingset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import pandas as pd
import numpy as np
import os
import cv2
import time
import random
import argparse
import rasterio
import matplotlib.pyplot as plt
from sklearn.utils import shuffle
# command-line interface: every option arrives as a single string; the underscore-
# delimited ones (--scale_bands, --labels, --det_classes) are split into lists downstream
parser = argparse.ArgumentParser(description='creates training sets to train and validate sealnet instances')
parser.add_argument('--rasters_dir', type=str, help='root directory where rasters are located')
# NOTE: the two help literals below are implicitly concatenated; the original was missing
# the space between "separated" and "by" ("...separatedby underscores")
parser.add_argument('--scale_bands', type=str, help='for multi-scale models, string with size of scale bands separated '
                                                    'by underscores')
parser.add_argument('--out_folder', type=str, help='directory where training set will be saved to')
parser.add_argument('--labels', type=str, help='class names, separated by underscores')
parser.add_argument('--det_classes', type=str, help='classes that will be targeted at detection, '
                                                    'separated by underscores')
parser.add_argument('--shape_file', type=str, help='path to shape file with seal points')
# parsed at import time: this module is a script, not a library
args = parser.parse_args()
def _make_output_dirs(out_folder: str, labels: list) -> None:
    """Create the ./training_sets/<out_folder>/{training,validation}/<label> directory tree."""
    for folder in ['training', 'validation']:
        for lbl in labels:
            os.makedirs("./training_sets/{}/{}/{}".format(out_folder, folder, lbl), exist_ok=True)


def _scan_detections(df_rs, start_row: int, step: int, lon: str, lat: str, x0: float, y0: float,
                     width: float, height: float, pad: int, upper_left: list, patch_size: int,
                     det_classes: list) -> str:
    """
    Walk df_rs from start_row in direction `step` (+1 or -1) and return a string of
    '_x_y' fragments for every detection-class point that falls inside the current patch.

    df_rs is assumed to be sorted by (lon, lat) with a 0..len-1 integer index, so the
    scan can stop as soon as a point's x coordinate leaves the patch horizontally.
    """
    locs = ""
    search_idx = start_row
    while 0 <= search_idx < len(df_rs):
        det_x = (int((df_rs.loc[search_idx, lon] - x0) / width) + pad) - upper_left[0]
        det_y = (int((df_rs.loc[search_idx, lat] - y0) / height) + pad) - upper_left[1]
        # points are sorted by lon first, so once x falls outside the patch no
        # later (or earlier, depending on step) point can fall back inside
        if not 0 <= det_x < patch_size:
            break
        if 0 <= det_y < patch_size and df_rs.loc[search_idx, 'label'] in det_classes:
            locs += "_{}_{}".format(det_x, det_y)
        search_idx += step
    return locs


def get_patches(out_folder: str, raster_dir: str, shape_file: str, lon: str, lat: str, patch_sizes: list,
                labels: list) -> object:
    """
    Generates multi-band patches at different scales around vector points to use as a training set.

    Input:
        out_folder: folder name for the created dataset
        raster_dir : directory with raster images (.tif) we wish to extract patches from.
        shape_file : path to .csv shape_file with training points with latitude, longitude, classification label and
            source raster layer.
        lon : column in vector_df containing longitude component.
        lat : column in vector_df containing latitude component.
        patch_sizes : list with pyramid dimensions in multi-band output, first element is used as base dimension,
            subsequent elements must be larger than the previous element and will be re-sized to match patch_sizes[0]
        labels : list with training set labels

    Output:
        folder with extracted multi-band training images separated in subfolders by intended label, plus a
        detections.csv mapping each patch file to '_'-joined in-patch (x, y) detection coordinates.

    Raises:
        ValueError: if patch_sizes is not non-decreasing.
    """
    # validate input with a real exception: assert is stripped under `python -O`
    if any(patch_sizes[i] > patch_sizes[i + 1] for i in range(len(patch_sizes) - 1)):
        raise ValueError("Patch sizes with non-increasing dimensions")
    # extract detection classes; the training set stores count and (x,y) within tile for objects of those classes
    det_classes = args.det_classes.split('_')
    # accumulate (shapeid, row-dict) pairs and build the DataFrame once at the end:
    # DataFrame.append was deprecated in pandas 1.4 and removed in 2.0, and appending
    # row-by-row was quadratic anyway
    det_records = []
    det_index = []
    # read csv exported from seal points shape file as a pandas DataFrame
    df = pd.read_csv(shape_file)
    # create training set directory tree
    _make_output_dirs(out_folder, labels)
    rasters = []
    print("Checking input folder for invalid files:\n\n")
    # hoisted: pd.unique over the whole column on every file was O(files * points)
    annotated_scenes = set(pd.unique(df['scene']))
    for path, _, files in os.walk(raster_dir):
        for fname in files:
            # only add raster files with annotated points
            if not fname.lower().endswith('.tif'):
                print(' {} is not a valid scene.'.format(fname))
                continue
            if fname not in annotated_scenes:
                print(' {} is not an annotated scene.'.format(fname))
                continue
            rasters.append(os.path.join(path, fname))
    # shuffle raster order (sklearn.utils.shuffle, driven by numpy's global RNG)
    rasters = shuffle(rasters)
    # keep track of how many points were processed
    num_imgs = 0
    since = time.time()
    print("\nCreating dataset:\n")
    for idx, rs in enumerate(rasters):
        # extract image data and affine transform
        with rasterio.open(rs) as src:
            band = src.read()[0, :, :].astype(np.uint8)
            # rasterio >= 1.0 exposes an Affine: a = pixel width, c = x origin,
            # e = pixel height (negative for north-up), f = y origin.
            # NOTE(review): the original indexed src.transform positionally in the
            # pre-1.0 GDAL tuple order, which under rasterio 1.x reads the rotation
            # term (0) as the pixel width and breaks the coordinate math.
            transform = src.transform
            width, x0, height, y0 = transform.a, transform.c, transform.e, transform.f
        # center of the base-scale tile, used as the anchor detection location
        tile_center = patch_sizes[0] // 2
        # pad image so the largest scale band never runs off the raster edge
        pad = patch_sizes[-1] // 2
        band = np.pad(band, pad_width=pad, mode='constant', constant_values=0)
        # filter points to those inside current raster, sort by coordinates and reset the index
        df_rs = df.loc[df['scene'] == os.path.basename(rs)]
        df_rs = df_rs.sort_values(by=[lon, lat])
        df_rs.index = range(len(df_rs.index))
        # iterate through the points
        for row, p in enumerate(df_rs.iterrows()):
            point = p[1]
            # map world coordinates to (padded) pixel coordinates
            x = int((point[lon] - x0) / width) + pad
            y = int((point[lat] - y0) / height) + pad
            upper_left = [x - patch_sizes[0] // 2, y - patch_sizes[0] // 2]
            bands = []
            # extract patches at different scales, all resized to the base dimension
            for scale in patch_sizes:
                try:
                    patch = band[y - scale // 2: y + (scale + 1) // 2, x - scale // 2: x + (scale + 1) // 2]
                    patch = cv2.resize(patch, (patch_sizes[0], patch_sizes[0]))
                    bands.append(patch)
                except cv2.error:
                    # resize fails on empty slices (point too close to the border);
                    # skip this scale -- the length check below discards the point
                    continue
            # only keep points where every scale band was extracted
            if len(bands) == len(patch_sizes):
                # combine bands into a single multi-channel image
                stacked = np.dstack(bands)
                # save patch image to the correct subfolder based on dataset split and label
                out_path = "./training_sets/{}/{}/{}/{}.jpg".format(out_folder, point['dataset'], point['label'],
                                                                    point['shapeid'])
                cv2.imwrite(out_path, stacked)
                # store counts and detections
                locs = ""
                # add a detection in the center of the tile if class is in det_classes
                if point['label'] in det_classes:
                    locs += "{}_{}".format(tile_center, tile_center)
                # look down, then up, the DataFrame for detections that also fall inside the tile
                # (the original duplicated this loop verbatim; leftover plt.imshow debug calls,
                # which leaked unclosed figures, have been removed)
                locs += _scan_detections(df_rs, row + 1, 1, lon, lat, x0, y0, width, height,
                                         pad, upper_left, patch_sizes[0], det_classes)
                locs += _scan_detections(df_rs, row - 1, -1, lon, lat, x0, y0, width, height,
                                         pad, upper_left, patch_sizes[0], det_classes)
                det_records.append({'file_name': os.path.basename(out_path), 'locations': locs})
                det_index.append(point['shapeid'])
                num_imgs += 1
        # release the padded raster before opening the next one to keep peak memory down
        del band
        print("\n Processed {} out of {} rasters".format(idx + 1, len(rasters)))
    time_elapsed = time.time() - since
    print("\n\n{} training images created in {:.0f}m {:.0f}s".format(num_imgs, time_elapsed // 60, time_elapsed % 60))
    detections = pd.DataFrame(det_records, index=det_index, columns=['file_name', 'locations'])
    detections = detections.sort_index()
    detections.to_csv('./training_sets/{}/detections.csv'.format(out_folder))
def main():
    """Parse the module-level CLI args and build the training set via get_patches."""
    # seed BOTH RNGs: sklearn.utils.shuffle draws from numpy's global RNG, which
    # random.seed alone never touches -- without np.random.seed the raster order
    # differed between runs, defeating the stated goal of getting the same order
    # of samples in vanilla and multiscale training sets
    random.seed(4)
    np.random.seed(4)
    raster_dir = args.rasters_dir
    # scale bands arrive as an underscore-delimited string, e.g. "32_64_128"
    patch_sizes = [int(ele) for ele in args.scale_bands.split('_')]
    out_folder = args.out_folder
    labels = args.labels.split('_')
    shape_file = args.shape_file
    # create vanilla and multi-scale training sets; shape file uses 'x'/'y' columns
    print('\nCreating {}:\n'.format(out_folder))
    get_patches(out_folder=out_folder, raster_dir=raster_dir, shape_file=shape_file, lat='y', lon='x',
                patch_sizes=patch_sizes, labels=labels)


if __name__ == '__main__':
    main()