Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Snow 1622111 implement python connector aes gcm #2037

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions src/snowflake/connector/azure_storage_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import TYPE_CHECKING, Any, NamedTuple

from .compat import quote
from .constants import FileHeader, ResultStatus
from .constants import CipherAlgorithm, FileHeader, ResultStatus
from .encryption_util import EncryptionMetadata
from .storage_client import SnowflakeStorageClient
from .vendored import requests
Expand Down Expand Up @@ -130,6 +130,15 @@ def get_file_header(self, filename: str) -> FileHeader | None:
key=encryption_data["WrappedContentKey"]["EncryptedKey"],
iv=encryption_data["ContentEncryptionIV"],
matdesc=r.headers.get(MATDESC),
cipher=(
str(CipherAlgorithm.AES_GCM)
if "AES_GCM"
in encryption_data["WrappedContentKey"]["Algorithm"]
else str(CipherAlgorithm.AES_CBC)
),
key_iv=encryption_data.get("KeyEncryptionIV", ""),
key_aad=encryption_data.get("KeyAad", ""),
data_aad=encryption_data.get("DataAad", ""),
)
)
return FileHeader(
Expand All @@ -151,6 +160,16 @@ def _prepare_file_metadata(self) -> dict[str, str | None]:
}
encryption_metadata = self.encryption_metadata
if encryption_metadata:
algorithm = (
"AES_GCM_256"
if encryption_metadata.cipher == CipherAlgorithm.AES_GCM
else "AES_CBC_256"
)
encryption_algorithm = (
"AES_GCM_256"
if encryption_metadata.cipher == CipherAlgorithm.AES_GCM
else "AES_CBC_128"
)
azure_metadata.update(
{
ENCRYPTION_DATA: json.dumps(
Expand All @@ -159,13 +178,16 @@ def _prepare_file_metadata(self) -> dict[str, str | None]:
"WrappedContentKey": {
"KeyId": "symmKey1",
"EncryptedKey": encryption_metadata.key,
"Algorithm": "AES_CBC_256",
"Algorithm": algorithm,
},
"EncryptionAgent": {
"Protocol": "1.0",
"EncryptionAlgorithm": "AES_CBC_128",
"EncryptionAlgorithm": encryption_algorithm,
},
"ContentEncryptionIV": encryption_metadata.iv,
"KeyEncryptionIV": encryption_metadata.key_iv or "",
"KeyAad": encryption_metadata.key_aad or "",
"DataAad": encryption_metadata.data_aad or "",
"KeyWrappingMetadata": {"EncryptionLibrary": "Java 5.3.0"},
}
),
Expand Down Expand Up @@ -237,7 +259,7 @@ def download_chunk(self, chunk_id: int) -> None:
if self.num_of_chunks > 1:
chunk_size = self.chunk_size
if chunk_id < self.num_of_chunks - 1:
_range = f"{chunk_id * chunk_size}-{(chunk_id+1)*chunk_size-1}"
_range = f"{chunk_id * chunk_size}-{(chunk_id + 1) * chunk_size - 1}"
else:
_range = f"{chunk_id * chunk_size}-"
headers = {"Range": f"bytes={_range}"}
Expand Down
15 changes: 15 additions & 0 deletions src/snowflake/connector/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,10 +299,25 @@ class MaterialDescriptor(NamedTuple):
key_size: int


class CipherAlgorithm(str, Enum):
AES_CBC = "AES_CBC" # Use AES-ECB/AES-CBC
AES_GCM = "AES_GCM" # Use AES-GCM/AES-GCM
AES_GCM_CBC = (
"AES_GCM,AES_CBC" # Use AES-GCM if the code version supports, otherwise AES-CBC
)

def __str__(self):
return self.value


class EncryptionMetadata(NamedTuple):
key: str
iv: str
matdesc: str
cipher: str = str(CipherAlgorithm.AES_CBC)
key_iv: str = ""
key_aad: str = ""
data_aad: str = ""


class FileHeader(NamedTuple):
Expand Down
187 changes: 186 additions & 1 deletion src/snowflake/connector/encryption_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,18 @@
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes

from .compat import PKCS5_OFFSET, PKCS5_PAD, PKCS5_UNPAD
from .constants import UTF8, EncryptionMetadata, MaterialDescriptor, kilobyte
from .constants import (
UTF8,
CipherAlgorithm,
EncryptionMetadata,
MaterialDescriptor,
kilobyte,
)
from .util_text import random_string

block_size = int(algorithms.AES.block_size / 8) # in bytes


if TYPE_CHECKING: # pragma: no cover
from .storage_client import SnowflakeFileEncryptionMaterial

Expand Down Expand Up @@ -110,6 +117,10 @@ def encrypt_stream(
key=base64.b64encode(enc_kek).decode("utf-8"),
iv=base64.b64encode(iv_data).decode("utf-8"),
matdesc=matdesc_to_unicode(mat_desc),
cipher=str(CipherAlgorithm.AES_CBC),
key_iv="",
key_aad="",
data_aad="",
)
return metadata

Expand Down Expand Up @@ -148,6 +159,106 @@ def encrypt_file(
)
return metadata, temp_output_file

@staticmethod
def encrypt_stream_gcm(
encryption_material: SnowflakeFileEncryptionMaterial,
src: IO[bytes],
out: IO[bytes],
key_aad: str = "",
data_aad: str = "",
):
logger = getLogger(__name__)
decoded_key = base64.standard_b64decode(
encryption_material.query_stage_master_key
)
key_size = len(decoded_key)
logger.debug("key_size = %s", key_size)

iv_data = SnowflakeEncryptionUtil.get_secure_random(block_size)
key_iv_data = SnowflakeEncryptionUtil.get_secure_random(key_size)
file_key = SnowflakeEncryptionUtil.get_secure_random(key_size)
backend = default_backend()
file_key_cipher = Cipher(
algorithms.AES(decoded_key), modes.GCM(key_iv_data), backend=backend
)
file_key_encryptor = file_key_cipher.encryptor()
if key_aad:
file_key_encryptor.authenticate_additional_data(
base64.standard_b64decode(key_aad)
)
encrypted_file_key = (
file_key_encryptor.update(file_key)
+ file_key_encryptor.finalize()
+ file_key_encryptor.tag
)
content_cipher = Cipher(
algorithms.AES(file_key), modes.GCM(iv_data), backend=backend
)
content_encryptor = content_cipher.encryptor()
if data_aad:
content_encryptor.authenticate_additional_data(
base64.standard_b64decode(data_aad)
)

encrypted_content = (
content_encryptor.update(src.read())
+ content_encryptor.finalize()
+ content_encryptor.tag
)
out.write(encrypted_content)

mat_desc = MaterialDescriptor(
smk_id=encryption_material.smk_id,
query_id=encryption_material.query_id,
key_size=key_size * 8,
)
metadata = EncryptionMetadata(
key=base64.b64encode(encrypted_file_key).decode("utf-8"),
iv=base64.b64encode(iv_data).decode("utf-8"),
matdesc=matdesc_to_unicode(mat_desc),
cipher=str(CipherAlgorithm.AES_GCM),
key_iv=base64.b64encode(key_iv_data).decode("utf-8"),
key_aad=key_aad,
data_aad=data_aad,
)
return metadata

@staticmethod
def encrypt_file_gcm(
encryption_material: SnowflakeFileEncryptionMaterial,
in_filename: str,
tmp_dir: str | None = None,
key_aad: str = "",
data_aad: str = "",
) -> tuple[EncryptionMetadata, str]:
"""Encrypts a file in a temporary directory.

Args:
encryption_material: The encryption material for file.
in_filename: The input file's name.
chunk_size: The size of read chunks (Default value = block_size * 4 * 1024).
tmp_dir: Temporary directory to use, optional (Default value = None).

Returns:
The encryption metadata and the encrypted file's location.
"""
logger = getLogger(__name__)
temp_output_fd, temp_output_file = tempfile.mkstemp(
text=False, dir=tmp_dir, prefix=os.path.basename(in_filename) + "#"
)
logger.debug(
"unencrypted file: %s, temp file: %s, tmp_dir: %s",
in_filename,
temp_output_file,
tmp_dir,
)
with open(in_filename, "rb") as infile:
with os.fdopen(temp_output_fd, "wb") as outfile:
metadata = SnowflakeEncryptionUtil.encrypt_stream_gcm(
encryption_material, infile, outfile, key_aad, data_aad
)
return metadata, temp_output_file

@staticmethod
def decrypt_stream(
metadata: EncryptionMetadata,
Expand Down Expand Up @@ -218,3 +329,77 @@ def decrypt_file(
metadata, encryption_material, infile, outfile, chunk_size
)
return temp_output_file

@staticmethod
def decrypt_stream_gcm(
metadata: EncryptionMetadata,
encryption_material: SnowflakeFileEncryptionMaterial,
src: IO[bytes],
out: IO[bytes],
) -> None:
"""To read from `src` stream then decrypt to `out` stream."""
key_base64 = metadata.key
iv_base64 = metadata.iv
key_iv_base64 = metadata.key_iv
decoded_key = base64.standard_b64decode(
encryption_material.query_stage_master_key
)
key_bytes = base64.standard_b64decode(key_base64)
key_bytes, key_tag = key_bytes[:-block_size], key_bytes[-block_size:]
iv_bytes = base64.standard_b64decode(iv_base64)
key_iv_bytes = base64.standard_b64decode(key_iv_base64)
key_aad = base64.standard_b64decode(metadata.key_aad)
data_aad = base64.standard_b64decode(metadata.data_aad)

backend = default_backend()
file_key_cipher = Cipher(
algorithms.AES(decoded_key),
modes.GCM(key_iv_bytes, key_tag),
backend=backend,
)
file_key_decryptor = file_key_cipher.decryptor()
if key_aad:
file_key_decryptor.authenticate_additional_data(key_aad)
file_key = file_key_decryptor.update(key_bytes) + file_key_decryptor.finalize()

src_bytes = src.read()
src_bytes, data_tag = src_bytes[:-block_size], src_bytes[-block_size:]
content_cipher = Cipher(
algorithms.AES(file_key), modes.GCM(iv_bytes, data_tag), backend=backend
)
content_decryptor = content_cipher.decryptor()
if data_aad:
content_decryptor.authenticate_additional_data(data_aad)
content = content_decryptor.update(src_bytes) + content_decryptor.finalize()
out.write(content)

@staticmethod
def decrypt_file_gcm(
metadata: EncryptionMetadata,
encryption_material: SnowflakeFileEncryptionMaterial,
in_filename: str,
tmp_dir: str | None = None,
) -> str:
"""Decrypts a file and stores the output in the temporary directory.

Args:
metadata: The file's metadata input.
encryption_material: The file's encryption material.
in_filename: The name of the input file.
chunk_size: The size of read chunks (Default value = block_size * 4 * 1024).
tmp_dir: Temporary directory to use, optional (Default value = None).

Returns:
The decrypted file's location.
"""
temp_output_file = f"{os.path.basename(in_filename)}#{random_string()}"
if tmp_dir:
temp_output_file = os.path.join(tmp_dir, temp_output_file)

logger.debug("encrypted file: %s, tmp file: %s", in_filename, temp_output_file)
with open(in_filename, "rb") as infile:
with open(temp_output_file, "wb") as outfile:
SnowflakeEncryptionUtil.decrypt_stream_gcm(
metadata, encryption_material, infile, outfile
)
return temp_output_file
21 changes: 19 additions & 2 deletions src/snowflake/connector/gcs_storage_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .constants import (
FILE_PROTOCOL,
HTTP_HEADER_CONTENT_ENCODING,
CipherAlgorithm,
FileHeader,
ResultStatus,
kilobyte,
Expand Down Expand Up @@ -119,6 +120,11 @@ def _upload_chunk(self, chunk_id: int, chunk: bytes) -> None:
}

if self.encryption_metadata:
algorithm = (
"AES_GCM_256"
if self.encryption_metadata.cipher == CipherAlgorithm.AES_GCM
else "AES_CBC_256"
)
gcs_headers.update(
{
GCS_METADATA_ENCRYPTIONDATAPROP: json.dumps(
Expand All @@ -127,13 +133,16 @@ def _upload_chunk(self, chunk_id: int, chunk: bytes) -> None:
"WrappedContentKey": {
"KeyId": "symmKey1",
"EncryptedKey": self.encryption_metadata.key,
"Algorithm": "AES_CBC_256",
"Algorithm": algorithm,
},
"EncryptionAgent": {
"Protocol": "1.0",
"EncryptionAlgorithm": "AES_CBC_256",
"EncryptionAlgorithm": algorithm,
},
"ContentEncryptionIV": self.encryption_metadata.iv,
"KeyEncryptionIV": self.encryption_metadata.key_iv,
"KeyAad": self.encryption_metadata.key_aad,
"DataAad": self.encryption_metadata.data_aad,
"KeyWrappingMetadata": {"EncryptionLibrary": "Java 5.3.0"},
}
),
Expand Down Expand Up @@ -208,6 +217,14 @@ def generate_url_and_rest_args() -> (
if GCS_METADATA_MATDESC_KEY in response.headers
else None
),
cipher=(
str(CipherAlgorithm.AES_GCM)
if "AES_GCM" in encryptiondata["WrappedContentKey"]["Algorithm"]
else str(CipherAlgorithm.AES_CBC)
),
key_iv=encryptiondata.get("KeyEncryptionIV", ""),
key_aad=encryptiondata.get("KeyAad", ""),
data_aad=encryptiondata.get("DataAad"),
)

meta.gcs_file_header_digest = response.headers.get(GCS_METADATA_SFC_DIGEST)
Expand Down
Loading
Loading