Commit

A few things
- Remove the dependency on openai by making direct POST requests (sketched below).
  - The API no longer has a free tier anyway; consider removing the AI commands entirely?
- Remove usages of `__file__`.
  - Avoiding it just seems like best practice.
  - This gains an importlib_resources dependency (on Python < 3.12).
- Give the timing cog an empty setup to avoid errors during startup for now.
Sachaa-Thanasius committed Mar 19, 2024
1 parent 1caa6ad commit f0ac95b
Showing 12 changed files with 78 additions and 51 deletions.
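
For orientation, a rough standalone sketch of the direct-request approach adopted below. The endpoint, headers, and payload follow OpenAI's public HTTP API; the environment-variable key and the `fetch_completion` name are placeholders for this sketch, not the bot's actual config.

```python
import asyncio
import os

import aiohttp

OPENAI_COMPLETIONS_URL = "https://api.openai.com/v1/completions"


async def fetch_completion(session: aiohttp.ClientSession, prompt: str, api_key: str) -> str:
    """POST a prompt to the legacy completions endpoint; return the generated text, or "" on failure."""
    async with session.post(
        OPENAI_COMPLETIONS_URL,
        headers={"Authorization": f"Bearer {api_key}"},
        json={"model": "gpt-3.5-turbo-instruct", "prompt": prompt, "max_tokens": 150, "temperature": 0},
    ) as response:
        result = await response.json()
    try:
        return result["choices"][0]["text"] or ""
    except (KeyError, IndexError):
        return ""


async def main() -> None:
    async with aiohttp.ClientSession() as session:
        # OPENAI_API_KEY in the environment is an assumption made only for this sketch.
        print(await fetch_completion(session, "Say hello.", os.environ["OPENAI_API_KEY"]))


asyncio.run(main())
```
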
3 changes: 0 additions & 3 deletions core/bot.py
@@ -18,7 +18,6 @@
import atlas_api
import discord
import fichub_api
import openai
import wavelink
from discord.ext import commands

@@ -75,8 +74,6 @@ def __init__(
self.fichub_client = fichub_api.Client(session=self.web_session)
self.ao3_client = ao3.Client(session=self.web_session)

self.openai_client = openai.AsyncOpenAI(api_key=CONFIG.openai.key)

# Things to load before connecting to the Gateway.
self.prefix_cache: dict[int, list[str]] = {}
self.blocked_entities_cache: dict[str, set[int]] = {}
4 changes: 3 additions & 1 deletion core/errors.py
@@ -3,10 +3,12 @@
"""

from discord import app_commands
from discord.app_commands.commands import Check as AppCheckFunc
from discord.ext import commands


AppCheckFunc = app_commands.commands.Check


__all__ = (
"CannotTargetSelf",
"NotOwnerOrFriend",
6 changes: 3 additions & 3 deletions exts/ai_generation/ai_generation.py
@@ -95,7 +95,7 @@ async def morph_user(self, target: discord.User, prompt: str) -> tuple[str, Byte
with Image.open(avatar_buffer) as avatar_image:
file_size = avatar_image.size

ai_url = await create_image(self.bot.openai_client, prompt, file_size)
ai_url = await create_image(self.bot.web_session, prompt, file_size)
ai_bytes = await get_image(self.bot.web_session, ai_url)
ai_buffer = await asyncio.to_thread(process_image, ai_bytes)
gif_buffer = await create_morph(avatar_buffer, ai_buffer)
@@ -218,7 +218,7 @@ async def generate(

if generation_type == "image":
log_start_time = perf_counter()
ai_url = await create_image(self.bot.openai_client, prompt, (512, 512))
ai_url = await create_image(ctx.session, prompt, (512, 512))
ai_bytes = await get_image(ctx.session, ai_url)
ai_buffer = await asyncio.to_thread(process_image, ai_bytes)
creation_time = perf_counter() - log_start_time
@@ -239,7 +239,7 @@ async def generate(

elif generation_type == "text":
log_start_time = perf_counter()
ai_text = await create_completion(self.bot.openai_client, prompt)
ai_text = await create_completion(ctx.session, prompt)
creation_time = perf_counter() - log_start_time

# Send the generated image in an embed.
61 changes: 39 additions & 22 deletions exts/ai_generation/utils.py
@@ -7,9 +7,10 @@
from pathlib import Path

import aiohttp
import openai
from PIL import Image

import core


__all__ = (
"get_image",
@@ -58,7 +59,7 @@ def process_image(image_bytes: bytes) -> BytesIO:
return output_buffer


async def create_completion(client: openai.AsyncOpenAI, prompt: str) -> str:
async def create_completion(session: aiohttp.ClientSession, prompt: str) -> str:
"""Makes a call to OpenAI's API to generate text based on given input.
Parameters
@@ -69,19 +70,27 @@ async def create_completion(session: aiohttp.ClientSession, prompt: str) -> str:
Returns
-------
text: :class:`str`
The generated text completion.
The generated text completion, or an empty string if it failed.
"""

completion_response = await client.completions.create(
model="text-davinci-003",
prompt=prompt,
max_tokens=150,
temperature=0,
)
return completion_response.choices[0].text


async def create_image(client: openai.AsyncOpenAI, prompt: str, size: tuple[int, int] = (256, 256)) -> str:
async with session.post(
"https://api.openai.com/v1/completions",
headers={"Authorization": f"Bearer {core.CONFIG}"},
json={
"model": "gpt-3.5-turbo-instruct",
"prompt": prompt,
"max_tokens": 150,
"temperature": 0,
},
) as response:
result = await response.json()
try:
return result["choices"][0]["text"] or ""
except (KeyError, IndexError):
return ""


async def create_image(session: aiohttp.ClientSession, prompt: str, size: tuple[int, int] = (256, 256)) -> str:
"""Makes a call to OpenAI's API to generate an image based on given inputs.
Parameters
@@ -94,17 +103,25 @@ async def create_image(client: openai.AsyncOpenAI, prompt: str, size: tuple[int,
Returns
-------
url: :class:`str`
The url of the generated image.
The url of the generated image, or an empty string if it failed.
"""

image_response = await client.images.generate(
prompt=prompt,
n=1,
response_format="url",
size=f"{size[0]}x{size[1]}", # type: ignore # FIXME: Find a way to pass in a literal.
)

return image_response.data[0].url or ""
async with session.post(
"https://api.openai.com/v1/images/generations",
headers={"Authorization": f"Bearer {core.CONFIG}"},
json={
"model": "dall-e-2",
"prompt": prompt,
"n": 1,
"size": f"{size[0]}x{size[1]}", # FIXME: Find a way to pass in a literal.
"response_format": "url",
},
) as response:
result = await response.json()
try:
return result["data"][0]["url"] or ""
except (KeyError, IndexError):
return ""


async def create_inspiration(session: aiohttp.ClientSession) -> str:
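
For completeness, a quick usage sketch of the rewritten helpers above. The import path mirrors the file location in this diff and is an assumption about how the package resolves; the prompts are placeholders.

```python
import asyncio

import aiohttp

from exts.ai_generation.utils import create_completion, create_image


async def demo() -> None:
    async with aiohttp.ClientSession() as session:
        text = await create_completion(session, "Write a one-line greeting.")
        image_url = await create_image(session, "a lighthouse at dusk", (512, 512))
        # Both helpers return an empty string if the API call failed.
        print(text or "<no text returned>")
        print(image_url or "<no image URL returned>")


asyncio.run(demo())
```
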
2 changes: 1 addition & 1 deletion exts/ff_metadata/utils.py
@@ -79,7 +79,7 @@ def html_to_markdown(node: lxml.html.HtmlElement, *, include_spans: bool = False

if child.tag in {"i", "em"}:
text.append(f"{italics_marker}{child_text}{italics_marker}")
italics_marker = "_" if italics_marker == "*" else "*"
italics_marker = "_" if italics_marker == "*" else "*" # type: ignore
elif child.tag in {"b", "strong"}:
if text and text[-1].endswith("*"):
text.append("\u200b")
4 changes: 2 additions & 2 deletions exts/lol.py
@@ -31,8 +31,8 @@

LOGGER = logging.getLogger(__name__)

GECKODRIVER = Path(__file__).parents[1].joinpath("drivers/geckodriver/geckodriver.exe")
GECKODRIVER_LOGS = Path(__file__).parents[1].joinpath("logs/geckodriver.log")
GECKODRIVER = Path().resolve().joinpath("drivers/geckodriver/geckodriver.exe")
GECKODRIVER_LOGS = Path().resolve().joinpath("logs/geckodriver.log")

PERSONAL_GUILD = 107584745809944576

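
For clarity on the path change above: `Path()` with no arguments refers to the current working directory, so the resolved locations depend on where the bot process is launched (the directory in the comment below is hypothetical).

```python
from pathlib import Path

# Path(__file__).parents[1] was anchored to the source tree; Path().resolve() is
# the current working directory. If the bot is started from e.g. /home/user/Beira,
# this resolves to /home/user/Beira/drivers/geckodriver/geckodriver.exe; started
# from anywhere else, the driver will not be found.
GECKODRIVER = Path().resolve().joinpath("drivers/geckodriver/geckodriver.exe")
print(GECKODRIVER)
```
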
2 changes: 1 addition & 1 deletion exts/notifications/aci_notifications.py
@@ -184,5 +184,5 @@ def make_listeners(bot: core.Beira) -> tuple[tuple[str, functools.partial[Any]],
return (
("on_member_update", functools.partial(on_leveled_role_member_update, bot, role_log_webhook)),
("on_member_update", functools.partial(on_server_boost_role_member_update, bot, role_log_webhook)),
# ("on_message", functools.partial(on_bad_twitter_link, bot)), # Twitter got their shit together.
# ("on_message", functools.partial(on_bad_twitter_link, bot)), # Twitter works. # noqa: ERA001
)
2 changes: 1 addition & 1 deletion exts/notifications/rss_notifications.py
@@ -55,7 +55,7 @@ async def check_url(self, rec: NotificationRecord) -> str | None:

def process_new_item(self, text: str) -> discord.Embed:
"""Turn new item/update into a nicely formatted discord Embed."""
...
... # noqa: PIE790

@tasks.loop(seconds=10)
async def notification_check_loop(self) -> None:
13 changes: 10 additions & 3 deletions exts/story_search.py
@@ -9,6 +9,7 @@
import logging
import random
import re
import sys
import textwrap
from bisect import bisect_left
from functools import lru_cache
@@ -26,6 +27,11 @@
from core.utils import EMOJI_URL, PaginatedEmbedView


if sys.version_info >= (3, 12):
from importlib import resources as importlib_resources
else:
import importlib_resources

if TYPE_CHECKING:
from typing_extensions import Self
else:
@@ -139,9 +145,10 @@ async def cog_load(self) -> None:
"""Load whatever is necessary to avoid reading from files or querying the database during runtime."""

# Load story text from markdown files.
project_path = Path(__file__).resolve().parents[1]
for file in project_path.glob("data/story_text/**/*.md"):
if "text" in file.name:
data_dir = importlib_resources.files("data.story_text")
with importlib_resources.as_file(data_dir) as data_path: # type: ignore
assert isinstance(data_path, Path) # Wouldn't be necessary if as_file were typed better.
for file in data_path.glob("**/*text.md"):
await self.load_story_text(file)

async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: # type: ignore # Narrowing
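
As a standalone sketch of the resource-loading pattern used in `cog_load` above: `importlib_resources.files()` requires `data.story_text` to be importable (as a regular or namespace package), and `as_file()` yields a concrete filesystem path for the duration of the `with` block. The loop body here just prints names for illustration.

```python
import sys
from pathlib import Path

if sys.version_info >= (3, 12):
    from importlib import resources as importlib_resources
else:
    import importlib_resources  # backport declared in requirements.txt for Python < 3.12

# "data.story_text" must resolve as an importable package for files() to find it.
data_dir = importlib_resources.files("data.story_text")
with importlib_resources.as_file(data_dir) as data_path:
    assert isinstance(data_path, Path)
    for file in sorted(data_path.glob("**/*text.md")):
        print(file.name)
```
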
6 changes: 3 additions & 3 deletions exts/timing.py
@@ -178,7 +178,7 @@ async def timezone_autocomplete(


# TODO: Complete and enable later.
# async def setup(bot: core.Beira) -> None:
# """Connects cog to bot."""
async def setup(bot: core.Beira) -> None:
"""Connects cog to bot."""

# await bot.add_cog(TimingCog(bot))
# await bot.add_cog(TimingCog(bot)) # noqa: ERA001
12 changes: 6 additions & 6 deletions pyproject.toml
@@ -5,14 +5,15 @@ description = "An personal Discord bot made in Python."
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"
authors = [{ name = "Sachaa-Thanasius", email = "111999343+Sachaa-Thanasius@users.noreply.github.com" }]
authors = [
{ name = "Sachaa-Thanasius", email = "111999343+Sachaa-Thanasius@users.noreply.github.com" },
]

[project.urls]
Homepage = "https://github.com/Sachaa-Thanasius/Beira"
"Bug Tracker" = "https://github.com/Sachaa-Thanasius/Beira/issues"

[tool.ruff]
include = ["main.py", "core/*.py", "exts/*.py", "typings/*.*", "**/pyproject.toml"]
include = ["main.py", "core/*", "exts/*", "**/pyproject.toml"]
line-length = 120
target-version = "py310"

@@ -75,12 +76,11 @@ extend-ignore = [
"ISC002",
]
unfixable = [
"ERA", # I don't want anything erroneously detected deleted by this.
"ERA", # Disallow erroneous detection into deletion.
]

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F403", "PLC0414"] # Star import usually throws these.
"typings/parsedatetime/parsedatetime.pyi" = ["F403"] # Star import usually throws these.
"misc/**" = [
"T201", # Leave prints alone.
"ERA001", # Leave commented code alone.
@@ -91,7 +91,7 @@ lines-after-imports = 2
combine-as-imports = true

[tool.pyright]
include = ["main.py", "core", "exts", "typings"]
include = ["main.py", "core", "exts"]
pythonVersion = "3.10"
typeCheckingMode = "strict"

14 changes: 9 additions & 5 deletions requirements.txt
@@ -1,16 +1,20 @@
ao3.py @ git+https://github.com/Sachaa-Thanasius/ao3.py.git
ao3.py @ git+https://github.com/Sachaa-Thanasius/ao3.py.git@main
arsenic
async-lru
asyncpg==0.29.0
asyncpg-stubs==0.29.1
atlas-api @ git+https://github.com/Sachaa-Thanasius/atlas-api-wrapper.git
atlas-api @ https://github.com/Sachaa-Thanasius/atlas-api-wrapper/releases/download/v0.2.2/atlas_api-0.2.2-py3-none-any.whl
discord.py[speed,voice]>=2.3.2
fichub-api @ git+https://github.com/Sachaa-Thanasius/fichub-api.git
fichub-api @ https://github.com/Sachaa-Thanasius/fichub-api/releases/download/v0.2.2/fichub_api-0.2.2-py3-none-any.whl
importlib_resources; python_version < "3.12"
jishaku @ git+https://github.com/Gorialis/jishaku@a6661e2813124fbfe53326913e54f7c91e5d0dec
lxml>=4.9.3
types-lxml
msgspec[toml]
openai>=1.8.0
parsedatetime-stubs @ git+https://github.com/Sachaa-Thanasius/parsedatetime-stubs
Pillow>=10.0.0
wavelink>=3.0.0

# To be used later:
# parsedatetime
# parsedatetime-stubs @ git+https://github.com/Sachaa-Thanasius/parsedatetime-stubs
# python-dateutil
