Skip to content

Commit

Permalink
Updates based on PR review
Browse files Browse the repository at this point in the history
  • Loading branch information
0xlianhu committed Oct 3, 2023
1 parent 43011c0 commit 77847f2
Showing 1 changed file with 61 additions and 59 deletions.
120 changes: 61 additions & 59 deletions test/test_rohmu.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,24 @@
import hashlib
import logging
import os
import pathlib
from tempfile import NamedTemporaryFile

import pytest
from rohmu import get_transfer, rohmufile
from rohmu.object_storage.base import BaseTransfer
from rohmu.rohmufile import create_sink_pipeline
from rohmu.typing import Metadata

from .base import CONSTANT_TEST_RSA_PRIVATE_KEY, CONSTANT_TEST_RSA_PUBLIC_KEY

EMPTY_FILE_SHA1 = "da39a3ee5e6b4b0d3255bfef95601890afd80709"

log = logging.getLogger(__name__)


@pytest.mark.parametrize(
"compress_algorithm, file_size",
[("lzma", 0), ("snappy", 0), ("zstd", 0), ("lzma", 1), ("snappy", 1), ("zstd", 1)],
ids=[
"test_lzma_0byte_file", "test_snappy_0byte_file", "test_zstd_0byte_file", "test_lzma_1byte_file",
"test_snappy_1byte_file", "test_zstd_1byte_file"
],
)
@pytest.mark.parametrize("compress_algorithm", ["lzma", "snappy", "zstd"], ids=["test_lzma", "test_snappy", "test_zstd"])
@pytest.mark.parametrize("file_size", [0, 1], ids=["0byte_file", "1byte_file"])
def test_rohmu_with_local_storage(compress_algorithm: str, file_size: int, tmp_path):
hash_algorithm = "sha1"
compression_level = 0
Expand All @@ -34,30 +33,20 @@ def test_rohmu_with_local_storage(compress_algorithm: str, file_size: int, tmp_p
with open(orig_file, "rb") as file_in:
assert file_in.read() == content

# 1 - Compressed the file
original_file_size = os.path.getsize(orig_file)
assert original_file_size == len(content)

# 1 - Compress the file
compressed_filepath = work_dir / "compressed" / "hello_compressed"
compressed_filepath.parent.mkdir(exist_ok=True)
hasher = hashlib.new(hash_algorithm)
input_obj = open(orig_file, "rb")
output_obj = NamedTemporaryFile(
dir=os.path.dirname(compressed_filepath), prefix=os.path.basename(compressed_filepath), suffix=".tmp-compress"
)
with input_obj, output_obj:
original_file_size, compressed_file_size = rohmufile.write_file(
data_callback=hasher.update,
input_obj=input_obj,
output_obj=output_obj,
compression_algorithm=compress_algorithm,
compression_level=compression_level,
rsa_public_key=CONSTANT_TEST_RSA_PUBLIC_KEY,
log_func=log.debug,
)
os.link(output_obj.name, compressed_filepath)

log.info("original_file_size: %s, compressed_file_size: %s", original_file_size, compressed_file_size)
assert original_file_size == len(content)
compressed_file_size = _compress_file(orig_file, compressed_filepath, compress_algorithm, compression_level, hasher)
file_hash = hasher.hexdigest()
log.info("original_file_hash: %s", file_hash)

log.info(
"original_file_size: %s, original_file_hash: %s, compressed_file_size: %s", original_file_size, file_hash,
compressed_file_size
)

# 2 - Upload the compressed file
upload_dir = work_dir / "uploaded"
Expand All @@ -66,22 +55,16 @@ def test_rohmu_with_local_storage(compress_algorithm: str, file_size: int, tmp_p
"directory": str(upload_dir),
"storage_type": "local",
}
storage = get_transfer(storage_config)

metadata = {
"encryption-key-id": "No matter",
"compression-algorithm": compress_algorithm,
"compression-level": compression_level,
"Content-Length": str(compressed_file_size)
}
storage = get_transfer(storage_config)

metadata_copy = metadata.copy()
metadata_copy["Content-Length"] = str(compressed_file_size)
file_key = "compressed/hello_compressed"

def upload_progress_callback(n_bytes: int) -> None:
log.debug("File: '%s', uploaded %d bytes", file_key, n_bytes)

with open(compressed_filepath, "rb") as f:
storage.store_file_object(file_key, f, metadata=metadata_copy, upload_progress_fn=upload_progress_callback)
_upload_compressed_file(storage=storage, file_to_upload=str(compressed_filepath), file_key=file_key, metadata=metadata)

# 3 - Decrypt and decompress
# 3.1 Use file downloading rohmu API
Expand All @@ -90,37 +73,58 @@ def upload_progress_callback(n_bytes: int) -> None:
decompressed_size = _download_and_decompress_with_file(storage, str(decompressed_filepath), file_key, metadata)
assert len(content) == decompressed_size
# Compare content
with open(decompressed_filepath, "rb") as file_in:
content_decrypted = file_in.read()
hasher = hashlib.new(hash_algorithm)
hasher.update(content_decrypted)
assert hasher.hexdigest() == file_hash
assert content_decrypted == content
content_decrypted = decompressed_filepath.read_bytes()
hasher = hashlib.new(hash_algorithm)
hasher.update(content_decrypted)
assert hasher.hexdigest() == file_hash
assert content_decrypted == content

# 3.2 Use rohmu SinkIO API
decompressed_filepath = work_dir / "hello_decompressed_2"
decompressed_size = _download_and_decompress_with_sink(storage, str(decompressed_filepath), file_key, metadata)
assert len(content) == decompressed_size

# Compare content
hasher.hexdigest()
with open(decompressed_filepath, "rb") as file_in:
content_decrypted = file_in.read()
hasher = hashlib.new(hash_algorithm)
hasher.update(content_decrypted)
assert hasher.hexdigest() == file_hash
assert content_decrypted == content
content_decrypted = decompressed_filepath.read_bytes()
hasher = hashlib.new(hash_algorithm)
hasher.update(content_decrypted)
assert hasher.hexdigest() == file_hash
assert content_decrypted == content

if file_size == 0:
empty_file_sha1 = "da39a3ee5e6b4b0d3255bfef95601890afd80709"
assert empty_file_sha1 == hasher.hexdigest()
assert EMPTY_FILE_SHA1 == hasher.hexdigest()


def _key_lookup(key_id: str): # pylint: disable=unused-argument
def _key_lookup(key_id: str) -> str:  # pylint: disable=unused-argument
    """Return the constant test RSA private key, whatever ``key_id`` is requested.

    The argument is accepted but never inspected, so every key id resolves
    to the same test key.
    """
    return CONSTANT_TEST_RSA_PRIVATE_KEY


def _download_and_decompress_with_sink(storage, output_path: str, file_key: str, metadata: dict):
def _compress_file(input_file: pathlib.Path, output_file: pathlib.Path, algorithm: str, compress_level: int, hasher) -> int:
    """Write ``input_file`` through ``rohmufile.write_file`` and publish the result as ``output_file``.

    The data is first written to a temporary file created next to
    ``output_file`` and then hard-linked to its final name.

    Args:
        input_file: Path of the file to read.
        output_file: Final path of the written file; its parent directory
            hosts the temporary file.
        algorithm: Value passed as ``compression_algorithm``.
        compress_level: Value passed as ``compression_level``.
        hasher: Object whose ``update`` method is passed as ``data_callback``.

    Returns:
        The second element returned by ``rohmufile.write_file`` (the
        compressed file size).
    """
    with open(input_file, "rb") as source:
        with NamedTemporaryFile(dir=output_file.parent, prefix=output_file.name, suffix=".tmp-compress") as scratch:
            _original_size, compressed_size = rohmufile.write_file(
                data_callback=hasher.update,
                input_obj=source,
                output_obj=scratch,
                compression_algorithm=algorithm,
                compression_level=compress_level,
                rsa_public_key=CONSTANT_TEST_RSA_PUBLIC_KEY,
                log_func=log.debug,
            )
            # Link while the temporary file is still open: it is removed when
            # the context manager closes it, and the hard link keeps the data.
            os.link(scratch.name, output_file)
    return compressed_size


def _upload_compressed_file(storage: BaseTransfer, file_to_upload: str, file_key: str, metadata: Metadata) -> None:
    """Store the local file ``file_to_upload`` in ``storage`` under ``file_key``.

    Args:
        storage: Transfer object whose ``store_file_object`` performs the upload.
        file_to_upload: Path of the local file, opened in binary mode.
        file_key: Key to store the object under.
        metadata: Metadata passed along with the stored object.
    """

    def report_progress(n_bytes: int) -> None:
        # Progress is only logged at debug level.
        log.debug("File: '%s', uploaded %d bytes", file_key, n_bytes)

    with open(file_to_upload, "rb") as source:
        storage.store_file_object(file_key, source, metadata=metadata, upload_progress_fn=report_progress)


def _download_and_decompress_with_sink(storage: BaseTransfer, output_path: str, file_key: str, metadata: Metadata) -> int:
data, _ = storage.get_contents_to_string(file_key)
if isinstance(data, str):
data = data.encode("latin1")
Expand All @@ -135,7 +139,7 @@ def _download_and_decompress_with_sink(storage, output_path: str, file_key: str,
return decompressed_size


def _download_and_decompress_with_file(storage, output_path: str, file_key: str, metadata: dict):
def _download_and_decompress_with_file(storage: BaseTransfer, output_path: str, file_key: str, metadata: Metadata) -> int:
# Download the compressed file
file_download_path = output_path + ".tmp"

Expand All @@ -146,9 +150,7 @@ def download_progress_callback(bytes_written: int, input_size: int) -> None:
storage.get_contents_to_fileobj(file_key, f, progress_callback=download_progress_callback)

# Decrypt and decompress
input_obj = open(file_download_path, "rb")
output_obj = open(output_path, "wb")
with input_obj, output_obj:
with open(file_download_path, "rb") as input_obj, open(output_path, "wb") as output_obj:
_, decompressed_size = rohmufile.read_file(
input_obj=input_obj,
output_obj=output_obj,
Expand Down

0 comments on commit 77847f2

Please sign in to comment.