diff --git a/core/utils/emojis.py b/core/utils/emojis.py index 96d6b41..0da670b 100644 --- a/core/utils/emojis.py +++ b/core/utils/emojis.py @@ -31,6 +31,7 @@ "d12": "", "d20": "", "d100": "", + "ao3": "<:ao3:1229883149136433325>", } # fmt: on diff --git a/core/utils/misc.py b/core/utils/misc.py index 591c9f0..654b804 100644 --- a/core/utils/misc.py +++ b/core/utils/misc.py @@ -5,6 +5,7 @@ from __future__ import annotations import logging +import re import time import lxml.html @@ -37,6 +38,10 @@ def __exit__(self, *exc: object) -> None: self.logger.info("Time: %.3f seconds", self.total_time) +_BEFORE_WS = re.compile(r"^([\s]+)") +_AFTER_WS = re.compile(r"([\s]+)$") + + def html_to_markdown(node: lxml.html.HtmlElement, *, include_spans: bool = False, base_url: str | None = None) -> str: # Modified from RoboDanny code: # https://github.com/Rapptz/RoboDanny/blob/6e54be1985793ed29fca6b7c5259677904b8e1ad/cogs/dictionary.py#L532 @@ -48,26 +53,35 @@ def html_to_markdown(node: lxml.html.HtmlElement, *, include_spans: bool = False node.make_links_absolute("".join(base_url.partition(".com/wiki/")[0:-1]), resolve_base_href=True) for child in node.iter(): - child_text = child.text.strip() if child.text else "" + if child.text: + # Account for whitespace within a block that should be outside of it. + before_ws = _match.group() if (_match := _BEFORE_WS.search(child.text)) else "" + after_ws = _match.group() if (_match := _AFTER_WS.search(child.text)) else "" + child_text = child.text.strip() + else: + before_ws = after_ws = child_text = "" if child.tag in {"i", "em"}: - text.append(f"{italics_marker}{child_text}{italics_marker}") + text.append(f"{before_ws}{italics_marker}{child_text}{italics_marker}{after_ws}") if italics_marker == "*": # type: ignore italics_marker = "_" elif child.tag in {"b", "strong"}: if text and text[-1].endswith("*"): text.append("\u200b") - text.append(f"**{child_text.strip()}**") + text.append(f"{before_ws}**{child_text}**{after_ws}") elif child.tag == "a": # No markup for links if base_url is None: - text.append(child_text) + text.append(f"{before_ws}{child_text}{after_ws}") else: - text.append(f"[{child.text}]({child.attrib['href']})") + text.append(f"{before_ws}[{child.text}]({child.attrib['href']}){after_ws}") elif child.tag == "p": text.append(f"\n{child_text}\n") elif include_spans and child.tag == "span": - text.append(child_text) + if len(child) > 1: + text.append(f"{html_to_markdown(child, include_spans=True)}") + else: + text.append(f"{before_ws}{child_text}{after_ws}") if child.tail: text.append(child.tail) diff --git a/exts/story_search.py b/exts/story_search.py index c519563..a1d8a60 100644 --- a/exts/story_search.py +++ b/exts/story_search.py @@ -12,11 +12,10 @@ import sys import textwrap from bisect import bisect_left -from functools import lru_cache -from pathlib import Path -from typing import Any, ClassVar, Self +from typing import TYPE_CHECKING, ClassVar, Self import aiohttp +import async_lru import asyncpg import discord import lxml.html @@ -24,7 +23,7 @@ from discord.ext import commands import core -from core.utils import EMOJI_URL, PaginatedEmbedView +from core.utils import EMOJI_STOCK, EMOJI_URL, PaginatedEmbedView if sys.version_info >= (3, 12): @@ -33,24 +32,97 @@ import importlib_resources +if TYPE_CHECKING: + + def markdownify(html: str, **kwargs: object) -> str: ... +else: + from markdownify import markdownify + + LOGGER = logging.getLogger(__name__) +AO3_EMOJI = discord.PartialEmoji.from_str(EMOJI_STOCK["ao3"]) -@lru_cache + +@async_lru.alru_cache(ttl=300) async def get_ao3_html(session: aiohttp.ClientSession, url: str) -> lxml.html.HtmlElement | None: async with session.get(url) as response: text = await response.text() - element = lxml.html.fromstring(text) - download_btn = element.find(".//li[@class='download']//[li='HTML']") - if download_btn: - download_link = download_btn.attrib["href"] - if download_link: - async with session.get(url) as response: + + element = lxml.html.fromstring(text, "https://archiveofourown.org") + element.make_links_absolute() + + download_btn = element.find(".//li[@class='download']/ul/li/a[.='HTML']") + + if download_btn is not None: + download_link = download_btn.get("href") + if download_link is not None: + async with session.get(download_link) as response: story_text = await response.text() - return lxml.html.fromstring(story_text) + return lxml.html.fromstring(story_text, "https://archiveofourown.org") + return None +def find_keywords_in_ao3_text( + element: lxml.html.HtmlElement, + query: str, + url: str, +) -> tuple[StoryInfo, list[tuple[str, str]]]: + query = query.casefold() + title = title_el.text_content() if ((title_el := element.find(".//h1")) is not None) else "" + + if (author_el := element.find(".//a[@rel='author']")) is not None: # noqa: SIM108 + author = author_el.text_content() + else: + author = "Unknown" + + results: list[tuple[str, str]] = [] + + for chapter_div in element.findall(".//div[@id='chapters']/div[@class='userstuff']"): + for elder_sibling in chapter_div.itersiblings("div", preceding=True): + if elder_sibling.get("class") == "meta group": + if (chapter_title_el := chapter_div.find(".//h2[@class='heading']")) is not None: + chapter_title = chapter_title_el.text_content() + else: + chapter_title = "" + break + else: + chapter_title = "" + + for para in chapter_div.xpath(".//p"): + assert isinstance(para, lxml.html.HtmlElement) + if query in para.text_content().casefold(): + # Calling next(para.itersiblings()) or next(iter(para.itersiblings())) doesn't work, + # so one-iteration-for-loops it is. + para_before = para_after = None + + for before in para.itersiblings(preceding=True): + para_before = before + break + + for after in para.itersiblings(): + para_after = after + break + + # Could be one-lined with map and filter, but I find it less readable. + # html_to_markdown(elem, include_spans=True) + combined = "\n\n".join( + markdownify( + lxml.html.tostring(elem, method="html", encoding="unicode") + .replace("\n", "") + .replace("\n", "") + ).strip() + for elem in (para_before, para, para_after) + if elem is not None + ) + + combined = re.sub(re.escape(query), lambda m: f"__{m.group()}__", combined, flags=re.IGNORECASE) + results.append((chapter_title, combined)) + + return StoryInfo("NONE", title, author, url, AO3_EMOJI.id), results + + class StoryInfo(msgspec.Struct): """A class to hold all the information about each story.""" @@ -58,7 +130,7 @@ class StoryInfo(msgspec.Struct): name: str author: str link: str - emoji_id: int + emoji_id: int | None = None text: list[str] = msgspec.field(default_factory=list) chapter_index: list[int] = msgspec.field(default_factory=list) collection_index: list[int] = msgspec.field(default_factory=list) @@ -87,8 +159,16 @@ class StoryQuoteView(PaginatedEmbedView[tuple[str, str, str]]): The story's data and metadata, including full name, author name, and image representation. """ - def __init__(self, *args: Any, story_data: StoryInfo, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__( + self, + author_id: int, + pages_content: list[tuple[str, str, str]], + per: int = 1, + *, + story_data: StoryInfo, + timeout: float | None = 180, + ) -> None: + super().__init__(author_id, pages_content, per, timeout=timeout) self.story_data = story_data def format_page(self) -> discord.Embed: @@ -112,6 +192,59 @@ def format_page(self) -> discord.Embed: return embed_page +class AO3StoryQuoteView(PaginatedEmbedView[tuple[str, str]]): + """A subclass of `PaginatedEmbedView` that handles paginated embeds, specifically for quotes from a story. + + Parameters + ---------- + *args + Positional arguments the normal initialization of an `PaginatedEmbedView`. See that class for more info. + story_data: StoryInfo + The story's data and metadata, including full name, author name, and image representation. + **kwargs + Keyword arguments the normal initialization of an `PaginatedEmbedView`. See that class for more info. + + Attributes + ---------- + story_data: StoryInfo + The story's data and metadata, including full name, author name, and image representation. + """ + + def __init__( + self, + author_id: int, + pages_content: list[tuple[str, str]], + per: int = 1, + *, + story_data: StoryInfo, + timeout: float | None = 180, + ) -> None: + super().__init__(author_id, pages_content, per, timeout=timeout) + self.story_data = story_data + + def format_page(self) -> discord.Embed: + """Makes, or retrieves from the cache, the quote embed 'page' that the user will see. + + Assumes a per_page value of 1. + """ + + name, url, emoji_id = self.story_data.name, self.story_data.link, self.story_data.emoji_id + icon_url = EMOJI_URL.format(emoji_id) if emoji_id else None + embed_page = discord.Embed(color=0x149CDF).set_author(name=name, url=url, icon_url=icon_url) + + if self.total_pages == 0: + embed_page.add_field(name="N/A", value="N/A").set_footer(text="Page 0/0").title = "No quotes found!" + else: + # per_page value of 1 means parsing a list of length 1. + content = self.pages[self.page_index] + for chapter_name, quote in content: + embed_page.title = chapter_name + embed_page.description = quote + embed_page.set_footer(text=f"Page {self.page_index + 1}/{self.total_pages}") + + return embed_page + + class StorySearchCog(commands.Cog, name="Quote Search"): """A cog with commands for people to search the text of some ACI100 books while in Discord. @@ -140,12 +273,12 @@ def cog_emoji(self) -> discord.PartialEmoji: async def cog_load(self) -> None: """Load whatever is necessary to avoid reading from files or querying the database during runtime.""" - # Load story text from markdown files. + # Load story text from markdown files. Directory structure is known. data_dir = importlib_resources.files("data.story_text") - with importlib_resources.as_file(data_dir) as data_path: # type: ignore - assert isinstance(data_path, Path) # Wouldn't be necessary if as_file were typed better. - for file in data_path.glob("**/*text.md"): - await self.load_story_text(file) + for author_works in data_dir.iterdir(): + for work in author_works.iterdir(): + if work.is_file() and work.name.endswith("text.md"): + self.load_story_text(work) async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: # type: ignore # Narrowing # Extract the original error. @@ -156,7 +289,7 @@ async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: LOGGER.exception("", exc_info=error) @classmethod - async def load_story_text(cls, filepath: Path) -> None: + def load_story_text(cls, filepath: importlib_resources.abc.Traversable) -> None: """Load the story metadata and text.""" # Compile all necessary regex patterns. @@ -170,8 +303,13 @@ async def load_story_text(cls, filepath: Path) -> None: # Start file copying and indexing. with filepath.open("r", encoding="utf-8") as story_file: + if TYPE_CHECKING: + import io + + assert isinstance(story_file, io.TextIOWrapper) + # Instantiate index lists, which act as a table of contents of sorts. - stem = str(filepath.stem)[:-5] + stem = str(filepath.name)[:-8] temp_text = cls.story_records[stem].text = [line for line in story_file if line.strip()] temp_chap_index = cls.story_records[stem].chapter_index temp_coll_index = cls.story_records[stem].collection_index @@ -205,7 +343,7 @@ async def load_story_text(cls, filepath: Path) -> None: ): temp_coll_index.append(index) - LOGGER.info("Loaded file: %s", filepath.stem) + LOGGER.info("Loaded file: %s", filepath.name.removesuffix(".md")) @classmethod def process_text(cls, story: str, terms: str, exact: bool = True) -> list[tuple[str, str, str]]: @@ -341,22 +479,21 @@ async def search_cadmean(self, ctx: core.Context, *, query: str) -> None: view.message = message @commands.hybrid_command() - async def search_ao3_link(self, ctx: core.Context, url: str, query: str) -> None: - """Search the text of an ao3 link.""" - - element = await get_ao3_html(ctx.session, url) - - title = title_el.text if (element is not None and (title_el := element.find("h1"))) else "" - - if element is not None: - results: list[tuple[str, str | None, str]] = [] - for div in element.iter("div[@id='chapters']/div[@class='userstuff']"): - for para in div.iter("p"): - if para.text and (query.casefold() in para.text.casefold()): - header = next( - sibling.findtext("h2") for sibling in div.itersiblings("div[@class='meta group']") - ) - results.append((title or "", header, para.text)) + async def search_ao3(self, ctx: core.Context, url: str, query: str) -> None: + """Search the text of an AO3 story.""" + + async with ctx.typing(): + element = await get_ao3_html(ctx.session, url) + + if element is not None: + story_metadata, search_results = find_keywords_in_ao3_text(element, query, url) + else: + story_metadata = StoryInfo("NONE", "Unknown", "Unknown", url, AO3_EMOJI.id) + search_results = [] + + view = AO3StoryQuoteView(ctx.author.id, search_results, story_data=story_metadata) + message = await ctx.send(embed=await view.get_first_page(), view=view) + view.message = message async def setup(bot: core.Beira) -> None: diff --git a/requirements.txt b/requirements.txt index 3c52d9e..a075496 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,6 +10,7 @@ fichub-api @ https://github.com/Sachaa-Thanasius/fichub-api/releases/download/v0 importlib_resources; python_version < "3.12" jishaku @ git+https://github.com/Gorialis/jishaku@a6661e2813124fbfe53326913e54f7c91e5d0dec lxml>=4.9.3 +markdownify msgspec[toml] openpyxl Pillow>=10.0.0