From a55cf0814e983b7aebddb6077f29dd3317ed5dad Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Tue, 27 Aug 2024 20:39:56 -0700 Subject: [PATCH 1/6] implement core aes-gcm encrypt/decrypt --- src/snowflake/connector/constants.py | 15 ++ src/snowflake/connector/encryption_util.py | 181 ++++++++++++++++++- src/snowflake/connector/s3_storage_client.py | 10 +- src/snowflake/connector/storage_client.py | 11 +- 4 files changed, 214 insertions(+), 3 deletions(-) diff --git a/src/snowflake/connector/constants.py b/src/snowflake/connector/constants.py index 0d14fda1f..b6fd0f554 100644 --- a/src/snowflake/connector/constants.py +++ b/src/snowflake/connector/constants.py @@ -303,6 +303,10 @@ class EncryptionMetadata(NamedTuple): key: str iv: str matdesc: str + cipher: str | None + key_iv: str | None + key_aad: str | None + data_aad: str | None class FileHeader(NamedTuple): @@ -428,3 +432,14 @@ class IterUnit(Enum): _DOMAIN_NAME_MAP = {_DEFAULT_HOSTNAME_TLD: "GLOBAL", _CHINA_HOSTNAME_TLD: "CHINA"} + + +class CipherAlgorithm(str, Enum): + AES_CBC = "AES_CBC" # Use AES-ECB/AES-CBC + AES_GCM = "AES_GCM" # Use AES-GCM/AES-GCM + AES_GCM_CBC = ( + "AES_GCM,AES_CBC" # Use AES-GCM if the code version supports, otherwise AES-CBC + ) + + def __str__(self): + return self.value diff --git a/src/snowflake/connector/encryption_util.py b/src/snowflake/connector/encryption_util.py index c1c34079e..051e4e98a 100644 --- a/src/snowflake/connector/encryption_util.py +++ b/src/snowflake/connector/encryption_util.py @@ -16,7 +16,13 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from .compat import PKCS5_OFFSET, PKCS5_PAD, PKCS5_UNPAD -from .constants import UTF8, EncryptionMetadata, MaterialDescriptor, kilobyte +from .constants import ( + UTF8, + CipherAlgorithm, + EncryptionMetadata, + MaterialDescriptor, + kilobyte, +) from .util_text import random_string block_size = int(algorithms.AES.block_size / 8) # in bytes @@ -110,6 +116,10 @@ def encrypt_stream( key=base64.b64encode(enc_kek).decode("utf-8"), iv=base64.b64encode(iv_data).decode("utf-8"), matdesc=matdesc_to_unicode(mat_desc), + cipher=None, + key_iv=None, + key_aad=None, + data_aad=None, ) return metadata @@ -148,6 +158,105 @@ def encrypt_file( ) return metadata, temp_output_file + @staticmethod + def encrypt_stream_gcm( + encryption_material: SnowflakeFileEncryptionMaterial, + src: IO[bytes], + out: IO[bytes], + key_aad: str = "", + data_aad: str = "", + ): + logger = getLogger(__name__) + decoded_key = base64.standard_b64decode( + encryption_material.query_stage_master_key + ) + key_size = len(decoded_key) + logger.debug("key_size = %s", key_size) + + iv_data = SnowflakeEncryptionUtil.get_secure_random(block_size) + key_iv_data = SnowflakeEncryptionUtil.get_secure_random(key_size) + file_key = SnowflakeEncryptionUtil.get_secure_random(key_size) + backend = default_backend() + file_key_cipher = Cipher( + algorithms.AES(decoded_key), modes.CBC(key_iv_data), backend=backend + ) + file_key_encryptor = file_key_cipher.encryptor() + if key_aad: + file_key_encryptor.authenticate_additional_data( + base64.standard_b64decode(key_aad) + ) + encrypted_file_key = ( + file_key_encryptor.update(file_key) + file_key_encryptor.finalize() + ) + # encrypted_file_key_tag = file_key_encryptor.tag # TODO: where to put tag? + + content_cipher = Cipher( + algorithms.AES(file_key), modes.GCM(iv_data), backend=backend + ) + content_encryptor = content_cipher.encryptor() + if data_aad: + content_encryptor.authenticate_additional_data( + base64.standard_b64decode(data_aad) + ) + + encrypted_content = ( + content_encryptor.update(src.read()) + content_encryptor.finalize() + ) + # encrypted_content_tag = content_encryptor.tag # TODO: where to put tag? + out.write(encrypted_content) + + mat_desc = MaterialDescriptor( + smk_id=encryption_material.smk_id, + query_id=encryption_material.query_id, + key_size=key_size * 8, + ) + metadata = EncryptionMetadata( + key=base64.b64encode(encrypted_file_key).decode("utf-8"), + iv=base64.b64encode(iv_data).decode("utf-8"), + matdesc=matdesc_to_unicode(mat_desc), + cipher=str(CipherAlgorithm.AES_GCM), + key_iv=base64.b64encode(key_iv_data).decode("utf-8"), + key_aad=key_aad, + data_aad=data_aad, + ) + return metadata + + @staticmethod + def encrypt_file_gcm( + encryption_material: SnowflakeFileEncryptionMaterial, + in_filename: str, + tmp_dir: str | None = None, + key_aad: str = "", + data_aad: str = "", + ) -> tuple[EncryptionMetadata, str]: + """Encrypts a file in a temporary directory. + + Args: + encryption_material: The encryption material for file. + in_filename: The input file's name. + chunk_size: The size of read chunks (Default value = block_size * 4 * 1024). + tmp_dir: Temporary directory to use, optional (Default value = None). + + Returns: + The encryption metadata and the encrypted file's location. + """ + logger = getLogger(__name__) + temp_output_fd, temp_output_file = tempfile.mkstemp( + text=False, dir=tmp_dir, prefix=os.path.basename(in_filename) + "#" + ) + logger.debug( + "unencrypted file: %s, temp file: %s, tmp_dir: %s", + in_filename, + temp_output_file, + tmp_dir, + ) + with open(in_filename, "rb") as infile: + with os.fdopen(temp_output_fd, "wb") as outfile: + metadata = SnowflakeEncryptionUtil.encrypt_stream_gcm( + encryption_material, infile, outfile, key_aad, data_aad + ) + return metadata, temp_output_file + @staticmethod def decrypt_stream( metadata: EncryptionMetadata, @@ -218,3 +327,73 @@ def decrypt_file( metadata, encryption_material, infile, outfile, chunk_size ) return temp_output_file + + @staticmethod + def decrypt_stream_gcm( + metadata: EncryptionMetadata, + encryption_material: SnowflakeFileEncryptionMaterial, + src: IO[bytes], + out: IO[bytes], + ) -> None: + """To read from `src` stream then decrypt to `out` stream.""" + # TODO: where to get tag for both? + key_base64 = metadata.key + iv_base64 = metadata.iv + key_iv_base64 = metadata.key_iv + decoded_key = base64.standard_b64decode( + encryption_material.query_stage_master_key + ) + key_bytes = base64.standard_b64decode(key_base64) + iv_bytes = base64.standard_b64decode(iv_base64) + key_iv_bytes = base64.standard_b64decode(key_iv_base64) + key_aad = base64.standard_b64decode(metadata.key_aad) + data_aad = base64.standard_b64decode(metadata.data_aad) + + backend = default_backend() + file_key_cipher = Cipher( + algorithms.AES(decoded_key), modes.GCM(key_iv_bytes), backend=backend + ) + file_key_decryptor = file_key_cipher.decryptor() + if key_aad: + file_key_decryptor.authenticate_additional_data(key_aad) + file_key = file_key_decryptor.update(key_bytes) + file_key_decryptor.finalize() + + content_cipher = Cipher( + algorithms.AES(file_key), modes.GCM(iv_bytes), backend=backend + ) + content_decryptor = content_cipher.decryptor() + if data_aad: + content_decryptor.authenticate_additional_data(data_aad) + content = content_decryptor.update(src.read()) + content_decryptor.finalize() + out.write(content) + + @staticmethod + def decrypt_file_gcm( + metadata: EncryptionMetadata, + encryption_material: SnowflakeFileEncryptionMaterial, + in_filename: str, + tmp_dir: str | None = None, + ) -> str: + """Decrypts a file and stores the output in the temporary directory. + + Args: + metadata: The file's metadata input. + encryption_material: The file's encryption material. + in_filename: The name of the input file. + chunk_size: The size of read chunks (Default value = block_size * 4 * 1024). + tmp_dir: Temporary directory to use, optional (Default value = None). + + Returns: + The decrypted file's location. + """ + temp_output_file = f"{os.path.basename(in_filename)}#{random_string()}" + if tmp_dir: + temp_output_file = os.path.join(tmp_dir, temp_output_file) + + logger.debug("encrypted file: %s, tmp file: %s", in_filename, temp_output_file) + with open(in_filename, "rb") as infile: + with open(temp_output_file, "wb") as outfile: + SnowflakeEncryptionUtil.decrypt_stream_gcm( + metadata, encryption_material, infile, outfile + ) + return temp_output_file diff --git a/src/snowflake/connector/s3_storage_client.py b/src/snowflake/connector/s3_storage_client.py index 673134081..0389b3bb9 100644 --- a/src/snowflake/connector/s3_storage_client.py +++ b/src/snowflake/connector/s3_storage_client.py @@ -37,6 +37,10 @@ AMZ_MATDESC = "x-amz-matdesc" AMZ_KEY = "x-amz-key" AMZ_IV = "x-amz-iv" +AMZ_CIPHER = "x-amz-cipher" +AMZ_KEY_IV = "x-amz-key-iv" +AMZ_KEY_AAD = "x-amz-key-aad" +AMZ_DATA_AAD = "x-amz-data-aad" ERRORNO_WSAECONNABORTED = 10053 # network connection was aborted @@ -397,6 +401,10 @@ def get_file_header(self, filename: str) -> FileHeader | None: key=metadata.get(META_PREFIX + AMZ_KEY), iv=metadata.get(META_PREFIX + AMZ_IV), matdesc=metadata.get(META_PREFIX + AMZ_MATDESC), + cipher=metadata.get(META_PREFIX + AMZ_CIPHER), + key_iv=metadata.get(META_PREFIX + AMZ_KEY_IV), + key_aad=metadata.get(META_PREFIX + AMZ_KEY_AAD), + data_aad=metadata.get(META_PREFIX + AMZ_DATA_AAD), ) if metadata.get(META_PREFIX + AMZ_KEY) else None @@ -557,7 +565,7 @@ def download_chunk(self, chunk_id: int) -> None: else: chunk_size = self.chunk_size if chunk_id < self.num_of_chunks - 1: - _range = f"{chunk_id * chunk_size}-{(chunk_id+1)*chunk_size-1}" + _range = f"{chunk_id * chunk_size}-{(chunk_id + 1) * chunk_size - 1}" else: _range = f"{chunk_id * chunk_size}-" diff --git a/src/snowflake/connector/storage_client.py b/src/snowflake/connector/storage_client.py index ba74f511b..5635ad462 100644 --- a/src/snowflake/connector/storage_client.py +++ b/src/snowflake/connector/storage_client.py @@ -23,6 +23,7 @@ HTTP_HEADER_CONTENT_ENCODING, REQUEST_CONNECTION_TIMEOUT, REQUEST_READ_TIMEOUT, + CipherAlgorithm, FileHeader, ResultStatus, ) @@ -115,6 +116,14 @@ def __init__( self.failed_transfers: int = 0 # only used when PRESIGNED_URL expires self.last_err_is_presigned_url = False + self._is_client_side_encrypted = self.stage_info.get( + "isClientSideEncrypted", True + ) + self._ciphers = ( + CipherAlgorithm(str(self.stage_info.get("ciphers").upper())) + if self.stage_info.get("ciphers", "") + else None + ) def compress(self) -> None: if self.meta.require_compress: @@ -233,7 +242,7 @@ def prepare_upload(self) -> None: logger.debug(f"Preparing to upload {meta.src_file_name}") - if meta.encryption_material: + if meta.encryption_material and self._is_client_side_encrypted: self.encrypt() else: self.data_file = meta.real_src_file_name From ebd2720bdd37714d6a98696493d57208245298d2 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Tue, 27 Aug 2024 20:42:45 -0700 Subject: [PATCH 2/6] minor --- src/snowflake/connector/encryption_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snowflake/connector/encryption_util.py b/src/snowflake/connector/encryption_util.py index 051e4e98a..f4df2647e 100644 --- a/src/snowflake/connector/encryption_util.py +++ b/src/snowflake/connector/encryption_util.py @@ -178,7 +178,7 @@ def encrypt_stream_gcm( file_key = SnowflakeEncryptionUtil.get_secure_random(key_size) backend = default_backend() file_key_cipher = Cipher( - algorithms.AES(decoded_key), modes.CBC(key_iv_data), backend=backend + algorithms.AES(decoded_key), modes.GCM(key_iv_data), backend=backend ) file_key_encryptor = file_key_cipher.encryptor() if key_aad: From 6d5280b864064a8bc0335518ae301c01050807e1 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Wed, 28 Aug 2024 16:10:51 -0700 Subject: [PATCH 3/6] append tag to the end of the stream --- src/snowflake/connector/encryption_util.py | 24 +++-- test/unit/test_encryption_util.py | 114 +++++++++++++++++++++ 2 files changed, 129 insertions(+), 9 deletions(-) diff --git a/src/snowflake/connector/encryption_util.py b/src/snowflake/connector/encryption_util.py index f4df2647e..a0be6cab8 100644 --- a/src/snowflake/connector/encryption_util.py +++ b/src/snowflake/connector/encryption_util.py @@ -27,6 +27,7 @@ block_size = int(algorithms.AES.block_size / 8) # in bytes + if TYPE_CHECKING: # pragma: no cover from .storage_client import SnowflakeFileEncryptionMaterial @@ -186,10 +187,10 @@ def encrypt_stream_gcm( base64.standard_b64decode(key_aad) ) encrypted_file_key = ( - file_key_encryptor.update(file_key) + file_key_encryptor.finalize() + file_key_encryptor.update(file_key) + + file_key_encryptor.finalize() + + file_key_encryptor.tag ) - # encrypted_file_key_tag = file_key_encryptor.tag # TODO: where to put tag? - content_cipher = Cipher( algorithms.AES(file_key), modes.GCM(iv_data), backend=backend ) @@ -200,9 +201,10 @@ def encrypt_stream_gcm( ) encrypted_content = ( - content_encryptor.update(src.read()) + content_encryptor.finalize() + content_encryptor.update(src.read()) + + content_encryptor.finalize() + + content_encryptor.tag ) - # encrypted_content_tag = content_encryptor.tag # TODO: where to put tag? out.write(encrypted_content) mat_desc = MaterialDescriptor( @@ -336,7 +338,6 @@ def decrypt_stream_gcm( out: IO[bytes], ) -> None: """To read from `src` stream then decrypt to `out` stream.""" - # TODO: where to get tag for both? key_base64 = metadata.key iv_base64 = metadata.iv key_iv_base64 = metadata.key_iv @@ -344,6 +345,7 @@ def decrypt_stream_gcm( encryption_material.query_stage_master_key ) key_bytes = base64.standard_b64decode(key_base64) + key_bytes, key_tag = key_bytes[:-block_size], key_bytes[-block_size:] iv_bytes = base64.standard_b64decode(iv_base64) key_iv_bytes = base64.standard_b64decode(key_iv_base64) key_aad = base64.standard_b64decode(metadata.key_aad) @@ -351,20 +353,24 @@ def decrypt_stream_gcm( backend = default_backend() file_key_cipher = Cipher( - algorithms.AES(decoded_key), modes.GCM(key_iv_bytes), backend=backend + algorithms.AES(decoded_key), + modes.GCM(key_iv_bytes, key_tag), + backend=backend, ) file_key_decryptor = file_key_cipher.decryptor() if key_aad: file_key_decryptor.authenticate_additional_data(key_aad) file_key = file_key_decryptor.update(key_bytes) + file_key_decryptor.finalize() + src_bytes = src.read() + src_bytes, data_tag = src_bytes[:-block_size], src_bytes[-block_size:] content_cipher = Cipher( - algorithms.AES(file_key), modes.GCM(iv_bytes), backend=backend + algorithms.AES(file_key), modes.GCM(iv_bytes, data_tag), backend=backend ) content_decryptor = content_cipher.decryptor() if data_aad: content_decryptor.authenticate_additional_data(data_aad) - content = content_decryptor.update(src.read()) + content_decryptor.finalize() + content = content_decryptor.update(src_bytes) + content_decryptor.finalize() out.write(content) @staticmethod diff --git a/test/unit/test_encryption_util.py b/test/unit/test_encryption_util.py index d1c08ab8c..93e9e04c9 100644 --- a/test/unit/test_encryption_util.py +++ b/test/unit/test_encryption_util.py @@ -5,11 +5,14 @@ from __future__ import annotations +import base64 import codecs import glob import os from os import path +import pytest + from snowflake.connector import encryption_util from snowflake.connector.constants import UTF8 from snowflake.connector.encryption_util import SnowflakeEncryptionUtil @@ -116,3 +119,114 @@ def test_encrypt_decrypt_large_file(tmpdir): os.remove(encrypted_file) if decrypted_file: os.remove(decrypted_file) + + +@pytest.mark.skipolddriver +@pytest.mark.parametrize( + "key_aad, data_aad", + [("", ""), (base64.b64encode(b"test key aad"), base64.b64encode(b"test data aad"))], +) +def test_gcm_encrypt_decrypt_file(tmp_path, key_aad, data_aad): + """Encrypts and Decrypts a file.""" + encryption_material = SnowflakeFileEncryptionMaterial( + query_stage_master_key="ztke8tIdVt1zmlQIZm0BMA==", + query_id="123873c7-3a66-40c4-ab89-e3722fbccce1", + smk_id=3112, + ) + data = "test data" + input_file = tmp_path / "test_encrypt_decrypt_file" + encrypted_file = None + decrypted_file = None + try: + with input_file.open("w", encoding=UTF8) as fd: + fd.write(data) + + (metadata, encrypted_file) = SnowflakeEncryptionUtil.encrypt_file_gcm( + encryption_material, + input_file, + tmp_dir=str(tmp_path), + key_aad=key_aad, + data_aad=data_aad, + ) + assert key_aad == metadata.key_aad and data_aad == metadata.data_aad + decrypted_file = SnowflakeEncryptionUtil.decrypt_file_gcm( + metadata, encryption_material, encrypted_file, tmp_dir=str(tmp_path) + ) + + contents = "" + with codecs.open(decrypted_file, "r", encoding=UTF8) as fd: + for line in fd: + contents += line + assert data == contents, "encrypted and decrypted contents" + finally: + input_file.unlink() + if encrypted_file: + os.remove(encrypted_file) + if decrypted_file: + os.remove(decrypted_file) + + +@pytest.mark.skipolddriver +@pytest.mark.parametrize( + "key_aad, data_aad", + [("", ""), (base64.b64encode(b"test key aad"), base64.b64encode(b"test data aad"))], +) +def test_gcm_encrypt_decrypt_large_file(tmpdir, key_aad, data_aad): + """Encrypts and Decrypts a large file.""" + encryption_material = SnowflakeFileEncryptionMaterial( + query_stage_master_key="ztke8tIdVt1zmlQIZm0BMA==", + query_id="123873c7-3a66-40c4-ab89-e3722fbccce1", + smk_id=3112, + ) + + # generates N files + number_of_files = 1 + number_of_lines = 100_000 + tmp_dir = generate_k_lines_of_n_files( + number_of_lines, number_of_files, tmp_dir=str(tmpdir.mkdir("data")) + ) + + files = glob.glob(os.path.join(tmp_dir, "file*")) + input_file = files[0] + encrypted_file = None + decrypted_file = None + try: + digest_in, size_in = SnowflakeFileUtil.get_digest_and_size_for_file(input_file) + for run_count in range(2): + # Test padding cases when size is and is not multiple of block_size + if run_count == 1: + # second time run, truncate the file to test a different padding case + with open(input_file, "wb") as f_in: + if size_in % encryption_util.block_size == 0: + size_in -= 3 + else: + size_in -= size_in % encryption_util.block_size + f_in.truncate(size_in) + digest_in, size_in = SnowflakeFileUtil.get_digest_and_size_for_file( + input_file + ) + + (metadata, encrypted_file) = SnowflakeEncryptionUtil.encrypt_file_gcm( + encryption_material, + input_file, + tmp_dir=str(tmpdir), + key_aad=key_aad, + data_aad=data_aad, + ) + assert key_aad == metadata.key_aad and data_aad == metadata.data_aad + decrypted_file = SnowflakeEncryptionUtil.decrypt_file_gcm( + metadata, encryption_material, encrypted_file, tmp_dir=str(tmpdir) + ) + + digest_dec, size_dec = SnowflakeFileUtil.get_digest_and_size_for_file( + decrypted_file + ) + assert size_in == size_dec + assert digest_in == digest_dec + + finally: + os.remove(input_file) + if encrypted_file: + os.remove(encrypted_file) + if decrypted_file: + os.remove(decrypted_file) From fcd67080fb5c9882df31bd0b62fd88f5f58adb97 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Wed, 28 Aug 2024 17:57:19 -0700 Subject: [PATCH 4/6] update upload and download metadata --- .../connector/azure_storage_client.py | 30 ++++++++++++++++--- src/snowflake/connector/encryption_util.py | 8 ++--- src/snowflake/connector/gcs_storage_client.py | 21 +++++++++++-- src/snowflake/connector/s3_storage_client.py | 10 +++++-- src/snowflake/connector/storage_client.py | 22 ++++++++++++-- 5 files changed, 75 insertions(+), 16 deletions(-) diff --git a/src/snowflake/connector/azure_storage_client.py b/src/snowflake/connector/azure_storage_client.py index 744f0ba91..63a89bffd 100644 --- a/src/snowflake/connector/azure_storage_client.py +++ b/src/snowflake/connector/azure_storage_client.py @@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Any, NamedTuple from .compat import quote -from .constants import FileHeader, ResultStatus +from .constants import CipherAlgorithm, FileHeader, ResultStatus from .encryption_util import EncryptionMetadata from .storage_client import SnowflakeStorageClient from .vendored import requests @@ -130,6 +130,15 @@ def get_file_header(self, filename: str) -> FileHeader | None: key=encryption_data["WrappedContentKey"]["EncryptedKey"], iv=encryption_data["ContentEncryptionIV"], matdesc=r.headers.get(MATDESC), + cipher=( + str(CipherAlgorithm.AES_GCM) + if "AES_GCM" + in encryption_data["WrappedContentKey"]["Algorithm"] + else str(CipherAlgorithm.AES_CBC) + ), + key_iv=encryption_data.get("KeyEncryptionIV", ""), + key_aad=encryption_data.get("KeyAad", ""), + data_aad=encryption_data.get("DataAad", ""), ) ) return FileHeader( @@ -151,6 +160,16 @@ def _prepare_file_metadata(self) -> dict[str, str | None]: } encryption_metadata = self.encryption_metadata if encryption_metadata: + algorithm = ( + "AES_GCM_256" + if encryption_metadata.cipher == CipherAlgorithm.AES_GCM + else "AES_CBC_256" + ) + encryption_algorithm = ( + "AES_GCM_256" + if encryption_metadata.cipher == CipherAlgorithm.AES_GCM + else "AES_CBC_128" + ) azure_metadata.update( { ENCRYPTION_DATA: json.dumps( @@ -159,13 +178,16 @@ def _prepare_file_metadata(self) -> dict[str, str | None]: "WrappedContentKey": { "KeyId": "symmKey1", "EncryptedKey": encryption_metadata.key, - "Algorithm": "AES_CBC_256", + "Algorithm": algorithm, }, "EncryptionAgent": { "Protocol": "1.0", - "EncryptionAlgorithm": "AES_CBC_128", + "EncryptionAlgorithm": encryption_algorithm, }, "ContentEncryptionIV": encryption_metadata.iv, + "KeyEncryptionIV": encryption_metadata.key_iv or "", + "KeyAad": encryption_metadata.key_aad or "", + "DataAad": encryption_metadata.data_aad or "", "KeyWrappingMetadata": {"EncryptionLibrary": "Java 5.3.0"}, } ), @@ -237,7 +259,7 @@ def download_chunk(self, chunk_id: int) -> None: if self.num_of_chunks > 1: chunk_size = self.chunk_size if chunk_id < self.num_of_chunks - 1: - _range = f"{chunk_id * chunk_size}-{(chunk_id+1)*chunk_size-1}" + _range = f"{chunk_id * chunk_size}-{(chunk_id + 1) * chunk_size - 1}" else: _range = f"{chunk_id * chunk_size}-" headers = {"Range": f"bytes={_range}"} diff --git a/src/snowflake/connector/encryption_util.py b/src/snowflake/connector/encryption_util.py index a0be6cab8..8d4dafa62 100644 --- a/src/snowflake/connector/encryption_util.py +++ b/src/snowflake/connector/encryption_util.py @@ -117,10 +117,10 @@ def encrypt_stream( key=base64.b64encode(enc_kek).decode("utf-8"), iv=base64.b64encode(iv_data).decode("utf-8"), matdesc=matdesc_to_unicode(mat_desc), - cipher=None, - key_iv=None, - key_aad=None, - data_aad=None, + cipher=str(CipherAlgorithm.AES_CBC), + key_iv="", + key_aad="", + data_aad="", ) return metadata diff --git a/src/snowflake/connector/gcs_storage_client.py b/src/snowflake/connector/gcs_storage_client.py index 0bf76a75a..5550f5967 100644 --- a/src/snowflake/connector/gcs_storage_client.py +++ b/src/snowflake/connector/gcs_storage_client.py @@ -15,6 +15,7 @@ from .constants import ( FILE_PROTOCOL, HTTP_HEADER_CONTENT_ENCODING, + CipherAlgorithm, FileHeader, ResultStatus, kilobyte, @@ -119,6 +120,11 @@ def _upload_chunk(self, chunk_id: int, chunk: bytes) -> None: } if self.encryption_metadata: + algorithm = ( + "AES_GCM_256" + if self.encryption_metadata.cipher == CipherAlgorithm.AES_GCM + else "AES_CBC_256" + ) gcs_headers.update( { GCS_METADATA_ENCRYPTIONDATAPROP: json.dumps( @@ -127,13 +133,16 @@ def _upload_chunk(self, chunk_id: int, chunk: bytes) -> None: "WrappedContentKey": { "KeyId": "symmKey1", "EncryptedKey": self.encryption_metadata.key, - "Algorithm": "AES_CBC_256", + "Algorithm": algorithm, }, "EncryptionAgent": { "Protocol": "1.0", - "EncryptionAlgorithm": "AES_CBC_256", + "EncryptionAlgorithm": algorithm, }, "ContentEncryptionIV": self.encryption_metadata.iv, + "KeyEncryptionIV": self.encryption_metadata.key_iv, + "KeyAad": self.encryption_metadata.key_aad, + "DataAad": self.encryption_metadata.data_aad, "KeyWrappingMetadata": {"EncryptionLibrary": "Java 5.3.0"}, } ), @@ -208,6 +217,14 @@ def generate_url_and_rest_args() -> ( if GCS_METADATA_MATDESC_KEY in response.headers else None ), + cipher=( + str(CipherAlgorithm.AES_GCM) + if "AES_GCM" in encryptiondata["WrappedContentKey"]["Algorithm"] + else str(CipherAlgorithm.AES_CBC) + ), + key_iv=encryptiondata.get("KeyEncryptionIV", ""), + key_aad=encryptiondata.get("KeyAad", ""), + data_aad=encryptiondata.get("DataAad"), ) meta.gcs_file_header_digest = response.headers.get(GCS_METADATA_SFC_DIGEST) diff --git a/src/snowflake/connector/s3_storage_client.py b/src/snowflake/connector/s3_storage_client.py index 0389b3bb9..b18281f05 100644 --- a/src/snowflake/connector/s3_storage_client.py +++ b/src/snowflake/connector/s3_storage_client.py @@ -402,9 +402,9 @@ def get_file_header(self, filename: str) -> FileHeader | None: iv=metadata.get(META_PREFIX + AMZ_IV), matdesc=metadata.get(META_PREFIX + AMZ_MATDESC), cipher=metadata.get(META_PREFIX + AMZ_CIPHER), - key_iv=metadata.get(META_PREFIX + AMZ_KEY_IV), - key_aad=metadata.get(META_PREFIX + AMZ_KEY_AAD), - data_aad=metadata.get(META_PREFIX + AMZ_DATA_AAD), + key_iv=metadata.get(META_PREFIX + AMZ_KEY_IV, ""), + key_aad=metadata.get(META_PREFIX + AMZ_KEY_AAD, ""), + data_aad=metadata.get(META_PREFIX + AMZ_DATA_AAD, ""), ) if metadata.get(META_PREFIX + AMZ_KEY) else None @@ -438,6 +438,10 @@ def _prepare_file_metadata(self) -> dict[str, Any]: META_PREFIX + AMZ_IV: self.encryption_metadata.iv, META_PREFIX + AMZ_KEY: self.encryption_metadata.key, META_PREFIX + AMZ_MATDESC: self.encryption_metadata.matdesc, + META_PREFIX + AMZ_CIPHER: self.encryption_metadata.cipher or "", + META_PREFIX + AMZ_KEY_IV: self.encryption_metadata.key_iv or "", + META_PREFIX + AMZ_KEY_AAD: self.encryption_metadata.key_aad or "", + META_PREFIX + AMZ_DATA_AAD: self.encryption_metadata.data_aad or "", } ) return s3_metadata diff --git a/src/snowflake/connector/storage_client.py b/src/snowflake/connector/storage_client.py index 5635ad462..99d3f241f 100644 --- a/src/snowflake/connector/storage_client.py +++ b/src/snowflake/connector/storage_client.py @@ -163,11 +163,22 @@ def get_digest(self) -> None: def encrypt(self) -> None: meta = self.meta logger.debug(f"encrypting file={meta.real_src_file_name}") + # TODO: when putting files, where to get key_aad and data_aad from? + encrypt_file_impl = ( + SnowflakeEncryptionUtil.encrypt_file + if self._ciphers == CipherAlgorithm.AES_CBC + else SnowflakeEncryptionUtil.encrypt_file_gcm + ) + encrypt_stream_impl = ( + SnowflakeEncryptionUtil.encrypt_stream + if self._ciphers == CipherAlgorithm.AES_CBC + else SnowflakeEncryptionUtil.encrypt_stream_gcm + ) if meta.intermediate_stream is None: ( self.encryption_metadata, self.data_file, - ) = SnowflakeEncryptionUtil.encrypt_file( + ) = encrypt_file_impl( meta.encryption_material, meta.real_src_file_name, tmp_dir=self.tmp_dir, @@ -177,7 +188,7 @@ def encrypt(self) -> None: encrypted_stream = BytesIO() src_stream = meta.src_stream or meta.intermediate_stream src_stream.seek(0) - self.encryption_metadata = SnowflakeEncryptionUtil.encrypt_stream( + self.encryption_metadata = encrypt_stream_impl( meta.encryption_material, src_stream, encrypted_stream ) src_stream.seek(0) @@ -389,7 +400,12 @@ def finish_download(self) -> None: file_header = self.get_file_header(meta.src_file_name) self.encryption_metadata = file_header.encryption_metadata - tmp_dst_file_name = SnowflakeEncryptionUtil.decrypt_file( + decrypt_file_impl = ( + SnowflakeEncryptionUtil.decrypt_file + if self._ciphers == CipherAlgorithm.AES_CBC + else SnowflakeEncryptionUtil.decrypt_file_gcm + ) + tmp_dst_file_name = decrypt_file_impl( self.encryption_metadata, meta.encryption_material, str(self.intermediate_dst_path), From 57e251a3845eeb790085310f877fdc0fd5273ddd Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Wed, 28 Aug 2024 18:08:53 -0700 Subject: [PATCH 5/6] non breaking change --- src/snowflake/connector/constants.py | 30 ++++++++++++++-------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/snowflake/connector/constants.py b/src/snowflake/connector/constants.py index b6fd0f554..5f1cc34a4 100644 --- a/src/snowflake/connector/constants.py +++ b/src/snowflake/connector/constants.py @@ -299,14 +299,25 @@ class MaterialDescriptor(NamedTuple): key_size: int +class CipherAlgorithm(str, Enum): + AES_CBC = "AES_CBC" # Use AES-ECB/AES-CBC + AES_GCM = "AES_GCM" # Use AES-GCM/AES-GCM + AES_GCM_CBC = ( + "AES_GCM,AES_CBC" # Use AES-GCM if the code version supports, otherwise AES-CBC + ) + + def __str__(self): + return self.value + + class EncryptionMetadata(NamedTuple): key: str iv: str matdesc: str - cipher: str | None - key_iv: str | None - key_aad: str | None - data_aad: str | None + cipher: str = str(CipherAlgorithm.AES_CBC) + key_iv: str = "" + key_aad: str = "" + data_aad: str = "" class FileHeader(NamedTuple): @@ -432,14 +443,3 @@ class IterUnit(Enum): _DOMAIN_NAME_MAP = {_DEFAULT_HOSTNAME_TLD: "GLOBAL", _CHINA_HOSTNAME_TLD: "CHINA"} - - -class CipherAlgorithm(str, Enum): - AES_CBC = "AES_CBC" # Use AES-ECB/AES-CBC - AES_GCM = "AES_GCM" # Use AES-GCM/AES-GCM - AES_GCM_CBC = ( - "AES_GCM,AES_CBC" # Use AES-GCM if the code version supports, otherwise AES-CBC - ) - - def __str__(self): - return self.value From 72cfe9f27740b094290f3491e32809bca43a7588 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Thu, 29 Aug 2024 10:12:46 -0700 Subject: [PATCH 6/6] fix tests --- test/unit/test_gcs_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/unit/test_gcs_client.py b/test/unit/test_gcs_client.py index 963d20d57..5d66c1d1a 100644 --- a/test/unit/test_gcs_client.py +++ b/test/unit/test_gcs_client.py @@ -344,7 +344,7 @@ def test_get_file_header_none_with_presigned_url(tmp_path): ) storage_credentials = Mock() storage_credentials.creds = {} - stage_info = Mock() + stage_info = {} connection = Mock() client = SnowflakeGCSRestClient( meta, storage_credentials, stage_info, connection, ""