Skip to content

Commit

Permalink
ensure faster reread of s1-coherence and add test for coh
Browse files Browse the repository at this point in the history
  • Loading branch information
cmarshak committed Dec 4, 2023
1 parent 58427e6 commit 5121fca
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 7 deletions.
18 changes: 16 additions & 2 deletions tests/test_stitch_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from tile_stitcher import get_raster_from_tiles
from tile_stitcher.stitcher import HANSEN_MOSAIC_YEARS
from tile_stitcher.stitcher import HANSEN_MOSAIC_YEARS, S1_TEMPORAL_BASELINE_DAYS, SEASONS


def test_esa_world_cover():
Expand All @@ -27,5 +27,19 @@ def test_pekel_water_occ():
def test_hansen_datasets(year):
# Note only getting 1 tile - these are large datasets!
bounds = [-120.45, 34.85, -120.15, 34.95]
X, p = get_raster_from_tiles(bounds, tile_shortname='hansen_annual_mosaic', year=year)
X, p = get_raster_from_tiles(bounds,
tile_shortname='hansen_annual_mosaic',
year=year)
assert len(X.shape) == 3


@pytest.mark.parametrize("season", SEASONS)
@pytest.mark.parametrize("temporal_baseline_days", S1_TEMPORAL_BASELINE_DAYS)
def test_coherence_dataset(season, temporal_baseline_days):
# Note only getting 1 tile
bounds = [-120.45, 34.85, -120.15, 34.95]
X, p = get_raster_from_tiles(bounds,
tile_shortname='s1_coherence_2020',
season=season,
temporal_baseline_days=temporal_baseline_days)
assert len(X.shape) == 3
16 changes: 11 additions & 5 deletions tile_stitcher/stitcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,25 @@
CURRENT_HANSEN_VERSION = 10
CURRENT_HANSEN_YEAR = 2022
SEASONS = ['fall', 'winter', 'spring', 'summer']
S1_TEMPORAL_BASELINE_DAYS = list(range(6, 49, 6))
S1_TEMPORAL_BASELINE_DAYS = [6, 12, 18, 24, 36, 48]


@lru_cache
def get_tile_data(tile_key: str,
year: int = None,
season: str = None,
temporal_baseline_days: int = None) -> gpd.GeoDataFrame:
def get_all_tile_data(tile_key: str) -> gpd.GeoDataFrame:
if tile_key not in DATASET_SHORTNAMES:
raise TilesetNotSupported
geojson_name = GEOJSON_DICT[tile_key]
geojson_path = DATA_DIR / geojson_name
df_tiles = read_geojson_gzip(geojson_path)
return df_tiles


@lru_cache
def get_tile_data(tile_key: str,
year: int = None,
season: str = None,
temporal_baseline_days: int = None) -> gpd.GeoDataFrame:
df_tiles = get_all_tile_data(tile_key)

if (year is not None):
if tile_key not in DATASETS_WITH_YEAR:
Expand Down

0 comments on commit 5121fca

Please sign in to comment.