From e5db130f26ddf98aae548041b0e61f10a292e4dd Mon Sep 17 00:00:00 2001 From: Darius Couchard Date: Tue, 24 Sep 2024 09:15:25 +0200 Subject: [PATCH] Made split_stac utility work with input catalogue optionally --- src/openeo_gfmap/utils/split_stac.py | 30 ++++++++++++++++----------- tests/test_openeo_gfmap/test_utils.py | 4 ++-- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/src/openeo_gfmap/utils/split_stac.py b/src/openeo_gfmap/utils/split_stac.py index 3b24815..5b93cd8 100644 --- a/src/openeo_gfmap/utils/split_stac.py +++ b/src/openeo_gfmap/utils/split_stac.py @@ -83,24 +83,30 @@ def _create_collection_skeleton( return new_collection -def split_collection_by_epsg(path: Union[str, Path], output_dir: Union[str, Path]): +def split_collection_by_epsg( + collection: Union[str, Path, pystac.Collection], output_dir: Union[str, Path] +): """ Split a STAC collection into multiple STAC collections based on EPSG code. - Parameters: - path (str): The path to the STAC collection. - output_dir (str): The output directory. + Parameters + ---------- + collection: Union[str, Path, pystac.Collection] + A collection of STAC items or a path to a STAC collection. + output_dir: Union[str, Path] + The directory where the split STAC collections will be saved. """ - path = Path(path) - output_dir = Path(output_dir) - os.makedirs(output_dir, exist_ok=True) + if not isinstance(collection, pystac.Collection): + collection = Path(collection) + output_dir = Path(output_dir) + os.makedirs(output_dir, exist_ok=True) - try: - collection = pystac.read_file(path) - except pystac.STACError: - print("Please provide a path to a valid STAC collection.") - return + try: + collection = pystac.read_file(collection) + except pystac.STACError: + print("Please provide a path to a valid STAC collection.") + return collections_by_epsg = {} diff --git a/tests/test_openeo_gfmap/test_utils.py b/tests/test_openeo_gfmap/test_utils.py index dd5a907..b904f6f 100644 --- a/tests/test_openeo_gfmap/test_utils.py +++ b/tests/test_openeo_gfmap/test_utils.py @@ -207,7 +207,7 @@ def test_split_collection_by_epsg(tmp_path): output_dir = str(tmp_path / "split_collections") collection.normalize_and_save(input_dir) - split_collection_by_epsg(path=input_dir, output_dir=output_dir) + split_collection_by_epsg(collection=input_dir, output_dir=output_dir) # Collection contains two different EPSG codes, so 2 collections should be created assert len([p for p in Path(output_dir).iterdir() if p.is_dir()]) == 2 @@ -236,7 +236,7 @@ def test_split_collection_by_epsg(tmp_path): with pytest.raises(KeyError): collection.add_item(missing_epsg_item) collection.normalize_and_save(input_dir) - split_collection_by_epsg(path=input_dir, output_dir=output_dir) + split_collection_by_epsg(collection=input_dir, output_dir=output_dir) @patch("openeo_gfmap.utils.catalogue._query_cdse_catalogue", mock_query_cdse_catalogue)