diff --git a/HISTORY.md b/HISTORY.md index a5e8a226..99da977f 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,8 +1,9 @@ # cloudpathlib Changelog -## v0.7.2 (UNRELEASED) +## v0.8.0 (2022-05-19) - Fixed pickling of `CloudPath` objects not working. ([Issue #223](https://github.com/drivendataorg/cloudpathlib/issues/223), [PR #224](https://github.com/drivendataorg/cloudpathlib/pull/224)) + - Added functionality to push the MIME (media) type to the content type property on cloud providers by default. ([Issue #222](https://github.com/drivendataorg/cloudpathlib/issues/222), [PR #226](https://github.com/drivendataorg/cloudpathlib/pull/226)) ## v0.7.1 (2022-04-06) diff --git a/cloudpathlib/azure/azblobclient.py b/cloudpathlib/azure/azblobclient.py index ddb0fd82..17c07984 100644 --- a/cloudpathlib/azure/azblobclient.py +++ b/cloudpathlib/azure/azblobclient.py @@ -1,7 +1,8 @@ from datetime import datetime +import mimetypes import os from pathlib import Path, PurePosixPath -from typing import Any, Dict, Iterable, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union from ..client import Client, register_client_class @@ -12,7 +13,7 @@ try: from azure.core.exceptions import ResourceNotFoundError - from azure.storage.blob import BlobServiceClient, BlobProperties + from azure.storage.blob import BlobServiceClient, BlobProperties, ContentSettings except ModuleNotFoundError: implementation_registry["azure"].dependencies_loaded = False @@ -32,6 +33,7 @@ def __init__( connection_string: Optional[str] = None, blob_service_client: Optional["BlobServiceClient"] = None, local_cache_dir: Optional[Union[str, os.PathLike]] = None, + content_type_method: Optional[Callable] = mimetypes.guess_type, ): """Class constructor. Sets up a [`BlobServiceClient`]( https://docs.microsoft.com/en-us/python/api/azure-storage-blob/azure.storage.blob.blobserviceclient?view=azure-python). 
@@ -68,6 +70,8 @@ def __init__( https://docs.microsoft.com/en-us/python/api/azure-storage-blob/azure.storage.blob.blobserviceclient?view=azure-python). local_cache_dir (Optional[Union[str, os.PathLike]]): Path to directory to use as cache for downloaded files. If None, will use a temporary directory. + content_type_method (Optional[Callable]): Function to call to guess media type (mimetype) when + writing a file to the cloud. Defaults to `mimetypes.guess_type`. Must return a tuple (content type, content encoding). """ if connection_string is None: connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING", None) @@ -86,7 +90,7 @@ def __init__( "Credentials are required; see docs for options." ) - super().__init__(local_cache_dir=local_cache_dir) + super().__init__(local_cache_dir=local_cache_dir, content_type_method=content_type_method) def _get_metadata(self, cloud_path: AzureBlobPath) -> Union["BlobProperties", Dict[str, Any]]: blob = self.service_client.get_blob_client( @@ -94,6 +98,8 @@ def _get_metadata(self, cloud_path: AzureBlobPath) -> Union["BlobProperties", Di ) properties = blob.get_blob_properties() + properties["content_type"] = properties.content_settings.content_type + return properties def _download_file( @@ -220,7 +226,18 @@ def _upload_file( container=cloud_path.container, blob=cloud_path.blob ) - blob.upload_blob(Path(local_path).read_bytes(), overwrite=True) # type: ignore + extra_args = {} + if self.content_type_method is not None: + content_type, content_encoding = self.content_type_method(str(local_path)) + + if content_type is not None: + extra_args["content_type"] = content_type + if content_encoding is not None: + extra_args["content_encoding"] = content_encoding + + content_settings = ContentSettings(**extra_args) + + blob.upload_blob(Path(local_path).read_bytes(), overwrite=True, content_settings=content_settings) # type: ignore return cloud_path diff --git a/cloudpathlib/client.py b/cloudpathlib/client.py index e1cb30d2..47404b44 
100644 --- a/cloudpathlib/client.py +++ b/cloudpathlib/client.py @@ -1,4 +1,5 @@ import abc +import mimetypes import os from pathlib import Path from tempfile import TemporaryDirectory @@ -25,7 +26,11 @@ class Client(abc.ABC, Generic[BoundedCloudPath]): _cloud_meta: CloudImplementation _default_client = None - def __init__(self, local_cache_dir: Optional[Union[str, os.PathLike]] = None): + def __init__( + self, + local_cache_dir: Optional[Union[str, os.PathLike]] = None, + content_type_method: Optional[Callable] = mimetypes.guess_type, + ): self._cloud_meta.validate_completeness() # setup caching and local versions of file and track if it is a tmp dir self._cache_tmp_dir = None @@ -34,6 +39,7 @@ def __init__(self, local_cache_dir: Optional[Union[str, os.PathLike]] = None): local_cache_dir = self._cache_tmp_dir.name self._local_cache_dir = Path(local_cache_dir) + self.content_type_method = content_type_method def __del__(self) -> None: # make sure temp is cleaned up if we created it diff --git a/cloudpathlib/gs/gsclient.py b/cloudpathlib/gs/gsclient.py index b8c79dcb..bdca6cc0 100644 --- a/cloudpathlib/gs/gsclient.py +++ b/cloudpathlib/gs/gsclient.py @@ -1,7 +1,8 @@ from datetime import datetime +import mimetypes import os from pathlib import Path, PurePosixPath -from typing import Any, Dict, Iterable, Optional, TYPE_CHECKING, Tuple, Union +from typing import Any, Callable, Dict, Iterable, Optional, TYPE_CHECKING, Tuple, Union from ..client import Client, register_client_class from ..cloudpath import implementation_registry @@ -34,6 +35,7 @@ def __init__( project: Optional[str] = None, storage_client: Optional["StorageClient"] = None, local_cache_dir: Optional[Union[str, os.PathLike]] = None, + content_type_method: Optional[Callable] = mimetypes.guess_type, ): """Class constructor. Sets up a [`Storage Client`](https://googleapis.dev/python/storage/latest/client.html). @@ -65,6 +67,8 @@ def __init__( https://googleapis.dev/python/storage/latest/client.html). 
local_cache_dir (Optional[Union[str, os.PathLike]]): Path to directory to use as cache for downloaded files. If None, will use a temporary directory. + content_type_method (Optional[Callable]): Function to call to guess media type (mimetype) when + writing a file to the cloud. Defaults to `mimetypes.guess_type`. Must return a tuple (content type, content encoding). """ if application_credentials is None: application_credentials = os.getenv("GOOGLE_APPLICATION_CREDENTIALS") @@ -81,7 +85,7 @@ def __init__( except DefaultCredentialsError: self.client = StorageClient.create_anonymous_client() - super().__init__(local_cache_dir=local_cache_dir) + super().__init__(local_cache_dir=local_cache_dir, content_type_method=content_type_method) def _get_metadata(self, cloud_path: GSPath) -> Optional[Dict[str, Any]]: bucket = self.client.bucket(cloud_path.bucket) @@ -94,6 +98,7 @@ def _get_metadata(self, cloud_path: GSPath) -> Optional[Dict[str, Any]]: "etag": blob.etag, "size": blob.size, "updated": blob.updated, + "content_type": blob.content_type, } def _download_file(self, cloud_path: GSPath, local_path: Union[str, os.PathLike]) -> Path: @@ -207,7 +212,12 @@ def _upload_file(self, local_path: Union[str, os.PathLike], cloud_path: GSPath) bucket = self.client.bucket(cloud_path.bucket) blob = bucket.blob(cloud_path.blob) - blob.upload_from_filename(str(local_path)) + extra_args = {} + if self.content_type_method is not None: + content_type, _ = self.content_type_method(str(local_path)) + extra_args["content_type"] = content_type + + blob.upload_from_filename(str(local_path), **extra_args) return cloud_path diff --git a/cloudpathlib/local/localclient.py b/cloudpathlib/local/localclient.py index 2ed1d001..cce9fcc0 100644 --- a/cloudpathlib/local/localclient.py +++ b/cloudpathlib/local/localclient.py @@ -1,10 +1,11 @@ import atexit from hashlib import md5 +import mimetypes import os from pathlib import Path, PurePosixPath import shutil from tempfile import TemporaryDirectory -from 
typing import Iterable, List, Optional, Tuple, Union +from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union from ..client import Client from .localpath import LocalPath @@ -21,6 +22,7 @@ def __init__( *args, local_cache_dir: Optional[Union[str, os.PathLike]] = None, local_storage_dir: Optional[Union[str, os.PathLike]] = None, + content_type_method: Optional[Callable] = mimetypes.guess_type, **kwargs, ): # setup caching and local versions of file. use default temp dir if not provided @@ -28,7 +30,7 @@ def __init__( local_storage_dir = self.get_default_storage_dir() self._local_storage_dir = Path(local_storage_dir) - super().__init__(local_cache_dir=local_cache_dir) + super().__init__(local_cache_dir=local_cache_dir, content_type_method=content_type_method) @classmethod def get_default_storage_dir(cls) -> Path: @@ -132,6 +134,17 @@ def _upload_file( shutil.copy(local_path, dst) return cloud_path + def _get_metadata(self, cloud_path: "LocalPath") -> Dict: + # content_type is the only metadata we test currently + if self.content_type_method is None: + content_type_method = lambda x: (None, None) + else: + content_type_method = self.content_type_method + + return { + "content_type": content_type_method(str(self._cloud_path_to_local(cloud_path)))[0], + } + _temp_dirs_to_clean: List[TemporaryDirectory] = [] diff --git a/cloudpathlib/s3/s3client.py b/cloudpathlib/s3/s3client.py index 7a321498..280b491a 100644 --- a/cloudpathlib/s3/s3client.py +++ b/cloudpathlib/s3/s3client.py @@ -1,6 +1,7 @@ +import mimetypes import os from pathlib import Path, PurePosixPath -from typing import Any, Dict, Iterable, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union from ..client import Client, register_client_class @@ -35,6 +36,7 @@ def __init__( local_cache_dir: Optional[Union[str, os.PathLike]] = None, endpoint_url: Optional[str] = None, boto3_transfer_config: Optional["TransferConfig"] = None, + content_type_method: 
Optional[Callable] = mimetypes.guess_type, ): """Class constructor. Sets up a boto3 [`Session`]( https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html). @@ -63,6 +65,8 @@ def __init__( Parameterize it to access a customly deployed S3-compatible object store such as MinIO, Ceph or any other. boto3_transfer_config (Optional[dict]): Instantiated TransferConfig for managing s3 transfers. (https://boto3.amazonaws.com/v1/documentation/api/latest/reference/customizations/s3.html#boto3.s3.transfer.TransferConfig) + content_type_method (Optional[Callable]): Function to call to guess media type (mimetype) when + writing a file to the cloud. Defaults to `mimetypes.guess_type`. Must return a tuple (content type, content encoding). """ endpoint_url = endpoint_url or os.getenv("AWS_ENDPOINT_URL") if boto3_session is not None: @@ -93,7 +97,7 @@ def __init__( self.boto3_transfer_config = boto3_transfer_config - super().__init__(local_cache_dir=local_cache_dir) + super().__init__(local_cache_dir=local_cache_dir, content_type_method=content_type_method) def _get_metadata(self, cloud_path: S3Path) -> Dict[str, Any]: data = self.s3.ObjectSummary(cloud_path.bucket, cloud_path.key).get() @@ -102,7 +106,7 @@ def _get_metadata(self, cloud_path: S3Path) -> Dict[str, Any]: "last_modified": data["LastModified"], "size": data["ContentLength"], "etag": data["ETag"], - "mime": data["ContentType"], + "content_type": data["ContentType"], "extra": data["Metadata"], } @@ -250,7 +254,16 @@ def _remove(self, cloud_path: S3Path) -> None: def _upload_file(self, local_path: Union[str, os.PathLike], cloud_path: S3Path) -> S3Path: obj = self.s3.Object(cloud_path.bucket, cloud_path.key) - obj.upload_file(str(local_path), Config=self.boto3_transfer_config) + extra_args = {} + + if self.content_type_method is not None: + content_type, content_encoding = self.content_type_method(str(local_path)) + if content_type is not None: + extra_args["ContentType"] = content_type + if 
content_encoding is not None: + extra_args["ContentEncoding"] = content_encoding + + obj.upload_file(str(local_path), Config=self.boto3_transfer_config, ExtraArgs=extra_args) return cloud_path diff --git a/docs/docs/other_client_settings.md b/docs/docs/other_client_settings.md new file mode 100644 index 00000000..92a33e64 --- /dev/null +++ b/docs/docs/other_client_settings.md @@ -0,0 +1,51 @@ +# Other `Client` settings + +## Content type guessing (`content_type_method`) + +All of the clients support passing a `content_type_method` when they are instantiated. +This is a method that is used to guess the [MIME (media) type](https://en.wikipedia.org/wiki/Media_type) +(often called the "content type") of the file and set that on the cloud provider. + +By default, `content_type_method` uses the Python built-in +[`guess_type`](https://docs.python.org/3/library/mimetypes.html#mimetypes.guess_type) +to set this content type. This guesses based on the file extension, and may not always get the correct type. +In these cases, you can set `content_type_method` to your own function that gets the proper type; for example, by +reading the file content or by looking it up in a dictionary of filename-to-media-type mappings that you maintain. + +If you set a custom method, it should follow the signature of `guess_type` and return a tuple of the form: +`(content_type, content_encoding)`; for example, `("text/css", None)`. + +If you set `content_type_method` to None, it will do whatever the default of the cloud provider's SDK does. This +varies from provider to provider. + +Here is an example of using a custom `content_type_method`. 
+ +```python +import mimetypes +from pathlib import Path + +from cloudpathlib import S3Client, CloudPath + +def my_content_type(path): + # do lookup for content types I define; fallback to + # guess_type for anything else + return { + ".potato": ("application/potato", None), + }.get(Path(path).suffix, mimetypes.guess_type(path)) + + +# create a client with my custom content type +client = S3Client(content_type_method=my_content_type) + +# To use this same method for every cloud path, set our client as the default. +# This is optional, and you could use client.CloudPath to create paths instead. +client.set_as_default_client() + +# create a cloud path +cp1 = CloudPath("s3://cloudpathlib-test-bucket/i_am_a.potato") +cp1.write_text("hello") + +# check content type with boto3 +print(client.s3.Object(cp1.bucket, cp1.key).content_type) +#> application/potato +``` diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 367ba8a5..470f4320 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -20,6 +20,7 @@ nav: - Authentication: "authentication.md" - Caching: "caching.ipynb" - AnyPath: "anypath-polymorphism.md" + - Other Client settings: "other_client_settings.md" - Testing code that uses cloudpathlib: "testing_mocked_cloudpathlib.ipynb" - Integrations: "integrations.md" - Changelog: "changelog.md" diff --git a/setup.py b/setup.py index 44e54ca7..965809a2 100644 --- a/setup.py +++ b/setup.py @@ -60,5 +60,5 @@ def load_requirements(path: Path): "Source Code": "https://github.com/drivendataorg/cloudpathlib", }, url="https://github.com/drivendataorg/cloudpathlib", - version="0.7.1", + version="0.8.0", ) diff --git a/tests/mock_clients/mock_azureblob.py b/tests/mock_clients/mock_azureblob.py index b8f92670..6861f5d3 100644 --- a/tests/mock_clients/mock_azureblob.py +++ b/tests/mock_clients/mock_azureblob.py @@ -20,6 +20,8 @@ def __init__(self, *args, **kwargs): self.tmp_path = Path(self.tmp.name) / "test_case_copy" shutil.copytree(TEST_ASSETS, self.tmp_path / test_dir) + 
self.metadata_cache = {} + @classmethod def from_connection_string(cls, *args, **kwargs): return cls() @@ -28,7 +30,7 @@ def __del__(self): self.tmp.cleanup() def get_blob_client(self, container, blob): - return MockBlobClient(self.tmp_path, blob) + return MockBlobClient(self.tmp_path, blob, service_client=self) def get_container_client(self, container): return MockContainerClient(self.tmp_path) @@ -37,10 +39,12 @@ def get_container_client(self, container): class MockBlobClient: - def __init__(self, root, key): + def __init__(self, root, key, service_client=None): self.root = root self.key = key + self.service_client = service_client + @property def url(self): return self.root / self.key @@ -53,6 +57,9 @@ def get_blob_properties(self): "name": self.key, "Last-Modified": datetime.fromtimestamp(path.stat().st_mtime), "ETag": "etag", + "content_type": self.service_client.metadata_cache.get( + self.root / self.key, None + ), } ) else: @@ -75,11 +82,16 @@ def delete_blob(self): path.unlink() delete_empty_parents_up_to_root(path=path, root=self.root) - def upload_blob(self, data, overwrite): + def upload_blob(self, data, overwrite, content_settings=None): path = self.root / self.key path.parent.mkdir(parents=True, exist_ok=True) path.write_bytes(data) + if content_settings is not None: + self.service_client.metadata_cache[ + self.root / self.key + ] = content_settings.content_type + class MockStorageStreamDownloader: def __init__(self, root, key): diff --git a/tests/mock_clients/mock_gs.py b/tests/mock_clients/mock_gs.py index 23d0ac39..60b4a01e 100644 --- a/tests/mock_clients/mock_gs.py +++ b/tests/mock_clients/mock_gs.py @@ -16,6 +16,8 @@ def __init__(self, *args, **kwargs): self.tmp_path = Path(self.tmp.name) / "test_case_copy" shutil.copytree(TEST_ASSETS, self.tmp_path / test_dir) + self.metadata_cache = {} + @classmethod def create_anonymous_client(cls): return cls() @@ -28,16 +30,17 @@ def __del__(self): self.tmp.cleanup() def bucket(self, bucket): - return 
MockBucket(self.tmp_path) + return MockBucket(self.tmp_path, client=self) return MockClient class MockBlob: - def __init__(self, root, name): + def __init__(self, root, name, client=None): self.bucket = root self.name = str(PurePosixPath(name)) self.metadata = None + self.client = client def delete(self): path = self.bucket / self.name @@ -56,12 +59,14 @@ def patch(self): if "updated" in self.metadata: (self.bucket / self.name).touch() - def upload_from_filename(self, filename): + def upload_from_filename(self, filename, content_type=None): data = Path(filename).read_bytes() path = self.bucket / self.name path.parent.mkdir(parents=True, exist_ok=True) path.write_bytes(data) + self.client.metadata_cache[self.bucket / self.name] = content_type + @property def etag(self): return "etag" @@ -76,13 +81,18 @@ def updated(self): path = self.bucket / self.name return datetime.fromtimestamp(path.stat().st_mtime) + @property + def content_type(self): + return self.client.metadata_cache.get(self.bucket / self.name, None) + class MockBucket: - def __init__(self, name): + def __init__(self, name, client=None): self.name = name + self.client = client def blob(self, blob): - return MockBlob(self.name, blob) + return MockBlob(self.name, blob, client=self.client) def copy_blob(self, blob, destination_bucket, new_name): data = (self.name / blob.name).read_bytes() @@ -92,14 +102,14 @@ def copy_blob(self, blob, destination_bucket, new_name): def get_blob(self, blob): if (self.name / blob).is_file(): - return MockBlob(self.name, blob) + return MockBlob(self.name, blob, client=self.client) else: return None def list_blobs(self, max_results=None, prefix=None): path = self.name if prefix is None else self.name / prefix items = [ - MockBlob(self.name, f.relative_to(self.name)) + MockBlob(self.name, f.relative_to(self.name), client=self.client) for f in path.glob("**/*") if f.is_file() and not f.name.startswith(".") ] diff --git a/tests/mock_clients/mock_s3.py b/tests/mock_clients/mock_s3.py 
index 138479d9..83074fe8 100644 --- a/tests/mock_clients/mock_s3.py +++ b/tests/mock_clients/mock_s3.py @@ -27,29 +27,32 @@ def __init__(self, *args, **kwargs): self.tmp_path = Path(self.tmp.name) / "test_case_copy" shutil.copytree(TEST_ASSETS, self.tmp_path / test_dir) + self.metadata_cache = {} + def __del__(self): self.tmp.cleanup() def resource(self, item, endpoint_url, config=None): - return MockBoto3Resource(self.tmp_path) + return MockBoto3Resource(self.tmp_path, session=self) def client(self, item, endpoint_url, config=None): - return MockBoto3Client(self.tmp_path) + return MockBoto3Client(self.tmp_path, session=self) return MockBoto3Session class MockBoto3Resource: - def __init__(self, root): + def __init__(self, root, session=None): self.root = root self.download_config = None self.upload_config = None + self.session = session def Bucket(self, bucket): - return MockBoto3Bucket(self.root) + return MockBoto3Bucket(self.root, session=self.session) def ObjectSummary(self, bucket, key): - return MockBoto3ObjectSummary(self.root, key) + return MockBoto3ObjectSummary(self.root, key, session=self.session) def Object(self, bucket, key): return MockBoto3Object(self.root, key, self) @@ -90,11 +93,14 @@ def download_file(self, to_path, Config=None): # track config to make sure it's used in tests self.resource.download_config = Config - def upload_file(self, from_path, Config=None): + def upload_file(self, from_path, Config=None, ExtraArgs=None): self.path.parent.mkdir(parents=True, exist_ok=True) self.path.write_bytes(Path(from_path).read_bytes()) self.resource.upload_config = Config + if ExtraArgs is not None: + self.resource.session.metadata_cache[self.path] = ExtraArgs.pop("ContentType", None) + def delete(self): self.path.unlink() delete_empty_parents_up_to_root(self.path, self.root) @@ -109,8 +115,9 @@ def copy(self, source): class MockBoto3ObjectSummary: - def __init__(self, root, path): + def __init__(self, root, path, session=None): self.path = root / path + 
self.session = session def get(self): if not self.path.exists() or self.path.is_dir(): @@ -120,41 +127,44 @@ def get(self): "LastModified": datetime.fromtimestamp(self.path.stat().st_mtime), "ContentLength": None, "ETag": hash(str(self.path)), - "ContentType": None, + "ContentType": self.session.metadata_cache.get(self.path, None), "Metadata": {}, } class MockBoto3Bucket: - def __init__(self, root): + def __init__(self, root, session=None): self.root = root + self.session = session @property def objects(self): - return MockObjects(self.root) + return MockObjects(self.root, session=self.session) class MockObjects: - def __init__(self, root): + def __init__(self, root, session=None): self.root = root + self.session = session def filter(self, Prefix=""): path = self.root / Prefix if path.is_file(): - return MockCollection([PurePosixPath(path)], self.root) + return MockCollection([PurePosixPath(path)], self.root, session=self.session) items = [ PurePosixPath(f) for f in path.glob("**/*") if f.is_file() and not f.name.startswith(".") ] - return MockCollection(items, self.root) + return MockCollection(items, self.root, session=self.session) class MockCollection: - def __init__(self, items, root): + def __init__(self, items, root, session=None): self.root = root + self.session = session s3_obj = collections.namedtuple("s3_obj", "key bucket_name") self.full_paths = items @@ -177,11 +187,12 @@ def delete(self): class MockBoto3Client: - def __init__(self, root): + def __init__(self, root, session=None): self.root = root + self.session = session def get_paginator(self, api): - return MockBoto3Paginator(self.root) + return MockBoto3Paginator(self.root, session=self.session) @property def exceptions(self): @@ -190,9 +201,10 @@ def exceptions(self): class MockBoto3Paginator: - def __init__(self, root, per_page=2): + def __init__(self, root, per_page=2, session=None): self.root = root self.per_page = per_page + self.session = session def paginate(self, Bucket=None, Prefix="", 
Delimiter=None): new_dir = self.root / Prefix diff --git a/tests/test_client.py b/tests/test_client.py index 96e8b4dd..3f17d84f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,4 +1,11 @@ +import mimetypes +import os +from pathlib import Path +import random +import string + from cloudpathlib import CloudPath +from cloudpathlib.s3.s3client import S3Client def test_default_client_instantiation(rig): @@ -41,3 +48,65 @@ def test_different_clients(rig): assert p.client is not p2.client assert p._local is not p2._local + + +def test_content_type_setting(rig, tmpdir): + random.seed(1337) # reproducible file names + + mimes = [ + (".css", "text/css"), + (".html", "text/html"), + (".js", "application/javascript"), + (".mp3", "audio/mpeg"), + (".mp4", "video/mp4"), + (".jpeg", "image/jpeg"), + (".png", "image/png"), + ] + + def _test_write_content_type(suffix, expected, rig_ref, check=True): + filename = "".join(random.choices(string.ascii_letters, k=8)) + suffix + filepath = Path(tmpdir / filename) + filepath.write_text("testing") + + cp = rig_ref.create_cloud_path(filename) + cp.upload_from(filepath) + + meta = cp.client._get_metadata(cp) + + if check: + assert meta["content_type"] == expected + + # should guess by default + for suffix, content_type in mimes: + _test_write_content_type(suffix, content_type, rig) + + # None does whatever library default is; not checked, just ensure + # we don't throw an error + for suffix, content_type in mimes: + _test_write_content_type(suffix, content_type, rig, check=False) + + # custom mime type method + def my_content_type(path): + # do lookup for content types I define; fallback to + # guess_type for anything else + return { + ".potato": ("application/potato", None), + }.get(Path(path).suffix, mimetypes.guess_type(path)) + + mimes.append((".potato", "application/potato")) + + # see if testing custom s3 endpoint, make sure to pass the url to the constructor + kwargs = {} + custom_endpoint = 
os.getenv("CUSTOM_S3_ENDPOINT", "https://s3.us-west-1.drivendatabws.com") + if ( + rig.client_class is S3Client + and rig.live_server + and custom_endpoint in rig.create_cloud_path("").client.client._endpoint.host + ): + kwargs["endpoint_url"] = custom_endpoint + + # set up default client to use content_type_method + rig.client_class(content_type_method=my_content_type, **kwargs).set_as_default_client() + + for suffix, content_type in mimes: + _test_write_content_type(suffix, content_type, rig)