-
Notifications
You must be signed in to change notification settings - Fork 516
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
chore: Split out protocols from Trust Ping
Signed-off-by: Colton Wolkins (Laptop) <colton@indicio.tech>
- Loading branch information
1 parent
be0dfee
commit ebefef9
Showing
22 changed files
with
890 additions
and
165 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
"""Version definitions for this protocol.""" | ||
|
||
versions = [ | ||
{ | ||
"major_version": 1, | ||
"minimum_minor_version": 0, | ||
"current_minor_version": 0, | ||
"path": "v1_0", | ||
} | ||
] |
Empty file.
46 changes: 46 additions & 0 deletions
46
acapy_agent/protocols_v2/basicmessage/v1_0/message_types.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
"""Message type identifiers for Trust Pings.""" | ||
|
||
#from ...didcomm_prefix import DIDCommPrefix | ||
import logging | ||
from ....messaging.v2_agent_message import V2AgentMessage | ||
from ....connections.models.connection_target import ConnectionTarget | ||
from didcomm_messaging import DIDCommMessaging, RoutingService | ||
|
||
SPEC_URI = ( | ||
"https://didcomm.org/basicmessage/2.0/message" | ||
) | ||
|
||
# Message types | ||
BASIC_MESSAGE = "https://didcomm.org/basicmessage/2.0/message" | ||
|
||
PROTOCOL_PACKAGE = "acapy_agent.protocols_v2.basicmessage.v1_0" | ||
|
||
class basic_message: | ||
async def __call__(self, *args, **kwargs): | ||
await self.handle(*args, **kwargs) | ||
@staticmethod | ||
async def handle(context, responder, payload): | ||
logger = logging.getLogger(__name__) | ||
their_did = context.message_receipt.sender_verkey.split('#')[0] | ||
our_did = context.message_receipt.recipient_verkey.split('#')[0] | ||
error_result = V2AgentMessage( | ||
message={ | ||
"type": BASIC_MESSAGE, | ||
"body": { | ||
"content": "Hello from acapy", | ||
}, | ||
"to": [their_did], | ||
"from": our_did, | ||
"lang": "en", | ||
} | ||
) | ||
await responder.send_reply(error_result) | ||
|
||
|
||
HANDLERS = { | ||
BASIC_MESSAGE: f"{PROTOCOL_PACKAGE}.message_types.basic_message", | ||
}.items() | ||
|
||
MESSAGE_TYPES = { | ||
BASIC_MESSAGE: f"{PROTOCOL_PACKAGE}.message_types.basic_message", | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,232 @@ | ||
"""Trust ping admin routes.""" | ||
|
||
from aiohttp import web | ||
from aiohttp_apispec import docs, match_info_schema, request_schema, response_schema | ||
from marshmallow import fields | ||
from didcomm_messaging import DIDCommMessaging, RoutingService | ||
from didcomm_messaging.resolver import DIDResolver as DMPResolver | ||
|
||
from ....admin.decorators.auth import tenant_authentication | ||
from ....admin.request_context import AdminRequestContext | ||
from ....connections.models.conn_record import ConnRecord | ||
from ....messaging.models.openapi import OpenAPISchema | ||
from ....messaging.valid import UUID4_EXAMPLE | ||
from ....storage.error import StorageNotFoundError | ||
from .message_types import SPEC_URI | ||
|
||
|
||
class BaseDIDCommV2Schema(OpenAPISchema): | ||
"""Request schema for performing a ping.""" | ||
|
||
to_did = fields.Str( | ||
required=True, | ||
allow_none=False, | ||
metadata={"description": "Comment for the ping message"}, | ||
) | ||
|
||
|
||
class PingRequestSchema(BaseDIDCommV2Schema): | ||
"""Request schema for performing a ping.""" | ||
|
||
response_requested = fields.Bool( | ||
required=False, | ||
allow_none=True, | ||
metadata={"description": "Comment for the ping message"}, | ||
) | ||
|
||
|
||
class PingRequestResponseSchema(OpenAPISchema): | ||
"""Request schema for performing a ping.""" | ||
|
||
thread_id = fields.Str( | ||
required=False, metadata={"description": "Thread ID of the ping message"} | ||
) | ||
|
||
|
||
class PingConnIdMatchInfoSchema(OpenAPISchema): | ||
"""Path parameters and validators for request taking connection id.""" | ||
|
||
conn_id = fields.Str( | ||
required=True, | ||
metadata={"description": "Connection identifier", "example": UUID4_EXAMPLE}, | ||
) | ||
|
||
from ....wallet.base import BaseWallet | ||
from ....wallet.did_info import DIDInfo | ||
from ....wallet.did_method import KEY, PEER2, PEER4, SOV, DIDMethod, DIDMethods, HolderDefinedDid | ||
from ....wallet.did_posture import DIDPosture | ||
from ....wallet.error import WalletError, WalletNotFoundError | ||
from ....messaging.v2_agent_message import V2AgentMessage | ||
from ....connections.models.connection_target import ConnectionTarget | ||
from didcomm_messaging import DIDCommMessaging, RoutingService | ||
def format_did_info(info: DIDInfo): | ||
"""Serialize a DIDInfo object.""" | ||
if info: | ||
return { | ||
"did": info.did, | ||
"verkey": info.verkey, | ||
"posture": DIDPosture.get(info.metadata).moniker, | ||
"key_type": info.key_type.key_type, | ||
"method": info.method.method_name, | ||
"metadata": info.metadata, | ||
} | ||
|
||
async def get_mydid(request: web.BaseRequest): | ||
context: AdminRequestContext = request["context"] | ||
#filter_did = request.query.get("did") | ||
#filter_verkey = request.query.get("verkey") | ||
filter_posture = DIDPosture.get(request.query.get("posture")) | ||
results = [] | ||
async with context.session() as session: | ||
did_methods: DIDMethods = session.inject(DIDMethods) | ||
filter_method: DIDMethod | None = did_methods.from_method( | ||
request.query.get("method") or "did:peer:2" | ||
) | ||
#key_types = session.inject(KeyTypes) | ||
#filter_key_type = key_types.from_key_type(request.query.get("key_type", "")) | ||
wallet: BaseWallet | None = session.inject_or(BaseWallet) | ||
if not wallet: | ||
raise web.HTTPForbidden(reason="No wallet available") | ||
else: | ||
dids = await wallet.get_local_dids() | ||
results = [ | ||
format_did_info(info) | ||
for info in dids | ||
if ( | ||
filter_posture is None | ||
or DIDPosture.get(info.metadata) is DIDPosture.WALLET_ONLY | ||
) | ||
and (not filter_method or info.method == filter_method) | ||
#and (not filter_key_type or info.key_type == filter_key_type) | ||
] | ||
|
||
results.sort(key=lambda info: (DIDPosture.get(info["posture"]).ordinal, info["did"])) | ||
our_did = results[0]["did"] | ||
return our_did | ||
|
||
async def get_target(request: web.BaseRequest, to_did: str, from_did: str): | ||
context: AdminRequestContext = request["context"] | ||
|
||
try: | ||
async with context.profile.session() as session: | ||
resolver = session.inject(DMPResolver) | ||
did_doc = await resolver.resolve(to_did) | ||
except Exception as err: | ||
raise web.HTTPNotFound(reason=str(err)) from err | ||
|
||
async with context.session() as session: | ||
ctx = session | ||
messaging = ctx.inject(DIDCommMessaging) | ||
routing_service = ctx.inject(RoutingService) | ||
frm = to_did | ||
services = await routing_service._resolve_services(messaging.resolver, frm) | ||
chain = [ | ||
{ | ||
"did": frm, | ||
"service": services, | ||
} | ||
] | ||
|
||
# Loop through service DIDs until we run out of DIDs to forward to | ||
to_target = services[0].service_endpoint.uri | ||
found_forwardable_service = await routing_service.is_forwardable_service( | ||
messaging.resolver, services[0] | ||
) | ||
while found_forwardable_service: | ||
services = await routing_service._resolve_services(messaging.resolver, to_target) | ||
if services: | ||
chain.append( | ||
{ | ||
"did": to_target, | ||
"service": services, | ||
} | ||
) | ||
to_target = services[0].service_endpoint.uri | ||
found_forwardable_service = ( | ||
await routing_service.is_forwardable_service(messaging.resolver, services[0]) | ||
if services | ||
else False | ||
) | ||
reply_destination = [ | ||
ConnectionTarget( | ||
did=f"{to_did}#key-1", | ||
endpoint=service.service_endpoint.uri, | ||
recipient_keys=[f"{to_did}#key-1"], | ||
sender_key=from_did + "#key-1", | ||
) | ||
for service in chain[-1]["service"] | ||
] | ||
return reply_destination | ||
|
||
|
||
class BasicMessageSchema(BaseDIDCommV2Schema): | ||
"""Request schema for performing a ping.""" | ||
|
||
content = fields.Str( | ||
required=True, | ||
allow_none=False, | ||
metadata={"description": "Basic Message message content"}, | ||
) | ||
|
||
|
||
@docs(tags=["basicmessagev2", "didcommv2"], summary="Send a Basic Message") | ||
@request_schema(BasicMessageSchema()) | ||
@response_schema(PingRequestResponseSchema(), 200, description="") | ||
@tenant_authentication | ||
async def basic_message_send(request: web.BaseRequest): | ||
"""Request handler for sending a trust ping to a connection. | ||
Args: | ||
request: aiohttp request object | ||
""" | ||
context: AdminRequestContext = request["context"] | ||
outbound_handler = request["outbound_message_router"] | ||
body = await request.json() | ||
to_did = body.get("to_did") | ||
message = body.get("content") | ||
|
||
our_did = await get_mydid(request) | ||
their_did = to_did | ||
reply_destination = await get_target(request, to_did, our_did) | ||
msg = V2AgentMessage( | ||
message={ | ||
"type": "https://didcomm.org/basicmessage/2.0/message", | ||
"body": { | ||
"content": message | ||
}, | ||
"lang": "en", | ||
"to": [their_did], | ||
"from": our_did, | ||
} | ||
) | ||
await outbound_handler(msg, target_list=reply_destination) | ||
return web.json_response(msg.message) | ||
|
||
|
||
async def register(app: web.Application): | ||
"""Register routes.""" | ||
|
||
app.add_routes([web.post("/basic-message/send-message", basic_message_send)]) | ||
|
||
|
||
def post_process_routes(app: web.Application): | ||
"""Amend swagger API.""" | ||
|
||
# Add top-level tags description | ||
if "tags" not in app._state["swagger_dict"]: | ||
app._state["swagger_dict"]["tags"] = [] | ||
app._state["swagger_dict"]["tags"].append( | ||
{ | ||
"name": "basicmessagev2", | ||
"description": "Basic Message to contact", | ||
"externalDocs": {"description": "Specification", "url": SPEC_URI}, | ||
} | ||
) | ||
app._state["swagger_dict"]["tags"].append( | ||
{ | ||
"name": "didcommv2", | ||
"description": "DIDComm V2 based protocols for Interop-a-thon", | ||
"externalDocs": {"description": "Specification", "url": "https://didcomm.org"}, | ||
} | ||
) |
Empty file.
84 changes: 84 additions & 0 deletions
84
acapy_agent/protocols_v2/basicmessage/v1_0/tests/test_routes.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
from unittest import IsolatedAsyncioTestCase | ||
|
||
from acapy_agent.tests import mock | ||
|
||
from .....admin.request_context import AdminRequestContext | ||
from .....core.in_memory import InMemoryProfile | ||
from .. import routes as test_module | ||
|
||
|
||
class TestTrustpingRoutes(IsolatedAsyncioTestCase): | ||
def setUp(self): | ||
self.session_inject = {} | ||
profile = InMemoryProfile.test_profile( | ||
settings={ | ||
"admin.admin_api_key": "secret-key", | ||
} | ||
) | ||
self.context = AdminRequestContext.test_context(self.session_inject, profile) | ||
self.request_dict = { | ||
"context": self.context, | ||
"outbound_message_router": mock.CoroutineMock(), | ||
} | ||
self.request = mock.MagicMock( | ||
app={}, | ||
match_info={}, | ||
query={}, | ||
__getitem__=lambda _, k: self.request_dict[k], | ||
headers={"x-api-key": "secret-key"}, | ||
) | ||
|
||
async def test_connections_send_ping(self): | ||
self.request.json = mock.CoroutineMock(return_value={"comment": "some comment"}) | ||
self.request.match_info = {"conn_id": "dummy"} | ||
|
||
with mock.patch.object( | ||
test_module.ConnRecord, "retrieve_by_id", mock.CoroutineMock() | ||
) as mock_retrieve, mock.patch.object( | ||
test_module, "Ping", mock.MagicMock() | ||
) as mock_ping, mock.patch.object( | ||
test_module.web, "json_response", mock.MagicMock() | ||
) as json_response: | ||
mock_ping.return_value = mock.MagicMock(_thread_id="dummy") | ||
mock_retrieve.return_value = mock.MagicMock(is_ready=True) | ||
result = await test_module.connections_send_ping(self.request) | ||
json_response.assert_called_once_with({"thread_id": "dummy"}) | ||
assert result is json_response.return_value | ||
|
||
async def test_connections_send_ping_no_conn(self): | ||
self.request.json = mock.CoroutineMock(return_value={"comment": "some comment"}) | ||
self.request.match_info = {"conn_id": "dummy"} | ||
|
||
with mock.patch.object( | ||
test_module.ConnRecord, "retrieve_by_id", mock.CoroutineMock() | ||
) as mock_retrieve, mock.patch.object( | ||
test_module.web, "json_response", mock.MagicMock() | ||
) as json_response: | ||
mock_retrieve.side_effect = test_module.StorageNotFoundError() | ||
with self.assertRaises(test_module.web.HTTPNotFound): | ||
await test_module.connections_send_ping(self.request) | ||
|
||
async def test_connections_send_ping_not_ready(self): | ||
self.request.json = mock.CoroutineMock(return_value={"comment": "some comment"}) | ||
self.request.match_info = {"conn_id": "dummy"} | ||
|
||
with mock.patch.object( | ||
test_module.ConnRecord, "retrieve_by_id", mock.CoroutineMock() | ||
) as mock_retrieve, mock.patch.object( | ||
test_module.web, "json_response", mock.MagicMock() | ||
) as json_response: | ||
mock_retrieve.return_value = mock.MagicMock(is_ready=False) | ||
with self.assertRaises(test_module.web.HTTPBadRequest): | ||
await test_module.connections_send_ping(self.request) | ||
|
||
async def test_register(self): | ||
mock_app = mock.MagicMock() | ||
mock_app.add_routes = mock.MagicMock() | ||
|
||
await test_module.register(mock_app) | ||
mock_app.add_routes.assert_called_once() | ||
|
||
async def test_post_process_routes(self): | ||
mock_app = mock.MagicMock(_state={"swagger_dict": {}}) | ||
test_module.post_process_routes(mock_app) | ||
assert "tags" in mock_app._state["swagger_dict"] |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
"""Version definitions for this protocol.""" | ||
|
||
versions = [ | ||
{ | ||
"major_version": 1, | ||
"minimum_minor_version": 0, | ||
"current_minor_version": 0, | ||
"path": "v1_0", | ||
} | ||
] |
Empty file.
Oops, something went wrong.