Skip to content

Commit

Permalink
Simple download manager
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 371030431
  • Loading branch information
kho authored and fedjax authors committed Apr 29, 2021
1 parent 6bf1fd9 commit 210f280
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
# limitations under the License.
"""Placeholder file for the datasets package that is not implemented."""

from fedjax.experimental.datasets import downloads # pylint: disable=g-bad-import-order
103 changes: 103 additions & 0 deletions fedjax/experimental/datasets/downloads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simple download manager."""

import math
import os.path
import sys
import time
from typing import Callable, Iterator, Optional
import urllib.parse

import requests


def progress(n: int) -> Iterator[int]:
"""A simple generator for tracking progress."""
start = time.time()
last_log = -1
for i in range(n):
yield i
elapsed = time.time() - start
if elapsed - last_log >= 1:
last_log = elapsed
eta = elapsed / (i + 1) * (n - i - 1)
log(f'\r{100*(i + 1)/n:3.0f}%, ETA: {format_duration(eta)}', end='')
log(f'\r100%, elapsed: {format_duration(time.time() - start)}')


def maybe_download(url: str,
cache_dir: Optional[str] = None,
progress: Callable[[int], Iterator[int]] = progress) -> str:
"""Downloads `url` to local disk.
Args:
url: URL to download from.
cache_dir: Where to cache the file. If None, uses default_cache_dir().
progress: A callable that yields like range(n), for tracking progress.
Returns:
Path to local file.
"""
# TODO(wuke): Avoid race conditions when downloading the same file from
# different threads/processes at the same time.
if cache_dir is None:
cache_dir = default_cache_dir()
os.makedirs(cache_dir, exist_ok=True)
path = os.path.join(cache_dir,
os.path.basename(urllib.parse.urlparse(url).path))
if os.path.exists(path):
log(f'Reusing cached file {path!r}')
else:
log(f'Downloading {url!r} to {path!r}')
with open(path + '.partial', 'wb') as fo:
r = requests.get(url, stream=True)
r.raise_for_status()
length = int(r.headers['content-length'])
block_size = 1 << 18
for _ in progress((length + block_size - 1) // block_size):
fo.write(r.raw.read(block_size))
os.rename(path + '.partial', path)
return path


def format_duration(seconds: float) -> str:
"""Formats duration in seconds into hours/minutes/seconds."""
if seconds < 60:
return f'{seconds:.0f}s'
elif seconds < 3600:
minutes = math.floor(seconds / 60)
seconds -= minutes * 60
return f'{minutes}m{seconds:.0f}s'
else:
hours = math.floor(seconds / 3600)
seconds -= hours * 3600
minutes = math.floor(seconds / 60)
seconds -= minutes * 60
return f'{hours}h{minutes}m{seconds:.0f}s'


def log(*args, **kwargs):
print(*args, flush=True, file=sys.stderr, **kwargs)


def default_cache_dir() -> str:
"""Returns default local file cache directory."""
running_on_colab = 'google.colab' in sys.modules
if running_on_colab:
base_dir = '/tmp'
else:
base_dir = os.path.expanduser('~')
cache_dir = os.path.join(base_dir, '.cache/fedjax')
return cache_dir
95 changes: 95 additions & 0 deletions fedjax/experimental/datasets/downloads_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for fedjax.google.experimental.datasets.downloads.
Can only be run locally because of network access.
"""

import binascii
import hashlib
import os
import os.path
import shutil

from absl import flags
from absl.testing import absltest
from fedjax.experimental.datasets import downloads

FLAGS = flags.FLAGS


class DownloadsTest(absltest.TestCase):
SOURCE = 'https://storage.googleapis.com/tff-datasets-public/shakespeare.sqlite.lzma'
HEXDIGEST = b'd3d11fceb9e105439ac6f4d52af6efafed5a2a1e1eb24c5bd2dd54ced242f5c4'

def validate_file(self, path):
h = hashlib.sha256()
with open(path, 'rb') as f:
data = f.read()
h.update(data)
self.assertEqual(binascii.hexlify(h.digest()), self.HEXDIGEST)
# Don't use assertLen. It dumps the entire bytes object when failing.
length = len(data)
self.assertEqual(length, 1329828)

def setUp(self):
super().setUp()
# This will be set in actual tests.
self._cache_dir = None

def tearDown(self):
super().tearDown()
if self._cache_dir is not None:
shutil.rmtree(self._cache_dir)

def _with_cache_dir(self, cache_dir, expected_path):
# First access, being copied.
path = downloads.maybe_download(self.SOURCE, cache_dir)
self._cache_dir = os.path.dirname(path)
self.assertEqual(path, expected_path)
self.validate_file(path)
stat = os.stat(path)
# ctime >= mtime because of renaming.
self.assertGreaterEqual(stat.st_ctime_ns, stat.st_mtime_ns)

# Second access, no copy.
path1 = downloads.maybe_download(self.SOURCE, cache_dir)
self.assertEqual(path1, path)
self.validate_file(path1)
# Check that the file hasn't been overwritten.
stat1 = os.stat(path1)
self.assertEqual(stat1.st_mtime_ns, stat.st_mtime_ns)

def test_default(self):
self._with_cache_dir(
cache_dir=None,
expected_path=os.path.join(
os.path.expanduser('~'), '.cache/fedjax/shakespeare.sqlite.lzma'))

def test_cache_dir(self):
cache_dir = os.path.join(FLAGS.test_tmpdir, 'test_cache_dir')
cache_file = os.path.join(cache_dir, 'shakespeare.sqlite.lzma')
self._with_cache_dir(cache_dir=cache_dir, expected_path=cache_file)

def test_partial_file(self):
cache_dir = os.path.join(FLAGS.test_tmpdir, 'test_partial_file')
cache_file = os.path.join(cache_dir, 'shakespeare.sqlite.lzma')
os.makedirs(cache_dir, exist_ok=True)
with open(cache_file + '.partial', 'w') as f:
f.write('hel')
self._with_cache_dir(cache_dir=cache_dir, expected_path=cache_file)


if __name__ == '__main__':
absltest.main()
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
'jaxlib',
'msgpack',
'optax',
'requests',
'tensorflow-federated',
],
# TensorFlow doesn't yet support Python 3.9.
Expand Down

0 comments on commit 210f280

Please sign in to comment.