Skip to content

Commit

Permalink
Merge pull request #2521 from andrewwhitehead/fix/session-usage
Browse files Browse the repository at this point in the history
Avoid multiple open wallet connections
  • Loading branch information
swcurran authored Sep 29, 2023
2 parents bad688e + ced024a commit 950732f
Show file tree
Hide file tree
Showing 14 changed files with 154 additions and 165 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,6 @@ async def _get_suite(
did_info: DIDInfo = None,
):
"""Get signature suite for issuance of verification."""
session = await self.profile.session()
wallet = session.inject(BaseWallet)

# Get signature class based on proof type
SignatureClass = PROOF_TYPE_SIGNATURE_SUITE_MAPPING[proof_type]

Expand All @@ -314,7 +311,7 @@ async def _get_suite(
verification_method=verification_method,
proof=proof,
key_pair=WalletKeyPair(
wallet=wallet,
profile=self.profile,
key_type=SIGNATURE_SUITE_KEY_TYPE_MAPPING[SignatureClass],
public_key_base58=did_info.verkey if did_info else None,
),
Expand Down
52 changes: 20 additions & 32 deletions aries_cloudagent/protocols/present_proof/dif/pres_exch_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def __init__(
async def _get_issue_suite(
self,
*,
wallet: BaseWallet,
issuer_id: str,
):
"""Get signature suite for signing presentation."""
Expand All @@ -139,17 +138,13 @@ async def _get_issue_suite(
return SignatureClass(
verification_method=verification_method,
key_pair=WalletKeyPair(
wallet=wallet,
profile=self.profile,
key_type=self.ISSUE_SIGNATURE_SUITE_KEY_TYPE_MAPPING[SignatureClass],
public_key_base58=did_info.verkey if did_info else None,
),
)

async def _get_derive_suite(
self,
*,
wallet: BaseWallet,
):
async def _get_derive_suite(self):
"""Get signature suite for deriving credentials."""
# Get signature class based on proof type
SignatureClass = self.DERIVED_PROOF_TYPE_SIGNATURE_SUITE_MAPPING[
Expand All @@ -159,7 +154,7 @@ async def _get_derive_suite(
# Generically create signature class
return SignatureClass(
key_pair=WalletKeyPair(
wallet=wallet,
profile=self.profile,
key_type=self.DERIVE_SIGNATURE_SUITE_KEY_TYPE_MAPPING[SignatureClass],
),
)
Expand Down Expand Up @@ -406,18 +401,14 @@ async def filter_constraints(
new_credential_dict = self.reveal_doc(
credential_dict=credential_dict, constraints=constraints
)
async with self.profile.session() as session:
wallet = session.inject(BaseWallet)
derive_suite = await self._get_derive_suite(
wallet=wallet,
)
signed_new_credential_dict = await derive_credential(
credential=credential_dict,
reveal_document=new_credential_dict,
suite=derive_suite,
document_loader=document_loader,
)
credential = self.create_vcrecord(signed_new_credential_dict)
derive_suite = await self._get_derive_suite()
signed_new_credential_dict = await derive_credential(
credential=credential_dict,
reveal_document=new_credential_dict,
suite=derive_suite,
document_loader=document_loader,
)
credential = self.create_vcrecord(signed_new_credential_dict)
result.append(credential)
return result

Expand Down Expand Up @@ -1297,18 +1288,15 @@ async def create_vp(
vp["presentation_submission"] = submission_property.serialize()
if self.proof_type is BbsBlsSignature2020.signature_type:
vp["@context"].append(SECURITY_CONTEXT_BBS_URL)
async with self.profile.session() as session:
wallet = session.inject(BaseWallet)
issue_suite = await self._get_issue_suite(
wallet=wallet,
issuer_id=issuer_id,
)
signed_vp = await sign_presentation(
presentation=vp,
suite=issue_suite,
challenge=challenge,
document_loader=document_loader,
)
issue_suite = await self._get_issue_suite(
issuer_id=issuer_id,
)
signed_vp = await sign_presentation(
presentation=vp,
suite=issue_suite,
challenge=challenge,
document_loader=document_loader,
)
result_vp.append(signed_vp)
if len(result_vp) == 1:
return result_vp[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
WalletKeyPair,
)
from ......vc.vc_ld.verify import verify_presentation
from ......wallet.base import BaseWallet
from ......wallet.key_type import ED25519, BLS12381G2

from .....problem_report.v1_0.message import ProblemReport
Expand Down Expand Up @@ -65,13 +64,13 @@ class DIFPresFormatHandler(V20PresFormatHandler):
ISSUE_SIGNATURE_SUITE_KEY_TYPE_MAPPING[BbsBlsSignature2020] = BLS12381G2
ISSUE_SIGNATURE_SUITE_KEY_TYPE_MAPPING[BbsBlsSignatureProof2020] = BLS12381G2

async def _get_all_suites(self, wallet: BaseWallet):
async def _get_all_suites(self):
"""Get all supported suites for verifying presentation."""
suites = []
for suite, key_type in self.ISSUE_SIGNATURE_SUITE_KEY_TYPE_MAPPING.items():
suites.append(
suite(
key_pair=WalletKeyPair(wallet=wallet, key_type=key_type),
key_pair=WalletKeyPair(profile=self._profile, key_type=key_type),
)
)
return suites
Expand Down Expand Up @@ -471,33 +470,31 @@ async def verify_pres(self, pres_ex_record: V20PresExRecord) -> V20PresExRecord:
presentation exchange record, updated
"""
async with self._profile.session() as session:
wallet = session.inject(BaseWallet)
dif_proof = pres_ex_record.pres.attachment(DIFPresFormatHandler.format)
pres_request = pres_ex_record.pres_request.attachment(
DIFPresFormatHandler.format
)
challenge = None
if "options" in pres_request:
challenge = pres_request["options"].get("challenge", str(uuid4()))
if not challenge:
challenge = str(uuid4())
if isinstance(dif_proof, Sequence):
for proof in dif_proof:
pres_ver_result = await verify_presentation(
presentation=proof,
suites=await self._get_all_suites(wallet=wallet),
document_loader=self._profile.inject(DocumentLoader),
challenge=challenge,
)
if not pres_ver_result.verified:
break
else:
dif_proof = pres_ex_record.pres.attachment(DIFPresFormatHandler.format)
pres_request = pres_ex_record.pres_request.attachment(
DIFPresFormatHandler.format
)
challenge = None
if "options" in pres_request:
challenge = pres_request["options"].get("challenge", str(uuid4()))
if not challenge:
challenge = str(uuid4())
if isinstance(dif_proof, Sequence):
for proof in dif_proof:
pres_ver_result = await verify_presentation(
presentation=dif_proof,
suites=await self._get_all_suites(wallet=wallet),
presentation=proof,
suites=await self._get_all_suites(),
document_loader=self._profile.inject(DocumentLoader),
challenge=challenge,
)
pres_ex_record.verified = json.dumps(pres_ver_result.verified)
return pres_ex_record
if not pres_ver_result.verified:
break
else:
pres_ver_result = await verify_presentation(
presentation=dif_proof,
suites=await self._get_all_suites(),
document_loader=self._profile.inject(DocumentLoader),
challenge=challenge,
)
pres_ex_record.verified = json.dumps(pres_ver_result.verified)
return pres_ex_record
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def test_validate_fields(self):
self.handler.validate_fields(PRES_20, incorrect_pres)

async def test_get_all_suites(self):
suites = await self.handler._get_all_suites(self.wallet)
suites = await self.handler._get_all_suites()
assert len(suites) == 4
types = [
Ed25519Signature2018,
Expand Down
56 changes: 32 additions & 24 deletions aries_cloudagent/vc/ld_proofs/crypto/tests/test_wallet_key_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,19 @@

from aries_cloudagent.wallet.key_type import ED25519


from .....core.in_memory import InMemoryProfile
from .....wallet.in_memory import InMemoryWallet
from ...error import LinkedDataProofException

from ..wallet_key_pair import WalletKeyPair


class TestWalletKeyPair(TestCase):
async def setUp(self):
self.wallet = async_mock.MagicMock()
self.profile = InMemoryProfile.test_profile()

async def test_sign_x_no_public_key(self):
key_pair = WalletKeyPair(wallet=self.wallet, key_type=ED25519)
key_pair = WalletKeyPair(profile=self.profile, key_type=ED25519)

with self.assertRaises(LinkedDataProofException) as context:
await key_pair.sign(b"Message")
Expand All @@ -22,23 +23,26 @@ async def test_sign_x_no_public_key(self):
async def test_sign(self):
public_key_base58 = "verkey"
key_pair = WalletKeyPair(
wallet=self.wallet,
profile=self.profile,
key_type=ED25519,
public_key_base58=public_key_base58,
)
signed = async_mock.MagicMock()

self.wallet.sign_message = async_mock.CoroutineMock(return_value=signed)

singed_ret = await key_pair.sign(b"Message")
with async_mock.patch.object(
InMemoryWallet,
"sign_message",
async_mock.CoroutineMock(return_value=signed),
) as sign_message:
singed_ret = await key_pair.sign(b"Message")

assert signed == singed_ret
self.wallet.sign_message.assert_called_once_with(
message=b"Message", from_verkey=public_key_base58
)
assert signed == singed_ret
sign_message.assert_called_once_with(
message=b"Message", from_verkey=public_key_base58
)

async def test_verify_x_no_public_key(self):
key_pair = WalletKeyPair(wallet=self.wallet, key_type=ED25519)
key_pair = WalletKeyPair(profile=self.profile, key_type=ED25519)

with self.assertRaises(LinkedDataProofException) as context:
await key_pair.verify(b"Message", b"signature")
Expand All @@ -47,24 +51,28 @@ async def test_verify_x_no_public_key(self):
async def test_verify(self):
public_key_base58 = "verkey"
key_pair = WalletKeyPair(
wallet=self.wallet,
profile=self.profile,
key_type=ED25519,
public_key_base58=public_key_base58,
)
self.wallet.verify_message = async_mock.CoroutineMock(return_value=True)

verified = await key_pair.verify(b"Message", b"signature")

assert verified
self.wallet.verify_message.assert_called_once_with(
message=b"Message",
signature=b"signature",
from_verkey=public_key_base58,
key_type=ED25519,
)
with async_mock.patch.object(
InMemoryWallet,
"verify_message",
async_mock.CoroutineMock(return_value=True),
) as verify_message:
verified = await key_pair.verify(b"Message", b"signature")

assert verified
verify_message.assert_called_once_with(
message=b"Message",
signature=b"signature",
from_verkey=public_key_base58,
key_type=ED25519,
)

async def test_from_verification_method_x_no_public_key_base58(self):
key_pair = WalletKeyPair(wallet=self.wallet, key_type=ED25519)
key_pair = WalletKeyPair(profile=self.profile, key_type=ED25519)

with self.assertRaises(LinkedDataProofException) as context:
key_pair.from_verification_method({})
Expand Down
35 changes: 20 additions & 15 deletions aries_cloudagent/vc/ld_proofs/crypto/wallet_key_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

from typing import List, Optional, Union

from ....wallet.util import b58_to_bytes
from ....wallet.key_type import KeyType
from ....core.profile import Profile
from ....wallet.base import BaseWallet
from ....wallet.key_type import KeyType
from ....wallet.util import b58_to_bytes

from ..error import LinkedDataProofException

Expand All @@ -17,13 +18,13 @@ class WalletKeyPair(KeyPair):
def __init__(
self,
*,
wallet: BaseWallet,
profile: Profile,
key_type: KeyType,
public_key_base58: Optional[str] = None,
) -> None:
"""Initialize new WalletKeyPair instance."""
super().__init__()
self.wallet = wallet
self.profile = profile
self.key_type = key_type
self.public_key_base58 = public_key_base58

Expand All @@ -33,10 +34,12 @@ async def sign(self, message: Union[List[bytes], bytes]) -> bytes:
raise LinkedDataProofException(
"Unable to sign message with WalletKey: No key to sign with"
)
return await self.wallet.sign_message(
message=message,
from_verkey=self.public_key_base58,
)
async with self.profile.session() as session:
wallet = session.inject(BaseWallet)
return await wallet.sign_message(
message=message,
from_verkey=self.public_key_base58,
)

async def verify(
self, message: Union[List[bytes], bytes], signature: bytes
Expand All @@ -47,12 +50,14 @@ async def verify(
"Unable to verify message with key pair: No key to verify with"
)

return await self.wallet.verify_message(
message=message,
signature=signature,
from_verkey=self.public_key_base58,
key_type=self.key_type,
)
async with self.profile.session() as session:
wallet = session.inject(BaseWallet)
return await wallet.verify_message(
message=message,
signature=signature,
from_verkey=self.public_key_base58,
key_type=self.key_type,
)

def from_verification_method(self, verification_method: dict) -> "WalletKeyPair":
"""Create new WalletKeyPair from public key in verification method."""
Expand All @@ -62,7 +67,7 @@ def from_verification_method(self, verification_method: dict) -> "WalletKeyPair"
)

return WalletKeyPair(
wallet=self.wallet,
profile=self.profile,
key_type=self.key_type,
public_key_base58=verification_method["publicKeyBase58"],
)
Expand Down
Loading

0 comments on commit 950732f

Please sign in to comment.