Update openai dependency and code.
Sachaa-Thanasius committed Nov 7, 2023
1 parent 2558d98 commit 7dbda8b
Showing 6 changed files with 44 additions and 41 deletions.
3 changes: 3 additions & 0 deletions core/bot.py
@@ -16,6 +16,7 @@
import atlas_api
import discord
import fichub_api
import openai
import wavelink
from discord.ext import commands
from wavelink.ext import spotify
@@ -73,6 +74,8 @@ def __init__(
self.fichub_client = fichub_api.Client(session=self.web_session)
self.ao3_client = ao3.Client(session=self.web_session)

self.openai_client = openai.AsyncOpenAI(api_key=CONFIG.openai.key)

# Things to load before connecting to the Gateway.
self.prefix_cache: dict[int, list[str]] = {}
self.blocked_entities_cache: dict[str, set[int]] = {}
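Note: the hunk above attaches one shared openai.AsyncOpenAI client to the bot, which cogs reach as self.bot.openai_client. A minimal standalone sketch of that pattern, for context (illustrative only; the placeholder key stands in for CONFIG.openai.key, and the model and call shape mirror the changes later in this commit):

import asyncio

import openai


async def main() -> None:
    # One client per process; it manages its own HTTP connections, so the pre-1.0
    # module-level openai.api_key / openai.aiosession setup is no longer needed.
    client = openai.AsyncOpenAI(api_key="sk-placeholder")
    completion = await client.completions.create(
        model="text-davinci-003",
        prompt="Write a one-line greeting.",
        max_tokens=20,
        temperature=0,
    )
    print(completion.choices[0].text)


asyncio.run(main())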
7 changes: 4 additions & 3 deletions exts/ai_generation/ai_generation.py
@@ -91,10 +91,11 @@ async def morph_user(self, target: discord.User, prompt: str) -> tuple[str, Byte
avatar_buffer = BytesIO()
await target.display_avatar.replace(size=256, format="png", static_format="png").save(avatar_buffer)

# Verify what the new size is.
with Image.open(avatar_buffer) as avatar_image:
file_size = avatar_image.size

ai_url = await create_image(prompt, file_size)
ai_url = await create_image(self.bot.openai_client, prompt, file_size)
ai_bytes = await get_image(self.bot.web_session, ai_url)
ai_buffer = await asyncio.to_thread(process_image, ai_bytes)
gif_buffer = await create_morph(avatar_buffer, ai_buffer)
@@ -217,7 +218,7 @@ async def generate(

if generation_type == "image":
log_start_time = perf_counter()
ai_url = await create_image(prompt, (512, 512))
ai_url = await create_image(self.bot.openai_client, prompt, (512, 512))
ai_bytes = await get_image(ctx.session, ai_url)
ai_buffer = await asyncio.to_thread(process_image, ai_bytes)
creation_time = perf_counter() - log_start_time
@@ -238,7 +239,7 @@

elif generation_type == "text":
log_start_time = perf_counter()
ai_text = await create_completion(prompt)
ai_text = await create_completion(self.bot.openai_client, prompt)
creation_time = perf_counter() - log_start_time

# Send the generated image in an embed.
64 changes: 34 additions & 30 deletions exts/ai_generation/utils.py
@@ -12,12 +12,12 @@


__all__ = (
"temp_file_names",
"get_image",
"process_image",
"create_completion",
"create_image",
"create_inspiration",
"temp_file_names",
"create_morph",
)

@@ -27,28 +27,6 @@
INSPIROBOT_API_URL = "https://inspirobot.me/api"


@asynccontextmanager
async def temp_file_names(*extensions: str) -> AsyncGenerator[tuple[Path, ...], None]:
"""Create temporary filesystem paths to generated filenames in a temporary folder.
Upon completion, the folder is removed.
Parameters
----------
*extensions: tuple[:class:`str`]
The file extensions that the generated filenames should have, e.g. py, txt, doc.
Yields
------
temp_paths: tuple[:class:`Path`]
Filepaths with random filenames with the given file extensions, in order.
"""

async with aiofiles.tempfile.TemporaryDirectory() as temp_dir:
temp_paths = tuple(Path(temp_dir).joinpath(f"temp_output{i}." + ext) for i, ext in enumerate(extensions))
yield temp_paths


async def get_image(session: aiohttp.ClientSession, url: str) -> bytes:
"""Asynchronously load the bytes of an image from a url.
@@ -80,7 +58,7 @@ def process_image(image_bytes: bytes) -> BytesIO:
return output_buffer


async def create_completion(prompt: str) -> str:
async def create_completion(client: openai.AsyncOpenAI, prompt: str) -> str:
"""Makes a call to OpenAI's API to generate text based on given input.
Parameters
@@ -94,16 +72,16 @@ async def create_completion(prompt: str) -> str:
The generated text completion.
"""

completion_response = await openai.Completion.acreate( # type: ignore # Possible args are weird.
prompt=prompt,
completion_response = await client.completions.create(
model="text-davinci-003",
prompt=prompt,
max_tokens=150,
temperature=0,
)
return completion_response.choices[0].text # type: ignore
return completion_response.choices[0].text


async def create_image(prompt: str, size: tuple[int, int] = (256, 256)) -> str:
async def create_image(client: openai.AsyncOpenAI, prompt: str, size: tuple[int, int] = (256, 256)) -> str:
"""Makes a call to OpenAI's API to generate an image based on given inputs.
Parameters
@@ -119,12 +97,16 @@ async def create_image(prompt: str, size: tuple[int, int] = (256, 256)) -> str:
The url of the generated image.
"""

image_response = await openai.Image.acreate( # type: ignore # Possible args are weird.
image_response = await client.images.generate(
prompt=prompt,
n=1,
response_format="url",
size=f"{size[0]}x{size[1]}",

Check failure on line 104 in exts/ai_generation/utils.py (GitHub Actions / Type Coverage and Linting @ 3.10):
Argument of type "str" cannot be assigned to parameter "size" of type "NotGiven | Literal['256x256', '512x512', '1024x1024', '1792x1024', '1024x1792'] | None" in function "generate"   Type "str" cannot be assigned to type "NotGiven | Literal['256x256', '512x512', '1024x1024', '1792x1024', '1024x1792'] | None"     "str" is incompatible with "None"     "str" is incompatible with "NotGiven"     "str" cannot be assigned to type "Literal['256x256']"     "str" cannot be assigned to type "Literal['512x512']"     "str" cannot be assigned to type "Literal['1024x1024']"     "str" cannot be assigned to type "Literal['1792x1024']"     "str" cannot be assigned to type "Literal['1024x1792']" (reportGeneralTypeIssues)
)
return image_response.data[0].url # type: ignore

url = image_response.data[0].url
assert url
return url


async def create_inspiration(session: aiohttp.ClientSession) -> str:
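Note: the check failure above comes from passing an arbitrary formatted string where the openai stubs expect a fixed set of size literals. One way to satisfy the checker is to validate the tuple and cast the result; a sketch with a hypothetical helper (to_image_size is not part of this commit):

from typing import Literal, cast, get_args

ImageSize = Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]
_ALLOWED_SIZES: frozenset[str] = frozenset(get_args(ImageSize))


def to_image_size(size: tuple[int, int]) -> ImageSize:
    """Convert a (width, height) tuple into the literal string form the openai client accepts."""

    formatted = f"{size[0]}x{size[1]}"
    if formatted not in _ALLOWED_SIZES:
        msg = f"Unsupported image size: {formatted}"
        raise ValueError(msg)
    return cast(ImageSize, formatted)

The call in create_image could then pass size=to_image_size(size) instead of the raw f-string.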
@@ -146,6 +128,28 @@ async def create_inspiration(session: aiohttp.ClientSession) -> str:
return await response.text()


@asynccontextmanager
async def temp_file_names(*extensions: str) -> AsyncGenerator[tuple[Path, ...], None]:
"""Create temporary filesystem paths to generated filenames in a temporary folder.
Upon completion, the folder is removed.
Parameters
----------
*extensions: tuple[:class:`str`]
The file extensions that the generated filenames should have, e.g. py, txt, doc.
Yields
------
temp_paths: tuple[:class:`Path`]
Filepaths with random filenames with the given file extensions, in order.
"""

async with aiofiles.tempfile.TemporaryDirectory() as temp_dir:
temp_paths = tuple(Path(temp_dir).joinpath(f"temp_output{i}." + ext) for i, ext in enumerate(extensions))
yield temp_paths


async def create_morph(before_img_buffer: BytesIO, after_img_buffer: BytesIO) -> BytesIO:
"""Create a morph gif between two images using ffmpeg.
5 changes: 0 additions & 5 deletions main.py
@@ -3,7 +3,6 @@
import aiohttp
import asyncpg
import discord
import openai

import core
from core.tree import HookableTree
@@ -19,10 +18,6 @@ async def main() -> None:
command_timeout=30,
init=pool_init,
) as pool, LoggingManager() as logging_manager:
# Set up OpenAI.
openai.api_key = core.CONFIG.openai.key
openai.aiosession.set(web_session)

# Set the bot's basic starting parameters.
intents = discord.Intents.all()
intents.presences = False
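Note: the lines removed here have no one-to-one replacement because openai>=1 clients no longer read module-level globals, and they use httpx internally rather than the bot's aiohttp session. If the HTTP layer ever needs tuning, the 1.x client accepts an explicit httpx client; a hedged sketch (not part of this commit):

import httpx
import openai

client = openai.AsyncOpenAI(
    api_key="sk-placeholder",  # the repo reads the real key from CONFIG.openai.key
    http_client=httpx.AsyncClient(timeout=30.0),
)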
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -8,7 +8,7 @@ requires-python = ">=3.10"
authors = [{ name = "Sachaa-Thanasius", email = "111999343+Sachaa-Thanasius@users.noreply.github.com" }]

[project.urls]
"Homepage" = "https://github.com/Sachaa-Thanasius/Beira"
Homepage = "https://github.com/Sachaa-Thanasius/Beira"
"Bug Tracker" = "https://github.com/Sachaa-Thanasius/Beira/issues"

[tool.black]
4 changes: 2 additions & 2 deletions requirements.txt
@@ -12,10 +12,10 @@ lxml>=4.9.3
matplotlib
msgspec[toml]
numpy
openai
openai>=1.1.0
Pillow>=10.0.0
tatsu @ git+https://github.com/Sachaa-Thanasius/Tatsu.git
typing_extensions>=4.5.0,<5
wavelink>=2.6.1,<3
wavelink>=2.6.3,<3
wavelink-stubs
yarl>=1.8.2,<2
