diff --git a/metricity/exts/event_listeners/_syncer_utils.py b/metricity/exts/event_listeners/_syncer_utils.py new file mode 100644 index 0000000..258a165 --- /dev/null +++ b/metricity/exts/event_listeners/_syncer_utils.py @@ -0,0 +1,152 @@ +import discord +from pydis_core.utils import logging +from sqlalchemy import update +from sqlalchemy.ext.asyncio import AsyncSession + +from metricity import models +from metricity.bot import Bot +from metricity.config import BotConfig +from metricity.database import async_session + +log = logging.get_logger(__name__) + + +def insert_thread(thread: discord.Thread, sess: AsyncSession) -> None: + """Insert the given thread to the database session.""" + sess.add(models.Thread( + id=str(thread.id), + parent_channel_id=str(thread.parent_id), + name=thread.name, + archived=thread.archived, + auto_archive_duration=thread.auto_archive_duration, + locked=thread.locked, + type=thread.type.name, + created_at=thread.created_at, + )) + + +async def sync_message(message: discord.Message, sess: AsyncSession, *, from_thread: bool) -> None: + """Sync the given message with the database.""" + if await sess.get(models.Message, str(message.id)): + return + + args = { + "id": str(message.id), + "channel_id": str(message.channel.id), + "author_id": str(message.author.id), + "created_at": message.created_at, + } + + if from_thread: + thread = message.channel + args["channel_id"] = str(thread.parent_id) + args["thread_id"] = str(thread.id) + + sess.add(models.Message(**args)) + + +async def sync_channels(bot: Bot, guild: discord.Guild) -> None: + """Sync channels and categories with the database.""" + bot.channel_sync_in_progress.clear() + + log.info("Beginning category synchronisation process") + + async with async_session() as sess: + for channel in guild.channels: + if isinstance(channel, discord.CategoryChannel): + if existing_cat := await sess.get(models.Category, str(channel.id)): + existing_cat.name = channel.name + else: + sess.add(models.Category(id=str(channel.id), name=channel.name, deleted=False)) + + await sess.commit() + + log.info("Category synchronisation process complete, synchronising deleted categories") + + async with async_session() as sess: + await sess.execute( + update(models.Category) + .where(~models.Category.id.in_( + [str(channel.id) for channel in guild.channels if isinstance(channel, discord.CategoryChannel)], + )) + .values(deleted=True), + ) + await sess.commit() + + log.info("Deleted category synchronisation process complete, synchronising channels") + + async with async_session() as sess: + for channel in guild.channels: + if channel.category and channel.category.id in BotConfig.ignore_categories: + continue + + if not isinstance(channel, discord.CategoryChannel): + category_id = str(channel.category.id) if channel.category else None + # Cast to bool so is_staff is False if channel.category is None + is_staff = channel.id in BotConfig.staff_channels or bool( + channel.category and channel.category.id in BotConfig.staff_categories, + ) + if db_chan := await sess.get(models.Channel, str(channel.id)): + db_chan.name = channel.name + else: + sess.add(models.Channel( + id=str(channel.id), + name=channel.name, + category_id=category_id, + is_staff=is_staff, + deleted=False, + )) + + await sess.commit() + + log.info("Channel synchronisation process complete, synchronising deleted channels") + + async with async_session() as sess: + await sess.execute( + update(models.Channel) + .where(~models.Channel.id.in_([str(channel.id) for channel in guild.channels])) + .values(deleted=True), + ) + await sess.commit() + + log.info("Deleted channel synchronisation process complete, synchronising threads") + + async with async_session() as sess: + for thread in guild.threads: + if thread.parent and thread.parent.category: + if thread.parent.category.id in BotConfig.ignore_categories: + continue + else: + # This is a forum channel, not currently supported by Discord.py. Ignore it. + continue + + if db_thread := await sess.get(models.Thread, str(thread.id)): + db_thread.name = thread.name + db_thread.archived = thread.archived + db_thread.auto_archive_duration = thread.auto_archive_duration + db_thread.locked = thread.locked + db_thread.type = thread.type.name + else: + insert_thread(thread, sess) + await sess.commit() + + log.info("Thread synchronisation process complete, finished synchronising guild.") + bot.channel_sync_in_progress.set() + + +async def sync_thread_archive_state(guild: discord.Guild) -> None: + """Sync the archive state of all threads in the database with the state in guild.""" + active_thread_ids = [str(thread.id) for thread in guild.threads] + + async with async_session() as sess: + await sess.execute( + update(models.Thread) + .where(models.Thread.id.in_(active_thread_ids)) + .values(archived=False), + ) + await sess.execute( + update(models.Thread) + .where(~models.Thread.id.in_(active_thread_ids)) + .values(archived=True), + ) + await sess.commit() diff --git a/metricity/exts/event_listeners/_utils.py b/metricity/exts/event_listeners/_utils.py deleted file mode 100644 index 4006ea2..0000000 --- a/metricity/exts/event_listeners/_utils.py +++ /dev/null @@ -1,38 +0,0 @@ -import discord -from sqlalchemy.ext.asyncio import AsyncSession - -from metricity import models - - -def insert_thread(thread: discord.Thread, sess: AsyncSession) -> None: - """Insert the given thread to the database session.""" - sess.add(models.Thread( - id=str(thread.id), - parent_channel_id=str(thread.parent_id), - name=thread.name, - archived=thread.archived, - auto_archive_duration=thread.auto_archive_duration, - locked=thread.locked, - type=thread.type.name, - created_at=thread.created_at, - )) - - -async def sync_message(message: discord.Message, sess: AsyncSession, *, from_thread: bool) -> None: - """Sync the given message with the database.""" - if await sess.get(models.Message, str(message.id)): - return - - args = { - "id": str(message.id), - "channel_id": str(message.channel.id), - "author_id": str(message.author.id), - "created_at": message.created_at, - } - - if from_thread: - thread = message.channel - args["channel_id"] = str(thread.parent_id) - args["thread_id"] = str(thread.id) - - sess.add(models.Message(**args)) diff --git a/metricity/exts/event_listeners/guild_listeners.py b/metricity/exts/event_listeners/guild_listeners.py index 79cd8f4..db976b2 100644 --- a/metricity/exts/event_listeners/guild_listeners.py +++ b/metricity/exts/event_listeners/guild_listeners.py @@ -1,18 +1,12 @@ """An ext to listen for guild (and guild channel) events and syncs them to the database.""" -import math - import discord from discord.ext import commands -from pydis_core.utils import logging, scheduling -from sqlalchemy import column, update -from sqlalchemy.dialects.postgresql import insert +from pydis_core.utils import logging -from metricity import models from metricity.bot import Bot from metricity.config import BotConfig -from metricity.database import async_session -from metricity.exts.event_listeners import _utils +from metricity.exts.event_listeners import _syncer_utils log = logging.get_logger(__name__) @@ -22,187 +16,6 @@ class GuildListeners(commands.Cog): def __init__(self, bot: Bot) -> None: self.bot = bot - scheduling.create_task(self.sync_guild()) - - async def sync_guild(self) -> None: - """Sync all channels and members in the guild.""" - await self.bot.wait_until_guild_available() - - guild = self.bot.get_guild(self.bot.guild_id) - await self.sync_channels(guild) - - log.info("Beginning thread archive state synchronisation process") - await self.sync_thread_archive_state(guild) - - log.info("Beginning user synchronisation process") - async with async_session() as sess: - await sess.execute(update(models.User).values(in_guild=False)) - await sess.commit() - - users = [ - { - "id": str(user.id), - "name": user.name, - "avatar_hash": getattr(user.avatar, "key", None), - "guild_avatar_hash": getattr(user.guild_avatar, "key", None), - "joined_at": user.joined_at, - "created_at": user.created_at, - "is_staff": BotConfig.staff_role_id in [role.id for role in user.roles], - "bot": user.bot, - "in_guild": True, - "public_flags": dict(user.public_flags), - "pending": user.pending, - } - for user in guild.members - ] - - user_chunks = discord.utils.as_chunks(users, 500) - created = 0 - updated = 0 - total_users = len(users) - - log.info("Performing bulk upsert of %d rows in %d chunks", total_users, math.ceil(total_users / 500)) - - async with async_session() as sess: - for chunk in user_chunks: - qs = insert(models.User).returning(column("xmax")).values(chunk) - - update_cols = [ - "name", - "avatar_hash", - "guild_avatar_hash", - "joined_at", - "is_staff", - "bot", - "in_guild", - "public_flags", - "pending", - ] - - res = await sess.execute(qs.on_conflict_do_update( - index_elements=[models.User.id], - set_={k: getattr(qs.excluded, k) for k in update_cols}, - )) - - objs = list(res) - - created += [obj[0] == 0 for obj in objs].count(True) - updated += [obj[0] != 0 for obj in objs].count(True) - - log.info("User upsert: inserted %d rows, updated %d rows, done %d rows, %d rows remaining", - created, updated, created + updated, total_users - (created + updated)) - - await sess.commit() - - log.info("User upsert complete") - - self.bot.sync_process_complete.set() - - @staticmethod - async def sync_thread_archive_state(guild: discord.Guild) -> None: - """Sync the archive state of all threads in the database with the state in guild.""" - active_thread_ids = [str(thread.id) for thread in guild.threads] - - async with async_session() as sess: - await sess.execute( - update(models.Thread) - .where(models.Thread.id.in_(active_thread_ids)) - .values(archived=False), - ) - await sess.execute( - update(models.Thread) - .where(~models.Thread.id.in_(active_thread_ids)) - .values(archived=True), - ) - await sess.commit() - - async def sync_channels(self, guild: discord.Guild) -> None: - """Sync channels and categories with the database.""" - self.bot.channel_sync_in_progress.clear() - - log.info("Beginning category synchronisation process") - - async with async_session() as sess: - for channel in guild.channels: - if isinstance(channel, discord.CategoryChannel): - if existing_cat := await sess.get(models.Category, str(channel.id)): - existing_cat.name = channel.name - else: - sess.add(models.Category(id=str(channel.id), name=channel.name, deleted=False)) - - await sess.commit() - - log.info("Category synchronisation process complete, synchronising deleted categories") - - async with async_session() as sess: - await sess.execute( - update(models.Category) - .where(~models.Category.id.in_( - [str(channel.id) for channel in guild.channels if isinstance(channel, discord.CategoryChannel)], - )) - .values(deleted=True), - ) - await sess.commit() - - log.info("Deleted category synchronisation process complete, synchronising channels") - - async with async_session() as sess: - for channel in guild.channels: - if channel.category and channel.category.id in BotConfig.ignore_categories: - continue - - if not isinstance(channel, discord.CategoryChannel): - category_id = str(channel.category.id) if channel.category else None - # Cast to bool so is_staff is False if channel.category is None - is_staff = channel.id in BotConfig.staff_channels or bool( - channel.category and channel.category.id in BotConfig.staff_categories, - ) - if db_chan := await sess.get(models.Channel, str(channel.id)): - db_chan.name = channel.name - else: - sess.add(models.Channel( - id=str(channel.id), - name=channel.name, - category_id=category_id, - is_staff=is_staff, - deleted=False, - )) - - await sess.commit() - - log.info("Channel synchronisation process complete, synchronising deleted channels") - - async with async_session() as sess: - await sess.execute( - update(models.Channel) - .where(~models.Channel.id.in_([str(channel.id) for channel in guild.channels])) - .values(deleted=True), - ) - await sess.commit() - - log.info("Deleted channel synchronisation process complete, synchronising threads") - - async with async_session() as sess: - for thread in guild.threads: - if thread.parent and thread.parent.category: - if thread.parent.category.id in BotConfig.ignore_categories: - continue - else: - # This is a forum channel, not currently supported by Discord.py. Ignore it. - continue - - if db_thread := await sess.get(models.Thread, str(thread.id)): - db_thread.name = thread.name - db_thread.archived = thread.archived - db_thread.auto_archive_duration = thread.auto_archive_duration - db_thread.locked = thread.locked - db_thread.type = thread.type.name - else: - _utils.insert_thread(thread, sess) - await sess.commit() - - log.info("Thread synchronisation process complete, finished synchronising guild.") - self.bot.channel_sync_in_progress.set() @commands.Cog.listener() async def on_guild_channel_create(self, channel: discord.abc.GuildChannel) -> None: @@ -210,7 +23,7 @@ async def on_guild_channel_create(self, channel: discord.abc.GuildChannel) -> No if channel.guild.id != BotConfig.guild_id: return - await self.sync_channels(channel.guild) + await _syncer_utils.sync_channels(self.bot, channel.guild) @commands.Cog.listener() async def on_guild_channel_delete(self, channel: discord.abc.GuildChannel) -> None: @@ -218,7 +31,7 @@ async def on_guild_channel_delete(self, channel: discord.abc.GuildChannel) -> No if channel.guild.id != BotConfig.guild_id: return - await self.sync_channels(channel.guild) + await _syncer_utils.sync_channels(self.bot, channel.guild) @commands.Cog.listener() async def on_guild_channel_update( @@ -230,7 +43,7 @@ async def on_guild_channel_update( if channel.guild.id != BotConfig.guild_id: return - await self.sync_channels(channel.guild) + await _syncer_utils.sync_channels(self.bot, channel.guild) @commands.Cog.listener() async def on_thread_create(self, thread: discord.Thread) -> None: @@ -238,7 +51,7 @@ async def on_thread_create(self, thread: discord.Thread) -> None: if thread.guild.id != BotConfig.guild_id: return - await self.sync_channels(thread.guild) + await _syncer_utils.sync_channels(self.bot, thread.guild) @commands.Cog.listener() async def on_thread_update(self, _before: discord.Thread, thread: discord.Thread) -> None: @@ -246,18 +59,7 @@ async def on_thread_update(self, _before: discord.Thread, thread: discord.Thread if thread.guild.id != BotConfig.guild_id: return - await self.sync_channels(thread.guild) - - @commands.Cog.listener() - async def on_guild_available(self, guild: discord.Guild) -> None: - """Synchronize the user table with the Discord users.""" - log.info("Received guild available for %d", guild.id) - - if guild.id != BotConfig.guild_id: - log.info("Guild was not the configured guild, discarding event") - return - - await self.sync_guild() + await _syncer_utils.sync_channels(self.bot, thread.guild) async def setup(bot: Bot) -> None: diff --git a/metricity/exts/event_listeners/message_listeners.py b/metricity/exts/event_listeners/message_listeners.py index a71e53f..917b13c 100644 --- a/metricity/exts/event_listeners/message_listeners.py +++ b/metricity/exts/event_listeners/message_listeners.py @@ -7,7 +7,7 @@ from metricity.bot import Bot from metricity.config import BotConfig from metricity.database import async_session -from metricity.exts.event_listeners import _utils +from metricity.exts.event_listeners import _syncer_utils from metricity.models import Message, User @@ -44,7 +44,7 @@ async def on_message(self, message: discord.Message) -> None: return from_thread = isinstance(message.channel, discord.Thread) - await _utils.sync_message(message, sess, from_thread=from_thread) + await _syncer_utils.sync_message(message, sess, from_thread=from_thread) await sess.commit() diff --git a/metricity/exts/event_listeners/startup_sync.py b/metricity/exts/event_listeners/startup_sync.py new file mode 100644 index 0000000..0f6264f --- /dev/null +++ b/metricity/exts/event_listeners/startup_sync.py @@ -0,0 +1,115 @@ +"""An ext to sync the guild when the bot starts up.""" + +import math + +import discord +from discord.ext import commands +from pydis_core.utils import logging, scheduling +from sqlalchemy import column, update +from sqlalchemy.dialects.postgresql import insert + +from metricity import models +from metricity.bot import Bot +from metricity.config import BotConfig +from metricity.database import async_session +from metricity.exts.event_listeners import _syncer_utils + +log = logging.get_logger(__name__) + + +class StartupSyncer(commands.Cog): + """Sync the guild on bot startup.""" + + def __init__(self, bot: Bot) -> None: + self.bot = bot + scheduling.create_task(self.sync_guild()) + + async def sync_guild(self) -> None: + """Sync all channels and members in the guild.""" + await self.bot.wait_until_guild_available() + + guild = self.bot.get_guild(self.bot.guild_id) + await _syncer_utils.sync_channels(self.bot, guild) + + log.info("Beginning thread archive state synchronisation process") + await _syncer_utils.sync_thread_archive_state(guild) + + log.info("Beginning user synchronisation process") + async with async_session() as sess: + await sess.execute(update(models.User).values(in_guild=False)) + await sess.commit() + + users = ( + { + "id": str(user.id), + "name": user.name, + "avatar_hash": getattr(user.avatar, "key", None), + "guild_avatar_hash": getattr(user.guild_avatar, "key", None), + "joined_at": user.joined_at, + "created_at": user.created_at, + "is_staff": BotConfig.staff_role_id in [role.id for role in user.roles], + "bot": user.bot, + "in_guild": True, + "public_flags": dict(user.public_flags), + "pending": user.pending, + } + for user in guild.members + ) + + user_chunks = discord.utils.as_chunks(users, 500) + created = 0 + updated = 0 + total_users = len(guild.members) + + log.info("Performing bulk upsert of %d rows in %d chunks", total_users, math.ceil(total_users / 500)) + + async with async_session() as sess: + for chunk in user_chunks: + qs = insert(models.User).returning(column("xmax")).values(chunk) + + update_cols = [ + "name", + "avatar_hash", + "guild_avatar_hash", + "joined_at", + "is_staff", + "bot", + "in_guild", + "public_flags", + "pending", + ] + + res = await sess.execute(qs.on_conflict_do_update( + index_elements=[models.User.id], + set_={k: getattr(qs.excluded, k) for k in update_cols}, + )) + + objs = list(res) + + created += [obj[0] == 0 for obj in objs].count(True) + updated += [obj[0] != 0 for obj in objs].count(True) + + log.info("User upsert: inserted %d rows, updated %d rows, done %d rows, %d rows remaining", + created, updated, created + updated, total_users - (created + updated)) + + await sess.commit() + + log.info("User upsert complete") + + self.bot.sync_process_complete.set() + + @commands.Cog.listener() + async def on_guild_available(self, guild: discord.Guild) -> None: + """Synchronize the user table with the Discord users.""" + log.info("Received guild available for %d", guild.id) + + if guild.id != BotConfig.guild_id: + log.info("Guild was not the configured guild, discarding event") + return + + await self.sync_guild() + + +async def setup(bot: Bot) -> None: + """Load the GuildListeners cog.""" + await bot.add_cog(StartupSyncer(bot)) diff --git a/pyproject.toml b/pyproject.toml index e817933..5c5e61c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "metricity" -version = "2.5.1" +version = "2.6.0" description = "Advanced metric collection for the Python Discord server" authors = ["Joe Banks "] license = "MIT"