Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor earthaccess.download updates #317

Merged
merged 2 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions earthaccess/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fsspec import AbstractFileSystem

from .auth import Auth
from .results import DataGranule
from .search import CollectionQuery, DataCollections, DataGranules, GranuleQuery
from .store import Store
from .utils import _validation as validate
Expand Down Expand Up @@ -150,8 +151,8 @@ def login(strategy: str = "all", persist: bool = False) -> Auth:


def download(
granules: Union[List[earthaccess.results.DataGranule], List[str]],
local_path: Optional[str],
granules: Union[DataGranule, List[DataGranule], List[str]],
local_path: Union[str, None],
provider: Optional[str] = None,
threads: int = 8,
) -> List[str]:
Expand All @@ -161,14 +162,16 @@ def download(
    * If we run it outside AWS (us-west-2 region) and the dataset is cloud hosted we'll use HTTP links

Parameters:
granules: a list of granules(DataGranule) instances or a list of granule links (HTTP)
granules: a granule, list of granules, or a list of granule links (HTTP)
local_path: local directory to store the remote data granules
provider: if we download a list of URLs we need to specify the provider.
threads: parallel number of threads to use to download the files, adjust as necessary, default = 8

Returns:
List of downloaded files
"""
if isinstance(granules, DataGranule):
granules = [granules]
try:
results = earthaccess.__store__.get(granules, local_path, provider, threads)
except AttributeError as err:
Expand Down
36 changes: 18 additions & 18 deletions earthaccess/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,12 @@ def get(
Returns:
List of downloaded files
"""
if local_path is None:
local_path = os.path.join(
".",
"data",
f"{datetime.datetime.today().strftime('%Y-%m-%d')}-{uuid4().hex[:6]}",
)
if len(granules):
files = self._get(granules, local_path, provider, threads)
return files
Expand All @@ -464,7 +470,7 @@ def get(
def _get(
self,
granules: Union[List[DataGranule], List[str]],
local_path: Optional[str] = None,
local_path: str,
provider: Optional[str] = None,
threads: int = 8,
) -> Union[None, List[str]]:
Expand Down Expand Up @@ -492,7 +498,7 @@ def _get(
def _get_urls(
self,
granules: List[str],
local_path: Optional[str] = None,
local_path: str,
provider: Optional[str] = None,
threads: int = 8,
) -> Union[None, List[str]]:
Expand All @@ -509,22 +515,21 @@ def _get_urls(
s3_fs = self.get_s3fs_session(provider=provider)
# TODO: make this parallel or concurrent
for file in data_links:
file_name = file.split("/")[-1]
s3_fs.get(file, local_path)
print(f"Retrieved: {file} to {local_path}")
file_name = os.path.join(local_path, os.path.basename(file))
print(f"Downloaded: {file_name}")
downloaded_files.append(file_name)
return downloaded_files

else:
# if we are not in AWS
return self._download_onprem_granules(data_links, local_path, threads)
return None

@_get.register
def _get_granules(
self,
granules: List[DataGranule],
local_path: Optional[str] = None,
local_path: str,
provider: Optional[str] = None,
threads: int = 8,
) -> Union[None, List[str]]:
Expand Down Expand Up @@ -557,14 +562,13 @@ def _get_granules(
# TODO: make this async
for file in data_links:
s3_fs.get(file, local_path)
file_name = file.split("/")[-1]
print(f"Retrieved: {file} to {local_path}")
file_name = os.path.join(local_path, os.path.basename(file))
print(f"Downloaded: {file_name}")
downloaded_files.append(file_name)
return downloaded_files
else:
            # if the data is cloud-based but we are not in AWS it will be downloaded as if it was on prem
return self._download_onprem_granules(data_links, local_path, threads)
return None

def _download_file(self, url: str, directory: str) -> str:
"""
Expand Down Expand Up @@ -598,10 +602,10 @@ def _download_file(self, url: str, directory: str) -> str:
raise Exception
else:
print(f"File {local_filename} already downloaded")
return local_filename
return local_path

def _download_onprem_granules(
self, urls: List[str], directory: Optional[str] = None, threads: int = 8
self, urls: List[str], directory: str, threads: int = 8
) -> List[Any]:
"""
downloads a list of URLS into the data directory.
Expand All @@ -618,14 +622,10 @@ def _download_onprem_granules(
"We need to be logged into NASA EDL in order to download data granules"
)
return []
if directory is None:
directory_prefix = f"./data/{datetime.datetime.today().strftime('%Y-%m-%d')}-{uuid4().hex[:6]}"
else:
directory_prefix = directory
if not os.path.exists(directory_prefix):
os.makedirs(directory_prefix)
if not os.path.exists(directory):
os.makedirs(directory)

arguments = [(url, directory_prefix) for url in urls]
arguments = [(url, directory) for url in urls]
results = pqdm(
arguments,
self._download_file,
Expand Down
12 changes: 5 additions & 7 deletions tests/integration/test_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# package imports
import logging
import os
import shutil
import unittest

import earthaccess
Expand Down Expand Up @@ -69,16 +68,15 @@ def test_granules_search_returns_valid_results(kwargs):
assertions.assertTrue(len(results) <= 10)


def test_earthaccess_api_can_download_granules():
@pytest.mark.parametrize("selection", [0, slice(None)])
def test_earthaccess_api_can_download_granules(tmp_path, selection):
results = earthaccess.search_data(
count=2,
short_name="ATL08",
cloud_hosted=True,
bounding_box=(-92.86, 16.26, -91.58, 16.97),
)
local_path = "./tests/integration/data/ATL08"
assertions.assertIsInstance(results, list)
assertions.assertTrue(len(results) <= 2)
files = earthaccess.download(results, local_path=local_path)
result = results[selection]
files = earthaccess.download(result, str(tmp_path))
assertions.assertIsInstance(files, list)
shutil.rmtree(local_path)
assert all(os.path.exists(f) for f in files)
Loading