diff --git a/ape_aws/client.py b/ape_aws/client.py index 9869a32..bc731ae 100644 --- a/ape_aws/client.py +++ b/ape_aws/client.py @@ -5,7 +5,7 @@ from typing import ClassVar import boto3 # type: ignore[import] -from pydantic import BaseModel, Field, ConfigDict +from pydantic import BaseModel, Field, ConfigDict, field_validator class AliasResponse(BaseModel): @@ -78,19 +78,22 @@ class ImportKeyRequest(CreateKeyModel): class ImportKey(ImportKeyRequest): key_id: str = Field(default=None, alias="KeyId") public_key: bytes = Field(default=None, alias="PublicKey") - private_key: bytes = Field( - default=ec.generate_private_key( - ec.SeCP256K1(), - default_backend() - ).private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption(), - ), - alias="PrivateKey", - ) + private_key: bytes | None = Field(default=None, alias="PrivateKey") import_token: bytes = Field(default=None, alias="ImportToken") + @field_validator("private_key") + def validate_private_key(cls, value): + if not isinstance(value, bytes): + return ec.generate_private_key( + ec.SECP256K1(), + default_backend() + ).private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + return value + @property def encrypted_key(self): if not self.public_key: @@ -172,7 +175,6 @@ def create_key(self, key_spec: CreateKey | ImportKey): return key_id def import_key(self, key_spec: ImportKey): - breakpoint() return self.client.import_key_material( KeyId=key_spec.key_id, ImportToken=key_spec.import_token, diff --git a/ape_aws/kms/_cli.py b/ape_aws/kms/_cli.py index e112789..f31155a 100644 --- a/ape_aws/kms/_cli.py +++ b/ape_aws/kms/_cli.py @@ -63,6 +63,13 @@ def create_key( @kms.command(name="import") @ape_cli_context() +@click.option( + "-p", + "--private-key", + "private_key", + multiple=False, + help="The private key to import", +) @click.option( "-a", "--admin", @@ -81,7 +88,6 @@ def create_key( ) @click.argument("alias_name") @click.argument("description") -@click.argument("private_key") def import_key( cli_ctx, alias_name: str, @@ -98,8 +104,8 @@ def import_key( ) key_id = kms_client.create_key(key_spec) create_key_response = kms_client.get_parameters(key_id) - public_key = base64.b64encode(create_key_response["PublicKey"]) - import_token = base64.b64encode(create_key_response["ImportToken"]) + public_key = create_key_response["PublicKey"] + import_token = create_key_response["ImportToken"] import_key_spec = ImportKey( **key_spec.model_dump(), key_id=key_id,