diff --git a/earthaccess/api.py b/earthaccess/api.py index f0264454..a10ee0e2 100644 --- a/earthaccess/api.py +++ b/earthaccess/api.py @@ -6,6 +6,7 @@ from fsspec import AbstractFileSystem from .auth import Auth +from .results import DataGranule from .search import CollectionQuery, DataCollections, DataGranules, GranuleQuery from .store import Store from .utils import _validation as validate @@ -150,8 +151,8 @@ def login(strategy: str = "all", persist: bool = False) -> Auth: def download( - granules: Union[List[earthaccess.results.DataGranule], List[str]], - local_path: Optional[str], + granules: Union[DataGranule, List[DataGranule], List[str]], + local_path: Union[str, None], provider: Optional[str] = None, threads: int = 8, ) -> List[str]: @@ -161,7 +162,7 @@ def download( * If we run it outside AWS (us-west-2 region) and the dataset is cloud hostes we'll use HTTP links Parameters: - granules: a list of granules(DataGranule) instances or a list of granule links (HTTP) + granules: a granule, list of granules, or a list of granule links (HTTP) local_path: local directory to store the remote data granules provider: if we download a list of URLs we need to specify the provider. threads: parallel number of threads to use to download the files, adjust as necessary, default = 8 @@ -169,6 +170,8 @@ def download( Returns: List of downloaded files """ + if isinstance(granules, DataGranule): + granules = [granules] try: results = earthaccess.__store__.get(granules, local_path, provider, threads) except AttributeError as err: diff --git a/earthaccess/store.py b/earthaccess/store.py index 62421e9f..cfe7bc79 100644 --- a/earthaccess/store.py +++ b/earthaccess/store.py @@ -453,6 +453,12 @@ def get( Returns: List of downloaded files """ + if local_path is None: + local_path = os.path.join( + ".", + "data", + f"{datetime.datetime.today().strftime('%Y-%m-%d')}-{uuid4().hex[:6]}", + ) if len(granules): files = self._get(granules, local_path, provider, threads) return files @@ -464,7 +470,7 @@ def get( def _get( self, granules: Union[List[DataGranule], List[str]], - local_path: Optional[str] = None, + local_path: str, provider: Optional[str] = None, threads: int = 8, ) -> Union[None, List[str]]: @@ -492,7 +498,7 @@ def _get( def _get_urls( self, granules: List[str], - local_path: Optional[str] = None, + local_path: str, provider: Optional[str] = None, threads: int = 8, ) -> Union[None, List[str]]: @@ -509,22 +515,21 @@ def _get_urls( s3_fs = self.get_s3fs_session(provider=provider) # TODO: make this parallel or concurrent for file in data_links: - file_name = file.split("/")[-1] s3_fs.get(file, local_path) - print(f"Retrieved: {file} to {local_path}") + file_name = os.path.join(local_path, os.path.basename(file)) + print(f"Downloaded: {file_name}") downloaded_files.append(file_name) return downloaded_files else: # if we are not in AWS return self._download_onprem_granules(data_links, local_path, threads) - return None @_get.register def _get_granules( self, granules: List[DataGranule], - local_path: Optional[str] = None, + local_path: str, provider: Optional[str] = None, threads: int = 8, ) -> Union[None, List[str]]: @@ -557,14 +562,13 @@ def _get_granules( # TODO: make this async for file in data_links: s3_fs.get(file, local_path) - file_name = file.split("/")[-1] - print(f"Retrieved: {file} to {local_path}") + file_name = os.path.join(local_path, os.path.basename(file)) + print(f"Downloaded: {file_name}") downloaded_files.append(file_name) return downloaded_files else: # if the data is cloud based bu we are not in AWS it will be downloaded as if it was on prem return self._download_onprem_granules(data_links, local_path, threads) - return None def _download_file(self, url: str, directory: str) -> str: """ @@ -598,10 +602,10 @@ def _download_file(self, url: str, directory: str) -> str: raise Exception else: print(f"File {local_filename} already downloaded") - return local_filename + return local_path def _download_onprem_granules( - self, urls: List[str], directory: Optional[str] = None, threads: int = 8 + self, urls: List[str], directory: str, threads: int = 8 ) -> List[Any]: """ downloads a list of URLS into the data directory. @@ -618,14 +622,10 @@ def _download_onprem_granules( "We need to be logged into NASA EDL in order to download data granules" ) return [] - if directory is None: - directory_prefix = f"./data/{datetime.datetime.today().strftime('%Y-%m-%d')}-{uuid4().hex[:6]}" - else: - directory_prefix = directory - if not os.path.exists(directory_prefix): - os.makedirs(directory_prefix) + if not os.path.exists(directory): + os.makedirs(directory) - arguments = [(url, directory_prefix) for url in urls] + arguments = [(url, directory) for url in urls] results = pqdm( arguments, self._download_file, diff --git a/tests/integration/test_api.py b/tests/integration/test_api.py index d4848d60..71745ff5 100644 --- a/tests/integration/test_api.py +++ b/tests/integration/test_api.py @@ -1,7 +1,6 @@ # package imports import logging import os -import shutil import unittest import earthaccess @@ -69,16 +68,15 @@ def test_granules_search_returns_valid_results(kwargs): assertions.assertTrue(len(results) <= 10) -def test_earthaccess_api_can_download_granules(): +@pytest.mark.parametrize("selection", [0, slice(None)]) +def test_earthaccess_api_can_download_granules(tmp_path, selection): results = earthaccess.search_data( count=2, short_name="ATL08", cloud_hosted=True, bounding_box=(-92.86, 16.26, -91.58, 16.97), ) - local_path = "./tests/integration/data/ATL08" - assertions.assertIsInstance(results, list) - assertions.assertTrue(len(results) <= 2) - files = earthaccess.download(results, local_path=local_path) + result = results[selection] + files = earthaccess.download(result, str(tmp_path)) assertions.assertIsInstance(files, list) - shutil.rmtree(local_path) + assert all(os.path.exists(f) for f in files)