From 4822034d47bffbe27e0047b410908e93481da44b Mon Sep 17 00:00:00 2001 From: Tom Foster Date: Fri, 30 Aug 2024 14:56:14 +0100 Subject: [PATCH] Error handling and comments --- maubot.yaml | 2 +- nsfwbot.py | 237 ++++++++++++++++++++++++++++++++-------------------- 2 files changed, 149 insertions(+), 90 deletions(-) diff --git a/maubot.yaml b/maubot.yaml index 3121f5a..f7d828e 100644 --- a/maubot.yaml +++ b/maubot.yaml @@ -1,6 +1,6 @@ maubot: 0.1.0 id: uk.tcpip.nsfwbot -version: 0.2.0 +version: 0.2.1 license: AGPL-3.0-or-later modules: - nsfwbot diff --git a/nsfwbot.py b/nsfwbot.py index 878f436..233145c 100644 --- a/nsfwbot.py +++ b/nsfwbot.py @@ -1,3 +1,10 @@ +""" +NSFW Model Plugin for Maubot + +This plugin detects NSFW content in images and text messages containing image tags, +and takes appropriate actions based on the configuration. +""" + import os from typing import List, Tuple, Type from uuid import uuid4 @@ -21,16 +28,25 @@ class Config(BaseProxyConfig): """ - Configuration manager + Configuration manager for the NSFWModelPlugin. """ def do_update(self, helper: ConfigUpdateHelper) -> None: + """ + Update the configuration with new values. + + :param helper: Helper object to copy configuration values. + """ helper.copy("max_concurrent_jobs") helper.copy("via_servers") helper.copy("actions") class NSFWModelPlugin(Plugin): + """ + Plugin to detect NSFW content in images and text messages. + """ + model = Model() semaphore = Semaphore(1) via_servers = [] @@ -39,35 +55,38 @@ class NSFWModelPlugin(Plugin): @classmethod def get_config_class(cls) -> Type[BaseProxyConfig]: + """ + Get the configuration class for the plugin. + + :return: Configuration class. + """ return Config async def start(self) -> None: - """Initialise plugin by loading config and setting up semaphore.""" + """ + Initialise plugin by loading config and setting up semaphore. + """ await super().start() - # Check if config exists - if not isinstance(self.config, Config): - self.log.error("Plugin not yet configured.") - else: - # Load in config - self.config.load_and_update() - # Load via_servers from config, with a default fallback - self.via_servers = self.config["via_servers"] - # Load actions from config - self.actions = self.config["actions"] - # Initialise the Semaphore based on the max_concurrent_jobs setting - max_concurrent_jobs = self.config["max_concurrent_jobs"] - self.semaphore = Semaphore(max_concurrent_jobs) - # Resolve room ID from alias if it starts with # - self.report_to_room = str(self.actions.get("report_to_room", "")) - if self.report_to_room.startswith("#"): - report_to_info = await self.client.resolve_room_alias( - RoomAlias(self.report_to_room) - ) - self.report_to_room = report_to_info.room_id - elif self.report_to_room and not self.report_to_room.startswith("!"): - self.log.warning("Invalid room ID or alias provided for report_to_room") - # Initialise the NSFW model - self.log.info("Loaded nsfwbot successfully") + try: + if not isinstance(self.config, Config): + self.log.error("Plugin not yet configured.") + else: + self.config.load_and_update() + self.via_servers = self.config["via_servers"] + self.actions = self.config["actions"] + max_concurrent_jobs = self.config["max_concurrent_jobs"] + self.semaphore = Semaphore(max_concurrent_jobs) + self.report_to_room = str(self.actions.get("report_to_room", "")) + if self.report_to_room.startswith("#"): + report_to_info = await self.client.resolve_room_alias( + RoomAlias(self.report_to_room) + ) + self.report_to_room = report_to_info.room_id + elif self.report_to_room and not self.report_to_room.startswith("!"): + self.log.warning("Invalid room ID or alias provided for report_to_room") + self.log.info("Loaded nsfwbot successfully") + except Exception as e: + self.log.error(f"Error during start: {e}") @command.passive( "^mxc://.+/.+$", @@ -75,17 +94,21 @@ async def start(self) -> None: msgtypes=(MessageType.IMAGE,), ) async def handle_image_message(self, evt: MessageEvent, url: Tuple[str]) -> None: - """Handle direct image messages.""" - if not isinstance(evt.content, MediaMessageEventContent) or not evt.content.url: - return - # Process image - results = await self.process_images([evt.content.url]) - # Create matrix.to URL for the original message - matrix_to_url = self.create_matrix_to_url(evt.room_id, evt.event_id) - # Prepare the response message - response = self.format_response(results, matrix_to_url) - # Send responses based on actions - await self.send_responses(evt, response, results) + """ + Handle direct image messages. + + :param evt: The message event containing the image. + :param url: The URL of the image. + """ + try: + if not isinstance(evt.content, MediaMessageEventContent) or not evt.content.url: + return + results = await self.process_images([evt.content.url]) + matrix_to_url = self.create_matrix_to_url(evt.room_id, evt.event_id) + response = self.format_response(results, matrix_to_url) + await self.send_responses(evt, response, results) + except Exception as e: + self.log.error(f"Error handling image message: {e}") @command.passive( '^ None msgtypes=(MessageType.TEXT,), ) async def handle_text_message(self, evt: MessageEvent) -> None: - """Handle text messages with possible tags.""" - if isinstance(evt.content, TextMessageEventContent) and evt.content.formatted_body: - img_urls = self.extract_img_tags(evt.content.formatted_body) - # Do nothing if no URLs found - if len(img_urls) == 0: - return - # Process all images at once - all_results = await self.process_images([ContentURI(url) for url in img_urls]) - # Create matrix.to URL for the original message - matrix_to_url = self.create_matrix_to_url(evt.room_id, evt.event_id) - # Prepare the response message - response = self.format_response(all_results, matrix_to_url) - # Send responses based on actions - await self.send_responses(evt, response, all_results) + """ + Handle text messages with possible tags. + + :param evt: The message event containing the text. + """ + try: + if isinstance(evt.content, TextMessageEventContent) and evt.content.formatted_body: + img_urls = self.extract_img_tags(evt.content.formatted_body) + if len(img_urls) == 0: + return + all_results = await self.process_images([ContentURI(url) for url in img_urls]) + matrix_to_url = self.create_matrix_to_url(evt.room_id, evt.event_id) + response = self.format_response(all_results, matrix_to_url) + await self.send_responses(evt, response, all_results) + except Exception as e: + self.log.error(f"Error handling text message: {e}") async def process_images(self, mxc_urls: List[ContentURI]) -> dict: - """Download and process the images using the NSFW model.""" + """ + Download and process the images using the NSFW model. + + :param mxc_urls: List of MXC URLs of the images. + :return: Dictionary of results with MXC URLs as keys and predictions as values. + """ async with self.semaphore: temp_files = [] try: - # Download all images and save to temporary files for mxc_url in mxc_urls: img_bytes = await self.client.download_media(mxc_url) # type:ignore temp_filename = f"/tmp/{uuid4()}.jpg" @@ -121,22 +150,28 @@ async def process_images(self, mxc_urls: List[ContentURI]) -> dict: img_file.write(img_bytes) temp_files.append((mxc_url, temp_filename)) - # Predict using the NSFW model predictions = self.model.predict([temp_filename for _, temp_filename in temp_files]) - # Replace temporary filenames in the results with the original MXC URLs final_results = { str(mxc_url): predictions.pop(temp_filename) for mxc_url, temp_filename in temp_files } return final_results + except Exception as e: + self.log.error(f"Error processing images: {e}") + return {} finally: - # Ensure all temporary files are removed for _, temp_filename in temp_files: os.remove(temp_filename) def create_matrix_to_url(self, room_id: RoomID, event_id: EventID) -> str: - """Create a matrix.to URL for a given room ID and event ID.""" + """ + Create a matrix.to URL for a given room ID and event ID. + + :param room_id: The room ID. + :param event_id: The event ID. + :return: The matrix.to URL. + """ via_params = ( str("?" + "&".join([f"via={server}" for server in self.via_servers])) if self.via_servers @@ -145,12 +180,23 @@ def create_matrix_to_url(self, room_id: RoomID, event_id: EventID) -> str: return f"https://matrix.to/#/{room_id}/{event_id}{via_params}" def extract_img_tags(self, html: str) -> List[str]: - """Extract image URLs from tags in the HTML content.""" + """ + Extract image URLs from tags in the HTML content. + + :param html: The HTML content. + :return: List of image URLs. + """ soup = BeautifulSoup(html, "html.parser") return [img["src"] for img in soup.find_all("img") if "src" in img.attrs] def format_response(self, results: dict, matrix_to_url: str) -> str: - """Format the response message based on the results.""" + """ + Format the response message based on the results. + + :param results: Dictionary of results with MXC URLs as keys and predictions as values. + :param matrix_to_url: The matrix.to URL for the original message. + :return: The formatted response message. + """ response_parts = [ f"{mxc_url} in {matrix_to_url} appears {res['Label']} with score {res['Score']:.2%}" for mxc_url, res in results.items() @@ -161,33 +207,46 @@ def format_response(self, results: dict, matrix_to_url: str) -> str: return "\n".join(response_parts) async def send_responses(self, evt: MessageEvent, response: str, results: dict) -> None: - """Send responses or take actions based on config.""" - # Check if we should ignore SFW images - ignore_sfw = self.actions.get("ignore_sfw", False) - nsfw_results = [res for res in results.values() if res["Label"] == "NSFW"] - # If all images were SFW and should be ignored - if ignore_sfw and not nsfw_results: - self.log.info(f"Ignored SFW images in {evt.room_id}") - return - - # Direct reply in the same room - if self.actions.get("direct_reply", False): - await evt.reply(response) - self.log.info(f"Replied to {evt.room_id}") - - # Report to a specific room - if self.report_to_room: - try: - await self.client.send_text(room_id=RoomID(self.report_to_room), text=response) - self.log.info(f"Sent report to {RoomID(self.report_to_room)}") - except MBadJSON as e: - self.log.warning(f"Failed to send message to {RoomID(self.report_to_room)}: {e}") - - # Redact the message if it's NSFW and redacting is enabled - redact_nsfw = self.actions.get("redact_nsfw", False) - if nsfw_results and redact_nsfw: - try: - await self.client.redact(room_id=evt.room_id, event_id=evt.event_id, reason="NSFW") - self.log.info(f"Redacted NSFW message in {evt.room_id}") - except MForbidden: - self.log.warning(f"Failed to redact NSFW message in {evt.room_id}") + """ + Send responses or take actions based on config. + + :param evt: The message event. + :param response: The formatted response message. + :param results: Dictionary of results with MXC URLs as keys and predictions as values. + """ + try: + # Check if we should ignore SFW images + ignore_sfw = self.actions.get("ignore_sfw", False) + nsfw_results = [res for res in results.values() if res["Label"] == "NSFW"] + # If all images were SFW and should be ignored + if ignore_sfw and not nsfw_results: + self.log.info(f"Ignored SFW images in {evt.room_id}") + return + + # Direct reply in the same room + if self.actions.get("direct_reply", False): + await evt.reply(response) + self.log.info(f"Replied to {evt.room_id}") + + # Report to a specific room + if self.report_to_room: + try: + await self.client.send_text(room_id=RoomID(self.report_to_room), text=response) + self.log.info(f"Sent report to {RoomID(self.report_to_room)}") + except MBadJSON as e: + self.log.warning( + f"Failed to send message to {RoomID(self.report_to_room)}: {e}" + ) + + # Redact the message if it's NSFW and redacting is enabled + redact_nsfw = self.actions.get("redact_nsfw", False) + if nsfw_results and redact_nsfw: + try: + await self.client.redact( + room_id=evt.room_id, event_id=evt.event_id, reason="NSFW" + ) + self.log.info(f"Redacted NSFW message in {evt.room_id}") + except MForbidden: + self.log.warning(f"Failed to redact NSFW message in {evt.room_id}") + except Exception as e: + self.log.error(f"Error sending responses: {e}")