Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update _slice_collection to use toBands and select #91

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 38 additions & 38 deletions xee/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,9 @@ def __init__(
else:
self.mask_value = mask_value

# verify that each image in the collection has a system:index property
self.has_system_index = self.get_info['system_index_count'] == self.n_images

@functools.cached_property
def get_info(self) -> Dict[str, Any]:
"""Make all getInfo() calls to EE at once."""
Expand Down Expand Up @@ -281,30 +284,27 @@ def get_info(self) -> Dict[str, Any]:
# (few) values of the primary dim (read: time) and interpolate the rest
# client-side. Ideally, this would live behind a xarray-backend-specific
# feature flag, since it's not guaranteed that data is this consistent.
columns = ['system:id', self.primary_dim_property]
rpcs.append((
'properties',
(
self.image_collection.reduceColumns(
ee.Reducer.toList().repeat(len(columns)), columns
).get('list')
),
'primary_coords',
self.image_collection.aggregate_array(self.primary_dim_property),
))

# Since we are using system:index in an ee filter, we don't need to pull all
# of the system:index values out with getInfo; we only need to verify that
# each image in the collection has a system:index property. If an image in
# a collection is missing a property, aggregate_array returns nothing for
# that image. So a call to aggregate_array('system:index') on a collection
# with 10 images where 1 is missing 'system:index' will return a list with
# 9 elements.
rpcs.append((
'system_index_count',
self.image_collection.aggregate_array('system:index').length(),
))

info = ee.List([rpc for _, rpc in rpcs]).getInfo()

return dict(zip((name for name, _ in rpcs), info))

@property
def image_collection_properties(self) -> Tuple[List[str], List[str]]:
  """Returns the `(system ids, primary coordinates)` pair.

  The two sequences are unpacked from the 'properties' entry of
  `self.get_info` and handed back as a tuple.
  """
  ids, coords = self.get_info['properties']
  return ids, coords

@property
def image_ids(self) -> List[str]:
  """Returns the first item of `image_collection_properties` (the ids)."""
  ids, _unused_coords = self.image_collection_properties
  return ids

def _max_itemsize(self) -> int:
return max(
_parse_dtype(b['data_type']).itemsize for b in self._img_info['bands']
Expand Down Expand Up @@ -522,7 +522,7 @@ def get_attrs(self) -> utils.Frozen[Any, Any]:

def _get_primary_coordinates(self) -> List[Any]:
"""Gets the primary dimension coordinate values from an ImageCollection."""
_, primary_coords = self.image_collection_properties
primary_coords = self.get_info['primary_coords']

if not primary_coords:
raise ValueError(
Expand Down Expand Up @@ -686,29 +686,29 @@ def _slice_collection(self, image_slice: slice) -> ee.Image:
# a range of images...
start, stop, stride = image_slice.indices(self.shape[0])

# If the input images have IDs, just slice them. Otherwise, we need to do
# an expensive `toList()` operation.
if self.store.image_ids:
imgs = self.store.image_ids[start:stop:stride]
# never need more than "stop" images so limit the collection upfront
col = self.store.image_collection.limit(stop)
col = col.select(self.variable_name)

if self.store.has_system_index: # recommended way to slice a collection
target_image = col.filter(
ee.Filter.listContains(
leftValue=col.aggregate_array('system:index').slice(
start, stop, stride
),
rightField='system:index',
)
).toBands()
elif (
stop <= 5000
): # toBands fails if it would create an image with 5000+ bands
selectors = list(range(start, stop, stride))
target_image = col.toBands().select(selectors)
else:
# TODO(alxr, mahrsee): Find a way to make this case more efficient.
list_range = stop - start
col0 = self.store.image_collection
imgs = col0.toList(list_range, offset=start).slice(0, list_range, stride)

col = ee.ImageCollection(imgs)

# For a more efficient slice of the series of images, we reduce each
# image in the collection to bands on a single image.
def reduce_bands(x, acc):
return ee.Image(acc).addBands(x, [self.variable_name])

aggregate_images_as_bands = ee.Image(col.iterate(reduce_bands, ee.Image()))
# Remove the first "constant" band from the reduction.
target_image = aggregate_images_as_bands.select(
aggregate_images_as_bands.bandNames().slice(1)
)

imgs = col.toList(list_range, offset=start).slice(0, list_range, stride)
target_image = ee.ImageCollection(imgs).toBands()
return target_image

def _raw_indexing_method(
Expand Down
64 changes: 64 additions & 0 deletions xee/ext_integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,19 @@ def test_can_chunk__opened_dataset(self):
except ValueError:
self.fail('Chunking failed.')

def test_can_slice_past_5000(self):
  """Checks that a time slice straddling index 5000 can be computed.

  The dataset is restricted to time indices 4999 and 5000 and to a single
  lon/lat cell before being chunked and computed.
  """
  ds = xr.open_dataset(
      'NASA/GPM_L3/IMERG_V06',
      crs='EPSG:4326',
      scale=0.25,
      engine=xee.EarthEngineBackendEntrypoint,
  ).isel(time=slice(4999, 5001), lon=slice(0, 1), lat=slice(0, 1))

  try:
    ds.chunk().compute()
  except ValueError as e:
    # Name the scenario and include the underlying error so that a failure
    # here can be told apart from failures in the plain chunking tests.
    self.fail(f'Slicing past 5000 images failed: {e}')

def test_honors_geometry(self):
ic = ee.ImageCollection('ECMWF/ERA5_LAND/HOURLY').filterDate(
'1992-10-05', '1993-03-31'
Expand Down Expand Up @@ -397,6 +410,57 @@ def test_validate_band_attrs(self):
for _, value in variable.attrs.items():
self.assertIsInstance(value, valid_types)

def test_rename_bands(self):
  """Checks that bands of a renamed collection can be combined in a dataset.

  Opens a Landsat collection whose bands were renamed server-side, adds the
  first two bands together and verifies the result is stored under 'sum'.
  """
  point = ee.Geometry.Point((-122.45, 37.79))
  col = ee.ImageCollection('LANDSAT/LC08/C02/T1_L2').filterBounds(point)
  # Rename every band via regex; '$' anchors at the end of the name, so this
  # should append '_new' to each band name.
  col = col.map(lambda im: im.regexpRename('$', '_new'))
  b1, b2 = col.first().bandNames().getInfo()[:2]

  ds = xr.open_dataset(
      col,
      engine=xee.EarthEngineBackendEntrypoint,
      scale=120,
      crs='epsg:32610',
      geometry=point.buffer(512).bounds(),
  )

  ds['sum'] = ds[b1] + ds[b2]

  # assertIn reports the missing member and the container on failure, unlike
  # assertTrue on a pre-evaluated boolean.
  self.assertIn('sum', ds)

def test_add_new_bands(self):
  """Checks that a band added via `map` can be opened and interpolated.

  An 'ndvi' band is computed for each image of a filtered Sentinel-2
  collection; the dataset opened from that band is chunked, interpolated at a
  single point, and its values are expected to be a NumPy array.
  """
  s2 = ee.ImageCollection('COPERNICUS/S2_HARMONIZED')
  geometry = ee.Geometry.Polygon([[
      [82.60642647743225, 27.16350437805251],
      [82.60984897613525, 27.1618529901377],
      [82.61088967323303, 27.163695288375266],
      [82.60757446289062, 27.16517483230927],
  ]])

  # Narrow the collection by date, cloud cover and location, in that order.
  by_date = s2.filter(ee.Filter.date('2017-01-01', '2018-01-01'))
  by_clouds = by_date.filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 30))
  filtered = by_clouds.filter(ee.Filter.bounds(geometry))

  def add_ndvi(image):
    """Returns `image` with a normalized-difference 'ndvi' band added."""
    ndvi_band = image.normalizedDifference(['B8', 'B4']).rename('ndvi')
    return image.addBands(ndvi_band)

  with_ndvi = filtered.map(add_ndvi)

  ds = xr.open_dataset(
      with_ndvi.select('ndvi'),
      engine=xee.EarthEngineBackendEntrypoint,
      crs='EPSG:3857',
      scale=10,
      geometry=geometry,
  )

  ndvi_ts = ds.ndvi.chunk('auto').interp(X=82.607376, Y=27.164335)

  self.assertIsInstance(ndvi_ts.values, np.ndarray)


if __name__ == '__main__':
absltest.main()
Loading