Skip to content

Commit

Permalink
Refactor and performance code while getting and processing the data i…
Browse files Browse the repository at this point in the history
…n chunks
  • Loading branch information
XavierCLL committed Aug 18, 2023
1 parent 2a5296c commit cd711f9
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 60 deletions.
106 changes: 50 additions & 56 deletions stack_composed/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,27 @@ class Image:
def __init__(self, file_path):
    """
    Read the geo-properties of the image and store them on the instance

    The dataset is opened only for the duration of this call; the
    instance is left with ``self.gdal_file = None`` and the dataset is
    re-opened on demand when pixel data is read.

    :param file_path: path of the image, resolved with get_dataset_path
    :raises RuntimeError: if GDAL cannot open the resolved path
    """
    self.file_path = self.get_dataset_path(file_path)
    ### set geoproperties ###
    dataset = gdal.Open(self.file_path, gdal.GA_ReadOnly)
    # gdal.Open returns None (instead of raising) when GDAL exceptions are
    # not enabled; fail here with a clear message rather than with an
    # AttributeError on the first attribute access below
    if dataset is None:
        raise RuntimeError(f"Could not open the image: {self.file_path}")
    # setting the extent, pixel sizes and projection
    min_x, x_res, x_skew, max_y, y_skew, y_res = dataset.GetGeoTransform()
    max_x = min_x + (dataset.RasterXSize * x_res)
    min_y = max_y + (dataset.RasterYSize * y_res)
    # extent
    self.extent = [min_x, max_y, max_x, min_y]
    # pixel sizes
    self.x_res = abs(float(x_res))
    self.y_res = abs(float(y_res))
    # number of bands
    self.n_bands = dataset.RasterCount
    # no data values
    self.nodata_parameters = None
    # one entry per band; index 0 holds the value of band number 1
    self.nodata_from_file = [
        dataset.GetRasterBand(band_number).GetNoDataValue()
        for band_number in range(1, self.n_bands + 1)
    ]
    # projection, taken once from the first image that is loaded
    if Image.projection is None:
        Image.projection = dataset.GetProjectionRef()
    # output type
    self.output_type = None
    # release the dataset; it is opened again lazily when a chunk is read
    del dataset
    self.gdal_file = None

@staticmethod
def get_dataset_path(file_path):
def get_chunk(self, band, xoff, xsize, yoff, ysize):
    """
    Get the array of the band for the respective chunk

    :param band: band number as passed to GetRasterBand (starts at 1)
    :param xoff: column offset of the chunk inside the image
    :param xsize: number of columns to read
    :param yoff: row offset of the chunk inside the image
    :param ysize: number of rows to read
    :return: float32 array of the chunk where every no data value (the one
        stored in the file and the ones set in ``self.nodata_parameters``)
        is replaced by NaN
    """
    import operator

    # open the dataset only once and reuse it for the following chunks
    if self.gdal_file is None:
        self.gdal_file = gdal.Open(self.file_path, gdal.GA_ReadOnly)
    raster_band = self.gdal_file.GetRasterBand(band).ReadAsArray(xoff, yoff, xsize, ysize).astype(np.float32)

    # convert the no data values from file to NaN
    # (band numbers start at 1 but the list of no data values starts at 0)
    nodata_from_file = self.nodata_from_file[band - 1]
    if nodata_from_file is not None:
        raster_band[raster_band == nodata_from_file] = np.nan

    # convert the no data values set from arguments to NaN
    if self.nodata_parameters is not None and self.nodata_parameters != nodata_from_file:
        if isinstance(self.nodata_parameters, (int, float)):
            raster_band[raster_band == self.nodata_parameters] = np.nan
        else:
            # each condition is an (operator, threshold) pair; the operator
            # is resolved through a fixed table instead of eval() so that
            # no arbitrary expression can be executed
            comparisons = {
                "<": operator.lt,
                "<=": operator.le,
                ">": operator.gt,
                ">=": operator.ge,
                "==": operator.eq,
                "!=": operator.ne,
            }
            for condition in self.nodata_parameters:
                compare = comparisons.get(condition[0])
                # unknown operators are skipped
                if compare is not None:
                    raster_band[compare(raster_band, condition[1])] = np.nan

    return raster_band

def get_chunk_in_wrapper(self, band, xc, xc_size, yc, yc_size):
    """
    Get the array of the band adjusted into the wrapper matrix for the respective chunk

    The chunk is given in wrapper coordinates (the 0,0 is the left-upper
    corner) and is compared against ``self.xi_min``, ``self.xi_max``,
    ``self.yi_min`` and ``self.yi_max`` (presumably the position of this
    image inside the wrapper, set elsewhere -- verify against the caller).

    :param band: band number forwarded to get_chunk
    :param xc: first column of the chunk
    :param xc_size: number of columns of the chunk
    :param yc: first row of the chunk
    :param yc_size: number of rows of the chunk
    :return: None if the chunk does not overlap the image, otherwise a
        (yc_size, xc_size) matrix filled with NaN except for the overlapping
        region, which holds the data returned by get_chunk
    """
    # bounds of the chunk with respect to the wrapper
    xc_min, xc_max = xc, xc + xc_size
    yc_min, yc_max = yc, yc + yc_size

    # check if the current chunk is outside of the image
    if xc_max <= self.xi_min or xc_min >= self.xi_max or yc_max <= self.yi_min or yc_min >= self.yi_max:
        return None

    # overlapping region between the chunk and the image (wrapper coordinates)
    x_start = max(xc_min, self.xi_min)
    x_end = min(xc_max, self.xi_max)
    y_start = max(yc_min, self.yi_min)
    y_end = min(yc_max, self.yi_max)

    # offset and size of the overlapping region with respect to the image;
    # the offsets cannot be negative because x_start >= xi_min and y_start >= yi_min
    xoff = x_start - self.xi_min
    xsize = x_end - x_start
    yoff = y_start - self.yi_min
    ysize = y_end - y_start

    chunk_data = self.get_chunk(band, xoff, xsize, yoff, ysize)

    # initialize the chunk with a nan matrix
    chunk_matrix = np.full((yc_size, xc_size), np.nan)

    # position of the overlapping region with respect to the chunk: it starts
    # where the overlap starts, not at the corner of the chunk
    fill_x_start = x_start - xc_min
    fill_y_start = y_start - yc_min

    # the data already has the size of the overlap, so it is placed whole
    chunk_matrix[fill_y_start:fill_y_start + ysize, fill_x_start:fill_x_start + xsize] = chunk_data

    return chunk_matrix
8 changes: 4 additions & 4 deletions stack_composed/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,14 @@ def _preprocess(self, chunks):
return np.array(chunks)

if self.prep_func.startswith('less_than_'):
less_than_N = int(self.prep_func.split('_')[2])
mask = np.array([chunk < less_than_N for chunk in chunks])
threshold = int(self.prep_func.split('_')[2])
mask = np.array([chunk < threshold for chunk in chunks])
chunks_data = np.where(mask, chunks, np.nan)
return chunks_data

if self.prep_func.startswith('greater_than_'):
greater_than_N = int(self.prep_func.split('_')[2])
mask = np.array([chunk > greater_than_N for chunk in chunks])
threshold = int(self.prep_func.split('_')[2])
mask = np.array([chunk > threshold for chunk in chunks])
chunks_data = np.where(mask, chunks, np.nan)
return chunks_data

Expand Down

0 comments on commit cd711f9

Please sign in to comment.