Skip to content

Commit

Permalink
Merge pull request openwallet-foundation#2891 from dbluhm/fix/did-reu…
Browse files Browse the repository at this point in the history
…se-invitation-issue

fix: look up conn record by invite msg id instead of key
  • Loading branch information
swcurran authored Apr 15, 2024
2 parents 58fc72c + 5fab831 commit 8df6b3e
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 45 deletions.
27 changes: 12 additions & 15 deletions aries_cloudagent/connections/base_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from ..resolver.base import ResolverError
from ..resolver.did_resolver import DIDResolver
from ..storage.base import BaseStorage
from ..storage.error import StorageDuplicateError, StorageError, StorageNotFoundError
from ..storage.error import StorageDuplicateError, StorageNotFoundError
from ..storage.record import StorageRecord
from ..transport.inbound.receipt import MessageReceipt
from ..utils.multiformats import multibase, multicodec
Expand Down Expand Up @@ -854,17 +854,17 @@ async def fetch_did_document(self, did: str) -> Tuple[dict, StorageRecord]:

async def find_connection(
self,
their_did: str,
their_did: Optional[str],
my_did: Optional[str] = None,
my_verkey: Optional[str] = None,
parent_thread_id: Optional[str] = None,
auto_complete=False,
) -> Optional[ConnRecord]:
"""Look up existing connection information for a sender verkey.
Args:
their_did: Their DID
my_did: My DID
my_verkey: My verkey
parent_thread_id: Parent thread ID
auto_complete: Should this connection automatically be promoted to active
Returns:
Expand Down Expand Up @@ -895,16 +895,13 @@ async def find_connection(
connection_id=connection.connection_id
)

if not connection and my_verkey:
try:
async with self._profile.session() as session:
connection = await ConnRecord.retrieve_by_invitation_key(
session,
my_verkey,
their_role=ConnRecord.Role.REQUESTER.rfc160,
)
except StorageError:
pass
if not connection and parent_thread_id:
async with self._profile.session() as session:
connection = await ConnRecord.retrieve_by_invitation_msg_id(
session,
parent_thread_id,
their_role=ConnRecord.Role.REQUESTER.rfc160,
)

return connection

Expand Down Expand Up @@ -1001,7 +998,7 @@ async def resolve_inbound_connection(
)

return await self.find_connection(
receipt.sender_did, receipt.recipient_did, receipt.recipient_verkey, True
receipt.sender_did, receipt.recipient_did, receipt.parent_thread_id, True
)

async def get_endpoints(self, conn_id: str) -> Tuple[Optional[str], Optional[str]]:
Expand Down
20 changes: 11 additions & 9 deletions aries_cloudagent/connections/tests/test_base_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ async def asyncSetUp(self):
self.test_target_did = "GbuDUYXaUZRfHD2jeDuQuP"
self.test_target_verkey = "9WCgWKUaAJj3VWxxtzvvMQN3AoFxoBtBDo9ntwJnVVCC"

self.test_pthid = "test-pthid"

self.responder = MockResponder()

self.oob_mock = mock.MagicMock(
Expand Down Expand Up @@ -1645,7 +1647,7 @@ async def test_find_connection_retrieve_by_did(self):
conn_rec = await self.manager.find_connection(
their_did=self.test_target_did,
my_did=self.test_did,
my_verkey=self.test_verkey,
parent_thread_id=self.test_pthid,
auto_complete=True,
)
assert ConnRecord.State.get(conn_rec.state) is ConnRecord.State.COMPLETED
Expand All @@ -1665,7 +1667,7 @@ async def test_find_connection_retrieve_by_did_auto_disclose_features(self):
conn_rec = await self.manager.find_connection(
their_did=self.test_target_did,
my_did=self.test_did,
my_verkey=self.test_verkey,
parent_thread_id=self.test_pthid,
auto_complete=True,
)
assert ConnRecord.State.get(conn_rec.state) is ConnRecord.State.COMPLETED
Expand All @@ -1675,18 +1677,18 @@ async def test_find_connection_retrieve_by_invitation_key(self):
with mock.patch.object(
ConnRecord, "retrieve_by_did", mock.CoroutineMock()
) as mock_conn_retrieve_by_did, mock.patch.object(
ConnRecord, "retrieve_by_invitation_key", mock.CoroutineMock()
) as mock_conn_retrieve_by_invitation_key:
ConnRecord, "retrieve_by_invitation_msg_id", mock.CoroutineMock()
) as mock_conn_retrieve_by_invitation_msg_id:
mock_conn_retrieve_by_did.side_effect = StorageNotFoundError()
mock_conn_retrieve_by_invitation_key.return_value = mock.MagicMock(
mock_conn_retrieve_by_invitation_msg_id.return_value = mock.MagicMock(
state=ConnRecord.State.RESPONSE,
save=mock.CoroutineMock(),
)

conn_rec = await self.manager.find_connection(
their_did=self.test_target_did,
my_did=self.test_did,
my_verkey=self.test_verkey,
parent_thread_id=self.test_pthid,
)
assert conn_rec

Expand All @@ -1695,14 +1697,14 @@ async def test_find_connection_retrieve_none_by_invitation_key(self):
ConnRecord, "retrieve_by_did", mock.CoroutineMock()
) as mock_conn_retrieve_by_did, mock.patch.object(
ConnRecord, "retrieve_by_invitation_key", mock.CoroutineMock()
) as mock_conn_retrieve_by_invitation_key:
) as mock_conn_retrieve_by_invitation_msg_id:
mock_conn_retrieve_by_did.side_effect = StorageNotFoundError()
mock_conn_retrieve_by_invitation_key.side_effect = StorageNotFoundError()
mock_conn_retrieve_by_invitation_msg_id.return_value = None

conn_rec = await self.manager.find_connection(
their_did=self.test_target_did,
my_did=self.test_did,
my_verkey=self.test_verkey,
parent_thread_id=self.test_pthid,
)
assert conn_rec is None

Expand Down
29 changes: 14 additions & 15 deletions aries_cloudagent/protocols/didexchange/v1_0/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,9 +522,7 @@ async def receive_request(
)

if recipient_verkey:
conn_rec = await self._receive_request_pairwise_did(
request, recipient_verkey, alias
)
conn_rec = await self._receive_request_pairwise_did(request, alias)
else:
conn_rec = await self._receive_request_public_did(
request, recipient_did, alias, auto_accept_implicit
Expand All @@ -539,22 +537,23 @@ async def receive_request(
async def _receive_request_pairwise_did(
self,
request: DIDXRequest,
recipient_verkey: str,
alias: Optional[str] = None,
) -> ConnRecord:
"""Receive a DID Exchange request against a pairwise (not public) DID."""
try:
async with self.profile.session() as session:
conn_rec = await ConnRecord.retrieve_by_invitation_key(
session=session,
invitation_key=recipient_verkey,
their_role=ConnRecord.Role.REQUESTER.rfc23,
)
except StorageNotFoundError:
if not request._thread.pthid:
raise DIDXManagerError("DID Exchange request missing parent thread ID")

async with self.profile.session() as session:
conn_rec = await ConnRecord.retrieve_by_invitation_msg_id(
session=session,
invitation_msg_id=request._thread.pthid,
their_role=ConnRecord.Role.REQUESTER.rfc23,
)

if not conn_rec:
raise DIDXManagerError(
"No explicit invitation found for pairwise connection "
f"in state {ConnRecord.State.INVITATION.rfc23}: "
"a prior connection request may have updated the connection state"
"Pairwise requests must be against explicit invitations that have not "
"been previously consumed"
)

if conn_rec.is_multiuse_invitation:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -721,8 +721,8 @@ async def test_receive_request_invi_not_found(self):
with mock.patch.object(
test_module, "ConnRecord", mock.MagicMock()
) as mock_conn_rec_cls:
mock_conn_rec_cls.retrieve_by_invitation_key = mock.CoroutineMock(
side_effect=StorageNotFoundError()
mock_conn_rec_cls.retrieve_by_invitation_msg_id = mock.CoroutineMock(
return_value=None
)
with self.assertRaises(DIDXManagerError) as context:
await self.manager.receive_request(
Expand All @@ -732,7 +732,7 @@ async def test_receive_request_invi_not_found(self):
alias=None,
auto_accept_implicit=None,
)
assert "No explicit invitation found" in str(context.exception)
assert "explicit invitations" in str(context.exception)

async def test_receive_request_public_did_no_did_doc_attachment(self):
async with self.profile.session() as session:
Expand Down Expand Up @@ -1376,7 +1376,7 @@ async def test_receive_request_peer_did(self):
), mock.patch.object(
self.manager, "store_did_document", mock.CoroutineMock()
):
mock_conn_rec_cls.retrieve_by_invitation_key = mock.CoroutineMock(
mock_conn_rec_cls.retrieve_by_invitation_msg_id = mock.CoroutineMock(
return_value=mock_conn
)
mock_conn_rec_cls.return_value = mock.MagicMock(
Expand Down Expand Up @@ -1435,8 +1435,8 @@ async def test_receive_request_peer_did_not_found_x(self):
with mock.patch.object(
test_module, "ConnRecord", mock.MagicMock()
) as mock_conn_rec_cls:
mock_conn_rec_cls.retrieve_by_invitation_key = mock.CoroutineMock(
side_effect=StorageNotFoundError()
mock_conn_rec_cls.retrieve_by_invitation_msg_id = mock.CoroutineMock(
return_value=None
)
with self.assertRaises(DIDXManagerError):
await self.manager.receive_request(
Expand Down

0 comments on commit 8df6b3e

Please sign in to comment.