diff --git a/basicmessage_storage/basicmessage_storage/v1_0/models.py b/basicmessage_storage/basicmessage_storage/v1_0/models.py index 66cd9d29f..dc32ae9bd 100644 --- a/basicmessage_storage/basicmessage_storage/v1_0/models.py +++ b/basicmessage_storage/basicmessage_storage/v1_0/models.py @@ -6,6 +6,7 @@ INDY_ISO8601_DATETIME_VALIDATE, ) from marshmallow import fields +from aries_cloudagent.storage.base import BaseStorage class BasicMessageRecord(BaseRecord): @@ -62,6 +63,25 @@ def record_tags(self) -> dict: """Get tags for record.""" return {"connection_id": self.connection_id, "message_id": self.message_id} + + async def delete_record(self, session: ProfileSession): + """Perform connection record deletion actions. + + Args: + session (ProfileSession): session + + """ + await super().delete_record(session) + + # Delete metadata + if self.message_id: + storage = session.inject(BaseStorage) + await storage.delete_all_records( + self.RECORD_TYPE, + {"message_id": self.message_id}, + ) + + @classmethod async def retrieve_by_message_id( cls, session: ProfileSession, message_id: str diff --git a/basicmessage_storage/basicmessage_storage/v1_0/routes.py b/basicmessage_storage/basicmessage_storage/v1_0/routes.py index 27e85c9eb..c71293b99 100644 --- a/basicmessage_storage/basicmessage_storage/v1_0/routes.py +++ b/basicmessage_storage/basicmessage_storage/v1_0/routes.py @@ -15,6 +15,7 @@ from aries_cloudagent.messaging.models.base import BaseModelError from aries_cloudagent.messaging.models.openapi import OpenAPISchema from aries_cloudagent.messaging.util import time_now, str_to_epoch +from aries_cloudagent.messaging.valid import UUID4_EXAMPLE from aries_cloudagent.multitenant.error import WalletKeyMissingError from aries_cloudagent.protocols.basicmessage.v1_0.message_types import SPEC_URI from aries_cloudagent.protocols.basicmessage.v1_0.routes import ( @@ -60,6 +61,18 @@ class BasicMessageListSchema(OpenAPISchema): description="List of basic message records", ) +class BasicMessageIdMatchInfoSchema(OpenAPISchema): + """Path parameters and validators for request taking message id.""" + + message_id = fields.Str( + required=True, + metadata={"description": "Message identifier", "example": UUID4_EXAMPLE}, + ) + + +class DeleteResponseSchema(OpenAPISchema): + """Response schema for DELETE endpoint.""" + class BasicMessageListQueryStringSchema(OpenAPISchema): """Basic Messages List query string schema.""" @@ -167,6 +180,34 @@ async def all_messages_list(request: web.BaseRequest): return web.json_response({"results": results}) +@docs( + tags=["basicmessage"], + summary="delete stored message by message_id", +) +@match_info_schema(BasicMessageIdMatchInfoSchema()) +@response_schema(DeleteResponseSchema(), 200, description="") +@error_handler +async def delete_message(request: web.BaseRequest): + """Request handler for searching basic message record by id. + + Args: + request: aiohttp request object + """ + context: AdminRequestContext = request["context"] + profile = context.profile + message_id = request.match_info["message_id"] + try: + async with profile.session() as session: + record = await BasicMessageRecord.retrieve_by_message_id(session, message_id) + await record.delete_record(session) + + except StorageNotFoundError as err: + raise web.HTTPNotFound(reason=err.roll_up) from err + except (StorageError, BaseModelError) as err: + raise web.HTTPBadRequest(reason=err.roll_up) from err + return web.json_response({}) + + async def register(app: web.Application): """Register routes.""" # we want to save messages when sent, so replace the default send message endpoint @@ -200,6 +241,7 @@ async def register(app: web.Application): # add in the message list(s) route app.add_routes([web.get("/basicmessages", all_messages_list, allow_head=False)]) + app.add_routes([web.delete("/basicmessages/{message_id}", delete_message)]) def post_process_routes(app: web.Application): diff --git a/basicmessage_storage/integration/tests/__init__.py b/basicmessage_storage/integration/tests/__init__.py index 81b278bd4..5e73614f0 100644 --- a/basicmessage_storage/integration/tests/__init__.py +++ b/basicmessage_storage/integration/tests/__init__.py @@ -18,6 +18,10 @@ def post(agent: str, path: str, **kwargs): """Post.""" return requests.post(f"{agent}{path}", **kwargs) +def delete(agent: str, path: str, **kwargs): + """Post.""" + return requests.delete(f"{agent}{path}", **kwargs) + def fail_if_not_ok(message: str): """Fail the current test if wrapped call fails with message.""" @@ -80,6 +84,13 @@ def accept_invite(self, connection_id: str): def retrieve_basicmessages(self, **kwargs): """Retrieve connections.""" return get(self.url, "/basicmessages", params=kwargs) + + + @unwrap_json_response + @fail_if_not_ok("Failed to delete basic messages") + def delete_basicmessage(self, message_id, **kwargs): + """Retrieve connections.""" + return delete(self.url, f"/basicmessages/{message_id}", params=kwargs) @unwrap_json_response @fail_if_not_ok("Failed to send basic message") diff --git a/basicmessage_storage/integration/tests/test_basicmessage_storage.py b/basicmessage_storage/integration/tests/test_basicmessage_storage.py index 78487dd77..5613ac4d7 100644 --- a/basicmessage_storage/integration/tests/test_basicmessage_storage.py +++ b/basicmessage_storage/integration/tests/test_basicmessage_storage.py @@ -47,3 +47,29 @@ def test_storage(bob, alice, established_connection): # alice should have 1 sent and 1 received (auto-reponse) alice_messages = alice.retrieve_basicmessages() assert len(alice_messages["results"]) == 2 + +def test_deletion(bob, alice, established_connection): + # make sure connection is active... + time.sleep(2) + + # alice send bob a message (alice will store their sent message) + resp = alice.send_message(established_connection, "hello bob") + assert True + + # make sure auto-respond messages have been exchanged + time.sleep(2) + + # bob should have 1 received + bob_messages = bob.retrieve_basicmessages() + assert len(bob_messages["results"]) == 2 + + # alice should have 1 sent and 1 received (auto-reponse) + alice_messages = alice.retrieve_basicmessages() + assert len(alice_messages["results"]) == 4 + time.sleep(2) + + + alice.delete_basicmessage(alice_messages["results"][0]["message_id"]) + time.sleep(2) + alice_messages = alice.retrieve_basicmessages() + assert len(alice_messages["results"]) == 3