diff --git a/acapy_agent/connections/base_manager.py b/acapy_agent/connections/base_manager.py index f5ea4451b5..fedd0b7242 100644 --- a/acapy_agent/connections/base_manager.py +++ b/acapy_agent/connections/base_manager.py @@ -444,7 +444,7 @@ async def resolve_didcomm_services( try: doc_dict: dict = await resolver.resolve(self._profile, did, service_accept) doc: ResolvedDocument = pydid.deserialize_document(doc_dict, strict=True) - except ResolverError as error: + except (ResolverError, ValueError) as error: raise BaseConnectionManagerError("Failed to resolve DID services") from error if not doc.service: diff --git a/acapy_agent/ledger/indy_vdr.py b/acapy_agent/ledger/indy_vdr.py index aad8963cd5..5a9b7a2ba0 100644 --- a/acapy_agent/ledger/indy_vdr.py +++ b/acapy_agent/ledger/indy_vdr.py @@ -624,7 +624,7 @@ async def credential_definition_id2schema_id(self, credential_definition_id): seq_no = tokens[3] return (await self.get_schema(seq_no))["id"] - async def get_key_for_did(self, did: str) -> str: + async def get_key_for_did(self, did: str) -> Optional[str]: """Fetch the verkey for a ledger DID. Args: diff --git a/acapy_agent/protocols/did_rotate/v1_0/manager.py b/acapy_agent/protocols/did_rotate/v1_0/manager.py index 8c5e508e06..684bb20dc6 100644 --- a/acapy_agent/protocols/did_rotate/v1_0/manager.py +++ b/acapy_agent/protocols/did_rotate/v1_0/manager.py @@ -128,7 +128,7 @@ async def receive_rotate(self, conn: ConnRecord, rotate: Rotate) -> RotateRecord ) try: - await self._ensure_supported_did(rotate.to_did) + await self.ensure_supported_did(rotate.to_did) except ReportableDIDRotateError as err: responder = self.profile.inject(BaseResponder) err.message.assign_thread_from(rotate) @@ -234,7 +234,7 @@ async def receive_hangup(self, conn: ConnRecord): async with self.profile.session() as session: await conn.delete_record(session) - async def _ensure_supported_did(self, did: str): + async def ensure_supported_did(self, did: str): """Check if the DID is supported.""" resolver = self.profile.inject(DIDResolver) conn_mgr = BaseConnectionManager(self.profile) diff --git a/acapy_agent/protocols/did_rotate/v1_0/message_types.py b/acapy_agent/protocols/did_rotate/v1_0/message_types.py index 2d0c0bfb6e..7cb418fc91 100644 --- a/acapy_agent/protocols/did_rotate/v1_0/message_types.py +++ b/acapy_agent/protocols/did_rotate/v1_0/message_types.py @@ -15,8 +15,8 @@ MESSAGE_TYPES = DIDCommPrefix.qualify_all( { ROTATE: f"{PROTOCOL_PACKAGE}.messages.rotate.Rotate", - ACK: f"{PROTOCOL_PACKAGE}.messages.ack.Ack", + ACK: f"{PROTOCOL_PACKAGE}.messages.ack.RotateAck", HANGUP: f"{PROTOCOL_PACKAGE}.messages.hangup.Hangup", - PROBLEM_REPORT: f"{PROTOCOL_PACKAGE}.messages.problem_report.ProblemReport", + PROBLEM_REPORT: f"{PROTOCOL_PACKAGE}.messages.problem_report.RotateProblemReport", } ) diff --git a/acapy_agent/protocols/did_rotate/v1_0/routes.py b/acapy_agent/protocols/did_rotate/v1_0/routes.py index f441ded27b..15c485371f 100644 --- a/acapy_agent/protocols/did_rotate/v1_0/routes.py +++ b/acapy_agent/protocols/did_rotate/v1_0/routes.py @@ -12,7 +12,12 @@ from ....messaging.models.openapi import OpenAPISchema from ....messaging.valid import DID_WEB_EXAMPLE, UUID4_EXAMPLE from ....storage.error import StorageNotFoundError -from .manager import DIDRotateManager +from .manager import ( + DIDRotateManager, + UnresolvableDIDCommServicesError, + UnresolvableDIDError, + UnsupportedDIDMethodError, +) from .message_types import SPEC_URI from .messages.hangup import HangupSchema as HangupMessageSchema from .messages.rotate import RotateSchema as RotateMessageSchema @@ -63,6 +68,16 @@ async def rotate(request: web.BaseRequest): body = await request.json() to_did = body["to_did"] + # Validate DID before proceeding + try: + await did_rotate_mgr.ensure_supported_did(to_did) + except ( + UnsupportedDIDMethodError, + UnresolvableDIDError, + UnresolvableDIDCommServicesError, + ) as err: + raise web.HTTPBadRequest(reason=str(err)) from err + async with context.session() as session: try: conn = await ConnRecord.retrieve_by_id(session, connection_id) diff --git a/acapy_agent/protocols/did_rotate/v1_0/tests/test_manager.py b/acapy_agent/protocols/did_rotate/v1_0/tests/test_manager.py index cf72da930e..2383bfdc17 100644 --- a/acapy_agent/protocols/did_rotate/v1_0/tests/test_manager.py +++ b/acapy_agent/protocols/did_rotate/v1_0/tests/test_manager.py @@ -94,7 +94,7 @@ async def test_receive_rotate_x(self): with ( mock.patch.object( - self.manager, "_ensure_supported_did", side_effect=test_problem_report + self.manager, "ensure_supported_did", side_effect=test_problem_report ), mock.patch.object(self.responder, "send", mock.CoroutineMock()) as mock_send, ): diff --git a/acapy_agent/protocols/did_rotate/v1_0/tests/test_routes.py b/acapy_agent/protocols/did_rotate/v1_0/tests/test_routes.py index 932d413ed9..22ec03fed2 100644 --- a/acapy_agent/protocols/did_rotate/v1_0/tests/test_routes.py +++ b/acapy_agent/protocols/did_rotate/v1_0/tests/test_routes.py @@ -57,7 +57,8 @@ async def asyncSetUp(self): "DIDRotateManager", autospec=True, return_value=mock.MagicMock( - rotate_my_did=mock.CoroutineMock(return_value=generate_mock_rotate_message()) + rotate_my_did=mock.CoroutineMock(return_value=generate_mock_rotate_message()), + ensure_supported_did=mock.CoroutineMock(), ), ) async def test_rotate(self, *_): @@ -102,7 +103,15 @@ async def test_hangup(self, *_): } ) - async def test_rotate_conn_not_found(self): + @mock.patch.object( + test_module, + "DIDRotateManager", + autospec=True, + return_value=mock.MagicMock( + ensure_supported_did=mock.CoroutineMock(), + ), + ) + async def test_rotate_conn_not_found(self, *_): self.request.match_info = {"conn_id": test_conn_id} self.request.json = mock.CoroutineMock(return_value=test_valid_rotate_request) @@ -114,6 +123,28 @@ async def test_rotate_conn_not_found(self): with self.assertRaises(test_module.web.HTTPNotFound): await test_module.rotate(self.request) + async def test_rotate_did_validation_errors(self): + self.request.match_info = {"conn_id": test_conn_id} + self.request.json = mock.CoroutineMock(return_value=test_valid_rotate_request) + + for error_class in [ + test_module.UnsupportedDIDMethodError, + test_module.UnresolvableDIDError, + test_module.UnresolvableDIDCommServicesError, + ]: + with mock.patch.object( + test_module, + "DIDRotateManager", + autospec=True, + return_value=mock.MagicMock( + ensure_supported_did=mock.CoroutineMock( + side_effect=error_class("test error") + ), + ), + ): + with self.assertRaises(test_module.web.HTTPBadRequest): + await test_module.rotate(self.request) + if __name__ == "__main__": unittest.main()