diff --git a/ape_aws/accounts.py b/ape_aws/accounts.py index 8c6ddde..33dd256 100644 --- a/ape_aws/accounts.py +++ b/ape_aws/accounts.py @@ -1,37 +1,78 @@ +from json import dumps from functools import cached_property +from pathlib import Path from typing import Any, Iterator, Optional from ape.api.accounts import AccountAPI, AccountContainerAPI, TransactionAPI from ape.types import AddressType, MessageSignature, SignableMessage, TransactionSignature +from ape.utils.validators import _validate_account_passphrase + from eth_account._utils.legacy_transactions import serializable_unsigned_transaction_from_dict from eth_account.messages import _hash_eip191_message, encode_defunct +from eth_account import Account as EthAccount from eth_pydantic_types import HexBytes from eth_typing import Hash32 -from eth_utils import keccak, to_checksum_address +from eth_utils import keccak, to_checksum_address, to_bytes from .client import kms_client from .utils import _convert_der_to_rsv class AwsAccountContainer(AccountContainerAPI): + loaded_accounts: dict[str, "KmsAccount"] = {} + + def model_post_init(self, __context: Any): + print("Initializing AWS KMS Account Container") + print([acc.alias for acc in self.accounts]) + + @property + def _keyfiles(self) -> list[Path]: + return [file for file in self.data_folder.glob("*.json")] + @property def aliases(self) -> Iterator[str]: - return map(lambda x: x.alias, kms_client.raw_aliases) + return map(lambda x: x.alias.replace("alias/", ""), kms_client.raw_aliases) def __len__(self) -> int: return len(kms_client.raw_aliases) @property def accounts(self) -> Iterator[AccountAPI]: + def _load_account(key_alias, key_id, key_arn) -> Iterator[AccountAPI]: + filename = f"{key_alias}.json" + keyfile = self.data_folder.joinpath(filename) + if filename not in self._keyfiles: + self.loaded_accounts[keyfile.stem] = KmsAccount( + key_alias=key_alias, + key_id=key_id, + key_arn=key_arn, + ) + keyfile.write_text( + self.loaded_accounts[keyfile.stem].dump_to_json() + ) + return self.loaded_accounts[keyfile.stem] return map( - lambda x: KmsAccount( - key_alias=x.alias, + lambda x: _load_account( + key_alias=x.alias.replace("alias/", ""), key_id=x.key_id, key_arn=x.arn, ), kms_client.raw_aliases, ) + def add_private_key(self, alias, passphrase, private_key): + kms_account = self.loaded_accounts[alias] + _validate_account_passphrase(passphrase) + account = EthAccount.from_key(to_bytes(hexstr=private_key)) + keyfile = self.data_folder.joinpath(f"{alias}.json") + account = EthAccount.encrypt(account.key, passphrase) + model = kms_account.model_dump() + model["address"] = kms_account.address + del account["address"] + model.update(account) + keyfile.write_text(dumps(model, indent=4)) + print("Key cached successfully") + return class KmsAccount(AccountAPI): key_alias: str @@ -40,7 +81,7 @@ class KmsAccount(AccountAPI): @property def alias(self) -> str: - return self.key_alias.replace("alias/", "") + return self.key_alias @property def public_key(self): @@ -105,3 +146,8 @@ def sign_transaction(self, txn: TransactionAPI, **signer_options) -> Optional[Tr ) return txn + + def dump_to_json(self, indent: int = 4): + model = self.model_dump() + model["address"] = self.address + return dumps(model, indent=indent) diff --git a/ape_aws/client.py b/ape_aws/client.py index 98ab26e..fccdb12 100644 --- a/ape_aws/client.py +++ b/ape_aws/client.py @@ -1,13 +1,17 @@ +from ape.utils.basemodel import ManagerAccessMixin + from cryptography.hazmat.primitives.asymmetric import ec, padding from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.backends import default_backend from eth_account import Account +from eth_utils import to_bytes from datetime import datetime from typing import ClassVar +from pydantic import BaseModel, Field, ConfigDict, field_validator import boto3 # type: ignore[import] -from pydantic import BaseModel, Field, ConfigDict, field_validator +import json class AliasResponse(BaseModel): @@ -91,13 +95,17 @@ def validate_private_key(cls, value): default_backend() ) if value.startswith('0x'): - return value[2:] + value = bytes.fromhex(value[2:]) return value @property def get_account(self): return Account.privateKeyToAccount(self.private_key) + @property + def private_key_hex(self): + return self.private_key.private_numbers().private_value.to_bytes(32, "big").hex() + @property def private_key_bin(self): """ @@ -142,6 +150,14 @@ def encrypted_private_key(self): ) ) + def import_account_from_private_key(self, passphrase: str): + account = Account.from_key(to_bytes(hexstr=self.private_key_hex)) + path = ManagerAccessMixin.account_manager.containers["accounts"].data_folder.joinpath( + f"{self.alias}.json" + ) + path.write_text(json.dumps(Account.encrypt(account.key, passphrase))) + return KmsAccount() + class DeleteKey(KeyBaseModel): key_id: str diff --git a/ape_aws/kms/_cli.py b/ape_aws/kms/_cli.py index adb8f13..33a9f61 100644 --- a/ape_aws/kms/_cli.py +++ b/ape_aws/kms/_cli.py @@ -1,6 +1,7 @@ import click from ape.cli import ape_cli_context +from ape_aws.accounts import AwsAccountContainer, KmsAccount from ape_aws.client import ( CreateKey, DeleteKey, @@ -95,6 +96,14 @@ def import_key( administrators: list[str], users: list[str], ): + def ask_for_passphrase(): + return click.prompt( + "Create Passphrase to encrypt account", + hide_input=True, + confirmation_prompt=True, + ) + + passphrase = ask_for_passphrase() key_spec = ImportKeyRequest( alias=alias_name, description=description, @@ -112,8 +121,12 @@ def import_key( private_key=private_key, import_token=import_token, ) - key_id = kms_client.import_key(import_key_spec) + response = kms_client.import_key(import_key_spec) + if response["ResponseMetadata"]["HTTPStatusCode"] != 200: + cli_ctx.abort("Key failed to import into KMS") cli_ctx.logger.success(f"Key imported successfully with ID: {key_id}") + aws_account_container = AwsAccountContainer(name="aws", account_type=KmsAccount) + aws_account_container.add_private_key(alias_name, passphrase, import_key_spec.private_key_hex) # TODO: Add `ape aws kms sign-message [message]`