diff --git a/flask_annex/compat.py b/flask_annex/compat.py index d16e859..55a35f6 100644 --- a/flask_annex/compat.py +++ b/flask_annex/compat.py @@ -1,3 +1,5 @@ +import fnmatch +import os import sys # ----------------------------------------------------------------------------- @@ -10,3 +12,17 @@ string_types = (str, unicode) # flake8: noqa else: string_types = (str,) + + +def recursive_glob(rootdir, pattern='*'): + """Search recursively for files matching a specified pattern. + + Adapted from http://stackoverflow.com/questions/2186525/use-a-glob-to-find-files-recursively-in-python + """ + + matches = [] + for root, dirnames, filenames in os.walk(rootdir): + for filename in fnmatch.filter(filenames, pattern): + matches.append(os.path.join(root, filename)) + + return matches diff --git a/flask_annex/file.py b/flask_annex/file.py index 4854b20..36b9586 100644 --- a/flask_annex/file.py +++ b/flask_annex/file.py @@ -1,12 +1,11 @@ import errno -import glob import os import shutil import flask from .base import AnnexBase -from .compat import string_types +from .compat import recursive_glob, string_types # ----------------------------------------------------------------------------- @@ -57,8 +56,8 @@ def get_file(self, key, out_file): shutil.copyfileobj(in_fp, out_file) def list_keys(self, prefix): - pattern = self._get_filename('{}*'.format(prefix)) - filenames = glob.glob(pattern) + root = self._get_filename(prefix) + filenames = [root] if os.path.isfile(root) else recursive_glob(root) return [ os.path.relpath(filename, self._root_path) diff --git a/tests/helpers.py b/tests/helpers.py index 1f61f73..2bf743c 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -66,6 +66,18 @@ def test_get_filename(self, tmpdir, annex): assert open(out_filename).read() == '1\n' def test_list_keys(self, annex): + annex.save_file('foo/qux.txt', BytesIO(b'3\n')) + assert sorted(annex.list_keys('foo/')) == [ + 'foo/bar.txt', + 'foo/baz.json', + 'foo/qux.txt', + ] + + # check that dangling '/' is not relevant + assert sorted(annex.list_keys('foo/')) == \ + sorted(annex.list_keys('foo')) + + def test_list_keys_nested(self, annex): assert sorted(annex.list_keys('foo/')) == [ 'foo/bar.txt', 'foo/baz.json',