diff --git a/rows/utils.py b/rows/utils.py index adadd0e0..c908920a 100644 --- a/rows/utils.py +++ b/rows/utils.py @@ -17,7 +17,9 @@ from __future__ import unicode_literals +import bz2 import cgi +import gzip import mimetypes import os import tempfile @@ -31,6 +33,11 @@ except ImportError: from urllib.parse import urlparse # Python 3 +try: + import lzma +except ImportError: + lzma = None + try: import magic except ImportError: @@ -52,6 +59,8 @@ # old versions of urllib3 or requests pass +import rows +from rows.plugins.utils import get_filename_and_fobj # TODO: should get this information from the plugins TEXT_PLAIN = { @@ -160,14 +169,7 @@ def plugin_name_by_mime_type(mime_type, mime_name, file_extension): None) -def detect_local_source(path, content, mime_type=None, encoding=None): - - # TODO: may add sample_size - - filename = os.path.basename(path) - parts = filename.split('.') - extension = parts[-1] if len(parts) > 1 else None - +def describe_file_type(filename, content, mime_type=None, encoding=None): if magic is not None: detected = magic.detect_from_content(content) encoding = detected.encoding or encoding @@ -179,6 +181,19 @@ def detect_local_source(path, content, mime_type=None, encoding=None): mime_name = None mime_type = mime_type or mimetypes.guess_type(filename)[0] + return mime_type, encoding, mime_name + +def detect_local_source(path, content, mime_type=None, encoding=None): + # TODO: may add sample_size + + filename = os.path.basename(path) + parts = filename.split('.') + extension = parts[-1] if len(parts) > 1 else None + + args = (filename, content) + kwargs = dict(mime_type=mime_type, encoding=encoding) + mime_type, encoding, mime_name = describe_file_type(*args, **kwargs) + plugin_name = plugin_name_by_mime_type(mime_type, mime_name, extension) if encoding == 'binary': encoding = None @@ -299,3 +314,37 @@ def export_to_uri(table, uri, *args, **kwargs): raise ValueError('Plugin (export) "{}" not found'.format(plugin_name)) return 
export_function(table, uri, *args, **kwargs) + + +def decompress(path_or_fobj, algorithm=None, **kwargs): + """ + Given a bz2, gzip or lzma file returns its decompressed contents. All + kwargs are passed to `bz2.BZ2File`, `gzip.GzipFile` or `lzma.LZMAFile`. + :param path_or_fobj: path to (or file object of) a bz2, gzip or lzma file + :param algorithm: (str) either bz2, gzip, gz, lzma or xz + """ + filename, fobj = get_filename_and_fobj(path_or_fobj, 'rb') + + extension = None + if not algorithm and filename: + _, extension = os.path.splitext(filename) + extension = extension.replace('.', '') + + open_mapping = dict( + bz2=bz2.BZ2File, + gzip=gzip.GzipFile, + gz=gzip.GzipFile, + lzma=getattr(lzma, 'LZMAFile', None), # lzma might not be available + xz=getattr(lzma, 'LZMAFile', None) # lzma might not be available + ) + open_compressed = open_mapping.get(algorithm or extension) + + if not open_compressed: + raise RuntimeError(( + 'Unknown extension and/or invalid algorithm: options are: bz2, ' + 'gzip, gz, lzma or xz ({})'.format(filename or fobj) + )) + + target = filename or fobj + with open_compressed(target, 'rb', **kwargs) as handler: + return handler.read() diff --git a/tests/tests_utils.py b/tests/tests_utils.py index eaef5b56..3a36d935 100644 --- a/tests/tests_utils.py +++ b/tests/tests_utils.py @@ -17,9 +17,21 @@ from __future__ import unicode_literals +import bz2 +import contextlib +import gzip +import os +import shutil import tempfile import unittest +try: + import lzma +except ImportError: + lzma = None + +import six + import rows.utils import tests.utils as utils @@ -70,3 +82,84 @@ def test_local_file_sample_size(self): # TODO: test normalize_mime_type # TODO: test plugin_name_by_mime_type # TODO: test plugin_name_by_uri + + +class TestUtilsDecompress(unittest.TestCase): + + def setUp(self): + self.contents = b'I use rows and it is awesome!' 
+ self.temp = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.temp) + + @contextlib.contextmanager + def _create_file(self, algorithm, extension=True): + extension = '.{}'.format(algorithm.__name__) if extension else '' + filename = 'test{}'.format(extension) + filepath = os.path.join(self.temp, filename) + + open_mapping = { + bz2: bz2.BZ2File, + gzip: gzip.GzipFile, + lzma: getattr(lzma, 'LZMAFile', None) + } + open_method = open_mapping.get(algorithm) + with open_method(filepath, 'wb') as obj: + obj.write(self.contents) + + with open(filepath, 'rb') as obj: + yield filepath, obj + + def _test_decompress_with_path(self, algorithm): + with self._create_file(algorithm) as path_and_obj: + path, _ = path_and_obj + decompressed = rows.utils.decompress(path) + self.assertEqual(self.contents, decompressed) + + def _test_decompress_with_file_obj(self, algorithm): + with self._create_file(algorithm) as path_and_obj: + _, obj = path_and_obj + decompressed = rows.utils.decompress(obj) + self.assertEqual(self.contents, decompressed) + + def _test_decompress_without_extension(self, algorithm): + with self._create_file(algorithm, False) as path_and_obj: + path, _ = path_and_obj + decompressed = rows.utils.decompress(path, algorithm.__name__) + self.assertEqual(self.contents, decompressed) + + def test_decompress_bz2_with_path(self): + self._test_decompress_with_path(bz2) + + def test_decompress_gzip_with_path(self): + self._test_decompress_with_path(gzip) + + @unittest.skipIf(not lzma, 'No lzma module available') + def test_decompress_lzma_with_path(self): + self._test_decompress_with_path(lzma) + + def test_decompress_bz2_with_file_object(self): + self._test_decompress_with_file_obj(bz2) + + def test_decompress_gzip_with_file_object(self): + self._test_decompress_with_file_obj(gzip) + + @unittest.skipIf(not lzma, 'No lzma module available') + def test_decompress_lzma_with_file_object(self): + self._test_decompress_with_file_obj(lzma) + + def 
test_decompress_bz2_without_extension(self): + self._test_decompress_without_extension(bz2) + + def test_decompress_gzip_without_extension(self): + self._test_decompress_without_extension(gzip) + + @unittest.skipIf(not lzma, 'No lzma module available') + def test_decompress_lzma_without_extension(self): + self._test_decompress_without_extension(lzma) + + def test_decompress_with_incompatible_file(self): + with self.assertRaises(RuntimeError): + with tempfile.NamedTemporaryFile() as tmp: + rows.utils.decompress(tmp.name)