Skip to content

Commit

Permalink
🐛 Ensure supported DID before calling Rotate (#3380)
Browse files Browse the repository at this point in the history
* 🎨 Make `ensure_supported_did` method public

Signed-off-by: ff137 <ff137@proton.me>

* 🐛 Ensure supported DID before calling Rotate

Resolves: #3379
Signed-off-by: ff137 <ff137@proton.me>

* ✅ Fix existing tests

Signed-off-by: ff137 <ff137@proton.me>

* ✅ Test coverage for new DID validation

Signed-off-by: ff137 <ff137@proton.me>

* 🎨 Fix lying method return type

Signed-off-by: ff137 <ff137@proton.me>

* 🐛 Fix message type class mapping

Signed-off-by: ff137 <ff137@proton.me>

* ✨ Handle pydantic.ValidationError when resolving did

Signed-off-by: ff137 <ff137@proton.me>

* 🎨 Replace pydantic.ValidationError with ValueError

Signed-off-by: ff137 <ff137@proton.me>

---------

Signed-off-by: ff137 <ff137@proton.me>
Co-authored-by: Stephen Curran <swcurran@gmail.com>
  • Loading branch information
ff137 and swcurran authored Dec 17, 2024
1 parent 5c1d357 commit e3b0841
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 10 deletions.
2 changes: 1 addition & 1 deletion acapy_agent/connections/base_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ async def resolve_didcomm_services(
try:
doc_dict: dict = await resolver.resolve(self._profile, did, service_accept)
doc: ResolvedDocument = pydid.deserialize_document(doc_dict, strict=True)
except ResolverError as error:
except (ResolverError, ValueError) as error:
raise BaseConnectionManagerError("Failed to resolve DID services") from error

if not doc.service:
Expand Down
2 changes: 1 addition & 1 deletion acapy_agent/ledger/indy_vdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ async def credential_definition_id2schema_id(self, credential_definition_id):
seq_no = tokens[3]
return (await self.get_schema(seq_no))["id"]

async def get_key_for_did(self, did: str) -> str:
async def get_key_for_did(self, did: str) -> Optional[str]:
"""Fetch the verkey for a ledger DID.
Args:
Expand Down
4 changes: 2 additions & 2 deletions acapy_agent/protocols/did_rotate/v1_0/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ async def receive_rotate(self, conn: ConnRecord, rotate: Rotate) -> RotateRecord
)

try:
await self._ensure_supported_did(rotate.to_did)
await self.ensure_supported_did(rotate.to_did)
except ReportableDIDRotateError as err:
responder = self.profile.inject(BaseResponder)
err.message.assign_thread_from(rotate)
Expand Down Expand Up @@ -234,7 +234,7 @@ async def receive_hangup(self, conn: ConnRecord):
async with self.profile.session() as session:
await conn.delete_record(session)

async def _ensure_supported_did(self, did: str):
async def ensure_supported_did(self, did: str):
"""Check if the DID is supported."""
resolver = self.profile.inject(DIDResolver)
conn_mgr = BaseConnectionManager(self.profile)
Expand Down
4 changes: 2 additions & 2 deletions acapy_agent/protocols/did_rotate/v1_0/message_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
MESSAGE_TYPES = DIDCommPrefix.qualify_all(
{
ROTATE: f"{PROTOCOL_PACKAGE}.messages.rotate.Rotate",
ACK: f"{PROTOCOL_PACKAGE}.messages.ack.Ack",
ACK: f"{PROTOCOL_PACKAGE}.messages.ack.RotateAck",
HANGUP: f"{PROTOCOL_PACKAGE}.messages.hangup.Hangup",
PROBLEM_REPORT: f"{PROTOCOL_PACKAGE}.messages.problem_report.ProblemReport",
PROBLEM_REPORT: f"{PROTOCOL_PACKAGE}.messages.problem_report.RotateProblemReport",
}
)
17 changes: 16 additions & 1 deletion acapy_agent/protocols/did_rotate/v1_0/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
from ....messaging.models.openapi import OpenAPISchema
from ....messaging.valid import DID_WEB_EXAMPLE, UUID4_EXAMPLE
from ....storage.error import StorageNotFoundError
from .manager import DIDRotateManager
from .manager import (
DIDRotateManager,
UnresolvableDIDCommServicesError,
UnresolvableDIDError,
UnsupportedDIDMethodError,
)
from .message_types import SPEC_URI
from .messages.hangup import HangupSchema as HangupMessageSchema
from .messages.rotate import RotateSchema as RotateMessageSchema
Expand Down Expand Up @@ -63,6 +68,16 @@ async def rotate(request: web.BaseRequest):
body = await request.json()
to_did = body["to_did"]

# Validate DID before proceeding
try:
await did_rotate_mgr.ensure_supported_did(to_did)
except (
UnsupportedDIDMethodError,
UnresolvableDIDError,
UnresolvableDIDCommServicesError,
) as err:
raise web.HTTPBadRequest(reason=str(err)) from err

async with context.session() as session:
try:
conn = await ConnRecord.retrieve_by_id(session, connection_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ async def test_receive_rotate_x(self):

with (
mock.patch.object(
self.manager, "_ensure_supported_did", side_effect=test_problem_report
self.manager, "ensure_supported_did", side_effect=test_problem_report
),
mock.patch.object(self.responder, "send", mock.CoroutineMock()) as mock_send,
):
Expand Down
35 changes: 33 additions & 2 deletions acapy_agent/protocols/did_rotate/v1_0/tests/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ async def asyncSetUp(self):
"DIDRotateManager",
autospec=True,
return_value=mock.MagicMock(
rotate_my_did=mock.CoroutineMock(return_value=generate_mock_rotate_message())
rotate_my_did=mock.CoroutineMock(return_value=generate_mock_rotate_message()),
ensure_supported_did=mock.CoroutineMock(),
),
)
async def test_rotate(self, *_):
Expand Down Expand Up @@ -102,7 +103,15 @@ async def test_hangup(self, *_):
}
)

async def test_rotate_conn_not_found(self):
@mock.patch.object(
test_module,
"DIDRotateManager",
autospec=True,
return_value=mock.MagicMock(
ensure_supported_did=mock.CoroutineMock(),
),
)
async def test_rotate_conn_not_found(self, *_):
self.request.match_info = {"conn_id": test_conn_id}
self.request.json = mock.CoroutineMock(return_value=test_valid_rotate_request)

Expand All @@ -114,6 +123,28 @@ async def test_rotate_conn_not_found(self):
with self.assertRaises(test_module.web.HTTPNotFound):
await test_module.rotate(self.request)

async def test_rotate_did_validation_errors(self):
self.request.match_info = {"conn_id": test_conn_id}
self.request.json = mock.CoroutineMock(return_value=test_valid_rotate_request)

for error_class in [
test_module.UnsupportedDIDMethodError,
test_module.UnresolvableDIDError,
test_module.UnresolvableDIDCommServicesError,
]:
with mock.patch.object(
test_module,
"DIDRotateManager",
autospec=True,
return_value=mock.MagicMock(
ensure_supported_did=mock.CoroutineMock(
side_effect=error_class("test error")
),
),
):
with self.assertRaises(test_module.web.HTTPBadRequest):
await test_module.rotate(self.request)


if __name__ == "__main__":
unittest.main()

0 comments on commit e3b0841

Please sign in to comment.