Commit 4d13d9a

Try the search_ao3 feature
Sachaa-Thanasius committed Apr 16, 2024
1 parent 5efd7ad commit 4d13d9a
Showing 4 changed files with 198 additions and 45 deletions.
1 change: 1 addition & 0 deletions core/utils/emojis.py
@@ -31,6 +31,7 @@
"d12": "<a:d12:1109234528431636672>",
"d20": "<a:d20:1109234550707593346>",
"d100": "<a:d100:1109960365967687841>",
"ao3": "<:ao3:1229883149136433325>",
}
# fmt: on

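A note on how this new entry gets consumed: the value follows the `<:name:id>` custom-emoji syntax that `discord.PartialEmoji.from_str` parses, which is exactly what `exts/story_search.py` does with it below. A minimal sketch of that round trip (the dict is re-declared here just for the example):

```python
import discord

# Re-declared for illustration; the real mapping lives in core/utils/emojis.py.
EMOJI_STOCK = {"ao3": "<:ao3:1229883149136433325>"}

ao3 = discord.PartialEmoji.from_str(EMOJI_STOCK["ao3"])
print(ao3.name, ao3.id)  # ao3 1229883149136433325
print(ao3.animated)      # False; an animated emoji string would start with "<a:"
```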
26 changes: 20 additions & 6 deletions core/utils/misc.py
@@ -5,6 +5,7 @@
from __future__ import annotations

import logging
import re
import time

import lxml.html
@@ -37,6 +38,10 @@ def __exit__(self, *exc: object) -> None:
self.logger.info("Time: %.3f seconds", self.total_time)


_BEFORE_WS = re.compile(r"^([\s]+)")
_AFTER_WS = re.compile(r"([\s]+)$")


def html_to_markdown(node: lxml.html.HtmlElement, *, include_spans: bool = False, base_url: str | None = None) -> str:
# Modified from RoboDanny code:
# https://github.com/Rapptz/RoboDanny/blob/6e54be1985793ed29fca6b7c5259677904b8e1ad/cogs/dictionary.py#L532
@@ -48,26 +53,35 @@ def html_to_markdown(node: lxml.html.HtmlElement, *, include_spans: bool = False
node.make_links_absolute("".join(base_url.partition(".com/wiki/")[0:-1]), resolve_base_href=True)

for child in node.iter():
child_text = child.text.strip() if child.text else ""
if child.text:
# Account for whitespace within a block that should be outside of it.
before_ws = _match.group() if (_match := _BEFORE_WS.search(child.text)) else ""
after_ws = _match.group() if (_match := _AFTER_WS.search(child.text)) else ""
child_text = child.text.strip()
else:
before_ws = after_ws = child_text = ""

if child.tag in {"i", "em"}:
text.append(f"{italics_marker}{child_text}{italics_marker}")
text.append(f"{before_ws}{italics_marker}{child_text}{italics_marker}{after_ws}")
if italics_marker == "*": # type: ignore
italics_marker = "_"
elif child.tag in {"b", "strong"}:
if text and text[-1].endswith("*"):
text.append("\u200b")
text.append(f"**{child_text.strip()}**")
text.append(f"{before_ws}**{child_text}**{after_ws}")
elif child.tag == "a":
# No markup for links
if base_url is None:
text.append(child_text)
text.append(f"{before_ws}{child_text}{after_ws}")
else:
text.append(f"[{child.text}]({child.attrib['href']})")
text.append(f"{before_ws}[{child.text}]({child.attrib['href']}){after_ws}")
elif child.tag == "p":
text.append(f"\n{child_text}\n")
elif include_spans and child.tag == "span":
text.append(child_text)
if len(child) > 1:
text.append(f"{html_to_markdown(child, include_spans=True)}")
else:
text.append(f"{before_ws}{child_text}{after_ws}")

if child.tail:
text.append(child.tail)
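The change above fixes a subtle markdown bug: lxml reports an inline tag's text with its surrounding whitespace intact, and the old code either stripped it (fusing adjacent words together) or would have left it inside the markers, producing invalid emphasis like `** bold **`. The new `_BEFORE_WS`/`_AFTER_WS` captures hoist that whitespace outside the markers instead. A standalone sketch of the idea, with `wrap_emphasis` as a hypothetical helper written for this example:

```python
import re

_BEFORE_WS = re.compile(r"^([\s]+)")
_AFTER_WS = re.compile(r"([\s]+)$")

def wrap_emphasis(raw: str, marker: str = "**") -> str:
    """Hypothetical helper: emphasize text, moving edge whitespace outside the markers."""
    before = m.group() if (m := _BEFORE_WS.search(raw)) else ""
    after = m.group() if (m := _AFTER_WS.search(raw)) else ""
    return f"{before}{marker}{raw.strip()}{marker}{after}"

# lxml reports the <strong> text of "word<strong> bold </strong>word" as " bold ":
print("word" + wrap_emphasis(" bold ") + "word")
# -> "word **bold** word", not "word**bold**word" or "word** bold **word"
```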
215 changes: 176 additions & 39 deletions exts/story_search.py
@@ -12,19 +12,18 @@
import sys
import textwrap
from bisect import bisect_left
from functools import lru_cache
from pathlib import Path
from typing import Any, ClassVar, Self
from typing import TYPE_CHECKING, ClassVar, Self

import aiohttp
import async_lru
import asyncpg
import discord
import lxml.html
import msgspec
from discord.ext import commands

import core
from core.utils import EMOJI_URL, PaginatedEmbedView
from core.utils import EMOJI_STOCK, EMOJI_URL, PaginatedEmbedView


if sys.version_info >= (3, 12):
@@ -33,32 +32,105 @@
import importlib_resources


if TYPE_CHECKING:

def markdownify(html: str, **kwargs: object) -> str: ...
else:
from markdownify import markdownify


LOGGER = logging.getLogger(__name__)

AO3_EMOJI = discord.PartialEmoji.from_str(EMOJI_STOCK["ao3"])

@lru_cache

@async_lru.alru_cache(ttl=300)
async def get_ao3_html(session: aiohttp.ClientSession, url: str) -> lxml.html.HtmlElement | None:
async with session.get(url) as response:
text = await response.text()
element = lxml.html.fromstring(text)
download_btn = element.find(".//li[@class='download']//[li='HTML']")
if download_btn:
download_link = download_btn.attrib["href"]
if download_link:
async with session.get(url) as response:

element = lxml.html.fromstring(text, "https://archiveofourown.org")
element.make_links_absolute()

download_btn = element.find(".//li[@class='download']/ul/li/a[.='HTML']")

if download_btn is not None:
download_link = download_btn.get("href")
if download_link is not None:
async with session.get(download_link) as response:
story_text = await response.text()
return lxml.html.fromstring(story_text)
return lxml.html.fromstring(story_text, "https://archiveofourown.org")

return None


def find_keywords_in_ao3_text(
element: lxml.html.HtmlElement,
query: str,
url: str,
) -> tuple[StoryInfo, list[tuple[str, str]]]:
query = query.casefold()
title = title_el.text_content() if ((title_el := element.find(".//h1")) is not None) else ""

if (author_el := element.find(".//a[@rel='author']")) is not None: # noqa: SIM108
author = author_el.text_content()
else:
author = "Unknown"

results: list[tuple[str, str]] = []

for chapter_div in element.findall(".//div[@id='chapters']/div[@class='userstuff']"):
for elder_sibling in chapter_div.itersiblings("div", preceding=True):
if elder_sibling.get("class") == "meta group":
if (chapter_title_el := chapter_div.find(".//h2[@class='heading']")) is not None:
chapter_title = chapter_title_el.text_content()
else:
chapter_title = ""
break
else:
chapter_title = ""

for para in chapter_div.xpath(".//p"):
assert isinstance(para, lxml.html.HtmlElement)
if query in para.text_content().casefold():
# Calling next(para.itersiblings()) or next(iter(para.itersiblings())) doesn't work,
# so one-iteration-for-loops it is.
para_before = para_after = None

for before in para.itersiblings(preceding=True):
para_before = before
break

for after in para.itersiblings():
para_after = after
break

# Could be one-lined with map and filter, but I find it less readable.
# html_to_markdown(elem, include_spans=True)
combined = "\n\n".join(
markdownify(
lxml.html.tostring(elem, method="html", encoding="unicode")
.replace("\n<span>", "")
.replace("</span>\n", "")
).strip()
for elem in (para_before, para, para_after)
if elem is not None
)

combined = re.sub(re.escape(query), lambda m: f"__{m.group()}__", combined, flags=re.IGNORECASE)
results.append((chapter_title, combined))

return StoryInfo("NONE", title, author, url, AO3_EMOJI.id), results


class StoryInfo(msgspec.Struct):
"""A class to hold all the information about each story."""

acronym: str
name: str
author: str
link: str
emoji_id: int
emoji_id: int | None = None
text: list[str] = msgspec.field(default_factory=list)
chapter_index: list[int] = msgspec.field(default_factory=list)
collection_index: list[int] = msgspec.field(default_factory=list)
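Two details of the highlighting step in `find_keywords_in_ao3_text` above are easy to miss: `re.escape` keeps metacharacters in a user-typed query (`.`, `?`, `(`, ...) from being treated as regex syntax, and replacing through a lambda echoes back `m.group()`, so each underlined match keeps its original casing rather than the casefolded query. A quick illustration:

```python
import re

text = "mr. darcy? Mr. Darcy? MR. DARCY?"
query = "mr. darcy?"

# Escaped, "." and "?" match literally; the lambda preserves each match's casing.
print(re.sub(re.escape(query), lambda m: f"__{m.group()}__", text, flags=re.IGNORECASE))
# -> "__mr. darcy?__ __Mr. Darcy?__ __MR. DARCY?__"
```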
@@ -87,8 +159,16 @@ class StoryQuoteView(PaginatedEmbedView[tuple[str, str, str]]):
The story's data and metadata, including full name, author name, and image representation.
"""

def __init__(self, *args: Any, story_data: StoryInfo, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
def __init__(
self,
author_id: int,
pages_content: list[tuple[str, str, str]],
per: int = 1,
*,
story_data: StoryInfo,
timeout: float | None = 180,
) -> None:
super().__init__(author_id, pages_content, per, timeout=timeout)
self.story_data = story_data

def format_page(self) -> discord.Embed:
@@ -112,6 +192,59 @@ def format_page(self) -> discord.Embed:
return embed_page


class AO3StoryQuoteView(PaginatedEmbedView[tuple[str, str]]):
"""A subclass of `PaginatedEmbedView` that handles paginated embeds, specifically for quotes from a story.

Parameters
----------
author_id: int
The ID of the user that invoked this view.
pages_content: list[tuple[str, str]]
The (chapter title, quote) pairs that make up the pages.
per: int, default 1
How many entries are displayed per page. See `PaginatedEmbedView` for more info.
story_data: StoryInfo
The story's data and metadata, including full name, author name, and image representation.
timeout: float | None, default 180
How many seconds the view stays active without interaction.

Attributes
----------
story_data: StoryInfo
The story's data and metadata, including full name, author name, and image representation.
"""

def __init__(
self,
author_id: int,
pages_content: list[tuple[str, str]],
per: int = 1,
*,
story_data: StoryInfo,
timeout: float | None = 180,
) -> None:
super().__init__(author_id, pages_content, per, timeout=timeout)
self.story_data = story_data

def format_page(self) -> discord.Embed:
"""Makes, or retrieves from the cache, the quote embed 'page' that the user will see.

Assumes a per_page value of 1.
"""

name, url, emoji_id = self.story_data.name, self.story_data.link, self.story_data.emoji_id
icon_url = EMOJI_URL.format(emoji_id) if emoji_id else None
embed_page = discord.Embed(color=0x149CDF).set_author(name=name, url=url, icon_url=icon_url)

if self.total_pages == 0:
embed_page.add_field(name="N/A", value="N/A").set_footer(text="Page 0/0").title = "No quotes found!"
else:
# per_page value of 1 means parsing a list of length 1.
content = self.pages[self.page_index]
for chapter_name, quote in content:
embed_page.title = chapter_name
embed_page.description = quote
embed_page.set_footer(text=f"Page {self.page_index + 1}/{self.total_pages}")

return embed_page


class StorySearchCog(commands.Cog, name="Quote Search"):
"""A cog with commands for people to search the text of some ACI100 books while in Discord.

@@ -140,12 +273,12 @@ def cog_emoji(self) -> discord.PartialEmoji:
async def cog_load(self) -> None:
"""Load whatever is necessary to avoid reading from files or querying the database during runtime."""

# Load story text from markdown files.
# Load story text from markdown files. Directory structure is known.
data_dir = importlib_resources.files("data.story_text")
with importlib_resources.as_file(data_dir) as data_path: # type: ignore
assert isinstance(data_path, Path) # Wouldn't be necessary if as_file were typed better.
for file in data_path.glob("**/*text.md"):
await self.load_story_text(file)
for author_works in data_dir.iterdir():
for work in author_works.iterdir():
if work.is_file() and work.name.endswith("text.md"):
self.load_story_text(work)

async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: # type: ignore # Narrowing
# Extract the original error.
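The `cog_load` rewrite above is driven by a type change: `importlib_resources.files(...)` returns a `Traversable`, which has `iterdir`, `is_file`, `name`, and `open` but no `glob`, so the old `Path.glob("**/*text.md")` walk becomes two nested `iterdir` loops over the known `data/story_text/<author>/<work>` layout. A sketch of the same walk (package name taken from the diff; the `is_dir` guard is an extra precaution not in the committed code):

```python
import importlib_resources

data_dir = importlib_resources.files("data.story_text")
for author_works in data_dir.iterdir():
    if not author_works.is_dir():
        continue  # e.g. skip a stray __init__.py at the top level
    for work in author_works.iterdir():
        # Traversable has no .glob(), so filter on the filename instead.
        if work.is_file() and work.name.endswith("text.md"):
            print("would load:", work.name)
```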
@@ -156,7 +289,7 @@ async def cog_command_error(self, ctx: core.Context, error: Exception) -> None:
LOGGER.exception("", exc_info=error)

@classmethod
async def load_story_text(cls, filepath: Path) -> None:
def load_story_text(cls, filepath: importlib_resources.abc.Traversable) -> None:
"""Load the story metadata and text."""

# Compile all necessary regex patterns.
@@ -170,8 +303,13 @@ async def load_story_text(cls, filepath: Path) -> None:

# Start file copying and indexing.
with filepath.open("r", encoding="utf-8") as story_file:

Check failure on line 305 in exts/story_search.py (GitHub Actions / Type Coverage and Linting @ 3.10):
    Type of "open" is partially unknown: "(mode: str = 'r', *args: Unknown, **kwargs: Unknown) -> None" (reportUnknownMemberType)
    Object of type "None" cannot be used with "with" because it does not implement __enter__ (reportGeneralTypeIssues)
    Object of type "None" cannot be used with "with" because it does not implement __exit__ (reportGeneralTypeIssues)
    Type of "story_file" is unknown (reportUnknownVariableType)
if TYPE_CHECKING:
import io

assert isinstance(story_file, io.TextIOWrapper)

# Instantiate index lists, which act as a table of contents of sorts.
stem = str(filepath.stem)[:-5]
stem = str(filepath.name)[:-8]
temp_text = cls.story_records[stem].text = [line for line in story_file if line.strip()]
temp_chap_index = cls.story_records[stem].chapter_index
temp_coll_index = cls.story_records[stem].collection_index
@@ -205,7 +343,7 @@ async def load_story_text(cls, filepath: Path) -> None:
):
temp_coll_index.append(index)

LOGGER.info("Loaded file: %s", filepath.stem)
LOGGER.info("Loaded file: %s", filepath.name.removesuffix(".md"))

@classmethod
def process_text(cls, story: str, terms: str, exact: bool = True) -> list[tuple[str, str, str]]:
@@ -341,22 +479,21 @@ async def search_cadmean(self, ctx: core.Context, *, query: str) -> None:
view.message = message

@commands.hybrid_command()
async def search_ao3_link(self, ctx: core.Context, url: str, query: str) -> None:
"""Search the text of an ao3 link."""

element = await get_ao3_html(ctx.session, url)

title = title_el.text if (element is not None and (title_el := element.find("h1"))) else ""

if element is not None:
results: list[tuple[str, str | None, str]] = []
for div in element.iter("div[@id='chapters']/div[@class='userstuff']"):
for para in div.iter("p"):
if para.text and (query.casefold() in para.text.casefold()):
header = next(
sibling.findtext("h2") for sibling in div.itersiblings("div[@class='meta group']")
)
results.append((title or "", header, para.text))
async def search_ao3(self, ctx: core.Context, url: str, query: str) -> None:
"""Search the text of an AO3 story."""

async with ctx.typing():
element = await get_ao3_html(ctx.session, url)

if element is not None:
story_metadata, search_results = find_keywords_in_ao3_text(element, query, url)
else:
story_metadata = StoryInfo("NONE", "Unknown", "Unknown", url, AO3_EMOJI.id)
search_results = []

view = AO3StoryQuoteView(ctx.author.id, search_results, story_data=story_metadata)
message = await ctx.send(embed=await view.get_first_page(), view=view)
view.message = message


async def setup(bot: core.Beira) -> None:
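Taken together, `get_ao3_html` and `find_keywords_in_ao3_text` form a self-contained pipeline that the new `search_ao3` command only wraps in Discord plumbing; `alru_cache(ttl=300)` additionally memoizes the fetch per (session, url) pair for five minutes, so repeat searches of the same work skip the network. A rough standalone driver, assuming the module imports cleanly outside the bot and the work is publicly readable:

```python
import asyncio

import aiohttp

from exts.story_search import find_keywords_in_ao3_text, get_ao3_html  # assumed import path

async def main() -> None:
    url = "https://archiveofourown.org/works/12345678"  # placeholder work URL
    async with aiohttp.ClientSession() as session:
        element = await get_ao3_html(session, url)
    if element is None:
        print("No HTML download link found; is the work restricted?")
        return
    metadata, results = find_keywords_in_ao3_text(element, "some phrase", url)
    print(f"{metadata.name} by {metadata.author}: {len(results)} match(es)")
    for chapter_title, quote in results[:3]:
        print(chapter_title or "(untitled chapter)", "->", quote[:80])

asyncio.run(main())
```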
1 change: 1 addition & 0 deletions requirements.txt
@@ -10,6 +10,7 @@ fichub-api @ https://github.com/Sachaa-Thanasius/fichub-api/releases/download/v0
importlib_resources; python_version < "3.12"
jishaku @ git+https://github.com/Gorialis/jishaku@a6661e2813124fbfe53326913e54f7c91e5d0dec
lxml>=4.9.3
markdownify
msgspec[toml]
openpyxl
Pillow>=10.0.0
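`markdownify` is the one genuinely new dependency; it ships untyped, which is why the diff imports the real function only at runtime and gives type checkers a small `if TYPE_CHECKING` stub instead. Its call shape is just HTML string in, markdown out:

```python
from markdownify import markdownify

html = "<p>He said it was <strong>fine</strong> and <em>meant</em> it.</p>"
print(markdownify(html).strip())
# -> "He said it was **fine** and *meant* it."
```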

0 comments on commit 4d13d9a
