diff --git a/aries_cloudagent/connections/base_manager.py b/aries_cloudagent/connections/base_manager.py index 2e9911f90e..4df1111c37 100644 --- a/aries_cloudagent/connections/base_manager.py +++ b/aries_cloudagent/connections/base_manager.py @@ -5,7 +5,7 @@ """ import logging -from typing import Optional, List, Sequence, Tuple, Text +from typing import List, Optional, Sequence, Text, Tuple, Union from multiformats import multibase, multicodec from pydid import ( @@ -16,34 +16,43 @@ import pydid from pydid.verification_method import ( Ed25519VerificationKey2018, - JsonWebKey2020, Ed25519VerificationKey2020, + JsonWebKey2020, ) +from ..cache.base import BaseCache +from ..config.base import InjectionError from ..config.logging import get_logger_inst from ..core.error import BaseError from ..core.profile import Profile from ..did.did_key import DIDKey +from ..multitenant.base import BaseMultitenantManager +from ..protocols.connections.v1_0.message_types import ARIES_PROTOCOL as CONN_PROTO from ..protocols.connections.v1_0.messages.connection_invitation import ( ConnectionInvitation, ) from ..protocols.coordinate_mediation.v1_0.models.mediation_record import ( MediationRecord, ) -from ..protocols.coordinate_mediation.v1_0.route_manager import ( - RouteManager, -) +from ..protocols.coordinate_mediation.v1_0.route_manager import RouteManager +from ..protocols.discovery.v2_0.manager import V20DiscoveryMgr +from ..protocols.out_of_band.v1_0.messages.invitation import InvitationMessage from ..resolver.base import ResolverError from ..resolver.did_resolver import DIDResolver from ..storage.base import BaseStorage -from ..storage.error import StorageNotFoundError +from ..storage.error import StorageError, StorageNotFoundError from ..storage.record import StorageRecord +from ..transport.inbound.receipt import MessageReceipt from ..wallet.base import BaseWallet +from ..wallet.crypto import create_keypair, seed_to_did from ..wallet.did_info import DIDInfo +from ..wallet.did_method import SOV +from ..wallet.error import WalletNotFoundError +from ..wallet.key_type import ED25519 +from ..wallet.util import b64_to_bytes, bytes_to_b58 from .models.conn_record import ConnRecord from .models.connection_target import ConnectionTarget from .models.diddoc import DIDDoc, PublicKey, PublicKeyType, Service -from ..wallet.util import bytes_to_b58, b64_to_bytes class BaseConnectionManagerError(BaseError): @@ -73,9 +82,9 @@ def __init__(self, profile: Profile): async def create_did_document( self, did_info: DIDInfo, - inbound_connection_id: str = None, - svc_endpoints: Sequence[str] = None, - mediation_records: List[MediationRecord] = None, + inbound_connection_id: Optional[str] = None, + svc_endpoints: Optional[Sequence[str]] = None, + mediation_records: Optional[List[MediationRecord]] = None, ) -> DIDDoc: """Create our DID doc for a given DID. @@ -213,7 +222,10 @@ async def add_key_for_did(self, did: str, key: str): record = StorageRecord(self.RECORD_TYPE_DID_KEY, key, {"did": did, "key": key}) async with self._profile.session() as session: storage: BaseStorage = session.inject(BaseStorage) - await storage.add_record(record) + try: + await storage.find_record(self.RECORD_TYPE_DID_KEY, {"key": key}) + except StorageNotFoundError: + await storage.add_record(record) async def find_did_for_key(self, key: str) -> str: """Find the DID previously associated with a key. @@ -236,15 +248,10 @@ async def remove_keys_for_did(self, did: str): storage: BaseStorage = session.inject(BaseStorage) await storage.delete_all_records(self.RECORD_TYPE_DID_KEY, {"did": did}) - async def resolve_invitation( + async def resolve_didcomm_services( self, did: str, service_accept: Optional[Sequence[Text]] = None - ): - """ - Resolve invitation with the DID Resolver. - - Args: - did: Document ID to resolve - """ + ) -> Tuple[ResolvedDocument, List[DIDCommService]]: + """Resolve a DIDComm services for a given DID.""" if not did.startswith("did:"): # DID is bare indy "nym" # prefix with did:sov: for backwards compatibility @@ -269,6 +276,41 @@ async def resolve_invitation( key=lambda service: service.priority, ) + return doc, didcomm_services + + async def verification_methods_for_service( + self, doc: ResolvedDocument, service: DIDCommService + ) -> Tuple[List[VerificationMethod], List[VerificationMethod]]: + """Dereference recipient and routing keys. + + Returns verification methods for a DIDComm service to enable extracting + key material. + """ + resolver = self._profile.inject(DIDResolver) + recipient_keys: List[VerificationMethod] = [ + await resolver.dereference_verification_method( + self._profile, url, document=doc + ) + for url in service.recipient_keys + ] + routing_keys: List[VerificationMethod] = [ + await resolver.dereference_verification_method( + self._profile, url, document=doc + ) + for url in service.routing_keys + ] + return recipient_keys, routing_keys + + async def resolve_invitation( + self, did: str, service_accept: Optional[Sequence[Text]] = None + ) -> Tuple[str, List[str], List[str]]: + """ + Resolve invitation with the DID Resolver. + + Args: + did: Document ID to resolve + """ + doc, didcomm_services = await self.resolve_didcomm_services(did, service_accept) if not didcomm_services: raise BaseConnectionManagerError( "Cannot connect via public DID that has no associated DIDComm services" @@ -276,15 +318,10 @@ async def resolve_invitation( first_didcomm_service, *_ = didcomm_services - endpoint = first_didcomm_service.service_endpoint - recipient_keys: List[VerificationMethod] = [ - await resolver.dereference(self._profile, url, document=doc) - for url in first_didcomm_service.recipient_keys - ] - routing_keys: List[VerificationMethod] = [ - await resolver.dereference(self._profile, url, document=doc) - for url in first_didcomm_service.routing_keys - ] + endpoint = str(first_didcomm_service.service_endpoint) + recipient_keys, routing_keys = await self.verification_methods_for_service( + doc, first_didcomm_service + ) return ( endpoint, @@ -295,6 +332,62 @@ async def resolve_invitation( [self._extract_key_material_in_base58_format(key) for key in routing_keys], ) + async def record_keys_for_public_did(self, did: str): + """Record the keys for a public DID. + + This is required to correlate sender verkeys back to a connection. + """ + doc, didcomm_services = await self.resolve_didcomm_services(did) + for service in didcomm_services: + recips, _ = await self.verification_methods_for_service(doc, service) + for recip in recips: + await self.add_key_for_did( + did, self._extract_key_material_in_base58_format(recip) + ) + + async def resolve_connection_targets( + self, + did: str, + sender_verkey: Optional[str] = None, + their_label: Optional[str] = None, + ) -> List[ConnectionTarget]: + """Resolve connection targets for a DID.""" + self._logger.debug("Resolving connection targets for DID %s", did) + doc, didcomm_services = await self.resolve_didcomm_services(did) + self._logger.debug("Resolved DID document: %s", doc) + self._logger.debug("Resolved DIDComm services: %s", didcomm_services) + targets = [] + for service in didcomm_services: + try: + recips, routing = await self.verification_methods_for_service( + doc, service + ) + endpoint = str(service.service_endpoint) + targets.append( + ConnectionTarget( + did=doc.id, + endpoint=endpoint, + label=their_label, + recipient_keys=[ + self._extract_key_material_in_base58_format(key) + for key in recips + ], + routing_keys=[ + self._extract_key_material_in_base58_format(key) + for key in routing + ], + sender_key=sender_verkey, + ) + ) + except ResolverError: + self._logger.exception( + "Failed to resolve service details while determining " + "connection targets; skipping service" + ) + continue + + return targets + @staticmethod def _extract_key_material_in_base58_format(method: VerificationMethod) -> str: if isinstance(method, Ed25519VerificationKey2018): @@ -326,6 +419,117 @@ def _extract_key_material_in_base58_format(method: VerificationMethod) -> str: f"Key type {type(method).__name__} is not supported" ) + async def _fetch_connection_targets_for_invitation( + self, + connection: ConnRecord, + invitation: Union[ConnectionInvitation, InvitationMessage], + sender_verkey: str, + ) -> Sequence[ConnectionTarget]: + """Get a list of connection targets for an invitation. + + This will extract target info for either a connection or OOB invitation. + + Args: + connection: ConnRecord the invitation is associated with. + invitation: Connection or OOB invitation retrieved from conn record. + + Returns: + A list of `ConnectionTarget` objects + """ + if isinstance(invitation, ConnectionInvitation): + # conn protocol invitation + if invitation.did: + did = invitation.did + ( + endpoint, + recipient_keys, + routing_keys, + ) = await self.resolve_invitation(did) + + else: + endpoint = invitation.endpoint + recipient_keys = invitation.recipient_keys + routing_keys = invitation.routing_keys + else: + # out-of-band invitation + oob_service_item = invitation.services[0] + if isinstance(oob_service_item, str): + ( + endpoint, + recipient_keys, + routing_keys, + ) = await self.resolve_invitation(oob_service_item) + + else: + endpoint = oob_service_item.service_endpoint + recipient_keys = [ + DIDKey.from_did(k).public_key_b58 + for k in oob_service_item.recipient_keys + ] + routing_keys = [ + DIDKey.from_did(k).public_key_b58 + for k in oob_service_item.routing_keys + ] + + return [ + ConnectionTarget( + did=connection.their_did, + endpoint=endpoint, + label=invitation.label if invitation else None, + recipient_keys=recipient_keys, + routing_keys=routing_keys, + sender_key=sender_verkey, + ) + ] + + async def _fetch_targets_for_connection_in_progress( + self, connection: ConnRecord, sender_verkey: str + ) -> Sequence[ConnectionTarget]: + """Get a list of connection targets from an incomplete `ConnRecord`. + + This covers retrieving targets for connections that are still in the + process of bootstrapping. This includes connections that are in states + invitation-received or request-received. + + Args: + connection: The connection record (with associated `DIDDoc`) + used to generate the connection target + Returns: + A list of `ConnectionTarget` objects + """ + if ( + connection.invitation_msg_id + or connection.invitation_key + or not connection.their_did + ): # invitation received or sending request to invitation + async with self._profile.session() as session: + invitation = await connection.retrieve_invitation(session) + targets = await self._fetch_connection_targets_for_invitation( + connection, + invitation, + sender_verkey, + ) + else: # sending implicit request + # request is implicit; did isn't set if we've received an + # invitation, only the invitation key + ( + endpoint, + recipient_keys, + routing_keys, + ) = await self.resolve_invitation(connection.their_did) + targets = [ + ConnectionTarget( + did=connection.their_did, + endpoint=endpoint, + label=None, + recipient_keys=recipient_keys, + routing_keys=routing_keys, + sender_key=sender_verkey, + ) + ] + + return targets + async def fetch_connection_targets( self, connection: ConnRecord ) -> Sequence[ConnectionTarget]: @@ -338,98 +542,92 @@ async def fetch_connection_targets( if not connection.my_did: self._logger.debug("No local DID associated with connection") - return None - results = None + return [] + + async with self._profile.session() as session: + wallet = session.inject(BaseWallet) + my_info = await wallet.get_local_did(connection.my_did) if ( ConnRecord.State.get(connection.state) in (ConnRecord.State.INVITATION, ConnRecord.State.REQUEST) and ConnRecord.Role.get(connection.their_role) is ConnRecord.Role.RESPONDER - ): - if ( - connection.invitation_msg_id - or connection.invitation_key - or not connection.their_did - ): - async with self._profile.session() as session: - invitation = await connection.retrieve_invitation(session) - if isinstance( - invitation, ConnectionInvitation - ): # conn protocol invitation - if invitation.did: - did = invitation.did - ( - endpoint, - recipient_keys, - routing_keys, - ) = await self.resolve_invitation(did) + ): # invitation received or sending request + return await self._fetch_targets_for_connection_in_progress( + connection, my_info.verkey + ) - else: - endpoint = invitation.endpoint - recipient_keys = invitation.recipient_keys - routing_keys = invitation.routing_keys - else: # out-of-band invitation - oob_service_item = invitation.services[0] - if isinstance(oob_service_item, str): - ( - endpoint, - recipient_keys, - routing_keys, - ) = await self.resolve_invitation(oob_service_item) + if not connection.their_did: + self._logger.debug("No target DID associated with connection") + return [] - else: - endpoint = oob_service_item.service_endpoint - recipient_keys = [ - DIDKey.from_did(k).public_key_b58 - for k in oob_service_item.recipient_keys - ] - routing_keys = [ - DIDKey.from_did(k).public_key_b58 - for k in oob_service_item.routing_keys - ] - else: - if connection.their_did: - invitation = None - did = connection.their_did - ( - endpoint, - recipient_keys, - routing_keys, - ) = await self.resolve_invitation(did) + return await self.resolve_connection_targets( + connection.their_did, my_info.verkey, connection.their_label + ) - async with self._profile.session() as session: - wallet = session.inject(BaseWallet) - my_info = await wallet.get_local_did(connection.my_did) + async def get_connection_targets( + self, + *, + connection_id: Optional[str] = None, + connection: Optional[ConnRecord] = None, + ): + """Create a connection target from a `ConnRecord`. - results = [ - ConnectionTarget( - did=connection.their_did, - endpoint=endpoint, - label=invitation.label if invitation else None, - recipient_keys=recipient_keys, - routing_keys=routing_keys, - sender_key=my_info.verkey, - ) - ] + Args: + connection_id: The connection ID to search for + connection: The connection record itself, if already available + """ + if connection_id is None and connection is None: + raise ValueError("Must supply either connection_id or connection") + + if not connection_id: + assert connection + connection_id = connection.connection_id + + cache = self._profile.inject_or(BaseCache) + cache_key = f"connection_target::{connection_id}" + if cache: + async with cache.acquire(cache_key) as entry: + if entry.result: + self._logger.debug("Connection targets retrieved from cache") + targets = [ + ConnectionTarget.deserialize(row) for row in entry.result + ] + else: + if not connection: + async with self._profile.session() as session: + connection = await ConnRecord.retrieve_by_id( + session, connection_id + ) + + targets = await self.fetch_connection_targets(connection) + + if connection.state == ConnRecord.State.COMPLETED.rfc160: + # Only set cache if connection has reached completed state + # Otherwise, a replica that participated early in exchange + # may have bad data set in cache. + self._logger.debug("Caching connection targets") + await entry.set_result( + [row.serialize() for row in targets], 3600 + ) + else: + self._logger.debug( + "Not caching connection targets for connection in " + f"state ({connection.state})" + ) else: - if not connection.their_did: - self._logger.debug("No target DID associated with connection") - return None - - did_doc, _ = await self.fetch_did_document(connection.their_did) - - async with self._profile.session() as session: - wallet = session.inject(BaseWallet) - my_info = await wallet.get_local_did(connection.my_did) - - results = self.diddoc_connection_targets( - did_doc, my_info.verkey, connection.their_label - ) + if not connection: + async with self._profile.session() as session: + connection = await ConnRecord.retrieve_by_id(session, connection_id) - return results + targets = await self.fetch_connection_targets(connection) + return targets def diddoc_connection_targets( - self, doc: DIDDoc, sender_verkey: str, their_label: str = None + self, + doc: DIDDoc, + sender_verkey: str, + their_label: Optional[str] = None, ) -> Sequence[ConnectionTarget]: """Get a list of connection targets from a DID Document. @@ -438,7 +636,6 @@ def diddoc_connection_targets( sender_verkey: The verkey we are using their_label: The connection label they are using """ - if not doc: raise BaseConnectionManagerError("No DIDDoc provided for connection target") if not doc.did: @@ -475,3 +672,279 @@ async def fetch_did_document(self, did: str) -> Tuple[DIDDoc, StorageRecord]: storage = session.inject(BaseStorage) record = await storage.find_record(self.RECORD_TYPE_DID_DOC, {"did": did}) return DIDDoc.from_json(record.value), record + + async def find_connection( + self, + their_did: str, + my_did: Optional[str] = None, + my_verkey: Optional[str] = None, + auto_complete=False, + ) -> Optional[ConnRecord]: + """ + Look up existing connection information for a sender verkey. + + Args: + their_did: Their DID + my_did: My DID + my_verkey: My verkey + auto_complete: Should this connection automatically be promoted to active + + Returns: + The located `ConnRecord`, if any + + """ + connection = None + if their_did: + try: + async with self._profile.session() as session: + connection = await ConnRecord.retrieve_by_did( + session, their_did, my_did + ) + except StorageNotFoundError: + pass + + if ( + connection + and ConnRecord.State.get(connection.state) is ConnRecord.State.RESPONSE + and auto_complete + ): + connection.state = ConnRecord.State.COMPLETED.rfc160 + async with self._profile.session() as session: + await connection.save(session, reason="Connection promoted to active") + if session.settings.get("auto_disclose_features"): + discovery_mgr = V20DiscoveryMgr(self._profile) + await discovery_mgr.proactive_disclose_features( + connection_id=connection.connection_id + ) + + if not connection and my_verkey: + try: + async with self._profile.session() as session: + connection = await ConnRecord.retrieve_by_invitation_key( + session, + my_verkey, + their_role=ConnRecord.Role.REQUESTER.rfc160, + ) + except StorageError: + pass + + return connection + + async def find_inbound_connection( + self, receipt: MessageReceipt + ) -> Optional[ConnRecord]: + """ + Deserialize an incoming message and further populate the request context. + + Args: + receipt: The message receipt + + Returns: + The `ConnRecord` associated with the expanded message, if any + + """ + + cache_key = None + connection = None + resolved = False + + if receipt.sender_verkey and receipt.recipient_verkey: + cache_key = ( + f"connection_by_verkey::{receipt.sender_verkey}" + f"::{receipt.recipient_verkey}" + ) + cache = self._profile.inject_or(BaseCache) + if cache: + async with cache.acquire(cache_key) as entry: + if entry.result: + cached = entry.result + receipt.sender_did = cached["sender_did"] + receipt.recipient_did_public = cached["recipient_did_public"] + receipt.recipient_did = cached["recipient_did"] + async with self._profile.session() as session: + connection = await ConnRecord.retrieve_by_id( + session, cached["id"] + ) + else: + connection = await self.resolve_inbound_connection(receipt) + if connection: + cache_val = { + "id": connection.connection_id, + "sender_did": receipt.sender_did, + "recipient_did": receipt.recipient_did, + "recipient_did_public": receipt.recipient_did_public, + } + await entry.set_result(cache_val, 3600) + resolved = True + + if not connection and not resolved: + connection = await self.resolve_inbound_connection(receipt) + return connection + + async def resolve_inbound_connection( + self, receipt: MessageReceipt + ) -> Optional[ConnRecord]: + """ + Populate the receipt DID information and find the related `ConnRecord`. + + Args: + receipt: The message receipt + + Returns: + The `ConnRecord` associated with the expanded message, if any + + """ + + if receipt.sender_verkey: + try: + receipt.sender_did = await self.find_did_for_key(receipt.sender_verkey) + except StorageNotFoundError: + self._logger.warning( + "No corresponding DID found for sender verkey: %s", + receipt.sender_verkey, + ) + + if receipt.recipient_verkey: + try: + async with self._profile.session() as session: + wallet = session.inject(BaseWallet) + my_info = await wallet.get_local_did_for_verkey( + receipt.recipient_verkey + ) + receipt.recipient_did = my_info.did + if "posted" in my_info.metadata and my_info.metadata["posted"] is True: + receipt.recipient_did_public = True + except InjectionError: + self._logger.warning( + "Cannot resolve recipient verkey, no wallet defined by " + "context: %s", + receipt.recipient_verkey, + ) + except WalletNotFoundError: + self._logger.warning( + "No corresponding DID found for recipient verkey: %s", + receipt.recipient_verkey, + ) + + return await self.find_connection( + receipt.sender_did, receipt.recipient_did, receipt.recipient_verkey, True + ) + + async def get_endpoints(self, conn_id: str) -> Tuple[Optional[str], Optional[str]]: + """ + Get connection endpoints. + + Args: + conn_id: connection identifier + + Returns: + Their endpoint for this connection + + """ + async with self._profile.session() as session: + connection = await ConnRecord.retrieve_by_id(session, conn_id) + wallet = session.inject(BaseWallet) + my_did_info = await wallet.get_local_did(connection.my_did) + my_endpoint = my_did_info.metadata.get( + "endpoint", + self._profile.settings.get("default_endpoint"), + ) + + conn_targets = await self.get_connection_targets( + connection_id=connection.connection_id, + connection=connection, + ) + return (my_endpoint, conn_targets[0].endpoint) + + async def create_static_connection( + self, + my_did: Optional[str] = None, + my_seed: Optional[str] = None, + their_did: Optional[str] = None, + their_seed: Optional[str] = None, + their_verkey: Optional[str] = None, + their_endpoint: Optional[str] = None, + their_label: Optional[str] = None, + alias: Optional[str] = None, + mediation_id: Optional[str] = None, + ) -> Tuple[DIDInfo, DIDInfo, ConnRecord]: + """ + Register a new static connection (for use by the test suite). + + Args: + my_did: override the DID used in the connection + my_seed: provide a seed used to generate our DID and keys + their_did: provide the DID used by the other party + their_seed: provide a seed used to generate their DID and keys + their_verkey: provide the verkey used by the other party + their_endpoint: their URL endpoint for routing messages + alias: an alias for this connection record + + Returns: + Tuple: my DIDInfo, their DIDInfo, new `ConnRecord` instance + + """ + async with self._profile.session() as session: + wallet = session.inject(BaseWallet) + # seed and DID optional + my_info = await wallet.create_local_did(SOV, ED25519, my_seed, my_did) + + # must provide their DID and verkey if the seed is not known + if (not their_did or not their_verkey) and not their_seed: + raise BaseConnectionManagerError( + "Either a verkey or seed must be provided for the other party" + ) + if not their_did: + their_did = seed_to_did(their_seed) + if not their_verkey: + their_verkey_bin, _ = create_keypair(ED25519, their_seed.encode()) + their_verkey = bytes_to_b58(their_verkey_bin) + their_info = DIDInfo(their_did, their_verkey, {}, method=SOV, key_type=ED25519) + + # Create connection record + connection = ConnRecord( + invitation_mode=ConnRecord.INVITATION_MODE_STATIC, + my_did=my_info.did, + their_did=their_info.did, + their_label=their_label, + state=ConnRecord.State.COMPLETED.rfc160, + alias=alias, + connection_protocol=CONN_PROTO, + ) + async with self._profile.session() as session: + await connection.save(session, reason="Created new static connection") + if session.settings.get("auto_disclose_features"): + discovery_mgr = V20DiscoveryMgr(self._profile) + await discovery_mgr.proactive_disclose_features( + connection_id=connection.connection_id + ) + + # Routing + mediation_record = await self._route_manager.mediation_record_if_id( + self._profile, mediation_id, or_default=True + ) + + multitenant_mgr = self._profile.inject_or(BaseMultitenantManager) + wallet_id = self._profile.settings.get("wallet.id") + + base_mediation_record = None + if multitenant_mgr and wallet_id: + base_mediation_record = await multitenant_mgr.get_default_mediator() + + await self._route_manager.route_static( + self._profile, connection, mediation_record + ) + + # Synthesize their DID doc + did_doc = await self.create_did_document( + their_info, + None, + [their_endpoint or ""], + mediation_records=list( + filter(None, [base_mediation_record, mediation_record]) + ), + ) + + await self.store_did_document(did_doc) + + return my_info, their_info, connection diff --git a/aries_cloudagent/connections/models/conn_record.py b/aries_cloudagent/connections/models/conn_record.py index 131f2eff64..a75b80e318 100644 --- a/aries_cloudagent/connections/models/conn_record.py +++ b/aries_cloudagent/connections/models/conn_record.py @@ -294,9 +294,9 @@ def record_value(self) -> dict: async def retrieve_by_did( cls, session: ProfileSession, - their_did: str = None, - my_did: str = None, - their_role: str = None, + their_did: Optional[str] = None, + my_did: Optional[str] = None, + their_role: Optional[str] = None, ) -> "ConnRecord": """Retrieve a connection record by target DID. diff --git a/aries_cloudagent/connections/models/connection_target.py b/aries_cloudagent/connections/models/connection_target.py index 4e34dc46b9..6795b970ba 100644 --- a/aries_cloudagent/connections/models/connection_target.py +++ b/aries_cloudagent/connections/models/connection_target.py @@ -1,6 +1,6 @@ """Record used to handle routing of messages to another agent.""" -from typing import Sequence +from typing import Optional, Sequence from marshmallow import EXCLUDE, fields @@ -24,12 +24,12 @@ class Meta: def __init__( self, *, - did: str = None, - endpoint: str = None, - label: str = None, - recipient_keys: Sequence[str] = None, - routing_keys: Sequence[str] = None, - sender_key: str = None, + did: Optional[str] = None, + endpoint: Optional[str] = None, + label: Optional[str] = None, + recipient_keys: Optional[Sequence[str]] = None, + routing_keys: Optional[Sequence[str]] = None, + sender_key: Optional[str] = None, ): """ Initialize a ConnectionTarget instance. diff --git a/aries_cloudagent/connections/tests/__init__.py b/aries_cloudagent/connections/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/aries_cloudagent/connections/tests/test_base_manager.py b/aries_cloudagent/connections/tests/test_base_manager.py new file mode 100644 index 0000000000..4d7c05515e --- /dev/null +++ b/aries_cloudagent/connections/tests/test_base_manager.py @@ -0,0 +1,1908 @@ +"""Test connections base manager.""" + +from unittest.mock import call + +from asynctest import TestCase as AsyncTestCase, mock as async_mock +from multiformats import multibase, multicodec +from pydid import DID, DIDDocument, DIDDocumentBuilder +from pydid.doc.builder import ServiceBuilder +from pydid.verification_method import ( + Ed25519VerificationKey2018, + Ed25519VerificationKey2020, + EcdsaSecp256k1VerificationKey2019, + JsonWebKey2020, +) + +from .. import base_manager as test_module +from ...cache.base import BaseCache +from ...cache.in_memory import InMemoryCache +from ...config.base import InjectionError +from ...connections.base_manager import BaseConnectionManagerError +from ...connections.models.conn_record import ConnRecord +from ...connections.models.connection_target import ConnectionTarget +from ...connections.models.diddoc import DIDDoc, PublicKey, PublicKeyType, Service +from ...core.in_memory import InMemoryProfile +from ...core.oob_processor import OobMessageProcessor +from ...did.did_key import DIDKey +from ...messaging.responder import BaseResponder, MockResponder +from ...multitenant.base import BaseMultitenantManager +from ...multitenant.manager import MultitenantManager +from ...protocols.connections.v1_0.messages.connection_invitation import ( + ConnectionInvitation, +) +from ...protocols.coordinate_mediation.v1_0.models.mediation_record import ( + MediationRecord, +) +from ...protocols.coordinate_mediation.v1_0.route_manager import RouteManager +from ...protocols.discovery.v2_0.manager import V20DiscoveryMgr +from ...resolver.default.key import KeyDIDResolver +from ...resolver.default.legacy_peer import LegacyPeerDIDResolver +from ...resolver.did_resolver import DIDResolver +from ...storage.error import StorageNotFoundError +from ...transport.inbound.receipt import MessageReceipt +from ...wallet.base import DIDInfo +from ...wallet.did_method import DIDMethods, SOV +from ...wallet.error import WalletNotFoundError +from ...wallet.in_memory import InMemoryWallet +from ...wallet.key_type import ED25519 +from ...wallet.util import b58_to_bytes, bytes_to_b64 +from ..base_manager import BaseConnectionManager + + +class TestBaseConnectionManager(AsyncTestCase): + def make_did_doc(self, did, verkey): + doc = DIDDoc(did=did) + controller = did + ident = "1" + pk_value = verkey + pk = PublicKey( + did, ident, pk_value, PublicKeyType.ED25519_SIG_2018, controller, False + ) + doc.set(pk) + recip_keys = [pk] + router_keys = [] + service = Service( + did, "indy", "IndyAgent", recip_keys, router_keys, self.test_endpoint + ) + doc.set(service) + return doc + + async def setUp(self): + self.test_seed = "testseed000000000000000000000001" + self.test_did = "55GkHamhTU1ZbTbV2ab9DE" + self.test_verkey = "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx" + self.test_endpoint = "http://localhost" + + self.test_target_did = "GbuDUYXaUZRfHD2jeDuQuP" + self.test_target_verkey = "9WCgWKUaAJj3VWxxtzvvMQN3AoFxoBtBDo9ntwJnVVCC" + + self.responder = MockResponder() + + self.oob_mock = async_mock.MagicMock( + clean_finished_oob_record=async_mock.CoroutineMock(return_value=None) + ) + self.route_manager = async_mock.MagicMock(RouteManager) + self.route_manager.routing_info = async_mock.CoroutineMock( + return_value=([], self.test_endpoint) + ) + self.route_manager.mediation_record_if_id = async_mock.CoroutineMock( + return_value=None + ) + self.resolver = DIDResolver() + self.resolver.register_resolver(LegacyPeerDIDResolver()) + self.resolver.register_resolver(KeyDIDResolver()) + + self.profile = InMemoryProfile.test_profile( + { + "default_endpoint": "http://aries.ca/endpoint", + "default_label": "This guy", + "additional_endpoints": ["http://aries.ca/another-endpoint"], + "debug.auto_accept_invites": True, + "debug.auto_accept_requests": True, + }, + bind={ + BaseResponder: self.responder, + BaseCache: InMemoryCache(), + OobMessageProcessor: self.oob_mock, + RouteManager: self.route_manager, + DIDMethods: DIDMethods(), + DIDResolver: self.resolver, + }, + ) + self.context = self.profile.context + + self.multitenant_mgr = async_mock.MagicMock(MultitenantManager, autospec=True) + self.context.injector.bind_instance( + BaseMultitenantManager, self.multitenant_mgr + ) + + self.test_mediator_routing_keys = [ + "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRR" + ] + self.test_mediator_conn_id = "mediator-conn-id" + self.test_mediator_endpoint = "http://mediator.example.com" + + self.manager = BaseConnectionManager(self.profile) + assert self.manager._profile + + async def test_create_did_document(self): + did_info = DIDInfo( + self.test_did, + self.test_verkey, + None, + method=SOV, + key_type=ED25519, + ) + + mock_conn = async_mock.MagicMock( + connection_id="dummy", + inbound_connection_id=None, + their_did=self.test_target_did, + state=ConnRecord.State.COMPLETED.rfc23, + ) + + did_doc = self.make_did_doc( + did=self.test_target_did, verkey=self.test_target_verkey + ) + for i in range(2): # first cover store-record, then update-value + await self.manager.store_did_document(did_doc) + + with async_mock.patch.object( + ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() + ) as mock_conn_rec_retrieve_by_id: + mock_conn_rec_retrieve_by_id.return_value = mock_conn + + did_doc = await self.manager.create_did_document( + did_info=did_info, + inbound_connection_id="dummy", + svc_endpoints=[self.test_endpoint], + ) + + async def test_create_did_document_not_active(self): + did_info = DIDInfo( + self.test_did, + self.test_verkey, + None, + method=SOV, + key_type=ED25519, + ) + + mock_conn = async_mock.MagicMock( + connection_id="dummy", + inbound_connection_id=None, + their_did=self.test_target_did, + state=ConnRecord.State.ABANDONED.rfc23, + ) + + with async_mock.patch.object( + ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() + ) as mock_conn_rec_retrieve_by_id: + mock_conn_rec_retrieve_by_id.return_value = mock_conn + + with self.assertRaises(BaseConnectionManagerError): + await self.manager.create_did_document( + did_info=did_info, + inbound_connection_id="dummy", + svc_endpoints=[self.test_endpoint], + ) + + async def test_create_did_document_no_services(self): + did_info = DIDInfo( + self.test_did, + self.test_verkey, + None, + method=SOV, + key_type=ED25519, + ) + + mock_conn = async_mock.MagicMock( + connection_id="dummy", + inbound_connection_id=None, + their_did=self.test_target_did, + state=ConnRecord.State.COMPLETED.rfc23, + ) + + x_did_doc = self.make_did_doc( + did=self.test_target_did, verkey=self.test_target_verkey + ) + x_did_doc._service = {} + for i in range(2): # first cover store-record, then update-value + await self.manager.store_did_document(x_did_doc) + + with async_mock.patch.object( + ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() + ) as mock_conn_rec_retrieve_by_id: + mock_conn_rec_retrieve_by_id.return_value = mock_conn + + with self.assertRaises(BaseConnectionManagerError): + await self.manager.create_did_document( + did_info=did_info, + inbound_connection_id="dummy", + svc_endpoints=[self.test_endpoint], + ) + + async def test_create_did_document_no_service_endpoint(self): + did_info = DIDInfo( + self.test_did, + self.test_verkey, + None, + method=SOV, + key_type=ED25519, + ) + + mock_conn = async_mock.MagicMock( + connection_id="dummy", + inbound_connection_id=None, + their_did=self.test_target_did, + state=ConnRecord.State.COMPLETED.rfc23, + ) + + x_did_doc = self.make_did_doc( + did=self.test_target_did, verkey=self.test_target_verkey + ) + x_did_doc._service = {} + x_did_doc.set( + Service(self.test_target_did, "dummy", "IndyAgent", [], [], "", 0) + ) + for i in range(2): # first cover store-record, then update-value + await self.manager.store_did_document(x_did_doc) + + with async_mock.patch.object( + ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() + ) as mock_conn_rec_retrieve_by_id: + mock_conn_rec_retrieve_by_id.return_value = mock_conn + + with self.assertRaises(BaseConnectionManagerError): + await self.manager.create_did_document( + did_info=did_info, + inbound_connection_id="dummy", + svc_endpoints=[self.test_endpoint], + ) + + async def test_create_did_document_no_service_recip_keys(self): + did_info = DIDInfo( + self.test_did, + self.test_verkey, + None, + method=SOV, + key_type=ED25519, + ) + + mock_conn = async_mock.MagicMock( + connection_id="dummy", + inbound_connection_id=None, + their_did=self.test_target_did, + state=ConnRecord.State.COMPLETED.rfc23, + ) + + x_did_doc = self.make_did_doc( + did=self.test_target_did, verkey=self.test_target_verkey + ) + x_did_doc._service = {} + x_did_doc.set( + Service( + self.test_target_did, + "dummy", + "IndyAgent", + [], + [], + self.test_endpoint, + 0, + ) + ) + for i in range(2): # first cover store-record, then update-value + await self.manager.store_did_document(x_did_doc) + + with async_mock.patch.object( + ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() + ) as mock_conn_rec_retrieve_by_id: + mock_conn_rec_retrieve_by_id.return_value = mock_conn + + with self.assertRaises(BaseConnectionManagerError): + await self.manager.create_did_document( + did_info=did_info, + inbound_connection_id="dummy", + svc_endpoints=[self.test_endpoint], + ) + + async def test_create_did_document_mediation(self): + did_info = DIDInfo( + self.test_did, + self.test_verkey, + None, + method=SOV, + key_type=ED25519, + ) + mediation_record = MediationRecord( + role=MediationRecord.ROLE_CLIENT, + state=MediationRecord.STATE_GRANTED, + connection_id=self.test_mediator_conn_id, + routing_keys=self.test_mediator_routing_keys, + endpoint=self.test_mediator_endpoint, + ) + doc = await self.manager.create_did_document( + did_info, mediation_records=[mediation_record] + ) + assert doc.service + services = list(doc.service.values()) + assert len(services) == 1 + (service,) = services + service_public_keys = service.routing_keys[0] + assert service_public_keys.value == mediation_record.routing_keys[0] + assert service.endpoint == mediation_record.endpoint + + async def test_create_did_document_multiple_mediators(self): + did_info = DIDInfo( + self.test_did, + self.test_verkey, + None, + method=SOV, + key_type=ED25519, + ) + mediation_record1 = MediationRecord( + role=MediationRecord.ROLE_CLIENT, + state=MediationRecord.STATE_GRANTED, + connection_id=self.test_mediator_conn_id, + routing_keys=self.test_mediator_routing_keys, + endpoint=self.test_mediator_endpoint, + ) + mediation_record2 = MediationRecord( + role=MediationRecord.ROLE_CLIENT, + state=MediationRecord.STATE_GRANTED, + connection_id="mediator-conn-id2", + routing_keys=["05e8afd1-b4f0-46b7-a285-7a08c8a37caf"], + endpoint="http://mediatorw.example.com", + ) + doc = await self.manager.create_did_document( + did_info, mediation_records=[mediation_record1, mediation_record2] + ) + assert doc.service + services = list(doc.service.values()) + assert len(services) == 1 + (service,) = services + assert service.routing_keys[0].value == mediation_record1.routing_keys[0] + assert service.routing_keys[1].value == mediation_record2.routing_keys[0] + assert service.endpoint == mediation_record2.endpoint + + async def test_create_did_document_mediation_svc_endpoints_overwritten(self): + did_info = DIDInfo( + self.test_did, + self.test_verkey, + None, + method=SOV, + key_type=ED25519, + ) + mediation_record = MediationRecord( + role=MediationRecord.ROLE_CLIENT, + state=MediationRecord.STATE_GRANTED, + connection_id=self.test_mediator_conn_id, + routing_keys=self.test_mediator_routing_keys, + endpoint=self.test_mediator_endpoint, + ) + doc = await self.manager.create_did_document( + did_info, + svc_endpoints=[self.test_endpoint], + mediation_records=[mediation_record], + ) + assert doc.service + services = list(doc.service.values()) + assert len(services) == 1 + (service,) = services + service_public_keys = service.routing_keys[0] + assert service_public_keys.value == mediation_record.routing_keys[0] + assert service.endpoint == mediation_record.endpoint + + async def test_did_key_storage(self): + await self.manager.add_key_for_did( + did=self.test_target_did, key=self.test_target_verkey + ) + await self.manager.add_key_for_did( + did=self.test_target_did, key=self.test_target_verkey + ) + + did = await self.manager.find_did_for_key(key=self.test_target_verkey) + assert did == self.test_target_did + await self.manager.remove_keys_for_did(self.test_target_did) + + async def test_fetch_connection_targets_no_my_did(self): + mock_conn = async_mock.MagicMock() + mock_conn.my_did = None + assert await self.manager.fetch_connection_targets(mock_conn) == [] + + async def test_fetch_connection_targets_conn_invitation_did_no_resolver(self): + async with self.profile.session() as session: + self.context.injector.bind_instance(DIDResolver, DIDResolver([])) + await session.wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=self.test_seed, + did=self.test_did, + metadata=None, + ) + + conn_invite = ConnectionInvitation( + did=self.test_target_did, + endpoint=self.test_endpoint, + recipient_keys=[self.test_target_verkey], + routing_keys=[self.test_verkey], + label="label", + ) + mock_conn = async_mock.MagicMock( + my_did=self.test_did, + their_did=self.test_target_did, + connection_id="dummy", + their_role=ConnRecord.Role.RESPONDER.rfc23, + state=ConnRecord.State.INVITATION.rfc23, + retrieve_invitation=async_mock.CoroutineMock(return_value=conn_invite), + ) + + with self.assertRaises(BaseConnectionManagerError): + await self.manager.fetch_connection_targets(mock_conn) + + async def test_fetch_connection_targets_conn_invitation_did_resolver(self): + async with self.profile.session() as session: + builder = DIDDocumentBuilder("did:sov:" + self.test_target_did) + vmethod = builder.verification_method.add( + Ed25519VerificationKey2018, public_key_base58=self.test_target_verkey + ) + builder.service.add_didcomm( + ident="did-communication", + service_endpoint=self.test_endpoint, + recipient_keys=[vmethod], + ) + did_doc = builder.build() + self.resolver.get_endpoint_for_did = async_mock.CoroutineMock( + return_value=self.test_endpoint + ) + self.resolver.resolve = async_mock.CoroutineMock(return_value=did_doc) + self.resolver.dereference = async_mock.CoroutineMock( + return_value=did_doc.verification_method[0] + ) + self.context.injector.bind_instance(DIDResolver, self.resolver) + + local_did = await session.wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=self.test_seed, + did=self.test_did, + metadata=None, + ) + + conn_invite = ConnectionInvitation( + did=self.test_target_did, + endpoint=self.test_endpoint, + recipient_keys=[self.test_target_verkey], + routing_keys=[self.test_verkey], + label="label", + ) + mock_conn = async_mock.MagicMock( + my_did=self.test_did, + their_did=self.test_target_did, + connection_id="dummy", + their_role=ConnRecord.Role.RESPONDER.rfc23, + state=ConnRecord.State.INVITATION.rfc23, + retrieve_invitation=async_mock.CoroutineMock(return_value=conn_invite), + ) + + targets = await self.manager.fetch_connection_targets(mock_conn) + assert len(targets) == 1 + target = targets[0] + assert target.did == mock_conn.their_did + assert target.endpoint == conn_invite.endpoint + assert target.label == conn_invite.label + assert target.recipient_keys == conn_invite.recipient_keys + assert target.routing_keys == [] + assert target.sender_key == local_did.verkey + + async def test_fetch_connection_targets_conn_invitation_btcr_resolver(self): + async with self.profile.session() as session: + builder = DIDDocumentBuilder("did:btcr:x705-jznz-q3nl-srs") + vmethod = builder.verification_method.add( + Ed25519VerificationKey2018, public_key_base58=self.test_target_verkey + ) + builder.service.add_didcomm( + type_="IndyAgent", + recipient_keys=[vmethod], + routing_keys=[vmethod], + service_endpoint=self.test_endpoint, + priority=1, + ) + + builder.service.add_didcomm( + recipient_keys=[vmethod], + routing_keys=[vmethod], + service_endpoint=self.test_endpoint, + priority=0, + ) + builder.service.add_didcomm( + recipient_keys=[vmethod], + routing_keys=[vmethod], + service_endpoint="{}/priority2".format(self.test_endpoint), + priority=2, + ) + did_doc = builder.build() + + self.resolver.get_endpoint_for_did = async_mock.CoroutineMock( + return_value=self.test_endpoint + ) + self.resolver.resolve = async_mock.CoroutineMock(return_value=did_doc) + self.resolver.dereference = async_mock.CoroutineMock( + return_value=did_doc.verification_method[0] + ) + self.context.injector.bind_instance(DIDResolver, self.resolver) + local_did = await session.wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=self.test_seed, + did=did_doc.id, + metadata=None, + ) + + conn_invite = ConnectionInvitation( + did=did_doc.id, + endpoint=self.test_endpoint, + recipient_keys=[vmethod.material], + routing_keys=[self.test_verkey], + label="label", + ) + mock_conn = async_mock.MagicMock( + my_did=did_doc.id, + their_did=self.test_target_did, + connection_id="dummy", + their_role=ConnRecord.Role.RESPONDER.rfc23, + state=ConnRecord.State.INVITATION.rfc23, + retrieve_invitation=async_mock.CoroutineMock(return_value=conn_invite), + ) + + targets = await self.manager.fetch_connection_targets(mock_conn) + assert len(targets) == 1 + target = targets[0] + assert target.did == mock_conn.their_did + assert target.endpoint == self.test_endpoint + assert target.label == conn_invite.label + assert target.recipient_keys == conn_invite.recipient_keys + assert target.routing_keys == [vmethod.material] + assert target.sender_key == local_did.verkey + + async def test_fetch_connection_targets_conn_invitation_btcr_without_services(self): + async with self.profile.session() as session: + did_doc_json = { + "@context": ["https://www.w3.org/ns/did/v1"], + "id": "did:btcr:x705-jznz-q3nl-srs", + "verificationMethod": [ + { + "type": "EcdsaSecp256k1VerificationKey2019", + "id": "did:btcr:x705-jznz-q3nl-srs#key-0", + "publicKeyBase58": "02e0e01a8c302976e1556e95c54146e8464adac8626a5d29474718a7281133ff49", + }, + { + "type": "EcdsaSecp256k1VerificationKey2019", + "id": "did:btcr:x705-jznz-q3nl-srs#key-1", + "publicKeyBase58": "02e0e01a8c302976e1556e95c54146e8464adac8626a5d29474718a7281133ff49", + }, + { + "type": "EcdsaSecp256k1VerificationKey2019", + "id": "did:btcr:x705-jznz-q3nl-srs#satoshi", + "publicKeyBase58": "02e0e01a8c302976e1556e95c54146e8464adac8626a5d29474718a7281133ff49", + }, + ], + } + did_doc = DIDDocument.deserialize(did_doc_json) + self.resolver.get_endpoint_for_did = async_mock.CoroutineMock( + return_value=self.test_endpoint + ) + self.resolver.resolve = async_mock.CoroutineMock(return_value=did_doc) + self.context.injector.bind_instance(DIDResolver, self.resolver) + + local_did = await session.wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=self.test_seed, + did=did_doc.id, + metadata=None, + ) + + conn_invite = ConnectionInvitation( + did=did_doc.id, + endpoint=self.test_endpoint, + recipient_keys=["{}#1".format(did_doc.id)], + routing_keys=[self.test_verkey], + label="label", + ) + mock_conn = async_mock.MagicMock( + my_did=did_doc.id, + their_did=self.test_target_did, + connection_id="dummy", + their_role=ConnRecord.Role.RESPONDER.rfc23, + state=ConnRecord.State.INVITATION.rfc23, + retrieve_invitation=async_mock.CoroutineMock(return_value=conn_invite), + ) + with self.assertRaises(BaseConnectionManagerError): + await self.manager.fetch_connection_targets(mock_conn) + + async def test_fetch_connection_targets_conn_invitation_no_didcomm_services(self): + async with self.profile.session() as session: + builder = DIDDocumentBuilder("did:btcr:x705-jznz-q3nl-srs") + builder.verification_method.add( + Ed25519VerificationKey2018, public_key_base58=self.test_target_verkey + ) + builder.service.add(type_="LinkedData", service_endpoint=self.test_endpoint) + did_doc = builder.build() + self.resolver.get_endpoint_for_did = async_mock.CoroutineMock( + return_value=self.test_endpoint + ) + self.resolver.resolve = async_mock.CoroutineMock(return_value=did_doc) + self.context.injector.bind_instance(DIDResolver, self.resolver) + await session.wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=self.test_seed, + did=did_doc.id, + metadata=None, + ) + + conn_invite = ConnectionInvitation( + did=did_doc.id, + endpoint=self.test_endpoint, + recipient_keys=["{}#1".format(did_doc.id)], + routing_keys=[self.test_verkey], + label="label", + ) + mock_conn = async_mock.MagicMock( + my_did=did_doc.id, + their_did=self.test_target_did, + connection_id="dummy", + their_role=ConnRecord.Role.RESPONDER.rfc23, + state=ConnRecord.State.INVITATION.rfc23, + retrieve_invitation=async_mock.CoroutineMock(return_value=conn_invite), + ) + with self.assertRaises(BaseConnectionManagerError): + await self.manager.fetch_connection_targets(mock_conn) + + async def test_fetch_connection_targets_conn_invitation_supports_Ed25519VerificationKey2018_key_type_no_multicodec( + self, + ): + async with self.profile.session() as session: + builder = DIDDocumentBuilder("did:btcr:x705-jznz-q3nl-srs") + vmethod = builder.verification_method.add( + Ed25519VerificationKey2020, + public_key_multibase=multibase.encode( + b58_to_bytes(self.test_target_verkey), "base58btc" + ), + ) + builder.service.add_didcomm( + type_="IndyAgent", + service_endpoint=self.test_endpoint, + recipient_keys=[vmethod], + ) + did_doc = builder.build() + self.resolver.get_endpoint_for_did = async_mock.CoroutineMock( + return_value=self.test_endpoint + ) + self.resolver.resolve = async_mock.CoroutineMock(return_value=did_doc) + self.resolver.dereference = async_mock.CoroutineMock( + return_value=did_doc.verification_method[0] + ) + self.context.injector.bind_instance(DIDResolver, self.resolver) + local_did = await session.wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=self.test_seed, + did=did_doc.id, + metadata=None, + ) + + conn_invite = ConnectionInvitation( + did=did_doc.id, + endpoint=self.test_endpoint, + recipient_keys=[vmethod.public_key_jwk], + routing_keys=[self.test_verkey], + label="label", + ) + mock_conn = async_mock.MagicMock( + my_did=did_doc.id, + their_did=self.test_target_did, + connection_id="dummy", + their_role=ConnRecord.Role.RESPONDER.rfc23, + state=ConnRecord.State.INVITATION.rfc23, + retrieve_invitation=async_mock.CoroutineMock(return_value=conn_invite), + ) + + targets = await self.manager.fetch_connection_targets(mock_conn) + assert len(targets) == 1 + target = targets[0] + assert target.did == mock_conn.their_did + assert target.endpoint == self.test_endpoint + assert target.label == conn_invite.label + assert target.recipient_keys == [self.test_target_verkey] + assert target.routing_keys == [] + assert target.sender_key == local_did.verkey + + async def test_fetch_connection_targets_conn_invitation_supports_Ed25519VerificationKey2018_key_type_with_multicodec( + self, + ): + async with self.profile.session() as session: + builder = DIDDocumentBuilder("did:btcr:x705-jznz-q3nl-srs") + vmethod = builder.verification_method.add( + Ed25519VerificationKey2020, + public_key_multibase=multibase.encode( + multicodec.wrap( + "ed25519-pub", b58_to_bytes(self.test_target_verkey) + ), + "base58btc", + ), + ) + builder.service.add_didcomm( + type_="IndyAgent", + service_endpoint=self.test_endpoint, + recipient_keys=[vmethod], + ) + did_doc = builder.build() + self.resolver.get_endpoint_for_did = async_mock.CoroutineMock( + return_value=self.test_endpoint + ) + self.resolver.resolve = async_mock.CoroutineMock(return_value=did_doc) + self.resolver.dereference = async_mock.CoroutineMock( + return_value=did_doc.verification_method[0] + ) + self.context.injector.bind_instance(DIDResolver, self.resolver) + local_did = await session.wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=self.test_seed, + did=did_doc.id, + metadata=None, + ) + + conn_invite = ConnectionInvitation( + did=did_doc.id, + endpoint=self.test_endpoint, + recipient_keys=[vmethod.public_key_jwk], + routing_keys=[self.test_verkey], + label="label", + ) + mock_conn = async_mock.MagicMock( + my_did=did_doc.id, + their_did=self.test_target_did, + connection_id="dummy", + their_role=ConnRecord.Role.RESPONDER.rfc23, + state=ConnRecord.State.INVITATION.rfc23, + retrieve_invitation=async_mock.CoroutineMock(return_value=conn_invite), + ) + + targets = await self.manager.fetch_connection_targets(mock_conn) + assert len(targets) == 1 + target = targets[0] + assert target.did == mock_conn.their_did + assert target.endpoint == self.test_endpoint + assert target.label == conn_invite.label + assert target.recipient_keys == [self.test_target_verkey] + assert target.routing_keys == [] + assert target.sender_key == local_did.verkey + + async def test_fetch_connection_targets_conn_invitation_supported_JsonWebKey2020_key_type( + self, + ): + async with self.profile.session() as session: + builder = DIDDocumentBuilder("did:btcr:x705-jznz-q3nl-srs") + vmethod = builder.verification_method.add( + JsonWebKey2020, + ident="1", + public_key_jwk={ + "kty": "OKP", + "crv": "Ed25519", + "x": bytes_to_b64(b58_to_bytes(self.test_target_verkey), True), + }, + ) + builder.service.add_didcomm( + type_="IndyAgent", + service_endpoint=self.test_endpoint, + recipient_keys=[vmethod], + ) + did_doc = builder.build() + self.resolver.get_endpoint_for_did = async_mock.CoroutineMock( + return_value=self.test_endpoint + ) + self.resolver.resolve = async_mock.CoroutineMock(return_value=did_doc) + self.resolver.dereference = async_mock.CoroutineMock( + return_value=did_doc.verification_method[0] + ) + self.context.injector.bind_instance(DIDResolver, self.resolver) + local_did = await session.wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=self.test_seed, + did=did_doc.id, + metadata=None, + ) + + conn_invite = ConnectionInvitation( + did=did_doc.id, + endpoint=self.test_endpoint, + recipient_keys=[vmethod.public_key_jwk], + routing_keys=[self.test_verkey], + label="label", + ) + mock_conn = async_mock.MagicMock( + my_did=did_doc.id, + their_did=self.test_target_did, + connection_id="dummy", + their_role=ConnRecord.Role.RESPONDER.rfc23, + state=ConnRecord.State.INVITATION.rfc23, + retrieve_invitation=async_mock.CoroutineMock(return_value=conn_invite), + ) + + targets = await self.manager.fetch_connection_targets(mock_conn) + assert len(targets) == 1 + target = targets[0] + assert target.did == mock_conn.their_did + assert target.endpoint == self.test_endpoint + assert target.label == conn_invite.label + assert target.recipient_keys == [self.test_target_verkey] + assert target.routing_keys == [] + assert target.sender_key == local_did.verkey + + async def test_fetch_connection_targets_conn_invitation_unsupported_key_type(self): + async with self.profile.session() as session: + builder = DIDDocumentBuilder("did:btcr:x705-jznz-q3nl-srs") + vmethod = builder.verification_method.add( + JsonWebKey2020, + ident="1", + public_key_jwk={ + "kty": "EC", + "crv": "P-256", + "x": "2syLh57B-dGpa0F8p1JrO6JU7UUSF6j7qL-vfk1eOoY", + "y": "BgsGtI7UPsObMRjdElxLOrgAO9JggNMjOcfzEPox18w", + }, + ) + builder.service.add_didcomm( + type_="IndyAgent", + service_endpoint=self.test_endpoint, + recipient_keys=[vmethod], + ) + did_doc = builder.build() + self.resolver.get_endpoint_for_did = async_mock.CoroutineMock( + return_value=self.test_endpoint + ) + self.resolver.resolve = async_mock.CoroutineMock(return_value=did_doc) + self.resolver.dereference = async_mock.CoroutineMock( + return_value=did_doc.verification_method[0] + ) + self.context.injector.bind_instance(DIDResolver, self.resolver) + local_did = await session.wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=self.test_seed, + did=did_doc.id, + metadata=None, + ) + + conn_invite = ConnectionInvitation( + did=did_doc.id, + endpoint=self.test_endpoint, + recipient_keys=["{}#1".format(did_doc.id)], + routing_keys=[self.test_verkey], + label="label", + ) + mock_conn = async_mock.MagicMock( + my_did=did_doc.id, + their_did=self.test_target_did, + connection_id="dummy", + their_role=ConnRecord.Role.RESPONDER.rfc23, + state=ConnRecord.State.INVITATION.rfc23, + retrieve_invitation=async_mock.CoroutineMock(return_value=conn_invite), + ) + with self.assertRaises(BaseConnectionManagerError): + await self.manager.fetch_connection_targets(mock_conn) + + async def test_fetch_connection_targets_oob_invitation_svc_did_no_resolver(self): + async with self.profile.session() as session: + self.context.injector.bind_instance(DIDResolver, DIDResolver([])) + await session.wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=self.test_seed, + did=self.test_did, + metadata=None, + ) + + mock_oob_invite = async_mock.MagicMock(services=[self.test_did]) + + mock_conn = async_mock.MagicMock( + my_did=self.test_did, + retrieve_invitation=async_mock.CoroutineMock( + return_value=mock_oob_invite + ), + state=ConnRecord.State.INVITATION.rfc23, + their_role=ConnRecord.Role.RESPONDER.rfc23, + ) + + with self.assertRaises(BaseConnectionManagerError): + await self.manager.fetch_connection_targets(mock_conn) + + async def test_fetch_connection_targets_oob_invitation_svc_did_resolver(self): + async with self.profile.session() as session: + builder = DIDDocumentBuilder("did:sov:" + self.test_target_did) + vmethod = builder.verification_method.add( + Ed25519VerificationKey2018, + ident="1", + public_key_base58=self.test_target_verkey, + ) + builder.service.add_didcomm( + ident="did-communication", + service_endpoint=self.test_endpoint, + recipient_keys=[vmethod], + ) + did_doc = builder.build() + + self.resolver.resolve = async_mock.CoroutineMock(return_value=did_doc) + self.resolver.dereference = async_mock.CoroutineMock( + return_value=did_doc.verification_method[0] + ) + self.context.injector.bind_instance(DIDResolver, self.resolver) + + local_did = await session.wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=self.test_seed, + did=self.test_did, + metadata=None, + ) + + mock_oob_invite = async_mock.MagicMock( + label="a label", + their_did=self.test_target_did, + services=["dummy"], + ) + mock_conn = async_mock.MagicMock( + my_did=self.test_did, + their_did=self.test_target_did, + connection_id="dummy", + their_role=ConnRecord.Role.RESPONDER.rfc23, + state=ConnRecord.State.INVITATION.rfc23, + retrieve_invitation=async_mock.CoroutineMock( + return_value=mock_oob_invite + ), + ) + + targets = await self.manager.fetch_connection_targets(mock_conn) + assert len(targets) == 1 + target = targets[0] + assert target.did == mock_conn.their_did + assert target.endpoint == self.test_endpoint + assert target.label == mock_oob_invite.label + assert target.recipient_keys == [vmethod.material] + assert target.routing_keys == [] + assert target.sender_key == local_did.verkey + + async def test_fetch_connection_targets_oob_invitation_svc_block_resolver(self): + async with self.profile.session() as session: + self.resolver.get_endpoint_for_did = async_mock.CoroutineMock( + return_value=self.test_endpoint + ) + self.resolver.get_key_for_did = async_mock.CoroutineMock( + return_value=self.test_target_verkey + ) + self.context.injector.bind_instance(DIDResolver, self.resolver) + + local_did = await session.wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=self.test_seed, + did=self.test_did, + metadata=None, + ) + + mock_oob_invite = async_mock.MagicMock( + label="a label", + their_did=self.test_target_did, + services=[ + async_mock.MagicMock( + service_endpoint=self.test_endpoint, + recipient_keys=[ + DIDKey.from_public_key_b58( + self.test_target_verkey, ED25519 + ).did + ], + routing_keys=[], + ) + ], + ) + mock_conn = async_mock.MagicMock( + my_did=self.test_did, + their_did=self.test_target_did, + connection_id="dummy", + their_role=ConnRecord.Role.RESPONDER.rfc23, + state=ConnRecord.State.INVITATION.rfc23, + retrieve_invitation=async_mock.CoroutineMock( + return_value=mock_oob_invite + ), + ) + + targets = await self.manager.fetch_connection_targets(mock_conn) + assert len(targets) == 1 + target = targets[0] + assert target.did == mock_conn.their_did + assert target.endpoint == self.test_endpoint + assert target.label == mock_oob_invite.label + assert target.recipient_keys == [self.test_target_verkey] + assert target.routing_keys == [] + assert target.sender_key == local_did.verkey + + async def test_fetch_connection_targets_conn_initiator_completed_no_their_did(self): + async with self.profile.session() as session: + await session.wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=self.test_seed, + did=self.test_did, + metadata=None, + ) + + mock_conn = async_mock.MagicMock( + my_did=self.test_did, + their_did=None, + state=ConnRecord.State.COMPLETED.rfc23, + ) + assert await self.manager.fetch_connection_targets(mock_conn) == [] + + async def test_fetch_connection_targets_conn_completed_their_did(self): + async with self.profile.session() as session: + local_did = await session.wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=self.test_seed, + did=self.test_did, + metadata=None, + ) + + did_doc = self.make_did_doc(did=self.test_did, verkey=self.test_verkey) + await self.manager.store_did_document(did_doc) + + mock_conn = async_mock.MagicMock( + my_did=self.test_did, + their_did=self.test_did, + their_label="label", + their_role=ConnRecord.Role.REQUESTER.rfc160, + state=ConnRecord.State.COMPLETED.rfc23, + ) + + targets = await self.manager.fetch_connection_targets(mock_conn) + assert len(targets) == 1 + target = targets[0] + # did:sov: dropped for this check + assert target.did[8:] == mock_conn.their_did + assert target.endpoint == self.test_endpoint + assert target.label == mock_conn.their_label + assert target.recipient_keys == [self.test_verkey] + assert target.routing_keys == [] + assert target.sender_key == local_did.verkey + + async def test_fetch_connection_targets_conn_no_invi_with_their_did(self): + async with self.profile.session() as session: + local_did = await session.wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=self.test_seed, + did=self.test_did, + metadata=None, + ) + + self.manager.resolve_invitation = async_mock.CoroutineMock() + self.manager.resolve_invitation.return_value = ( + self.test_endpoint, + [self.test_verkey], + [], + ) + + did_doc = self.make_did_doc(did=self.test_did, verkey=self.test_verkey) + await self.manager.store_did_document(did_doc) + + mock_conn = async_mock.MagicMock( + my_did=self.test_did, + their_did=self.test_did, + their_label="label", + their_role=ConnRecord.Role.RESPONDER.rfc23, + state=ConnRecord.State.REQUEST.rfc23, + invitation_key=None, + invitation_msg_id=None, + ) + + targets = await self.manager.fetch_connection_targets(mock_conn) + assert len(targets) == 1 + target = targets[0] + assert target.did == mock_conn.their_did + assert target.endpoint == self.test_endpoint + assert target.label is None + assert target.recipient_keys == [self.test_verkey] + assert target.routing_keys == [] + assert target.sender_key == local_did.verkey + + async def test_verification_methods_for_service(self): + did = "did:sov:" + self.test_did + doc_builder = DIDDocumentBuilder(did) + vm = doc_builder.verification_method.add( + Ed25519VerificationKey2018, + public_key_base58=self.test_verkey, + ) + route_key = DIDKey.from_public_key_b58(self.test_verkey, ED25519) + service = doc_builder.service.add( + type_="did-communication", + service_endpoint=self.test_endpoint, + recipient_keys=[vm.id], + routing_keys=[route_key.key_id], + ) + doc = doc_builder.build() + self.manager.resolve_didcomm_services = async_mock.CoroutineMock( + return_value=(doc, doc.service) + ) + recip, routing = await self.manager.verification_methods_for_service( + doc, service + ) + assert recip == [vm] + assert routing + + async def test_resolve_connection_targets_empty(self): + """Test resolve connection targets.""" + did = "did:sov:" + self.test_did + self.manager.resolve_didcomm_services = async_mock.CoroutineMock( + return_value=(DIDDocument(id=DID(did)), []) + ) + targets = await self.manager.resolve_connection_targets(did) + assert targets == [] + + async def test_resolve_connection_targets(self): + """Test resolve connection targets.""" + did = "did:sov:" + self.test_did + doc_builder = DIDDocumentBuilder(did) + vm = doc_builder.verification_method.add( + Ed25519VerificationKey2018, + public_key_base58=self.test_verkey, + ) + route_key = DIDKey.from_public_key_b58(self.test_verkey, ED25519) + doc_builder.service.add( + type_="did-communication", + service_endpoint=self.test_endpoint, + recipient_keys=[vm.id], + routing_keys=[route_key.key_id], + ) + doc = doc_builder.build() + self.manager.resolve_didcomm_services = async_mock.CoroutineMock( + return_value=(doc, doc.service) + ) + targets = await self.manager.resolve_connection_targets(did) + assert targets + assert targets[0].routing_keys[0] == self.test_verkey + + async def test_resolve_connection_targets_x_bad_reference(self): + """Test resolve connection targets.""" + did = "did:sov:" + self.test_did + doc_builder = DIDDocumentBuilder(did) + vm = doc_builder.verification_method.add( + Ed25519VerificationKey2018, + public_key_base58=self.test_verkey, + ) + doc_builder.service.add( + type_="did-communication", + service_endpoint=self.test_endpoint, + recipient_keys=[vm.id], + routing_keys=["did:example:123#some-random-id"], + ) + doc = doc_builder.build() + self.manager.resolve_didcomm_services = async_mock.CoroutineMock( + return_value=(doc, doc.service) + ) + with self.assertLogs() as cm: + await self.manager.resolve_connection_targets(did) + assert cm.output and "Failed to resolve service" in cm.output[0] + + async def test_resolve_connection_targets_x_bad_key_material(self): + did = "did:sov:" + self.test_did + doc_builder = DIDDocumentBuilder(did) + vm = doc_builder.verification_method.add( + Ed25519VerificationKey2020, + public_key_multibase=multibase.encode( + multicodec.wrap("secp256k1-pub", b58_to_bytes(self.test_verkey)), + "base58btc", + ), + ) + route_key = DIDKey.from_public_key_b58(self.test_verkey, ED25519) + doc_builder.service.add( + type_="did-communication", + service_endpoint=self.test_endpoint, + recipient_keys=[vm.id], + routing_keys=[route_key.key_id], + ) + doc = doc_builder.build() + self.manager.resolve_didcomm_services = async_mock.CoroutineMock( + return_value=(doc, doc.service) + ) + with self.assertRaises(BaseConnectionManagerError) as cm: + await self.manager.resolve_connection_targets(did) + assert "not supported" in str(cm.exception) + + async def test_resolve_connection_targets_x_unsupported_key(self): + did = "did:sov:" + self.test_did + doc_builder = DIDDocumentBuilder(did) + vm = doc_builder.verification_method.add( + EcdsaSecp256k1VerificationKey2019, + public_key_hex="deadbeef", + ) + route_key = DIDKey.from_public_key_b58(self.test_verkey, ED25519) + doc_builder.service.add( + type_="did-communication", + service_endpoint=self.test_endpoint, + recipient_keys=[vm.id], + routing_keys=[route_key.key_id], + ) + doc = doc_builder.build() + self.manager.resolve_didcomm_services = async_mock.CoroutineMock( + return_value=(doc, doc.service) + ) + with self.assertRaises(BaseConnectionManagerError) as cm: + await self.manager.resolve_connection_targets(did) + assert "not supported" in str(cm.exception) + + async def test_record_keys_for_public_did_empty(self): + did = "did:sov:" + self.test_did + service_builder = ServiceBuilder(DID(did)) + service_builder.add_didcomm( + self.test_endpoint, recipient_keys=[], routing_keys=[] + ) + self.manager.resolve_didcomm_services = async_mock.CoroutineMock( + return_value=(DIDDocument(id=DID(did)), service_builder.services) + ) + await self.manager.record_keys_for_public_did(did) + + async def test_record_keys_for_public_did(self): + did = "did:sov:" + self.test_did + doc_builder = DIDDocumentBuilder(did) + vm = doc_builder.verification_method.add( + Ed25519VerificationKey2018, + public_key_base58=self.test_verkey, + ) + doc_builder.service.add_didcomm( + self.test_endpoint, recipient_keys=[vm], routing_keys=[] + ) + doc = doc_builder.build() + self.manager.resolve_didcomm_services = async_mock.CoroutineMock( + return_value=(doc, doc.service) + ) + await self.manager.record_keys_for_public_did(did) + + async def test_diddoc_connection_targets_diddoc_underspecified(self): + with self.assertRaises(BaseConnectionManagerError): + self.manager.diddoc_connection_targets(None, self.test_verkey) + + x_did_doc = DIDDoc(did=None) + with self.assertRaises(BaseConnectionManagerError): + self.manager.diddoc_connection_targets(x_did_doc, self.test_verkey) + + x_did_doc = self.make_did_doc( + did=self.test_target_did, verkey=self.test_target_verkey + ) + x_did_doc._service = {} + with self.assertRaises(BaseConnectionManagerError): + self.manager.diddoc_connection_targets(x_did_doc, self.test_verkey) + + async def test_find_inbound_connection(self): + receipt = MessageReceipt( + sender_verkey=self.test_verkey, + recipient_verkey=self.test_target_verkey, + recipient_did_public=False, + ) + + mock_conn = async_mock.MagicMock() + mock_conn.connection_id = "dummy" + + # First pass: not yet in cache + with async_mock.patch.object( + BaseConnectionManager, + "resolve_inbound_connection", + async_mock.CoroutineMock(), + ) as mock_conn_mgr_resolve_conn: + mock_conn_mgr_resolve_conn.return_value = mock_conn + + conn_rec = await self.manager.find_inbound_connection(receipt) + assert conn_rec + + # Second pass: in cache + with async_mock.patch.object( + ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() + ) as mock_conn_rec_retrieve_by_id: + mock_conn_rec_retrieve_by_id.return_value = mock_conn + + conn_rec = await self.manager.find_inbound_connection(receipt) + assert conn_rec.id == mock_conn.id + + async def test_find_inbound_connection_no_cache(self): + receipt = MessageReceipt( + sender_verkey=self.test_verkey, + recipient_verkey=self.test_target_verkey, + recipient_did_public=False, + ) + + mock_conn = async_mock.MagicMock() + mock_conn.connection_id = "dummy" + + with async_mock.patch.object( + BaseConnectionManager, + "resolve_inbound_connection", + async_mock.CoroutineMock(), + ) as mock_conn_mgr_resolve_conn: + self.context.injector.clear_binding(BaseCache) + mock_conn_mgr_resolve_conn.return_value = mock_conn + + conn_rec = await self.manager.find_inbound_connection(receipt) + assert conn_rec + + async def test_resolve_inbound_connection(self): + receipt = MessageReceipt( + sender_verkey=self.test_verkey, + recipient_verkey=self.test_target_verkey, + recipient_did_public=True, + ) + + mock_conn = async_mock.MagicMock() + mock_conn.connection_id = "dummy" + + with async_mock.patch.object( + InMemoryWallet, "get_local_did_for_verkey", async_mock.CoroutineMock() + ) as mock_wallet_get_local_did_for_verkey, async_mock.patch.object( + self.manager, "find_connection", async_mock.CoroutineMock() + ) as mock_mgr_find_conn: + mock_wallet_get_local_did_for_verkey.return_value = DIDInfo( + self.test_did, + self.test_verkey, + {"posted": True}, + method=SOV, + key_type=ED25519, + ) + mock_mgr_find_conn.return_value = mock_conn + + assert await self.manager.resolve_inbound_connection(receipt) + + async def test_resolve_inbound_connection_injector_error(self): + receipt = MessageReceipt( + sender_verkey=self.test_verkey, + recipient_verkey=self.test_target_verkey, + recipient_did_public=True, + ) + + mock_conn = async_mock.MagicMock() + mock_conn.connection_id = "dummy" + + with async_mock.patch.object( + InMemoryWallet, "get_local_did_for_verkey", async_mock.CoroutineMock() + ) as mock_wallet_get_local_did_for_verkey, async_mock.patch.object( + self.manager, "find_connection", async_mock.CoroutineMock() + ) as mock_mgr_find_conn: + mock_wallet_get_local_did_for_verkey.side_effect = InjectionError() + mock_mgr_find_conn.return_value = mock_conn + + assert await self.manager.resolve_inbound_connection(receipt) + + async def test_resolve_inbound_connection_wallet_not_found_error(self): + receipt = MessageReceipt( + sender_verkey=self.test_verkey, + recipient_verkey=self.test_target_verkey, + recipient_did_public=True, + ) + + mock_conn = async_mock.MagicMock() + mock_conn.connection_id = "dummy" + + with async_mock.patch.object( + InMemoryWallet, "get_local_did_for_verkey", async_mock.CoroutineMock() + ) as mock_wallet_get_local_did_for_verkey, async_mock.patch.object( + self.manager, "find_connection", async_mock.CoroutineMock() + ) as mock_mgr_find_conn: + mock_wallet_get_local_did_for_verkey.side_effect = WalletNotFoundError() + mock_mgr_find_conn.return_value = mock_conn + + assert await self.manager.resolve_inbound_connection(receipt) + + async def test_get_connection_targets_conn_invitation_no_did(self): + async with self.profile.session() as session: + local_did = await session.wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=self.test_seed, + did=self.test_did, + metadata=None, + ) + + did_doc = self.make_did_doc( + did=self.test_target_did, verkey=self.test_target_verkey + ) + await self.manager.store_did_document(did_doc) + + # First pass: not yet in cache + conn_invite = ConnectionInvitation( + did=None, + endpoint=self.test_endpoint, + recipient_keys=[self.test_target_verkey], + routing_keys=[self.test_verkey], + label="label", + ) + mock_conn = async_mock.MagicMock( + my_did=self.test_did, + their_did=self.test_target_did, + connection_id="dummy", + their_role=ConnRecord.Role.RESPONDER.rfc23, + state=ConnRecord.State.INVITATION.rfc23, + retrieve_invitation=async_mock.CoroutineMock(return_value=conn_invite), + ) + + targets = await self.manager.get_connection_targets( + connection_id=None, + connection=mock_conn, + ) + assert len(targets) == 1 + target = targets[0] + assert target.did == mock_conn.their_did + assert target.endpoint == conn_invite.endpoint + assert target.label == conn_invite.label + assert target.recipient_keys == conn_invite.recipient_keys + assert target.routing_keys == conn_invite.routing_keys + assert target.sender_key == local_did.verkey + + # Next pass: exercise cache + targets = await self.manager.get_connection_targets( + connection_id=None, + connection=mock_conn, + ) + assert len(targets) == 1 + target = targets[0] + assert target.did == mock_conn.their_did + assert target.endpoint == conn_invite.endpoint + assert target.label == conn_invite.label + assert target.recipient_keys == conn_invite.recipient_keys + assert target.routing_keys == conn_invite.routing_keys + assert target.sender_key == local_did.verkey + + async def test_get_connection_targets_retrieve_connection(self): + async with self.profile.session() as session: + local_did = await session.wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=self.test_seed, + did=self.test_did, + metadata=None, + ) + + did_doc = self.make_did_doc( + did=self.test_target_did, verkey=self.test_target_verkey + ) + await self.manager.store_did_document(did_doc) + + # Connection target not in cache + conn_invite = ConnectionInvitation( + did=None, + endpoint=self.test_endpoint, + recipient_keys=[self.test_target_verkey], + routing_keys=[self.test_verkey], + label="label", + ) + mock_conn = async_mock.MagicMock( + my_did=self.test_did, + their_did=self.test_target_did, + connection_id="dummy", + their_role=ConnRecord.Role.RESPONDER.rfc23, + state=ConnRecord.State.INVITATION.rfc23, + retrieve_invitation=async_mock.CoroutineMock(return_value=conn_invite), + ) + + with async_mock.patch.object( + ConnectionTarget, "serialize", autospec=True + ) as mock_conn_target_ser, async_mock.patch.object( + ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() + ) as mock_conn_rec_retrieve_by_id: + mock_conn_rec_retrieve_by_id.return_value = mock_conn + mock_conn_target_ser.return_value = {"serialized": "value"} + targets = await self.manager.get_connection_targets( + connection_id="dummy", + connection=None, + ) + assert len(targets) == 1 + target = targets[0] + assert target.did == mock_conn.their_did + assert target.endpoint == conn_invite.endpoint + assert target.label == conn_invite.label + assert target.recipient_keys == conn_invite.recipient_keys + assert target.routing_keys == conn_invite.routing_keys + assert target.sender_key == local_did.verkey + + async def test_get_connection_targets_from_cache(self): + async with self.profile.session() as session: + local_did = await session.wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=self.test_seed, + did=self.test_did, + metadata=None, + ) + + did_doc = self.make_did_doc( + did=self.test_target_did, verkey=self.test_target_verkey + ) + await self.manager.store_did_document(did_doc) + + mock_conn = async_mock.MagicMock( + my_did=self.test_did, + their_did=self.test_target_did, + connection_id="dummy", + their_role=ConnRecord.Role.RESPONDER.rfc23, + state=ConnRecord.State.COMPLETED.rfc160, + ) + + with async_mock.patch.object( + ConnectionTarget, "serialize", autospec=True + ) as mock_conn_target_ser, async_mock.patch.object( + ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() + ) as mock_conn_rec_retrieve_by_id, async_mock.patch.object( + self.manager, "fetch_connection_targets", async_mock.CoroutineMock() + ) as mock_fetch_connection_targets: + mock_fetch_connection_targets.return_value = [ConnectionTarget()] + mock_conn_rec_retrieve_by_id.return_value = mock_conn + mock_conn_target_ser.return_value = {"serialized": "value"} + targets = await self.manager.get_connection_targets( + connection_id="dummy", + connection=None, + ) + + cached_targets = await self.manager.get_connection_targets( + connection_id="dummy", + connection=None, + ) + assert mock_fetch_connection_targets.call_count == 1 + + async def test_get_connection_targets_no_cache(self): + async with self.profile.session() as session: + local_did = await session.wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=self.test_seed, + did=self.test_did, + metadata=None, + ) + + did_doc = self.make_did_doc( + did=self.test_target_did, verkey=self.test_target_verkey + ) + await self.manager.store_did_document(did_doc) + + mock_conn = async_mock.MagicMock( + my_did=self.test_did, + their_did=self.test_target_did, + connection_id="dummy", + their_role=ConnRecord.Role.RESPONDER.rfc23, + state=ConnRecord.State.COMPLETED.rfc160, + ) + + with async_mock.patch.object( + ConnectionTarget, "serialize", autospec=True + ) as mock_conn_target_ser, async_mock.patch.object( + ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() + ) as mock_conn_rec_retrieve_by_id, async_mock.patch.object( + self.manager, "fetch_connection_targets", async_mock.CoroutineMock() + ) as mock_fetch_connection_targets: + mock_fetch_connection_targets.return_value = [ConnectionTarget()] + mock_conn_rec_retrieve_by_id.return_value = mock_conn + mock_conn_target_ser.return_value = {"serialized": "value"} + self.profile.context.injector.clear_binding(BaseCache) + targets = await self.manager.get_connection_targets( + connection_id="dummy", + connection=None, + ) + assert targets + targets = await self.manager.get_connection_targets( + connection_id=None, + connection=mock_conn, + ) + assert targets + + async def test_get_connection_targets_no_conn_or_id(self): + with self.assertRaises(ValueError): + await self.manager.get_connection_targets() + + async def test_get_conn_targets_conn_invitation_no_cache(self): + async with self.profile.session() as session: + self.context.injector.clear_binding(BaseCache) + local_did = await session.wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=self.test_seed, + did=self.test_did, + metadata=None, + ) + + did_doc = self.make_did_doc( + did=self.test_target_did, verkey=self.test_target_verkey + ) + await self.manager.store_did_document(did_doc) + + conn_invite = ConnectionInvitation( + did=None, + endpoint=self.test_endpoint, + recipient_keys=[self.test_target_verkey], + routing_keys=[self.test_verkey], + label="label", + ) + mock_conn = async_mock.MagicMock( + my_did=self.test_did, + their_did=self.test_target_did, + connection_id="dummy", + their_role=ConnRecord.Role.RESPONDER.rfc23, + state=ConnRecord.State.INVITATION.rfc23, + retrieve_invitation=async_mock.CoroutineMock(return_value=conn_invite), + ) + + targets = await self.manager.get_connection_targets( + connection_id=None, + connection=mock_conn, + ) + assert len(targets) == 1 + target = targets[0] + assert target.did == mock_conn.their_did + assert target.endpoint == conn_invite.endpoint + assert target.label == conn_invite.label + assert target.recipient_keys == conn_invite.recipient_keys + assert target.routing_keys == conn_invite.routing_keys + assert target.sender_key == local_did.verkey + + async def test_create_static_connection(self): + with async_mock.patch.object( + ConnRecord, "save", autospec=True + ) as mock_conn_rec_save: + _my, _their, conn_rec = await self.manager.create_static_connection( + my_did=self.test_did, + their_did=self.test_target_did, + their_verkey=self.test_target_verkey, + their_endpoint=self.test_endpoint, + ) + + assert ConnRecord.State.get(conn_rec.state) is ConnRecord.State.COMPLETED + + async def test_create_static_connection_multitenant(self): + self.context.update_settings( + {"wallet.id": "test_wallet", "multitenant.enabled": True} + ) + + self.multitenant_mgr.get_default_mediator.return_value = None + + with async_mock.patch.object( + ConnRecord, "save", autospec=True + ), async_mock.patch.object( + InMemoryWallet, "create_local_did", autospec=True + ) as mock_wallet_create_local_did: + mock_wallet_create_local_did.return_value = DIDInfo( + self.test_did, + self.test_verkey, + None, + method=SOV, + key_type=ED25519, + ) + + await self.manager.create_static_connection( + my_did=self.test_did, + their_did=self.test_target_did, + their_verkey=self.test_target_verkey, + their_endpoint=self.test_endpoint, + ) + + self.route_manager.route_static.assert_called_once() + + async def test_create_static_connection_multitenant_auto_disclose_features(self): + self.context.update_settings( + { + "auto_disclose_features": True, + "multitenant.enabled": True, + "wallet.id": "test_wallet", + } + ) + self.multitenant_mgr.get_default_mediator.return_value = None + with async_mock.patch.object( + ConnRecord, "save", autospec=True + ), async_mock.patch.object( + InMemoryWallet, "create_local_did", autospec=True + ) as mock_wallet_create_local_did, async_mock.patch.object( + V20DiscoveryMgr, "proactive_disclose_features", async_mock.CoroutineMock() + ) as mock_proactive_disclose_features: + mock_wallet_create_local_did.return_value = DIDInfo( + self.test_did, + self.test_verkey, + None, + method=SOV, + key_type=ED25519, + ) + await self.manager.create_static_connection( + my_did=self.test_did, + their_did=self.test_target_did, + their_verkey=self.test_target_verkey, + their_endpoint=self.test_endpoint, + ) + self.route_manager.route_static.assert_called_once() + mock_proactive_disclose_features.assert_called_once() + + async def test_create_static_connection_multitenant_mediator(self): + self.context.update_settings( + {"wallet.id": "test_wallet", "multitenant.enabled": True} + ) + + default_mediator = async_mock.MagicMock() + + with async_mock.patch.object( + ConnRecord, "save", autospec=True + ), async_mock.patch.object( + InMemoryWallet, "create_local_did", autospec=True + ) as mock_wallet_create_local_did, async_mock.patch.object( + BaseConnectionManager, "create_did_document" + ) as create_did_document, async_mock.patch.object( + BaseConnectionManager, "store_did_document" + ) as store_did_document: + mock_wallet_create_local_did.return_value = DIDInfo( + self.test_did, + self.test_verkey, + None, + method=SOV, + key_type=ED25519, + ) + + # With default mediator + self.multitenant_mgr.get_default_mediator.return_value = default_mediator + await self.manager.create_static_connection( + my_did=self.test_did, + their_did=self.test_target_did, + their_verkey=self.test_target_verkey, + their_endpoint=self.test_endpoint, + ) + + # Without default mediator + self.multitenant_mgr.get_default_mediator.return_value = None + await self.manager.create_static_connection( + my_did=self.test_did, + their_did=self.test_target_did, + their_verkey=self.test_target_verkey, + their_endpoint=self.test_endpoint, + ) + + assert self.route_manager.route_static.call_count == 2 + + their_info = DIDInfo( + self.test_target_did, + self.test_target_verkey, + {}, + method=SOV, + key_type=ED25519, + ) + create_did_document.assert_has_calls( + [ + call( + their_info, + None, + [self.test_endpoint], + mediation_records=[default_mediator], + ), + call(their_info, None, [self.test_endpoint], mediation_records=[]), + ] + ) + + async def test_create_static_connection_no_their(self): + with async_mock.patch.object( + ConnRecord, "save", autospec=True + ) as mock_conn_rec_save: + with self.assertRaises(BaseConnectionManagerError): + await self.manager.create_static_connection( + my_did=self.test_did, + their_did=None, + their_verkey=self.test_target_verkey, + their_endpoint=self.test_endpoint, + ) + + async def test_create_static_connection_their_seed_only(self): + with async_mock.patch.object( + ConnRecord, "save", autospec=True + ) as mock_conn_rec_save: + _my, _their, conn_rec = await self.manager.create_static_connection( + my_did=self.test_did, + their_seed=self.test_seed, + their_endpoint=self.test_endpoint, + ) + + assert ConnRecord.State.get(conn_rec.state) is ConnRecord.State.COMPLETED + + async def test_find_connection_retrieve_by_did(self): + with async_mock.patch.object( + ConnRecord, "retrieve_by_did", async_mock.CoroutineMock() + ) as mock_conn_retrieve_by_did: + mock_conn_retrieve_by_did.return_value = async_mock.MagicMock( + state=ConnRecord.State.RESPONSE.rfc23, + save=async_mock.CoroutineMock(), + ) + + conn_rec = await self.manager.find_connection( + their_did=self.test_target_did, + my_did=self.test_did, + my_verkey=self.test_verkey, + auto_complete=True, + ) + assert ConnRecord.State.get(conn_rec.state) is ConnRecord.State.COMPLETED + + async def test_find_connection_retrieve_by_did_auto_disclose_features(self): + self.context.update_settings({"auto_disclose_features": True}) + with async_mock.patch.object( + ConnRecord, "retrieve_by_did", async_mock.CoroutineMock() + ) as mock_conn_retrieve_by_did, async_mock.patch.object( + V20DiscoveryMgr, "proactive_disclose_features", async_mock.CoroutineMock() + ) as mock_proactive_disclose_features: + mock_conn_retrieve_by_did.return_value = async_mock.MagicMock( + state=ConnRecord.State.RESPONSE.rfc23, + save=async_mock.CoroutineMock(), + ) + + conn_rec = await self.manager.find_connection( + their_did=self.test_target_did, + my_did=self.test_did, + my_verkey=self.test_verkey, + auto_complete=True, + ) + assert ConnRecord.State.get(conn_rec.state) is ConnRecord.State.COMPLETED + mock_proactive_disclose_features.assert_called_once() + + async def test_find_connection_retrieve_by_invitation_key(self): + with async_mock.patch.object( + ConnRecord, "retrieve_by_did", async_mock.CoroutineMock() + ) as mock_conn_retrieve_by_did, async_mock.patch.object( + ConnRecord, "retrieve_by_invitation_key", async_mock.CoroutineMock() + ) as mock_conn_retrieve_by_invitation_key: + mock_conn_retrieve_by_did.side_effect = StorageNotFoundError() + mock_conn_retrieve_by_invitation_key.return_value = async_mock.MagicMock( + state=ConnRecord.State.RESPONSE, + save=async_mock.CoroutineMock(), + ) + + conn_rec = await self.manager.find_connection( + their_did=self.test_target_did, + my_did=self.test_did, + my_verkey=self.test_verkey, + ) + assert conn_rec + + async def test_find_connection_retrieve_none_by_invitation_key(self): + with async_mock.patch.object( + ConnRecord, "retrieve_by_did", async_mock.CoroutineMock() + ) as mock_conn_retrieve_by_did, async_mock.patch.object( + ConnRecord, "retrieve_by_invitation_key", async_mock.CoroutineMock() + ) as mock_conn_retrieve_by_invitation_key: + mock_conn_retrieve_by_did.side_effect = StorageNotFoundError() + mock_conn_retrieve_by_invitation_key.side_effect = StorageNotFoundError() + + conn_rec = await self.manager.find_connection( + their_did=self.test_target_did, + my_did=self.test_did, + my_verkey=self.test_verkey, + ) + assert conn_rec is None + + async def test_get_endpoints(self): + conn_id = "dummy" + + with async_mock.patch.object( + ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() + ) as mock_retrieve, async_mock.patch.object( + InMemoryWallet, "get_local_did", autospec=True + ) as mock_wallet_get_local_did, async_mock.patch.object( + self.manager, "get_connection_targets", async_mock.CoroutineMock() + ) as mock_get_conn_targets: + mock_retrieve.return_value = async_mock.MagicMock() + mock_wallet_get_local_did.return_value = async_mock.MagicMock( + metadata={"endpoint": "localhost:8020"} + ) + mock_get_conn_targets.return_value = [ + async_mock.MagicMock(endpoint="10.20.30.40:5060") + ] + assert await self.manager.get_endpoints(conn_id) == ( + "localhost:8020", + "10.20.30.40:5060", + ) diff --git a/aries_cloudagent/core/dispatcher.py b/aries_cloudagent/core/dispatcher.py index c1a176dd4f..99197ed0f4 100644 --- a/aries_cloudagent/core/dispatcher.py +++ b/aries_cloudagent/core/dispatcher.py @@ -8,14 +8,14 @@ import asyncio import logging import os +from typing import Callable, Coroutine, Optional, Tuple, Union import warnings - -from typing import Callable, Coroutine, Optional, Union, Tuple import weakref from aiohttp.web import HTTPException from ..config.logging import get_logger_inst +from ..connections.base_manager import BaseConnectionManager from ..connections.models.conn_record import ConnRecord from ..core.profile import Profile from ..messaging.agent_message import AgentMessage @@ -25,7 +25,6 @@ from ..messaging.request_context import RequestContext from ..messaging.responder import BaseResponder, SKIP_ACTIVE_CONN_CHECK_MSG_TYPES from ..messaging.util import datetime_now -from ..protocols.connections.v1_0.manager import ConnectionManager from ..protocols.problem_report.v1_0.message import ProblemReport from ..transport.inbound.message import InboundMessage from ..transport.outbound.message import OutboundMessage @@ -33,16 +32,9 @@ from ..utils.stats import Collector from ..utils.task_queue import CompletedTask, PendingTask, TaskQueue from ..utils.tracing import get_timer, trace_event - from .error import ProtocolMinorVersionNotSupported from .protocol_registry import ProtocolRegistry -from .util import ( - get_version_from_message_type, - validate_get_response_version, - # WARNING_DEGRADED_FEATURES, - # WARNING_VERSION_MISMATCH, - # WARNING_VERSION_NOT_SUPPORTED, -) +from .util import get_version_from_message_type, validate_get_response_version class ProblemReportParseError(MessageParseError): @@ -242,7 +234,7 @@ async def handle_message( session, inbound_message.connection_id ) else: - connection_mgr = ConnectionManager(profile) + connection_mgr = BaseConnectionManager(profile) connection = await connection_mgr.find_inbound_connection( inbound_message.receipt ) diff --git a/aries_cloudagent/core/tests/test_dispatcher.py b/aries_cloudagent/core/tests/test_dispatcher.py index c76415133a..8c793c327d 100644 --- a/aries_cloudagent/core/tests/test_dispatcher.py +++ b/aries_cloudagent/core/tests/test_dispatcher.py @@ -112,7 +112,7 @@ async def test_dispatch(self): with async_mock.patch.object( StubAgentMessageHandler, "handle", autospec=True ) as handler_mock, async_mock.patch.object( - test_module, "ConnectionManager", autospec=True + test_module, "BaseConnectionManager", autospec=True ) as conn_mgr_mock, async_mock.patch.object( test_module, "get_version_from_message_type", diff --git a/aries_cloudagent/protocols/connections/v1_0/manager.py b/aries_cloudagent/protocols/connections/v1_0/manager.py index 8a55161bbb..ab53e06db4 100644 --- a/aries_cloudagent/protocols/connections/v1_0/manager.py +++ b/aries_cloudagent/protocols/connections/v1_0/manager.py @@ -5,28 +5,20 @@ from ....core.oob_processor import OobMessageProcessor -from ....cache.base import BaseCache -from ....config.base import InjectionError from ....config.logging import get_logger_inst from ....connections.base_manager import BaseConnectionManager from ....connections.models.conn_record import ConnRecord -from ....connections.models.connection_target import ConnectionTarget from ....core.error import BaseError from ....core.profile import Profile from ....messaging.responder import BaseResponder from ....messaging.valid import IndyDID from ....multitenant.base import BaseMultitenantManager -from ....storage.error import StorageError, StorageNotFoundError +from ....storage.error import StorageNotFoundError from ....transport.inbound.receipt import MessageReceipt from ....wallet.base import BaseWallet -from ....wallet.crypto import create_keypair, seed_to_did -from ....wallet.did_info import DIDInfo from ....wallet.did_method import SOV -from ....wallet.error import WalletNotFoundError from ....wallet.key_type import ED25519 -from ....wallet.util import bytes_to_b58 from ...coordinate_mediation.v1_0.manager import MediationManager -from ...discovery.v2_0.manager import V20DiscoveryMgr from ...routing.v1_0.manager import RoutingManager from .message_types import ARIES_PROTOCOL as CONN_PROTO from .messages.connection_invitation import ConnectionInvitation @@ -797,321 +789,6 @@ async def accept_response( return connection - async def get_endpoints(self, conn_id: str) -> Tuple[str, str]: - """ - Get connection endpoints. - - Args: - conn_id: connection identifier - - Returns: - Their endpoint for this connection - - """ - async with self.profile.session() as session: - connection = await ConnRecord.retrieve_by_id(session, conn_id) - wallet = session.inject(BaseWallet) - my_did_info = await wallet.get_local_did(connection.my_did) - my_endpoint = my_did_info.metadata.get( - "endpoint", - self.profile.settings.get("default_endpoint"), - ) - - conn_targets = await self.get_connection_targets( - connection_id=connection.connection_id, - connection=connection, - ) - return (my_endpoint, conn_targets[0].endpoint) - - async def create_static_connection( - self, - my_did: str = None, - my_seed: str = None, - their_did: str = None, - their_seed: str = None, - their_verkey: str = None, - their_endpoint: str = None, - their_label: str = None, - alias: str = None, - mediation_id: str = None, - ) -> Tuple[DIDInfo, DIDInfo, ConnRecord]: - """ - Register a new static connection (for use by the test suite). - - Args: - my_did: override the DID used in the connection - my_seed: provide a seed used to generate our DID and keys - their_did: provide the DID used by the other party - their_seed: provide a seed used to generate their DID and keys - their_verkey: provide the verkey used by the other party - their_endpoint: their URL endpoint for routing messages - alias: an alias for this connection record - - Returns: - Tuple: my DIDInfo, their DIDInfo, new `ConnRecord` instance - - """ - async with self.profile.session() as session: - wallet = session.inject(BaseWallet) - # seed and DID optional - my_info = await wallet.create_local_did(SOV, ED25519, my_seed, my_did) - - # must provide their DID and verkey if the seed is not known - if (not their_did or not their_verkey) and not their_seed: - raise ConnectionManagerError( - "Either a verkey or seed must be provided for the other party" - ) - if not their_did: - their_did = seed_to_did(their_seed) - if not their_verkey: - their_verkey_bin, _ = create_keypair(ED25519, their_seed.encode()) - their_verkey = bytes_to_b58(their_verkey_bin) - their_info = DIDInfo(their_did, their_verkey, {}, method=SOV, key_type=ED25519) - - # Create connection record - connection = ConnRecord( - invitation_mode=ConnRecord.INVITATION_MODE_STATIC, - my_did=my_info.did, - their_did=their_info.did, - their_label=their_label, - state=ConnRecord.State.COMPLETED.rfc160, - alias=alias, - connection_protocol=CONN_PROTO, - ) - async with self.profile.session() as session: - await connection.save(session, reason="Created new static connection") - if session.settings.get("auto_disclose_features"): - discovery_mgr = V20DiscoveryMgr(self._profile) - await discovery_mgr.proactive_disclose_features( - connection_id=connection.connection_id - ) - - # Routing - mediation_record = await self._route_manager.mediation_record_if_id( - self.profile, mediation_id, or_default=True - ) - - multitenant_mgr = self.profile.inject_or(BaseMultitenantManager) - wallet_id = self.profile.settings.get("wallet.id") - - base_mediation_record = None - if multitenant_mgr and wallet_id: - base_mediation_record = await multitenant_mgr.get_default_mediator() - - await self._route_manager.route_static( - self.profile, connection, mediation_record - ) - - # Synthesize their DID doc - did_doc = await self.create_did_document( - their_info, - None, - [their_endpoint or ""], - mediation_records=list( - filter(None, [base_mediation_record, mediation_record]) - ), - ) - - await self.store_did_document(did_doc) - - return my_info, their_info, connection - - async def find_connection( - self, - their_did: str, - my_did: str = None, - my_verkey: str = None, - auto_complete=False, - ) -> ConnRecord: - """ - Look up existing connection information for a sender verkey. - - Args: - their_did: Their DID - my_did: My DID - my_verkey: My verkey - auto_complete: Should this connection automatically be promoted to active - - Returns: - The located `ConnRecord`, if any - - """ - # self._log_state( - # "Finding connection", - # {"their_did": their_did, "my_did": my_did, "my_verkey": my_verkey}, - # ) - connection = None - if their_did: - try: - async with self.profile.session() as session: - connection = await ConnRecord.retrieve_by_did( - session, their_did, my_did - ) - except StorageNotFoundError: - pass - - if ( - connection - and ConnRecord.State.get(connection.state) is ConnRecord.State.RESPONSE - and auto_complete - ): - connection.state = ConnRecord.State.COMPLETED.rfc160 - async with self.profile.session() as session: - await connection.save(session, reason="Connection promoted to active") - if session.settings.get("auto_disclose_features"): - discovery_mgr = V20DiscoveryMgr(self._profile) - await discovery_mgr.proactive_disclose_features( - connection_id=connection.connection_id - ) - - if not connection and my_verkey: - try: - async with self.profile.session() as session: - connection = await ConnRecord.retrieve_by_invitation_key( - session, - my_verkey, - their_role=ConnRecord.Role.REQUESTER.rfc160, - ) - except StorageError: - pass - - return connection - - async def find_inbound_connection(self, receipt: MessageReceipt) -> ConnRecord: - """ - Deserialize an incoming message and further populate the request context. - - Args: - receipt: The message receipt - - Returns: - The `ConnRecord` associated with the expanded message, if any - - """ - - cache_key = None - connection = None - resolved = False - - if receipt.sender_verkey and receipt.recipient_verkey: - cache_key = ( - f"connection_by_verkey::{receipt.sender_verkey}" - f"::{receipt.recipient_verkey}" - ) - cache = self.profile.inject_or(BaseCache) - if cache: - async with cache.acquire(cache_key) as entry: - if entry.result: - cached = entry.result - receipt.sender_did = cached["sender_did"] - receipt.recipient_did_public = cached["recipient_did_public"] - receipt.recipient_did = cached["recipient_did"] - async with self.profile.session() as session: - connection = await ConnRecord.retrieve_by_id( - session, cached["id"] - ) - else: - connection = await self.resolve_inbound_connection(receipt) - if connection: - cache_val = { - "id": connection.connection_id, - "sender_did": receipt.sender_did, - "recipient_did": receipt.recipient_did, - "recipient_did_public": receipt.recipient_did_public, - } - await entry.set_result(cache_val, 3600) - resolved = True - - if not connection and not resolved: - connection = await self.resolve_inbound_connection(receipt) - return connection - - async def resolve_inbound_connection(self, receipt: MessageReceipt) -> ConnRecord: - """ - Populate the receipt DID information and find the related `ConnRecord`. - - Args: - receipt: The message receipt - - Returns: - The `ConnRecord` associated with the expanded message, if any - - """ - - if receipt.sender_verkey: - try: - receipt.sender_did = await self.find_did_for_key(receipt.sender_verkey) - except StorageNotFoundError: - self._logger.warning( - "No corresponding DID found for sender verkey: %s", - receipt.sender_verkey, - ) - - if receipt.recipient_verkey: - try: - async with self.profile.session() as session: - wallet = session.inject(BaseWallet) - my_info = await wallet.get_local_did_for_verkey( - receipt.recipient_verkey - ) - receipt.recipient_did = my_info.did - if "posted" in my_info.metadata and my_info.metadata["posted"] is True: - receipt.recipient_did_public = True - except InjectionError: - self._logger.warning( - "Cannot resolve recipient verkey, no wallet defined by " - "context: %s", - receipt.recipient_verkey, - ) - except WalletNotFoundError: - self._logger.warning( - "No corresponding DID found for recipient verkey: %s", - receipt.recipient_verkey, - ) - - return await self.find_connection( - receipt.sender_did, receipt.recipient_did, receipt.recipient_verkey, True - ) - - async def get_connection_targets( - self, *, connection_id: str = None, connection: ConnRecord = None - ): - """Create a connection target from a `ConnRecord`. - - Args: - connection_id: The connection ID to search for - connection: The connection record itself, if already available - """ - if not connection_id: - connection_id = connection.connection_id - cache = self.profile.inject_or(BaseCache) - cache_key = f"connection_target::{connection_id}" - if cache: - async with cache.acquire(cache_key) as entry: - if entry.result: - targets = [ - ConnectionTarget.deserialize(row) for row in entry.result - ] - else: - if not connection: - async with self.profile.session() as session: - connection = await ConnRecord.retrieve_by_id( - session, connection_id - ) - - targets = await self.fetch_connection_targets(connection) - - if connection.state == ConnRecord.State.COMPLETED.rfc160: - # Only set cache if connection has reached completed state - # Otherwise, a replica that participated early in exchange - # may have bad data set in cache. - await entry.set_result( - [row.serialize() for row in targets], 3600 - ) - else: - targets = await self.fetch_connection_targets(connection) - return targets - async def establish_inbound( self, connection: ConnRecord, diff --git a/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py index f114bcfbe5..8bf13a35a9 100644 --- a/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py @@ -9,6 +9,7 @@ JsonWebKey2020, ) +from .. import manager as test_module from .....cache.base import BaseCache from .....cache.in_memory import InMemoryCache from .....config.base import InjectionError @@ -16,35 +17,34 @@ from .....connections.models.conn_record import ConnRecord from .....connections.models.connection_target import ConnectionTarget from .....connections.models.diddoc import DIDDoc, PublicKey, PublicKeyType, Service -from .....core.oob_processor import OobMessageProcessor from .....core.in_memory import InMemoryProfile +from .....core.oob_processor import OobMessageProcessor from .....core.profile import ProfileSession from .....did.did_key import DIDKey from .....messaging.responder import BaseResponder, MockResponder from .....multitenant.base import BaseMultitenantManager from .....multitenant.manager import MultitenantManager from .....protocols.routing.v1_0.manager import RoutingManager +from .....resolver.default.legacy_peer import LegacyPeerDIDResolver from .....resolver.did_resolver import DIDResolver from .....storage.error import StorageNotFoundError from .....transport.inbound.receipt import MessageReceipt from .....wallet.base import DIDInfo -from .....wallet.did_method import SOV, DIDMethods +from .....wallet.did_method import DIDMethods, SOV from .....wallet.error import WalletNotFoundError from .....wallet.in_memory import InMemoryWallet from .....wallet.key_type import ED25519 +from .....wallet.util import b58_to_bytes, bytes_to_b64 from ....coordinate_mediation.v1_0.manager import MediationManager -from ....coordinate_mediation.v1_0.route_manager import RouteManager from ....coordinate_mediation.v1_0.messages.mediate_request import MediationRequest from ....coordinate_mediation.v1_0.models.mediation_record import MediationRecord +from ....coordinate_mediation.v1_0.route_manager import RouteManager from ....discovery.v2_0.manager import V20DiscoveryMgr - from ..manager import ConnectionManager, ConnectionManagerError -from .. import manager as test_module from ..messages.connection_invitation import ConnectionInvitation from ..messages.connection_request import ConnectionRequest from ..messages.connection_response import ConnectionResponse from ..models.connection_detail import ConnectionDetail -from .....wallet.util import bytes_to_b64, b58_to_bytes class TestConnectionManager(AsyncTestCase): @@ -86,6 +86,8 @@ async def setUp(self): self.route_manager.mediation_record_if_id = async_mock.CoroutineMock( return_value=None ) + self.resolver = DIDResolver() + self.resolver.register_resolver(LegacyPeerDIDResolver()) self.profile = InMemoryProfile.test_profile( { @@ -101,6 +103,7 @@ async def setUp(self): OobMessageProcessor: self.oob_mock, RouteManager: self.route_manager, DIDMethods: DIDMethods(), + DIDResolver: self.resolver, }, ) self.context = self.profile.context @@ -1268,1545 +1271,6 @@ async def test_accept_response_auto_send_mediation_request(self): assert isinstance(message, MediationRequest) assert target["connection_id"] == conn_rec.connection_id - async def test_get_endpoints(self): - conn_id = "dummy" - - with async_mock.patch.object( - ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() - ) as mock_retrieve, async_mock.patch.object( - InMemoryWallet, "get_local_did", autospec=True - ) as mock_wallet_get_local_did, async_mock.patch.object( - self.manager, "get_connection_targets", async_mock.CoroutineMock() - ) as mock_get_conn_targets: - mock_retrieve.return_value = async_mock.MagicMock() - mock_wallet_get_local_did.return_value = async_mock.MagicMock( - metadata={"endpoint": "localhost:8020"} - ) - mock_get_conn_targets.return_value = [ - async_mock.MagicMock(endpoint="10.20.30.40:5060") - ] - assert await self.manager.get_endpoints(conn_id) == ( - "localhost:8020", - "10.20.30.40:5060", - ) - - async def test_create_static_connection(self): - with async_mock.patch.object( - ConnRecord, "save", autospec=True - ) as mock_conn_rec_save: - _my, _their, conn_rec = await self.manager.create_static_connection( - my_did=self.test_did, - their_did=self.test_target_did, - their_verkey=self.test_target_verkey, - their_endpoint=self.test_endpoint, - ) - - assert ConnRecord.State.get(conn_rec.state) is ConnRecord.State.COMPLETED - - async def test_create_static_connection_multitenant(self): - self.context.update_settings( - {"wallet.id": "test_wallet", "multitenant.enabled": True} - ) - - self.multitenant_mgr.get_default_mediator.return_value = None - - with async_mock.patch.object( - ConnRecord, "save", autospec=True - ), async_mock.patch.object( - InMemoryWallet, "create_local_did", autospec=True - ) as mock_wallet_create_local_did: - mock_wallet_create_local_did.return_value = DIDInfo( - self.test_did, - self.test_verkey, - None, - method=SOV, - key_type=ED25519, - ) - - await self.manager.create_static_connection( - my_did=self.test_did, - their_did=self.test_target_did, - their_verkey=self.test_target_verkey, - their_endpoint=self.test_endpoint, - ) - - self.route_manager.route_static.assert_called_once() - - async def test_create_static_connection_multitenant_auto_disclose_features(self): - self.context.update_settings( - { - "auto_disclose_features": True, - "multitenant.enabled": True, - "wallet.id": "test_wallet", - } - ) - self.multitenant_mgr.get_default_mediator.return_value = None - with async_mock.patch.object( - ConnRecord, "save", autospec=True - ), async_mock.patch.object( - InMemoryWallet, "create_local_did", autospec=True - ) as mock_wallet_create_local_did, async_mock.patch.object( - V20DiscoveryMgr, "proactive_disclose_features", async_mock.CoroutineMock() - ) as mock_proactive_disclose_features: - mock_wallet_create_local_did.return_value = DIDInfo( - self.test_did, - self.test_verkey, - None, - method=SOV, - key_type=ED25519, - ) - await self.manager.create_static_connection( - my_did=self.test_did, - their_did=self.test_target_did, - their_verkey=self.test_target_verkey, - their_endpoint=self.test_endpoint, - ) - self.route_manager.route_static.assert_called_once() - mock_proactive_disclose_features.assert_called_once() - - async def test_create_static_connection_multitenant_mediator(self): - self.context.update_settings( - {"wallet.id": "test_wallet", "multitenant.enabled": True} - ) - - default_mediator = async_mock.MagicMock() - - with async_mock.patch.object( - ConnRecord, "save", autospec=True - ), async_mock.patch.object( - InMemoryWallet, "create_local_did", autospec=True - ) as mock_wallet_create_local_did, async_mock.patch.object( - ConnectionManager, "create_did_document" - ) as create_did_document, async_mock.patch.object( - ConnectionManager, "store_did_document" - ) as store_did_document: - mock_wallet_create_local_did.return_value = DIDInfo( - self.test_did, - self.test_verkey, - None, - method=SOV, - key_type=ED25519, - ) - - # With default mediator - self.multitenant_mgr.get_default_mediator.return_value = default_mediator - await self.manager.create_static_connection( - my_did=self.test_did, - their_did=self.test_target_did, - their_verkey=self.test_target_verkey, - their_endpoint=self.test_endpoint, - ) - - # Without default mediator - self.multitenant_mgr.get_default_mediator.return_value = None - await self.manager.create_static_connection( - my_did=self.test_did, - their_did=self.test_target_did, - their_verkey=self.test_target_verkey, - their_endpoint=self.test_endpoint, - ) - - assert self.route_manager.route_static.call_count == 2 - - their_info = DIDInfo( - self.test_target_did, - self.test_target_verkey, - {}, - method=SOV, - key_type=ED25519, - ) - create_did_document.assert_has_calls( - [ - call( - their_info, - None, - [self.test_endpoint], - mediation_records=[default_mediator], - ), - call(their_info, None, [self.test_endpoint], mediation_records=[]), - ] - ) - - async def test_create_static_connection_no_their(self): - with async_mock.patch.object( - ConnRecord, "save", autospec=True - ) as mock_conn_rec_save: - with self.assertRaises(ConnectionManagerError): - await self.manager.create_static_connection( - my_did=self.test_did, - their_did=None, - their_verkey=self.test_target_verkey, - their_endpoint=self.test_endpoint, - ) - - async def test_create_static_connection_their_seed_only(self): - with async_mock.patch.object( - ConnRecord, "save", autospec=True - ) as mock_conn_rec_save: - _my, _their, conn_rec = await self.manager.create_static_connection( - my_did=self.test_did, - their_seed=self.test_seed, - their_endpoint=self.test_endpoint, - ) - - assert ConnRecord.State.get(conn_rec.state) is ConnRecord.State.COMPLETED - - async def test_find_connection_retrieve_by_did(self): - with async_mock.patch.object( - ConnRecord, "retrieve_by_did", async_mock.CoroutineMock() - ) as mock_conn_retrieve_by_did: - mock_conn_retrieve_by_did.return_value = async_mock.MagicMock( - state=ConnRecord.State.RESPONSE.rfc23, - save=async_mock.CoroutineMock(), - ) - - conn_rec = await self.manager.find_connection( - their_did=self.test_target_did, - my_did=self.test_did, - my_verkey=self.test_verkey, - auto_complete=True, - ) - assert ConnRecord.State.get(conn_rec.state) is ConnRecord.State.COMPLETED - - async def test_find_connection_retrieve_by_did_auto_disclose_features(self): - self.context.update_settings({"auto_disclose_features": True}) - with async_mock.patch.object( - ConnRecord, "retrieve_by_did", async_mock.CoroutineMock() - ) as mock_conn_retrieve_by_did, async_mock.patch.object( - V20DiscoveryMgr, "proactive_disclose_features", async_mock.CoroutineMock() - ) as mock_proactive_disclose_features: - mock_conn_retrieve_by_did.return_value = async_mock.MagicMock( - state=ConnRecord.State.RESPONSE.rfc23, - save=async_mock.CoroutineMock(), - ) - - conn_rec = await self.manager.find_connection( - their_did=self.test_target_did, - my_did=self.test_did, - my_verkey=self.test_verkey, - auto_complete=True, - ) - assert ConnRecord.State.get(conn_rec.state) is ConnRecord.State.COMPLETED - mock_proactive_disclose_features.assert_called_once() - - async def test_find_connection_retrieve_by_invitation_key(self): - with async_mock.patch.object( - ConnRecord, "retrieve_by_did", async_mock.CoroutineMock() - ) as mock_conn_retrieve_by_did, async_mock.patch.object( - ConnRecord, "retrieve_by_invitation_key", async_mock.CoroutineMock() - ) as mock_conn_retrieve_by_invitation_key: - mock_conn_retrieve_by_did.side_effect = StorageNotFoundError() - mock_conn_retrieve_by_invitation_key.return_value = async_mock.MagicMock( - state=ConnRecord.State.RESPONSE, - save=async_mock.CoroutineMock(), - ) - - conn_rec = await self.manager.find_connection( - their_did=self.test_target_did, - my_did=self.test_did, - my_verkey=self.test_verkey, - ) - assert conn_rec - - async def test_find_connection_retrieve_none_by_invitation_key(self): - with async_mock.patch.object( - ConnRecord, "retrieve_by_did", async_mock.CoroutineMock() - ) as mock_conn_retrieve_by_did, async_mock.patch.object( - ConnRecord, "retrieve_by_invitation_key", async_mock.CoroutineMock() - ) as mock_conn_retrieve_by_invitation_key: - mock_conn_retrieve_by_did.side_effect = StorageNotFoundError() - mock_conn_retrieve_by_invitation_key.side_effect = StorageNotFoundError() - - conn_rec = await self.manager.find_connection( - their_did=self.test_target_did, - my_did=self.test_did, - my_verkey=self.test_verkey, - ) - assert conn_rec is None - - async def test_find_inbound_connection(self): - receipt = MessageReceipt( - sender_verkey=self.test_verkey, - recipient_verkey=self.test_target_verkey, - recipient_did_public=False, - ) - - mock_conn = async_mock.MagicMock() - mock_conn.connection_id = "dummy" - - # First pass: not yet in cache - with async_mock.patch.object( - ConnectionManager, "resolve_inbound_connection", async_mock.CoroutineMock() - ) as mock_conn_mgr_resolve_conn: - mock_conn_mgr_resolve_conn.return_value = mock_conn - - conn_rec = await self.manager.find_inbound_connection(receipt) - assert conn_rec - - # Second pass: in cache - with async_mock.patch.object( - ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() - ) as mock_conn_rec_retrieve_by_id: - mock_conn_rec_retrieve_by_id.return_value = mock_conn - - conn_rec = await self.manager.find_inbound_connection(receipt) - assert conn_rec.id == mock_conn.id - - async def test_find_inbound_connection_no_cache(self): - receipt = MessageReceipt( - sender_verkey=self.test_verkey, - recipient_verkey=self.test_target_verkey, - recipient_did_public=False, - ) - - mock_conn = async_mock.MagicMock() - mock_conn.connection_id = "dummy" - - with async_mock.patch.object( - ConnectionManager, "resolve_inbound_connection", async_mock.CoroutineMock() - ) as mock_conn_mgr_resolve_conn: - self.context.injector.clear_binding(BaseCache) - mock_conn_mgr_resolve_conn.return_value = mock_conn - - conn_rec = await self.manager.find_inbound_connection(receipt) - assert conn_rec - - async def test_resolve_inbound_connection(self): - receipt = MessageReceipt( - sender_verkey=self.test_verkey, - recipient_verkey=self.test_target_verkey, - recipient_did_public=True, - ) - - mock_conn = async_mock.MagicMock() - mock_conn.connection_id = "dummy" - - with async_mock.patch.object( - InMemoryWallet, "get_local_did_for_verkey", async_mock.CoroutineMock() - ) as mock_wallet_get_local_did_for_verkey, async_mock.patch.object( - self.manager, "find_connection", async_mock.CoroutineMock() - ) as mock_mgr_find_conn: - mock_wallet_get_local_did_for_verkey.return_value = DIDInfo( - self.test_did, - self.test_verkey, - {"posted": True}, - method=SOV, - key_type=ED25519, - ) - mock_mgr_find_conn.return_value = mock_conn - - assert await self.manager.resolve_inbound_connection(receipt) - - async def test_resolve_inbound_connection_injector_error(self): - receipt = MessageReceipt( - sender_verkey=self.test_verkey, - recipient_verkey=self.test_target_verkey, - recipient_did_public=True, - ) - - mock_conn = async_mock.MagicMock() - mock_conn.connection_id = "dummy" - - with async_mock.patch.object( - InMemoryWallet, "get_local_did_for_verkey", async_mock.CoroutineMock() - ) as mock_wallet_get_local_did_for_verkey, async_mock.patch.object( - self.manager, "find_connection", async_mock.CoroutineMock() - ) as mock_mgr_find_conn: - mock_wallet_get_local_did_for_verkey.side_effect = InjectionError() - mock_mgr_find_conn.return_value = mock_conn - - assert await self.manager.resolve_inbound_connection(receipt) - - async def test_resolve_inbound_connection_wallet_not_found_error(self): - receipt = MessageReceipt( - sender_verkey=self.test_verkey, - recipient_verkey=self.test_target_verkey, - recipient_did_public=True, - ) - - mock_conn = async_mock.MagicMock() - mock_conn.connection_id = "dummy" - - with async_mock.patch.object( - InMemoryWallet, "get_local_did_for_verkey", async_mock.CoroutineMock() - ) as mock_wallet_get_local_did_for_verkey, async_mock.patch.object( - self.manager, "find_connection", async_mock.CoroutineMock() - ) as mock_mgr_find_conn: - mock_wallet_get_local_did_for_verkey.side_effect = WalletNotFoundError() - mock_mgr_find_conn.return_value = mock_conn - - assert await self.manager.resolve_inbound_connection(receipt) - - async def test_create_did_document(self): - did_info = DIDInfo( - self.test_did, - self.test_verkey, - None, - method=SOV, - key_type=ED25519, - ) - - mock_conn = async_mock.MagicMock( - connection_id="dummy", - inbound_connection_id=None, - their_did=self.test_target_did, - state=ConnRecord.State.COMPLETED.rfc23, - ) - - did_doc = self.make_did_doc( - did=self.test_target_did, verkey=self.test_target_verkey - ) - for i in range(2): # first cover store-record, then update-value - await self.manager.store_did_document(did_doc) - - with async_mock.patch.object( - ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() - ) as mock_conn_rec_retrieve_by_id: - mock_conn_rec_retrieve_by_id.return_value = mock_conn - - did_doc = await self.manager.create_did_document( - did_info=did_info, - inbound_connection_id="dummy", - svc_endpoints=[self.test_endpoint], - ) - - async def test_create_did_document_not_active(self): - did_info = DIDInfo( - self.test_did, - self.test_verkey, - None, - method=SOV, - key_type=ED25519, - ) - - mock_conn = async_mock.MagicMock( - connection_id="dummy", - inbound_connection_id=None, - their_did=self.test_target_did, - state=ConnRecord.State.ABANDONED.rfc23, - ) - - with async_mock.patch.object( - ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() - ) as mock_conn_rec_retrieve_by_id: - mock_conn_rec_retrieve_by_id.return_value = mock_conn - - with self.assertRaises(BaseConnectionManagerError): - await self.manager.create_did_document( - did_info=did_info, - inbound_connection_id="dummy", - svc_endpoints=[self.test_endpoint], - ) - - async def test_create_did_document_no_services(self): - did_info = DIDInfo( - self.test_did, - self.test_verkey, - None, - method=SOV, - key_type=ED25519, - ) - - mock_conn = async_mock.MagicMock( - connection_id="dummy", - inbound_connection_id=None, - their_did=self.test_target_did, - state=ConnRecord.State.COMPLETED.rfc23, - ) - - x_did_doc = self.make_did_doc( - did=self.test_target_did, verkey=self.test_target_verkey - ) - x_did_doc._service = {} - for i in range(2): # first cover store-record, then update-value - await self.manager.store_did_document(x_did_doc) - - with async_mock.patch.object( - ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() - ) as mock_conn_rec_retrieve_by_id: - mock_conn_rec_retrieve_by_id.return_value = mock_conn - - with self.assertRaises(BaseConnectionManagerError): - await self.manager.create_did_document( - did_info=did_info, - inbound_connection_id="dummy", - svc_endpoints=[self.test_endpoint], - ) - - async def test_create_did_document_no_service_endpoint(self): - did_info = DIDInfo( - self.test_did, - self.test_verkey, - None, - method=SOV, - key_type=ED25519, - ) - - mock_conn = async_mock.MagicMock( - connection_id="dummy", - inbound_connection_id=None, - their_did=self.test_target_did, - state=ConnRecord.State.COMPLETED.rfc23, - ) - - x_did_doc = self.make_did_doc( - did=self.test_target_did, verkey=self.test_target_verkey - ) - x_did_doc._service = {} - x_did_doc.set( - Service(self.test_target_did, "dummy", "IndyAgent", [], [], "", 0) - ) - for i in range(2): # first cover store-record, then update-value - await self.manager.store_did_document(x_did_doc) - - with async_mock.patch.object( - ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() - ) as mock_conn_rec_retrieve_by_id: - mock_conn_rec_retrieve_by_id.return_value = mock_conn - - with self.assertRaises(BaseConnectionManagerError): - await self.manager.create_did_document( - did_info=did_info, - inbound_connection_id="dummy", - svc_endpoints=[self.test_endpoint], - ) - - async def test_create_did_document_no_service_recip_keys(self): - did_info = DIDInfo( - self.test_did, - self.test_verkey, - None, - method=SOV, - key_type=ED25519, - ) - - mock_conn = async_mock.MagicMock( - connection_id="dummy", - inbound_connection_id=None, - their_did=self.test_target_did, - state=ConnRecord.State.COMPLETED.rfc23, - ) - - x_did_doc = self.make_did_doc( - did=self.test_target_did, verkey=self.test_target_verkey - ) - x_did_doc._service = {} - x_did_doc.set( - Service( - self.test_target_did, - "dummy", - "IndyAgent", - [], - [], - self.test_endpoint, - 0, - ) - ) - for i in range(2): # first cover store-record, then update-value - await self.manager.store_did_document(x_did_doc) - - with async_mock.patch.object( - ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() - ) as mock_conn_rec_retrieve_by_id: - mock_conn_rec_retrieve_by_id.return_value = mock_conn - - with self.assertRaises(BaseConnectionManagerError): - await self.manager.create_did_document( - did_info=did_info, - inbound_connection_id="dummy", - svc_endpoints=[self.test_endpoint], - ) - - async def test_create_did_document_mediation(self): - did_info = DIDInfo( - self.test_did, - self.test_verkey, - None, - method=SOV, - key_type=ED25519, - ) - mediation_record = MediationRecord( - role=MediationRecord.ROLE_CLIENT, - state=MediationRecord.STATE_GRANTED, - connection_id=self.test_mediator_conn_id, - routing_keys=self.test_mediator_routing_keys, - endpoint=self.test_mediator_endpoint, - ) - doc = await self.manager.create_did_document( - did_info, mediation_records=[mediation_record] - ) - assert doc.service - services = list(doc.service.values()) - assert len(services) == 1 - (service,) = services - service_public_keys = service.routing_keys[0] - assert service_public_keys.value == mediation_record.routing_keys[0] - assert service.endpoint == mediation_record.endpoint - - async def test_create_did_document_multiple_mediators(self): - did_info = DIDInfo( - self.test_did, - self.test_verkey, - None, - method=SOV, - key_type=ED25519, - ) - mediation_record1 = MediationRecord( - role=MediationRecord.ROLE_CLIENT, - state=MediationRecord.STATE_GRANTED, - connection_id=self.test_mediator_conn_id, - routing_keys=self.test_mediator_routing_keys, - endpoint=self.test_mediator_endpoint, - ) - mediation_record2 = MediationRecord( - role=MediationRecord.ROLE_CLIENT, - state=MediationRecord.STATE_GRANTED, - connection_id="mediator-conn-id2", - routing_keys=["05e8afd1-b4f0-46b7-a285-7a08c8a37caf"], - endpoint="http://mediatorw.example.com", - ) - doc = await self.manager.create_did_document( - did_info, mediation_records=[mediation_record1, mediation_record2] - ) - assert doc.service - services = list(doc.service.values()) - assert len(services) == 1 - (service,) = services - assert service.routing_keys[0].value == mediation_record1.routing_keys[0] - assert service.routing_keys[1].value == mediation_record2.routing_keys[0] - assert service.endpoint == mediation_record2.endpoint - - async def test_create_did_document_mediation_svc_endpoints_overwritten(self): - did_info = DIDInfo( - self.test_did, - self.test_verkey, - None, - method=SOV, - key_type=ED25519, - ) - mediation_record = MediationRecord( - role=MediationRecord.ROLE_CLIENT, - state=MediationRecord.STATE_GRANTED, - connection_id=self.test_mediator_conn_id, - routing_keys=self.test_mediator_routing_keys, - endpoint=self.test_mediator_endpoint, - ) - doc = await self.manager.create_did_document( - did_info, - svc_endpoints=[self.test_endpoint], - mediation_records=[mediation_record], - ) - assert doc.service - services = list(doc.service.values()) - assert len(services) == 1 - (service,) = services - service_public_keys = service.routing_keys[0] - assert service_public_keys.value == mediation_record.routing_keys[0] - assert service.endpoint == mediation_record.endpoint - - async def test_did_key_storage(self): - await self.manager.add_key_for_did( - did=self.test_target_did, key=self.test_target_verkey - ) - - did = await self.manager.find_did_for_key(key=self.test_target_verkey) - assert did == self.test_target_did - await self.manager.remove_keys_for_did(self.test_target_did) - - async def test_get_connection_targets_conn_invitation_no_did(self): - async with self.profile.session() as session: - local_did = await session.wallet.create_local_did( - method=SOV, - key_type=ED25519, - seed=self.test_seed, - did=self.test_did, - metadata=None, - ) - - did_doc = self.make_did_doc( - did=self.test_target_did, verkey=self.test_target_verkey - ) - await self.manager.store_did_document(did_doc) - - # First pass: not yet in cache - conn_invite = ConnectionInvitation( - did=None, - endpoint=self.test_endpoint, - recipient_keys=[self.test_target_verkey], - routing_keys=[self.test_verkey], - label="label", - ) - mock_conn = async_mock.MagicMock( - my_did=self.test_did, - their_did=self.test_target_did, - connection_id="dummy", - their_role=ConnRecord.Role.RESPONDER.rfc23, - state=ConnRecord.State.INVITATION.rfc23, - retrieve_invitation=async_mock.CoroutineMock(return_value=conn_invite), - ) - - targets = await self.manager.get_connection_targets( - connection_id=None, - connection=mock_conn, - ) - assert len(targets) == 1 - target = targets[0] - assert target.did == mock_conn.their_did - assert target.endpoint == conn_invite.endpoint - assert target.label == conn_invite.label - assert target.recipient_keys == conn_invite.recipient_keys - assert target.routing_keys == conn_invite.routing_keys - assert target.sender_key == local_did.verkey - - # Next pass: exercise cache - targets = await self.manager.get_connection_targets( - connection_id=None, - connection=mock_conn, - ) - assert len(targets) == 1 - target = targets[0] - assert target.did == mock_conn.their_did - assert target.endpoint == conn_invite.endpoint - assert target.label == conn_invite.label - assert target.recipient_keys == conn_invite.recipient_keys - assert target.routing_keys == conn_invite.routing_keys - assert target.sender_key == local_did.verkey - - async def test_get_connection_targets_retrieve_connection(self): - async with self.profile.session() as session: - local_did = await session.wallet.create_local_did( - method=SOV, - key_type=ED25519, - seed=self.test_seed, - did=self.test_did, - metadata=None, - ) - - did_doc = self.make_did_doc( - did=self.test_target_did, verkey=self.test_target_verkey - ) - await self.manager.store_did_document(did_doc) - - # Connection target not in cache - conn_invite = ConnectionInvitation( - did=None, - endpoint=self.test_endpoint, - recipient_keys=[self.test_target_verkey], - routing_keys=[self.test_verkey], - label="label", - ) - mock_conn = async_mock.MagicMock( - my_did=self.test_did, - their_did=self.test_target_did, - connection_id="dummy", - their_role=ConnRecord.Role.RESPONDER.rfc23, - state=ConnRecord.State.INVITATION.rfc23, - retrieve_invitation=async_mock.CoroutineMock(return_value=conn_invite), - ) - - with async_mock.patch.object( - ConnectionTarget, "serialize", autospec=True - ) as mock_conn_target_ser, async_mock.patch.object( - ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() - ) as mock_conn_rec_retrieve_by_id: - mock_conn_rec_retrieve_by_id.return_value = mock_conn - mock_conn_target_ser.return_value = {"serialized": "value"} - targets = await self.manager.get_connection_targets( - connection_id="dummy", - connection=None, - ) - assert len(targets) == 1 - target = targets[0] - assert target.did == mock_conn.their_did - assert target.endpoint == conn_invite.endpoint - assert target.label == conn_invite.label - assert target.recipient_keys == conn_invite.recipient_keys - assert target.routing_keys == conn_invite.routing_keys - assert target.sender_key == local_did.verkey - - async def test_get_conn_targets_conn_invitation_no_cache(self): - async with self.profile.session() as session: - self.context.injector.clear_binding(BaseCache) - local_did = await session.wallet.create_local_did( - method=SOV, - key_type=ED25519, - seed=self.test_seed, - did=self.test_did, - metadata=None, - ) - - did_doc = self.make_did_doc( - did=self.test_target_did, verkey=self.test_target_verkey - ) - await self.manager.store_did_document(did_doc) - - conn_invite = ConnectionInvitation( - did=None, - endpoint=self.test_endpoint, - recipient_keys=[self.test_target_verkey], - routing_keys=[self.test_verkey], - label="label", - ) - mock_conn = async_mock.MagicMock( - my_did=self.test_did, - their_did=self.test_target_did, - connection_id="dummy", - their_role=ConnRecord.Role.RESPONDER.rfc23, - state=ConnRecord.State.INVITATION.rfc23, - retrieve_invitation=async_mock.CoroutineMock(return_value=conn_invite), - ) - - targets = await self.manager.get_connection_targets( - connection_id=None, - connection=mock_conn, - ) - assert len(targets) == 1 - target = targets[0] - assert target.did == mock_conn.their_did - assert target.endpoint == conn_invite.endpoint - assert target.label == conn_invite.label - assert target.recipient_keys == conn_invite.recipient_keys - assert target.routing_keys == conn_invite.routing_keys - assert target.sender_key == local_did.verkey - - async def test_fetch_connection_targets_no_my_did(self): - mock_conn = async_mock.MagicMock() - mock_conn.my_did = None - assert await self.manager.fetch_connection_targets(mock_conn) is None - - async def test_fetch_connection_targets_conn_invitation_did_no_resolver(self): - async with self.profile.session() as session: - self.context.injector.bind_instance(DIDResolver, DIDResolver([])) - await session.wallet.create_local_did( - method=SOV, - key_type=ED25519, - seed=self.test_seed, - did=self.test_did, - metadata=None, - ) - - conn_invite = ConnectionInvitation( - did=self.test_target_did, - endpoint=self.test_endpoint, - recipient_keys=[self.test_target_verkey], - routing_keys=[self.test_verkey], - label="label", - ) - mock_conn = async_mock.MagicMock( - my_did=self.test_did, - their_did=self.test_target_did, - connection_id="dummy", - their_role=ConnRecord.Role.RESPONDER.rfc23, - state=ConnRecord.State.INVITATION.rfc23, - retrieve_invitation=async_mock.CoroutineMock(return_value=conn_invite), - ) - - with self.assertRaises(BaseConnectionManagerError): - await self.manager.fetch_connection_targets(mock_conn) - - async def test_fetch_connection_targets_conn_invitation_did_resolver(self): - async with self.profile.session() as session: - builder = DIDDocumentBuilder("did:sov:" + self.test_target_did) - vmethod = builder.verification_method.add( - Ed25519VerificationKey2018, public_key_base58=self.test_target_verkey - ) - builder.service.add_didcomm( - ident="did-communication", - service_endpoint=self.test_endpoint, - recipient_keys=[vmethod], - ) - did_doc = builder.build() - self.resolver = async_mock.MagicMock() - self.resolver.get_endpoint_for_did = async_mock.CoroutineMock( - return_value=self.test_endpoint - ) - self.resolver.resolve = async_mock.CoroutineMock(return_value=did_doc) - self.resolver.dereference = async_mock.CoroutineMock( - return_value=did_doc.verification_method[0] - ) - self.context.injector.bind_instance(DIDResolver, self.resolver) - - local_did = await session.wallet.create_local_did( - method=SOV, - key_type=ED25519, - seed=self.test_seed, - did=self.test_did, - metadata=None, - ) - - conn_invite = ConnectionInvitation( - did=self.test_target_did, - endpoint=self.test_endpoint, - recipient_keys=[self.test_target_verkey], - routing_keys=[self.test_verkey], - label="label", - ) - mock_conn = async_mock.MagicMock( - my_did=self.test_did, - their_did=self.test_target_did, - connection_id="dummy", - their_role=ConnRecord.Role.RESPONDER.rfc23, - state=ConnRecord.State.INVITATION.rfc23, - retrieve_invitation=async_mock.CoroutineMock(return_value=conn_invite), - ) - - targets = await self.manager.fetch_connection_targets(mock_conn) - assert len(targets) == 1 - target = targets[0] - assert target.did == mock_conn.their_did - assert target.endpoint == conn_invite.endpoint - assert target.label == conn_invite.label - assert target.recipient_keys == conn_invite.recipient_keys - assert target.routing_keys == [] - assert target.sender_key == local_did.verkey - - async def test_fetch_connection_targets_conn_invitation_btcr_resolver(self): - async with self.profile.session() as session: - builder = DIDDocumentBuilder("did:btcr:x705-jznz-q3nl-srs") - vmethod = builder.verification_method.add( - Ed25519VerificationKey2018, public_key_base58=self.test_target_verkey - ) - builder.service.add_didcomm( - type_="IndyAgent", - recipient_keys=[vmethod], - routing_keys=[vmethod], - service_endpoint=self.test_endpoint, - priority=1, - ) - - builder.service.add_didcomm( - recipient_keys=[vmethod], - routing_keys=[vmethod], - service_endpoint=self.test_endpoint, - priority=0, - ) - builder.service.add_didcomm( - recipient_keys=[vmethod], - routing_keys=[vmethod], - service_endpoint="{}/priority2".format(self.test_endpoint), - priority=2, - ) - did_doc = builder.build() - - self.resolver = async_mock.MagicMock() - self.resolver.get_endpoint_for_did = async_mock.CoroutineMock( - return_value=self.test_endpoint - ) - self.resolver.resolve = async_mock.CoroutineMock(return_value=did_doc) - self.resolver.dereference = async_mock.CoroutineMock( - return_value=did_doc.verification_method[0] - ) - self.context.injector.bind_instance(DIDResolver, self.resolver) - local_did = await session.wallet.create_local_did( - method=SOV, - key_type=ED25519, - seed=self.test_seed, - did=did_doc.id, - metadata=None, - ) - - conn_invite = ConnectionInvitation( - did=did_doc.id, - endpoint=self.test_endpoint, - recipient_keys=[vmethod.material], - routing_keys=[self.test_verkey], - label="label", - ) - mock_conn = async_mock.MagicMock( - my_did=did_doc.id, - their_did=self.test_target_did, - connection_id="dummy", - their_role=ConnRecord.Role.RESPONDER.rfc23, - state=ConnRecord.State.INVITATION.rfc23, - retrieve_invitation=async_mock.CoroutineMock(return_value=conn_invite), - ) - - targets = await self.manager.fetch_connection_targets(mock_conn) - assert len(targets) == 1 - target = targets[0] - assert target.did == mock_conn.their_did - assert target.endpoint == self.test_endpoint - assert target.label == conn_invite.label - assert target.recipient_keys == conn_invite.recipient_keys - assert target.routing_keys == [vmethod.material] - assert target.sender_key == local_did.verkey - - async def test_fetch_connection_targets_conn_invitation_btcr_without_services(self): - async with self.profile.session() as session: - did_doc_json = { - "@context": ["https://www.w3.org/ns/did/v1"], - "id": "did:btcr:x705-jznz-q3nl-srs", - "verificationMethod": [ - { - "type": "EcdsaSecp256k1VerificationKey2019", - "id": "did:btcr:x705-jznz-q3nl-srs#key-0", - "publicKeyBase58": "02e0e01a8c302976e1556e95c54146e8464adac8626a5d29474718a7281133ff49", - }, - { - "type": "EcdsaSecp256k1VerificationKey2019", - "id": "did:btcr:x705-jznz-q3nl-srs#key-1", - "publicKeyBase58": "02e0e01a8c302976e1556e95c54146e8464adac8626a5d29474718a7281133ff49", - }, - { - "type": "EcdsaSecp256k1VerificationKey2019", - "id": "did:btcr:x705-jznz-q3nl-srs#satoshi", - "publicKeyBase58": "02e0e01a8c302976e1556e95c54146e8464adac8626a5d29474718a7281133ff49", - }, - ], - } - did_doc = DIDDocument.deserialize(did_doc_json) - self.resolver = async_mock.MagicMock() - self.resolver.get_endpoint_for_did = async_mock.CoroutineMock( - return_value=self.test_endpoint - ) - self.resolver.resolve = async_mock.CoroutineMock(return_value=did_doc) - self.context.injector.bind_instance(DIDResolver, self.resolver) - - local_did = await session.wallet.create_local_did( - method=SOV, - key_type=ED25519, - seed=self.test_seed, - did=did_doc.id, - metadata=None, - ) - - conn_invite = ConnectionInvitation( - did=did_doc.id, - endpoint=self.test_endpoint, - recipient_keys=["{}#1".format(did_doc.id)], - routing_keys=[self.test_verkey], - label="label", - ) - mock_conn = async_mock.MagicMock( - my_did=did_doc.id, - their_did=self.test_target_did, - connection_id="dummy", - their_role=ConnRecord.Role.RESPONDER.rfc23, - state=ConnRecord.State.INVITATION.rfc23, - retrieve_invitation=async_mock.CoroutineMock(return_value=conn_invite), - ) - with self.assertRaises(BaseConnectionManagerError): - await self.manager.fetch_connection_targets(mock_conn) - - async def test_fetch_connection_targets_conn_invitation_no_didcomm_services(self): - async with self.profile.session() as session: - builder = DIDDocumentBuilder("did:btcr:x705-jznz-q3nl-srs") - builder.verification_method.add( - Ed25519VerificationKey2018, public_key_base58=self.test_target_verkey - ) - builder.service.add(type_="LinkedData", service_endpoint=self.test_endpoint) - did_doc = builder.build() - self.resolver = async_mock.MagicMock() - self.resolver.get_endpoint_for_did = async_mock.CoroutineMock( - return_value=self.test_endpoint - ) - self.resolver.resolve = async_mock.CoroutineMock(return_value=did_doc) - self.context.injector.bind_instance(DIDResolver, self.resolver) - await session.wallet.create_local_did( - method=SOV, - key_type=ED25519, - seed=self.test_seed, - did=did_doc.id, - metadata=None, - ) - - conn_invite = ConnectionInvitation( - did=did_doc.id, - endpoint=self.test_endpoint, - recipient_keys=["{}#1".format(did_doc.id)], - routing_keys=[self.test_verkey], - label="label", - ) - mock_conn = async_mock.MagicMock( - my_did=did_doc.id, - their_did=self.test_target_did, - connection_id="dummy", - their_role=ConnRecord.Role.RESPONDER.rfc23, - state=ConnRecord.State.INVITATION.rfc23, - retrieve_invitation=async_mock.CoroutineMock(return_value=conn_invite), - ) - with self.assertRaises(BaseConnectionManagerError): - await self.manager.fetch_connection_targets(mock_conn) - - async def test_fetch_connection_targets_conn_invitation_supports_Ed25519VerificationKey2018_key_type_no_multicodec( - self, - ): - async with self.profile.session() as session: - builder = DIDDocumentBuilder("did:btcr:x705-jznz-q3nl-srs") - vmethod = builder.verification_method.add( - Ed25519VerificationKey2020, - public_key_multibase=multibase.encode( - b58_to_bytes(self.test_target_verkey), "base58btc" - ), - ) - builder.service.add_didcomm( - type_="IndyAgent", - service_endpoint=self.test_endpoint, - recipient_keys=[vmethod], - ) - did_doc = builder.build() - self.resolver = async_mock.MagicMock() - self.resolver.get_endpoint_for_did = async_mock.CoroutineMock( - return_value=self.test_endpoint - ) - self.resolver.resolve = async_mock.CoroutineMock(return_value=did_doc) - self.resolver.dereference = async_mock.CoroutineMock( - return_value=did_doc.verification_method[0] - ) - self.context.injector.bind_instance(DIDResolver, self.resolver) - local_did = await session.wallet.create_local_did( - method=SOV, - key_type=ED25519, - seed=self.test_seed, - did=did_doc.id, - metadata=None, - ) - - conn_invite = ConnectionInvitation( - did=did_doc.id, - endpoint=self.test_endpoint, - recipient_keys=[vmethod.public_key_jwk], - routing_keys=[self.test_verkey], - label="label", - ) - mock_conn = async_mock.MagicMock( - my_did=did_doc.id, - their_did=self.test_target_did, - connection_id="dummy", - their_role=ConnRecord.Role.RESPONDER.rfc23, - state=ConnRecord.State.INVITATION.rfc23, - retrieve_invitation=async_mock.CoroutineMock(return_value=conn_invite), - ) - - targets = await self.manager.fetch_connection_targets(mock_conn) - assert len(targets) == 1 - target = targets[0] - assert target.did == mock_conn.their_did - assert target.endpoint == self.test_endpoint - assert target.label == conn_invite.label - assert target.recipient_keys == [self.test_target_verkey] - assert target.routing_keys == [] - assert target.sender_key == local_did.verkey - - async def test_fetch_connection_targets_conn_invitation_supports_Ed25519VerificationKey2018_key_type_with_multicodec( - self, - ): - async with self.profile.session() as session: - builder = DIDDocumentBuilder("did:btcr:x705-jznz-q3nl-srs") - vmethod = builder.verification_method.add( - Ed25519VerificationKey2020, - public_key_multibase=multibase.encode( - multicodec.wrap( - "ed25519-pub", b58_to_bytes(self.test_target_verkey) - ), - "base58btc", - ), - ) - builder.service.add_didcomm( - type_="IndyAgent", - service_endpoint=self.test_endpoint, - recipient_keys=[vmethod], - ) - did_doc = builder.build() - self.resolver = async_mock.MagicMock() - self.resolver.get_endpoint_for_did = async_mock.CoroutineMock( - return_value=self.test_endpoint - ) - self.resolver.resolve = async_mock.CoroutineMock(return_value=did_doc) - self.resolver.dereference = async_mock.CoroutineMock( - return_value=did_doc.verification_method[0] - ) - self.context.injector.bind_instance(DIDResolver, self.resolver) - local_did = await session.wallet.create_local_did( - method=SOV, - key_type=ED25519, - seed=self.test_seed, - did=did_doc.id, - metadata=None, - ) - - conn_invite = ConnectionInvitation( - did=did_doc.id, - endpoint=self.test_endpoint, - recipient_keys=[vmethod.public_key_jwk], - routing_keys=[self.test_verkey], - label="label", - ) - mock_conn = async_mock.MagicMock( - my_did=did_doc.id, - their_did=self.test_target_did, - connection_id="dummy", - their_role=ConnRecord.Role.RESPONDER.rfc23, - state=ConnRecord.State.INVITATION.rfc23, - retrieve_invitation=async_mock.CoroutineMock(return_value=conn_invite), - ) - - targets = await self.manager.fetch_connection_targets(mock_conn) - assert len(targets) == 1 - target = targets[0] - assert target.did == mock_conn.their_did - assert target.endpoint == self.test_endpoint - assert target.label == conn_invite.label - assert target.recipient_keys == [self.test_target_verkey] - assert target.routing_keys == [] - assert target.sender_key == local_did.verkey - - async def test_fetch_connection_targets_conn_invitation_supported_JsonWebKey2020_key_type( - self, - ): - async with self.profile.session() as session: - builder = DIDDocumentBuilder("did:btcr:x705-jznz-q3nl-srs") - vmethod = builder.verification_method.add( - JsonWebKey2020, - ident="1", - public_key_jwk={ - "kty": "OKP", - "crv": "Ed25519", - "x": bytes_to_b64(b58_to_bytes(self.test_target_verkey), True), - }, - ) - builder.service.add_didcomm( - type_="IndyAgent", - service_endpoint=self.test_endpoint, - recipient_keys=[vmethod], - ) - did_doc = builder.build() - self.resolver = async_mock.MagicMock() - self.resolver.get_endpoint_for_did = async_mock.CoroutineMock( - return_value=self.test_endpoint - ) - self.resolver.resolve = async_mock.CoroutineMock(return_value=did_doc) - self.resolver.dereference = async_mock.CoroutineMock( - return_value=did_doc.verification_method[0] - ) - self.context.injector.bind_instance(DIDResolver, self.resolver) - local_did = await session.wallet.create_local_did( - method=SOV, - key_type=ED25519, - seed=self.test_seed, - did=did_doc.id, - metadata=None, - ) - - conn_invite = ConnectionInvitation( - did=did_doc.id, - endpoint=self.test_endpoint, - recipient_keys=[vmethod.public_key_jwk], - routing_keys=[self.test_verkey], - label="label", - ) - mock_conn = async_mock.MagicMock( - my_did=did_doc.id, - their_did=self.test_target_did, - connection_id="dummy", - their_role=ConnRecord.Role.RESPONDER.rfc23, - state=ConnRecord.State.INVITATION.rfc23, - retrieve_invitation=async_mock.CoroutineMock(return_value=conn_invite), - ) - - targets = await self.manager.fetch_connection_targets(mock_conn) - assert len(targets) == 1 - target = targets[0] - assert target.did == mock_conn.their_did - assert target.endpoint == self.test_endpoint - assert target.label == conn_invite.label - assert target.recipient_keys == [self.test_target_verkey] - assert target.routing_keys == [] - assert target.sender_key == local_did.verkey - - async def test_fetch_connection_targets_conn_invitation_unsupported_key_type(self): - async with self.profile.session() as session: - builder = DIDDocumentBuilder("did:btcr:x705-jznz-q3nl-srs") - vmethod = builder.verification_method.add( - JsonWebKey2020, - ident="1", - public_key_jwk={ - "kty": "EC", - "crv": "P-256", - "x": "2syLh57B-dGpa0F8p1JrO6JU7UUSF6j7qL-vfk1eOoY", - "y": "BgsGtI7UPsObMRjdElxLOrgAO9JggNMjOcfzEPox18w", - }, - ) - builder.service.add_didcomm( - type_="IndyAgent", - service_endpoint=self.test_endpoint, - recipient_keys=[vmethod], - ) - did_doc = builder.build() - self.resolver = async_mock.MagicMock() - self.resolver.get_endpoint_for_did = async_mock.CoroutineMock( - return_value=self.test_endpoint - ) - self.resolver.resolve = async_mock.CoroutineMock(return_value=did_doc) - self.resolver.dereference = async_mock.CoroutineMock( - return_value=did_doc.verification_method[0] - ) - self.context.injector.bind_instance(DIDResolver, self.resolver) - local_did = await session.wallet.create_local_did( - method=SOV, - key_type=ED25519, - seed=self.test_seed, - did=did_doc.id, - metadata=None, - ) - - conn_invite = ConnectionInvitation( - did=did_doc.id, - endpoint=self.test_endpoint, - recipient_keys=["{}#1".format(did_doc.id)], - routing_keys=[self.test_verkey], - label="label", - ) - mock_conn = async_mock.MagicMock( - my_did=did_doc.id, - their_did=self.test_target_did, - connection_id="dummy", - their_role=ConnRecord.Role.RESPONDER.rfc23, - state=ConnRecord.State.INVITATION.rfc23, - retrieve_invitation=async_mock.CoroutineMock(return_value=conn_invite), - ) - with self.assertRaises(BaseConnectionManagerError): - await self.manager.fetch_connection_targets(mock_conn) - - async def test_fetch_connection_targets_oob_invitation_svc_did_no_resolver(self): - async with self.profile.session() as session: - self.context.injector.bind_instance(DIDResolver, DIDResolver([])) - await session.wallet.create_local_did( - method=SOV, - key_type=ED25519, - seed=self.test_seed, - did=self.test_did, - metadata=None, - ) - - mock_oob_invite = async_mock.MagicMock(services=[self.test_did]) - - mock_conn = async_mock.MagicMock( - my_did=self.test_did, - retrieve_invitation=async_mock.CoroutineMock( - return_value=mock_oob_invite - ), - state=ConnRecord.State.INVITATION.rfc23, - their_role=ConnRecord.Role.RESPONDER.rfc23, - ) - - with self.assertRaises(BaseConnectionManagerError): - await self.manager.fetch_connection_targets(mock_conn) - - async def test_fetch_connection_targets_oob_invitation_svc_did_resolver(self): - async with self.profile.session() as session: - builder = DIDDocumentBuilder("did:sov:" + self.test_target_did) - vmethod = builder.verification_method.add( - Ed25519VerificationKey2018, - ident="1", - public_key_base58=self.test_target_verkey, - ) - builder.service.add_didcomm( - ident="did-communication", - service_endpoint=self.test_endpoint, - recipient_keys=[vmethod], - ) - did_doc = builder.build() - - self.resolver = async_mock.MagicMock() - self.resolver.resolve = async_mock.CoroutineMock(return_value=did_doc) - self.resolver.dereference = async_mock.CoroutineMock( - return_value=did_doc.verification_method[0] - ) - self.context.injector.bind_instance(DIDResolver, self.resolver) - - local_did = await session.wallet.create_local_did( - method=SOV, - key_type=ED25519, - seed=self.test_seed, - did=self.test_did, - metadata=None, - ) - - mock_oob_invite = async_mock.MagicMock( - label="a label", - their_did=self.test_target_did, - services=["dummy"], - ) - mock_conn = async_mock.MagicMock( - my_did=self.test_did, - their_did=self.test_target_did, - connection_id="dummy", - their_role=ConnRecord.Role.RESPONDER.rfc23, - state=ConnRecord.State.INVITATION.rfc23, - retrieve_invitation=async_mock.CoroutineMock( - return_value=mock_oob_invite - ), - ) - - targets = await self.manager.fetch_connection_targets(mock_conn) - assert len(targets) == 1 - target = targets[0] - assert target.did == mock_conn.their_did - assert target.endpoint == self.test_endpoint - assert target.label == mock_oob_invite.label - assert target.recipient_keys == [vmethod.material] - assert target.routing_keys == [] - assert target.sender_key == local_did.verkey - - async def test_fetch_connection_targets_oob_invitation_svc_block_resolver(self): - async with self.profile.session() as session: - self.resolver = async_mock.MagicMock() - self.resolver.get_endpoint_for_did = async_mock.CoroutineMock( - return_value=self.test_endpoint - ) - self.resolver.get_key_for_did = async_mock.CoroutineMock( - return_value=self.test_target_verkey - ) - self.context.injector.bind_instance(DIDResolver, self.resolver) - - local_did = await session.wallet.create_local_did( - method=SOV, - key_type=ED25519, - seed=self.test_seed, - did=self.test_did, - metadata=None, - ) - - mock_oob_invite = async_mock.MagicMock( - label="a label", - their_did=self.test_target_did, - services=[ - async_mock.MagicMock( - service_endpoint=self.test_endpoint, - recipient_keys=[ - DIDKey.from_public_key_b58( - self.test_target_verkey, ED25519 - ).did - ], - routing_keys=[], - ) - ], - ) - mock_conn = async_mock.MagicMock( - my_did=self.test_did, - their_did=self.test_target_did, - connection_id="dummy", - their_role=ConnRecord.Role.RESPONDER.rfc23, - state=ConnRecord.State.INVITATION.rfc23, - retrieve_invitation=async_mock.CoroutineMock( - return_value=mock_oob_invite - ), - ) - - targets = await self.manager.fetch_connection_targets(mock_conn) - assert len(targets) == 1 - target = targets[0] - assert target.did == mock_conn.their_did - assert target.endpoint == self.test_endpoint - assert target.label == mock_oob_invite.label - assert target.recipient_keys == [self.test_target_verkey] - assert target.routing_keys == [] - assert target.sender_key == local_did.verkey - - async def test_fetch_connection_targets_conn_initiator_completed_no_their_did(self): - async with self.profile.session() as session: - await session.wallet.create_local_did( - method=SOV, - key_type=ED25519, - seed=self.test_seed, - did=self.test_did, - metadata=None, - ) - - mock_conn = async_mock.MagicMock( - my_did=self.test_did, - their_did=None, - state=ConnRecord.State.COMPLETED.rfc23, - ) - assert await self.manager.fetch_connection_targets(mock_conn) is None - - async def test_fetch_connection_targets_conn_completed_their_did(self): - async with self.profile.session() as session: - local_did = await session.wallet.create_local_did( - method=SOV, - key_type=ED25519, - seed=self.test_seed, - did=self.test_did, - metadata=None, - ) - - did_doc = self.make_did_doc(did=self.test_did, verkey=self.test_verkey) - await self.manager.store_did_document(did_doc) - - mock_conn = async_mock.MagicMock( - my_did=self.test_did, - their_did=self.test_did, - their_label="label", - their_role=ConnRecord.Role.REQUESTER.rfc160, - state=ConnRecord.State.COMPLETED.rfc23, - ) - - targets = await self.manager.fetch_connection_targets(mock_conn) - assert len(targets) == 1 - target = targets[0] - assert target.did == mock_conn.their_did - assert target.endpoint == self.test_endpoint - assert target.label == mock_conn.their_label - assert target.recipient_keys == [self.test_verkey] - assert target.routing_keys == [] - assert target.sender_key == local_did.verkey - - async def test_fetch_connection_targets_conn_no_invi_with_their_did(self): - async with self.profile.session() as session: - local_did = await session.wallet.create_local_did( - method=SOV, - key_type=ED25519, - seed=self.test_seed, - did=self.test_did, - metadata=None, - ) - - self.manager.resolve_invitation = async_mock.CoroutineMock() - self.manager.resolve_invitation.return_value = ( - self.test_endpoint, - [self.test_verkey], - [], - ) - - did_doc = self.make_did_doc(did=self.test_did, verkey=self.test_verkey) - await self.manager.store_did_document(did_doc) - - mock_conn = async_mock.MagicMock( - my_did=self.test_did, - their_did=self.test_did, - their_label="label", - their_role=ConnRecord.Role.RESPONDER.rfc23, - state=ConnRecord.State.REQUEST.rfc23, - invitation_key=None, - invitation_msg_id=None, - ) - - targets = await self.manager.fetch_connection_targets(mock_conn) - assert len(targets) == 1 - target = targets[0] - assert target.did == mock_conn.their_did - assert target.endpoint == self.test_endpoint - assert target.label is None - assert target.recipient_keys == [self.test_verkey] - assert target.routing_keys == [] - assert target.sender_key == local_did.verkey - - async def test_diddoc_connection_targets_diddoc_underspecified(self): - with self.assertRaises(BaseConnectionManagerError): - self.manager.diddoc_connection_targets(None, self.test_verkey) - - x_did_doc = DIDDoc(did=None) - with self.assertRaises(BaseConnectionManagerError): - self.manager.diddoc_connection_targets(x_did_doc, self.test_verkey) - - x_did_doc = self.make_did_doc( - did=self.test_target_did, verkey=self.test_target_verkey - ) - x_did_doc._service = {} - with self.assertRaises(BaseConnectionManagerError): - self.manager.diddoc_connection_targets(x_did_doc, self.test_verkey) - async def test_establish_inbound(self): async with self.profile.session() as session: await session.wallet.create_local_did( diff --git a/aries_cloudagent/protocols/didexchange/v1_0/manager.py b/aries_cloudagent/protocols/didexchange/v1_0/manager.py index 7e69780b8e..4b2b95966f 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/manager.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/manager.py @@ -211,8 +211,27 @@ async def create_request_implicit( async with self.profile.session() as session: wallet = session.inject(BaseWallet) my_public_info = await wallet.get_public_did() - if not my_public_info: - raise WalletError("No public DID configured") + if not my_public_info: + raise WalletError("No public DID configured") + if ( + my_public_info.did == their_public_did + or f"did:sov:{my_public_info.did}" == their_public_did + ): + raise DIDXManagerError( + "Cannot connect to yourself through public DID" + ) + try: + await ConnRecord.retrieve_by_did( + session, + their_did=their_public_did, + my_did=my_public_info.did, + ) + raise DIDXManagerError( + "Connection already exists for their_did " + f"{their_public_did} and my_did {my_public_info.did}" + ) + except StorageNotFoundError: + pass conn_rec = ConnRecord( my_did=my_public_info.did @@ -320,6 +339,9 @@ async def create_request( # Omit DID Doc attachment if we're using a public DID did_doc = None attach = None + did = conn_rec.my_did + if not did.startswith("did:"): + did = f"did:sov:{did}" else: did_doc = await self.create_did_document( my_info, @@ -333,6 +355,7 @@ async def create_request( async with self.profile.session() as session: wallet = session.inject(BaseWallet) await attach.data.sign(my_info.verkey, wallet) + did = conn_rec.my_did if conn_rec.their_public_did is not None: qualified_did = conn_rec.their_public_did @@ -348,7 +371,7 @@ async def create_request( request = DIDXRequest( label=my_label, - did=conn_rec.my_did, + did=did, did_doc_attach=attach, goal_code=goal_code, goal=goal, @@ -478,23 +501,28 @@ async def receive_request( conn_rec = new_conn_rec # request DID doc describes requester DID - if not (request.did_doc_attach and request.did_doc_attach.data): - raise DIDXManagerError( - "DID Doc attachment missing or has no data: " - "cannot connect to public DID" - ) - async with self.profile.session() as session: - wallet = session.inject(BaseWallet) - conn_did_doc = await self.verify_diddoc(wallet, request.did_doc_attach) - if request.did != conn_did_doc.did: - raise DIDXManagerError( - ( - f"Connection DID {request.did} does not match " - f"DID Doc id {conn_did_doc.did}" - ), - error_code=ProblemReportReason.REQUEST_NOT_ACCEPTED.value, + if request.did_doc_attach and request.did_doc_attach.data: + self._logger.debug("Received DID Doc attachment in request") + async with self.profile.session() as session: + wallet = session.inject(BaseWallet) + conn_did_doc = await self.verify_diddoc(wallet, request.did_doc_attach) + await self.store_did_document(conn_did_doc) + if request.did != conn_did_doc.did: + raise DIDXManagerError( + ( + f"Connection DID {request.did} does not match " + f"DID Doc id {conn_did_doc.did}" + ), + error_code=ProblemReportReason.REQUEST_NOT_ACCEPTED.value, + ) + else: + if request.did is None: + raise DIDXManagerError("No DID in request") + + self._logger.debug( + "No DID Doc attachment in request; doc will be resolved from DID" ) - await self.store_did_document(conn_did_doc) + await self.record_keys_for_public_did(request.did) if conn_rec: # request is against explicit invitation auto_accept = ( @@ -515,13 +543,7 @@ async def receive_request( # request is against implicit invitation on public DID if not self.profile.settings.get("requests_through_public_did"): raise DIDXManagerError( - "Unsolicited connection requests to " "public DID is not enabled" - ) - async with self.profile.session() as session: - wallet = session.inject(BaseWallet) - my_info = await wallet.create_local_did( - method=SOV, - key_type=ED25519, + "Unsolicited connection requests to public DID is not enabled" ) auto_accept = bool( @@ -533,7 +555,7 @@ async def receive_request( ) conn_rec = ConnRecord( - my_did=my_info.did, + my_did=None, # Defer DID creation until create_response accept=( ConnRecord.ACCEPT_AUTO if auto_accept else ConnRecord.ACCEPT_MANUAL ), @@ -567,6 +589,7 @@ async def create_response( conn_rec: ConnRecord, my_endpoint: Optional[str] = None, mediation_id: Optional[str] = None, + use_public_did: Optional[bool] = None, ) -> DIDXResponse: """ Create a connection response for a received connection request. @@ -610,6 +633,17 @@ async def create_response( async with self.profile.session() as session: wallet = session.inject(BaseWallet) my_info = await wallet.get_local_did(conn_rec.my_did) + did = my_info.did + elif use_public_did: + async with self.profile.session() as session: + wallet = session.inject(BaseWallet) + my_info = await wallet.get_public_did() + if not my_info: + raise DIDXManagerError("No public DID configured") + conn_rec.my_did = my_info.did + did = my_info.did + if not did.startswith("did:"): + did = f"did:sov:{did}" else: async with self.profile.session() as session: wallet = session.inject(BaseWallet) @@ -618,6 +652,7 @@ async def create_response( key_type=ED25519, ) conn_rec.my_did = my_info.did + did = my_info.did # Idempotent; if routing has already been set up, no action taken await self._route_manager.route_connection_as_inviter( @@ -634,19 +669,25 @@ async def create_response( my_endpoints.append(default_endpoint) my_endpoints.extend(self.profile.settings.get("additional_endpoints", [])) - did_doc = await self.create_did_document( - my_info, - conn_rec.inbound_connection_id, - my_endpoints, - mediation_records=list( - filter(None, [base_mediation_record, mediation_record]) - ), - ) - attach = AttachDecorator.data_base64(did_doc.serialize()) - async with self.profile.session() as session: - wallet = session.inject(BaseWallet) - await attach.data.sign(conn_rec.invitation_key, wallet) - response = DIDXResponse(did=my_info.did, did_doc_attach=attach) + if use_public_did: + # Omit DID Doc attachment if we're using a public DID + did_doc = None + attach = None + else: + did_doc = await self.create_did_document( + my_info, + conn_rec.inbound_connection_id, + my_endpoints, + mediation_records=list( + filter(None, [base_mediation_record, mediation_record]) + ), + ) + attach = AttachDecorator.data_base64(did_doc.serialize()) + async with self.profile.session() as session: + wallet = session.inject(BaseWallet) + await attach.data.sign(conn_rec.invitation_key, wallet) + + response = DIDXResponse(did=did, did_doc_attach=attach) # Assign thread information response.assign_thread_from(request) response.assign_trace_from(request) @@ -748,19 +789,26 @@ async def accept_response( ) their_did = response.did - if not response.did_doc_attach: - raise DIDXManagerError("No DIDDoc attached; cannot connect to public DID") - async with self.profile.session() as session: - wallet = session.inject(BaseWallet) - conn_did_doc = await self.verify_diddoc( - wallet, response.did_doc_attach, conn_rec.invitation_key - ) - if their_did != conn_did_doc.did: - raise DIDXManagerError( - f"Connection DID {their_did} " - f"does not match DID doc id {conn_did_doc.did}" + if response.did_doc_attach: + async with self.profile.session() as session: + wallet = session.inject(BaseWallet) + conn_did_doc = await self.verify_diddoc( + wallet, response.did_doc_attach, conn_rec.invitation_key + ) + if their_did != conn_did_doc.did: + raise DIDXManagerError( + f"Connection DID {their_did} " + f"does not match DID doc id {conn_did_doc.did}" + ) + await self.store_did_document(conn_did_doc) + else: + if response.did is None: + raise DIDXManagerError("No DID in response") + + self._logger.debug( + "No DID Doc attachment in response; doc will be resolved from DID" ) - await self.store_did_document(conn_did_doc) + await self.record_keys_for_public_did(response.did) conn_rec.their_did = their_did conn_rec.state = ConnRecord.State.RESPONSE.rfc23 diff --git a/aries_cloudagent/protocols/didexchange/v1_0/messages/request.py b/aries_cloudagent/protocols/didexchange/v1_0/messages/request.py index 99cfd0c960..501f829f92 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/messages/request.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/messages/request.py @@ -9,7 +9,7 @@ AttachDecorator, AttachDecoratorSchema, ) -from .....messaging.valid import INDY_DID_EXAMPLE, INDY_DID_VALIDATE +from .....messaging.valid import GENERIC_DID_EXAMPLE, GENERIC_DID_VALIDATE from ..message_types import DIDX_REQUEST, PROTOCOL_PACKAGE HANDLER_CLASS = f"{PROTOCOL_PACKAGE}.handlers.request_handler.DIDXRequestHandler" @@ -75,8 +75,8 @@ class Meta: }, ) did = fields.Str( - validate=INDY_DID_VALIDATE, - metadata={"description": "DID of exchange", "example": INDY_DID_EXAMPLE}, + validate=GENERIC_DID_VALIDATE, + metadata={"description": "DID of exchange", "example": GENERIC_DID_EXAMPLE}, ) did_doc_attach = fields.Nested( AttachDecoratorSchema, diff --git a/aries_cloudagent/protocols/didexchange/v1_0/messages/response.py b/aries_cloudagent/protocols/didexchange/v1_0/messages/response.py index b438c68fd0..6634005e66 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/messages/response.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/messages/response.py @@ -1,5 +1,6 @@ """Represents a DID exchange response message under RFC 23.""" +from typing import Optional from marshmallow import EXCLUDE, fields from .....messaging.agent_message import AgentMessage, AgentMessageSchema @@ -7,7 +8,7 @@ AttachDecorator, AttachDecoratorSchema, ) -from .....messaging.valid import INDY_DID_EXAMPLE, INDY_DID_VALIDATE +from .....messaging.valid import GENERIC_DID_EXAMPLE, GENERIC_DID_VALIDATE from ..message_types import DIDX_RESPONSE, PROTOCOL_PACKAGE HANDLER_CLASS = f"{PROTOCOL_PACKAGE}.handlers.response_handler.DIDXResponseHandler" @@ -27,7 +28,7 @@ def __init__( self, *, did: str = None, - did_doc_attach: AttachDecorator = None, + did_doc_attach: Optional[AttachDecorator] = None, **kwargs, ): """ @@ -52,8 +53,8 @@ class Meta: unknown = EXCLUDE did = fields.Str( - validate=INDY_DID_VALIDATE, - metadata={"description": "DID of exchange", "example": INDY_DID_EXAMPLE}, + validate=GENERIC_DID_VALIDATE, + metadata={"description": "DID of exchange", "example": GENERIC_DID_EXAMPLE}, ) did_doc_attach = fields.Nested( AttachDecoratorSchema, diff --git a/aries_cloudagent/protocols/didexchange/v1_0/messages/tests/test_request.py b/aries_cloudagent/protocols/didexchange/v1_0/messages/tests/test_request.py index 098e343762..3b769d56c4 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/messages/tests/test_request.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/messages/tests/test_request.py @@ -109,6 +109,19 @@ def test_serialize(self, mock_request_schema_dump): assert request_dict is mock_request_schema_dump.return_value + def test_method_other_than_indy(self): + """Test method other than indy.""" + request = DIDXRequest( + label=TestConfig.test_label, + did="did:web:example.com:alice", + did_doc_attach=None, + goal_code=TestConfig.goal_code, + goal=TestConfig.goal, + ) + request_dict = request.serialize() + new_request = DIDXRequest.deserialize(request_dict) + assert request.serialize() == new_request.serialize() + class TestDIDXRequestSchema(AsyncTestCase, TestConfig): """Test request schema.""" diff --git a/aries_cloudagent/protocols/didexchange/v1_0/routes.py b/aries_cloudagent/protocols/didexchange/v1_0/routes.py index e072cb5a2f..c0a66c6cbb 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/routes.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/routes.py @@ -146,6 +146,9 @@ class DIDXAcceptRequestQueryStringSchema(OpenAPISchema): "example": UUID4_EXAMPLE, }, ) + use_public_did = fields.Boolean( + required=False, metadata={"description": "Use public DID for this connection"} + ) class DIDXConnIdMatchInfoSchema(OpenAPISchema): @@ -351,6 +354,7 @@ async def didx_accept_request(request: web.BaseRequest): connection_id = request.match_info["conn_id"] my_endpoint = request.query.get("my_endpoint") or None mediation_id = request.query.get("mediation_id") or None + use_public_did = json.loads(request.query.get("use_public_did", "null")) profile = context.profile didx_mgr = DIDXManager(profile) @@ -361,6 +365,7 @@ async def didx_accept_request(request: web.BaseRequest): conn_rec=conn_rec, my_endpoint=my_endpoint, mediation_id=mediation_id, + use_public_did=use_public_did, ) result = conn_rec.serialize() except StorageNotFoundError as err: diff --git a/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py index c15bf555d5..041fc4a982 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py @@ -121,6 +121,10 @@ async def setUp(self): self.resolver = async_mock.MagicMock() did_doc = DIDDocument.deserialize(DOC) self.resolver.resolve = async_mock.CoroutineMock(return_value=did_doc) + assert did_doc.verification_method + self.resolver.dereference_verification_method = async_mock.CoroutineMock( + return_value=did_doc.verification_method[0] + ) self.context.injector.bind_instance(DIDResolver, self.resolver) self.multitenant_mgr = async_mock.MagicMock(MultitenantManager, autospec=True) @@ -328,6 +332,44 @@ async def test_create_request_implicit_no_public_did(self): assert "No public DID configured" in str(context.exception) + async def test_create_request_implicit_x_public_self(self): + async with self.profile.session() as session: + info_public = await session.wallet.create_public_did( + SOV, + ED25519, + ) + with self.assertRaises(DIDXManagerError) as context: + await self.manager.create_request_implicit( + their_public_did=info_public.did, + my_label=None, + my_endpoint=None, + mediation_id=None, + use_public_did=True, + alias="Tester", + ) + + assert "Cannot connect to yourself" in str(context.exception) + + async def test_create_request_implicit_x_public_already_connected(self): + async with self.profile.session() as session: + info_public = await session.wallet.create_public_did( + SOV, + ED25519, + ) + with self.assertRaises(DIDXManagerError) as context, async_mock.patch.object( + test_module.ConnRecord, "retrieve_by_did", async_mock.CoroutineMock() + ) as mock_retrieve_by_did: + await self.manager.create_request_implicit( + their_public_did=TestConfig.test_target_did, + my_label=None, + my_endpoint=None, + mediation_id=None, + use_public_did=True, + alias="Tester", + ) + + assert "Connection already exists for their_did" in str(context.exception) + async def test_create_request(self): mock_conn_rec = async_mock.MagicMock( connection_id="dummy", @@ -648,6 +690,15 @@ async def test_receive_request_public_did_no_did_doc_attachment(self): _thread=async_mock.MagicMock(pthid="did:sov:publicdid0000000000000"), ) + mediation_record = MediationRecord( + role=MediationRecord.ROLE_CLIENT, + state=MediationRecord.STATE_GRANTED, + connection_id=self.test_mediator_conn_id, + routing_keys=self.test_mediator_routing_keys, + endpoint=self.test_mediator_endpoint, + ) + await mediation_record.save(session) + await session.wallet.create_local_did( method=SOV, key_type=ED25519, @@ -655,26 +706,138 @@ async def test_receive_request_public_did_no_did_doc_attachment(self): did=TestConfig.test_did, ) + STATE_REQUEST = ConnRecord.State.REQUEST self.profile.context.update_settings({"public_invites": True}) - mock_conn_rec_state_request = ConnRecord.State.REQUEST + ACCEPT_AUTO = ConnRecord.ACCEPT_AUTO with async_mock.patch.object( test_module, "ConnRecord", async_mock.MagicMock() ) as mock_conn_rec_cls, async_mock.patch.object( + test_module, "DIDDoc", autospec=True + ) as mock_did_doc, async_mock.patch.object( test_module, "DIDPosture", autospec=True - ) as mock_did_posture: + ) as mock_did_posture, async_mock.patch.object( + test_module, "AttachDecorator", autospec=True + ) as mock_attach_deco, async_mock.patch.object( + test_module, "DIDXResponse", autospec=True + ) as mock_response, async_mock.patch.object( + self.manager, + "verify_diddoc", + async_mock.CoroutineMock(return_value=DIDDoc(TestConfig.test_did)), + ), async_mock.patch.object( + self.manager, "create_did_document", async_mock.CoroutineMock() + ) as mock_create_did_doc, async_mock.patch.object( + self.manager, "record_keys_for_public_did", async_mock.CoroutineMock() + ) as mock_record_keys_for_public_did, async_mock.patch.object( + MediationManager, "prepare_request", autospec=True + ) as mock_mediation_mgr_prep_req: + mock_create_did_doc.return_value = async_mock.MagicMock( + serialize=async_mock.MagicMock(return_value={}) + ) + mock_mediation_mgr_prep_req.return_value = ( + mediation_record, + mock_request, + ) + mock_conn_record = async_mock.MagicMock( - accept=ConnRecord.ACCEPT_MANUAL, + accept=ACCEPT_AUTO, my_did=None, - state=mock_conn_rec_state_request.rfc23, + state=STATE_REQUEST.rfc23, attach_request=async_mock.CoroutineMock(), retrieve_request=async_mock.CoroutineMock(), metadata_get_all=async_mock.CoroutineMock(return_value={}), + metadata_get=async_mock.CoroutineMock(return_value=True), save=async_mock.CoroutineMock(), ) + + mock_conn_rec_cls.ACCEPT_AUTO = ConnRecord.ACCEPT_AUTO + mock_conn_rec_cls.State.REQUEST = STATE_REQUEST + mock_conn_rec_cls.State.get = async_mock.MagicMock( + return_value=STATE_REQUEST + ) + mock_conn_rec_cls.retrieve_by_id = async_mock.CoroutineMock( + return_value=async_mock.MagicMock(save=async_mock.CoroutineMock()) + ) + mock_conn_rec_cls.retrieve_by_invitation_msg_id = ( + async_mock.CoroutineMock(return_value=mock_conn_record) + ) mock_conn_rec_cls.return_value = mock_conn_record + + mock_did_posture.get = async_mock.MagicMock( + return_value=test_module.DIDPosture.PUBLIC + ) + + mock_did_doc.from_json = async_mock.MagicMock( + return_value=async_mock.MagicMock(did=TestConfig.test_did) + ) + mock_attach_deco.data_base64 = async_mock.MagicMock( + return_value=async_mock.MagicMock( + data=async_mock.MagicMock(sign=async_mock.CoroutineMock()) + ) + ) + mock_response.return_value = async_mock.MagicMock( + assign_thread_from=async_mock.MagicMock(), + assign_trace_from=async_mock.MagicMock(), + ) + + conn_rec = await self.manager.receive_request( + request=mock_request, + recipient_did=TestConfig.test_did, + recipient_verkey=None, + my_endpoint=None, + alias=None, + auto_accept_implicit=None, + ) + assert conn_rec + self.oob_mock.clean_finished_oob_record.assert_called_once_with( + self.profile, mock_request + ) + + async def test_receive_request_public_did_no_did_doc_attachment_no_did(self): + async with self.profile.session() as session: + mock_request = async_mock.MagicMock( + did=None, + did_doc_attach=None, + _thread=async_mock.MagicMock(pthid="did:sov:publicdid0000000000000"), + ) + + await session.wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=None, + did=TestConfig.test_did, + ) + + STATE_REQUEST = ConnRecord.State.REQUEST + self.profile.context.update_settings({"public_invites": True}) + ACCEPT_AUTO = ConnRecord.ACCEPT_AUTO + with async_mock.patch.object( + test_module, "ConnRecord", async_mock.MagicMock() + ) as mock_conn_rec_cls, async_mock.patch.object( + test_module, "DIDPosture", autospec=True + ) as mock_did_posture: + mock_conn_record = async_mock.MagicMock( + accept=ACCEPT_AUTO, + my_did=None, + state=STATE_REQUEST.rfc23, + attach_request=async_mock.CoroutineMock(), + retrieve_request=async_mock.CoroutineMock(), + metadata_get_all=async_mock.CoroutineMock(return_value={}), + metadata_get=async_mock.CoroutineMock(return_value=True), + save=async_mock.CoroutineMock(), + ) + + mock_conn_rec_cls.ACCEPT_AUTO = ConnRecord.ACCEPT_AUTO + mock_conn_rec_cls.State.REQUEST = STATE_REQUEST + mock_conn_rec_cls.State.get = async_mock.MagicMock( + return_value=STATE_REQUEST + ) + mock_conn_rec_cls.retrieve_by_id = async_mock.CoroutineMock( + return_value=async_mock.MagicMock(save=async_mock.CoroutineMock()) + ) mock_conn_rec_cls.retrieve_by_invitation_msg_id = ( async_mock.CoroutineMock(return_value=mock_conn_record) ) + mock_conn_rec_cls.return_value = mock_conn_record mock_did_posture.get = async_mock.MagicMock( return_value=test_module.DIDPosture.PUBLIC @@ -685,13 +848,11 @@ async def test_receive_request_public_did_no_did_doc_attachment(self): request=mock_request, recipient_did=TestConfig.test_did, recipient_verkey=None, - my_endpoint=TestConfig.test_endpoint, - alias="Alias", + my_endpoint=None, + alias=None, auto_accept_implicit=None, ) - assert "DID Doc attachment missing or has no data" in str( - context.exception - ) + assert "No DID in request" in str(context.exception) async def test_receive_request_public_did_x_not_public(self): async with self.profile.session() as session: @@ -1475,6 +1636,76 @@ async def test_create_response_bad_state(self): ) ) + async def test_create_response_use_public_did(self): + async with self.profile.session() as session: + info_public = await session.wallet.create_public_did( + SOV, + ED25519, + ) + + conn_rec = ConnRecord( + connection_id="dummy", state=ConnRecord.State.REQUEST.rfc23 + ) + + with async_mock.patch.object( + test_module.ConnRecord, "retrieve_request", async_mock.CoroutineMock() + ) as mock_retrieve_req, async_mock.patch.object( + conn_rec, "save", async_mock.CoroutineMock() + ) as mock_save, async_mock.patch.object( + test_module, "DIDDoc", autospec=True + ) as mock_did_doc, async_mock.patch.object( + test_module, "AttachDecorator", autospec=True + ) as mock_attach_deco, async_mock.patch.object( + test_module, "DIDXResponse", autospec=True + ) as mock_response, async_mock.patch.object( + self.manager, "create_did_document", async_mock.CoroutineMock() + ) as mock_create_did_doc: + mock_create_did_doc.return_value = async_mock.MagicMock( + serialize=async_mock.MagicMock() + ) + mock_attach_deco.data_base64 = async_mock.MagicMock( + return_value=async_mock.MagicMock( + data=async_mock.MagicMock(sign=async_mock.CoroutineMock()) + ) + ) + + await self.manager.create_response( + conn_rec, "http://10.20.30.40:5060/", use_public_did=True + ) + + async def test_create_response_use_public_did_x_no_public_did(self): + conn_rec = ConnRecord( + connection_id="dummy", state=ConnRecord.State.REQUEST.rfc23 + ) + + with async_mock.patch.object( + test_module.ConnRecord, "retrieve_request", async_mock.CoroutineMock() + ) as mock_retrieve_req, async_mock.patch.object( + conn_rec, "save", async_mock.CoroutineMock() + ) as mock_save, async_mock.patch.object( + test_module, "DIDDoc", autospec=True + ) as mock_did_doc, async_mock.patch.object( + test_module, "AttachDecorator", autospec=True + ) as mock_attach_deco, async_mock.patch.object( + test_module, "DIDXResponse", autospec=True + ) as mock_response, async_mock.patch.object( + self.manager, "create_did_document", async_mock.CoroutineMock() + ) as mock_create_did_doc: + mock_create_did_doc.return_value = async_mock.MagicMock( + serialize=async_mock.MagicMock() + ) + mock_attach_deco.data_base64 = async_mock.MagicMock( + return_value=async_mock.MagicMock( + data=async_mock.MagicMock(sign=async_mock.CoroutineMock()) + ) + ) + + with self.assertRaises(DIDXManagerError) as context: + await self.manager.create_response( + conn_rec, "http://10.20.30.40:5060/", use_public_did=True + ) + assert "No public DID configured" in str(context.exception) + async def test_accept_response_find_by_thread_id(self): mock_response = async_mock.MagicMock() mock_response._thread = async_mock.MagicMock() @@ -1712,31 +1943,81 @@ async def test_accept_response_find_by_thread_id_no_did_doc_attached(self): mock_response.did = TestConfig.test_target_did mock_response.did_doc_attach = None - receipt = MessageReceipt(sender_did=TestConfig.test_target_did) + receipt = MessageReceipt( + recipient_did=TestConfig.test_did, + recipient_did_public=True, + ) with async_mock.patch.object( ConnRecord, "save", autospec=True ) as mock_conn_rec_save, async_mock.patch.object( ConnRecord, "retrieve_by_request_id", async_mock.CoroutineMock() - ) as mock_conn_retrieve_by_req_id: + ) as mock_conn_retrieve_by_req_id, async_mock.patch.object( + ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() + ) as mock_conn_retrieve_by_id, async_mock.patch.object( + DIDDoc, "deserialize", async_mock.MagicMock() + ) as mock_did_doc_deser, async_mock.patch.object( + self.manager, "record_keys_for_public_did", async_mock.CoroutineMock() + ) as mock_record_keys_for_public_did: + mock_did_doc_deser.return_value = async_mock.MagicMock( + did=TestConfig.test_target_did + ) mock_conn_retrieve_by_req_id.return_value = async_mock.MagicMock( did=TestConfig.test_target_did, - did_doc_attach=async_mock.MagicMock( - data=async_mock.MagicMock( - verify=async_mock.CoroutineMock(return_value=True), - signed=async_mock.MagicMock( - decode=async_mock.MagicMock( - return_value=json.dumps({"dummy": "did-doc"}) - ) - ), - ) - ), state=ConnRecord.State.REQUEST.rfc23, save=async_mock.CoroutineMock(), + metadata_get=async_mock.CoroutineMock(), + connection_id="test-conn-id", + ) + mock_conn_retrieve_by_id.return_value = async_mock.MagicMock( + their_did=TestConfig.test_target_did, + save=async_mock.CoroutineMock(), ) - with self.assertRaises(DIDXManagerError): + conn_rec = await self.manager.accept_response(mock_response, receipt) + assert conn_rec.their_did == TestConfig.test_target_did + assert ConnRecord.State.get(conn_rec.state) is ConnRecord.State.COMPLETED + + async def test_accept_response_find_by_thread_id_no_did_doc_attached_no_did(self): + mock_response = async_mock.MagicMock() + mock_response._thread = async_mock.MagicMock() + mock_response.did = None + mock_response.did_doc_attach = None + + receipt = MessageReceipt( + recipient_did=TestConfig.test_did, + recipient_did_public=True, + ) + + with async_mock.patch.object( + ConnRecord, "save", autospec=True + ) as mock_conn_rec_save, async_mock.patch.object( + ConnRecord, "retrieve_by_request_id", async_mock.CoroutineMock() + ) as mock_conn_retrieve_by_req_id, async_mock.patch.object( + ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() + ) as mock_conn_retrieve_by_id, async_mock.patch.object( + DIDDoc, "deserialize", async_mock.MagicMock() + ) as mock_did_doc_deser, async_mock.patch.object( + self.manager, "record_keys_for_public_did", async_mock.CoroutineMock() + ) as mock_record_keys_for_public_did: + mock_did_doc_deser.return_value = async_mock.MagicMock( + did=TestConfig.test_target_did + ) + mock_conn_retrieve_by_req_id.return_value = async_mock.MagicMock( + did=TestConfig.test_target_did, + state=ConnRecord.State.REQUEST.rfc23, + save=async_mock.CoroutineMock(), + metadata_get=async_mock.CoroutineMock(), + connection_id="test-conn-id", + ) + mock_conn_retrieve_by_id.return_value = async_mock.MagicMock( + their_did=TestConfig.test_target_did, + save=async_mock.CoroutineMock(), + ) + + with self.assertRaises(DIDXManagerError) as context: await self.manager.accept_response(mock_response, receipt) + assert "No DID in response" in str(context.exception) async def test_accept_response_find_by_thread_id_did_mismatch(self): mock_response = async_mock.MagicMock() diff --git a/aries_cloudagent/resolver/base.py b/aries_cloudagent/resolver/base.py index df31cf4dd3..3fd1487fe5 100644 --- a/aries_cloudagent/resolver/base.py +++ b/aries_cloudagent/resolver/base.py @@ -1,14 +1,14 @@ """Base Class for DID Resolvers.""" -import re -import warnings - from abc import ABC, abstractmethod from enum import Enum -from typing import Optional, NamedTuple, Pattern, Sequence, Union, Text +import re +from typing import NamedTuple, Optional, Pattern, Sequence, Text, Union +import warnings from pydid import DID +from ..cache.base import BaseCache from ..config.injection_context import InjectionContext from ..core.error import BaseError from ..core.profile import Profile @@ -70,7 +70,9 @@ def serialize(self) -> dict: class BaseDIDResolver(ABC): """Base Class for DID Resolvers.""" - def __init__(self, type_: ResolverType = None): + DEFAULT_TTL = 3600 + + def __init__(self, type_: Optional[ResolverType] = None): """Initialize BaseDIDResolver. Args: @@ -138,7 +140,10 @@ async def resolve( did: Union[str, DID], service_accept: Optional[Sequence[Text]] = None, ) -> dict: - """Resolve a DID using this resolver.""" + """Resolve a DID using this resolver. + + Handles caching of results. + """ if isinstance(did, DID): did = str(did) else: @@ -148,6 +153,17 @@ async def resolve( f"{self.__class__.__name__} does not support DID method for: {did}" ) + cache_key = f"resolver::{type(self).__name__}::{did}" + cache = profile.inject_or(BaseCache) + if cache: + async with cache.acquire(cache_key) as entry: + if entry.result: + return entry.result + else: + result = await self._resolve(profile, did, service_accept) + await entry.set_result(result, ttl=self.DEFAULT_TTL) + return result + return await self._resolve(profile, did, service_accept) @abstractmethod diff --git a/aries_cloudagent/resolver/default/legacy_peer.py b/aries_cloudagent/resolver/default/legacy_peer.py index b960eeed3e..238d92702d 100644 --- a/aries_cloudagent/resolver/default/legacy_peer.py +++ b/aries_cloudagent/resolver/default/legacy_peer.py @@ -3,13 +3,12 @@ Resolution is performed by looking up a stored DID Document. """ -from collections.abc import Awaitable from copy import deepcopy from dataclasses import asdict, dataclass -import functools import logging -from typing import Callable, Optional, Sequence, Text, TypeVar -from typing_extensions import ParamSpec +from typing import Optional, Sequence, Text, Union + +from pydid import DID from ...cache.base import BaseCache from ...config.injection_context import InjectionContext @@ -25,18 +24,6 @@ LOGGER = logging.getLogger(__name__) -@dataclass -class RetrieveResult: - """Entry in the peer DID cache.""" - - is_local: bool - doc: Optional[dict] = None - - -T = TypeVar("T") -P = ParamSpec("P") - - class LegacyDocCorrections: """Legacy peer DID document corrections. @@ -94,7 +81,9 @@ class LegacyDocCorrections: "type": "did-communication", "priority": 0, "recipientKeys": ["did:sov:JNKL9kJxQi5pNCfA8QBXdJ#1"], - "routingKeys": ["9NnKFUZoYcCqYC2PcaXH3cnaGsoRfyGgyEHbvbLJYh8j"], + "routingKeys": [ + "did:key:z6Mknq3MqipEt9hJegs6J9V7tiLa6T5H5rX3fFCXksJKTuv7#z6Mknq3MqipEt9hJegs6J9V7tiLa6T5H5rX3fFCXksJKTuv7" + ], "serviceEndpoint": "http://bob:3000" } ] @@ -144,6 +133,8 @@ def didcomm_services_recip_keys_are_refs_routing_keys_are_did_key( if "routingKeys" in service: service["routingKeys"] = [ DIDKey.from_public_key_b58(key, ED25519).key_id + if "did:key:" not in key + else key for key in service["routingKeys"] ] return value @@ -163,6 +154,14 @@ def apply(cls, value: dict) -> dict: return value +@dataclass +class RetrieveResult: + """Entry in the peer DID cache.""" + + is_local: bool + doc: Optional[dict] = None + + class LegacyPeerDIDResolver(BaseDIDResolver): """Resolve legacy peer DIDs.""" @@ -173,36 +172,10 @@ def __init__(self): async def setup(self, context: InjectionContext): """Perform required setup for the resolver.""" - def _cached_resource( - self, - profile: Profile, - key: str, - retrieve: Callable[P, Awaitable[RetrieveResult]], - ttl: Optional[int] = None, - ) -> Callable[P, Awaitable[RetrieveResult]]: - """Get a cached resource.""" - - @functools.wraps(retrieve) - async def _wrapped(*args: P.args, **kwargs: P.kwargs): - cache = profile.inject_or(BaseCache) - if cache: - async with cache.acquire(key) as entry: - if entry.result: - value = RetrieveResult(**entry.result) - else: - value = await retrieve(*args, **kwargs) - await entry.set_result(asdict(value), ttl) - else: - value = await retrieve(*args, **kwargs) - - return value - - return _wrapped - - async def _fetch_did_document(self, profile: Profile, did: str): + async def _fetch_did_document(self, profile: Profile, did: str) -> RetrieveResult: """Fetch DID from wallet if available. - This is the method to be used with _cached_resource to enable caching. + This is the method to be used with fetch_did_document to enable caching. """ conn_mgr = BaseConnectionManager(profile) if did.startswith("did:sov:"): @@ -217,15 +190,26 @@ async def _fetch_did_document(self, profile: Profile, did: str): return to_cache - async def fetch_did_document(self, profile: Profile, did: str): + async def fetch_did_document( + self, profile: Profile, did: str, *, ttl: Optional[int] = None + ): """Fetch DID from wallet if available. Return value is cached. """ - cache_key = f"legacy_peer_did_resolver::{did}" - return await self._cached_resource( - profile, cache_key, self._fetch_did_document, ttl=3600 - )(profile, did) + cache_key = f"resolver::LegacyPeerDIDResolver::{did}" + cache = profile.inject_or(BaseCache) + if cache: + async with cache.acquire(cache_key) as entry: + if entry.result: + result = RetrieveResult(**entry.result) + else: + result = await self._fetch_did_document(profile, did) + await entry.set_result(asdict(result), ttl or self.DEFAULT_TTL) + else: + result = await self._fetch_did_document(profile, did) + + return result async def supports(self, profile: Profile, did: str) -> bool: """Return whether this resolver supports the given DID. @@ -252,6 +236,19 @@ async def supports(self, profile: Profile, did: str) -> bool: else: return False + async def resolve( + self, + profile: Profile, + did: Union[str, DID], + service_accept: Optional[Sequence[Text]] = None, + ) -> dict: + """Resolve a Legacy Peer DID to a DID document by fetching from the wallet. + + This overrides the default resolve method so we can take care of caching + ourselves since we use it for the supports method as well. + """ + return await self._resolve(profile, str(did), service_accept) + async def _resolve( self, profile: Profile, diff --git a/aries_cloudagent/resolver/default/tests/test_legacy_peer.py b/aries_cloudagent/resolver/default/tests/test_legacy_peer.py index 1ed3837650..7e3a7c8ff7 100644 --- a/aries_cloudagent/resolver/default/tests/test_legacy_peer.py +++ b/aries_cloudagent/resolver/default/tests/test_legacy_peer.py @@ -178,3 +178,4 @@ def test_corrections_examples(self): } actual = test_module.LegacyDocCorrections.apply(input_doc) assert actual == expected + assert expected == test_module.LegacyDocCorrections.apply(expected) diff --git a/aries_cloudagent/resolver/default/tests/test_universal.py b/aries_cloudagent/resolver/default/tests/test_universal.py index 37fbfd3243..51bdde6f9f 100644 --- a/aries_cloudagent/resolver/default/tests/test_universal.py +++ b/aries_cloudagent/resolver/default/tests/test_universal.py @@ -6,7 +6,8 @@ from asynctest import mock as async_mock import pytest -from aries_cloudagent.config.settings import Settings +from ....config.settings import Settings +from ....core.in_memory import InMemoryProfile from .. import universal as test_module from ...base import DIDNotFound, ResolverError @@ -34,7 +35,7 @@ async def resolver(): @pytest.fixture def profile(): """Profile fixture.""" - yield async_mock.MagicMock() + yield InMemoryProfile.test_profile() class MockResponse: diff --git a/aries_cloudagent/resolver/did_resolver.py b/aries_cloudagent/resolver/did_resolver.py index a3567f7208..55f1f08314 100644 --- a/aries_cloudagent/resolver/did_resolver.py +++ b/aries_cloudagent/resolver/did_resolver.py @@ -10,7 +10,7 @@ import logging from typing import List, Optional, Sequence, Text, Tuple, Union -from pydid import DID, DIDError, DIDUrl, Resource +from pydid import DID, DIDError, DIDUrl, Resource, VerificationMethod import pydid from pydid.doc.doc import BaseDIDDocument, IDNotFoundError @@ -29,7 +29,7 @@ class DIDResolver: """did resolver singleton.""" - def __init__(self, resolvers: List[BaseDIDResolver] = None): + def __init__(self, resolvers: Optional[List[BaseDIDResolver]] = None): """Create DID Resolver.""" self.resolvers = resolvers or [] @@ -151,3 +151,16 @@ async def dereference( raise ResolverError( "Failed to dereference DID URL: {}".format(error) ) from error + + async def dereference_verification_method( + self, + profile: Profile, + did_url: str, + *, + document: Optional[BaseDIDDocument] = None, + ) -> VerificationMethod: + """Dereference a DID URL to a verification method.""" + dereferenced = await self.dereference(profile, did_url, document=document) + if isinstance(dereferenced, VerificationMethod): + return dereferenced + raise ValueError("DID URL does not dereference to a verification method") diff --git a/aries_cloudagent/resolver/tests/test_did_resolver.py b/aries_cloudagent/resolver/tests/test_did_resolver.py index 73c3d1bc64..a65a8d8dc1 100644 --- a/aries_cloudagent/resolver/tests/test_did_resolver.py +++ b/aries_cloudagent/resolver/tests/test_did_resolver.py @@ -9,6 +9,7 @@ from asynctest import mock as async_mock from pydid import DID, DIDDocument, VerificationMethod, BasicDIDDocument +from ...core.in_memory import InMemoryProfile from ..base import ( BaseDIDResolver, DIDMethodNotSupported, @@ -108,7 +109,7 @@ def resolver(): @pytest.fixture def profile(): - yield async_mock.MagicMock() + yield InMemoryProfile.test_profile() def test_create_resolver(resolver):