From 1ef57b21d00d0ef183d95ea17b90403b8bf84821 Mon Sep 17 00:00:00 2001 From: hypergonial <46067571+hypergonial@users.noreply.github.com> Date: Wed, 13 Dec 2023 01:16:14 +0100 Subject: [PATCH] Fix some basic typing mistakes --- extensions/reminders.py | 2 +- models/checks.py | 4 ++-- models/db.py | 4 ++-- models/db_user.py | 24 ++++++++++++------------ models/journal.py | 2 +- models/mod_actions.py | 4 ++-- models/rolebutton.py | 39 ++++++++++++++++++++------------------- models/starboard.py | 4 ++-- models/tag.py | 28 ++++++++++++++-------------- utils/cache.py | 2 +- 10 files changed, 57 insertions(+), 56 deletions(-) diff --git a/extensions/reminders.py b/extensions/reminders.py index 0c2e9f7..bbb137e 100644 --- a/extensions/reminders.py +++ b/extensions/reminders.py @@ -367,7 +367,7 @@ async def reminder_list(ctx: SnedSlashContext) -> None: hikari.Embed(title="✉️ Your reminders:", description="\n".join(content), color=const.EMBED_BLUE) for content in reminders ] - # TODO: wtf + # FIXME: wtf typing navigator = AuthorOnlyNavigator(ctx, pages=pages, timeout=600) # type: ignore await navigator.send(ctx.interaction) diff --git a/models/checks.py b/models/checks.py index 5031947..f6506c2 100644 --- a/models/checks.py +++ b/models/checks.py @@ -103,7 +103,7 @@ async def _has_permissions(ctx: SnedApplicationContext, *, perms: hikari.Permiss if isinstance(channel, hikari.GuildThreadChannel): channel = ctx.app.cache.get_guild_channel(channel.parent_id) - assert isinstance(channel, hikari.GuildChannel) + assert isinstance(channel, hikari.PermissibleGuildChannel) member_perms = lightbulb.utils.permissions_in(channel, ctx.member) missing_perms = ~member_perms & perms @@ -140,7 +140,7 @@ async def _bot_has_permissions(ctx: SnedContext, *, perms: hikari.Permissions) - if isinstance(channel, hikari.GuildThreadChannel): channel = ctx.app.cache.get_guild_channel(channel.parent_id) - assert isinstance(channel, hikari.GuildChannel) + assert isinstance(channel, hikari.PermissibleGuildChannel) bot_perms = lightbulb.utils.permissions_in(channel, member) missing_perms = ~bot_perms & perms diff --git a/models/db.py b/models/db.py index 218b32c..8ab4104 100644 --- a/models/db.py +++ b/models/db.py @@ -121,7 +121,7 @@ async def acquire(self) -> t.AsyncIterator[asyncpg.Connection]: """Acquire a database connection from the connection pool.""" con = await self.pool.acquire() try: - yield con + yield con # type: ignore finally: await self.pool.release(con) @@ -213,7 +213,7 @@ async def fetchrow(self, query: str, *args, timeout: float | None = None) -> asy DatabaseStateConflictError The application is not connected to the database server. """ - return await self.pool.fetchrow(query, *args, timeout=timeout) + return await self.pool.fetchrow(query, *args, timeout=timeout) # type: ignore async def fetchval(self, query: str, *args, column: int = 0, timeout: float | None = None) -> t.Any: """Run a query and return a value in the first row that matched query parameters. diff --git a/models/db_user.py b/models/db_user.py index 1272d14..5d91500 100644 --- a/models/db_user.py +++ b/models/db_user.py @@ -60,7 +60,7 @@ async def update(self) -> None: @classmethod async def fetch( cls, user: hikari.SnowflakeishOr[hikari.PartialUser], guild: hikari.SnowflakeishOr[hikari.PartialGuild] - ) -> DatabaseUser: + ) -> t.Self: """Fetch a user from the database. If not present, returns a default DatabaseUser object. Parameters @@ -86,15 +86,15 @@ async def fetch( return cls(hikari.Snowflake(user), hikari.Snowflake(guild), flags=DatabaseUserFlag.NONE, warns=0) return cls( - id=hikari.Snowflake(record.get("user_id")), - guild_id=hikari.Snowflake(record.get("guild_id")), - flags=DatabaseUserFlag(record.get("flags")), - warns=record.get("warns"), - data=json.loads(record.get("data")) if record.get("data") else {}, + id=hikari.Snowflake(record["user_id"]), + guild_id=hikari.Snowflake(record["guild_id"]), + flags=DatabaseUserFlag(record["flags"]), + warns=record["warns"], + data=json.loads(record["data"]) if record.get("data") else {}, ) @classmethod - async def fetch_all(cls, guild: hikari.SnowflakeishOr[hikari.PartialGuild]) -> list[DatabaseUser]: + async def fetch_all(cls, guild: hikari.SnowflakeishOr[hikari.PartialGuild]) -> list[t.Self]: """Fetch all stored user data that belongs to the specified guild. Parameters @@ -115,11 +115,11 @@ async def fetch_all(cls, guild: hikari.SnowflakeishOr[hikari.PartialGuild]) -> l return [ cls( - id=hikari.Snowflake(record.get("user_id")), - guild_id=hikari.Snowflake(record.get("guild_id")), - flags=DatabaseUserFlag(record.get("flags")), - warns=record.get("warns"), - data=json.loads(record.get("data")) if record.get("data") else {}, + id=hikari.Snowflake(record["user_id"]), + guild_id=hikari.Snowflake(record["guild_id"]), + flags=DatabaseUserFlag(record["flags"]), + warns=record["warns"], + data=json.loads(record["data"]) if record.get("data") else {}, ) for record in records ] diff --git a/models/journal.py b/models/journal.py index fdb68e8..f603db5 100644 --- a/models/journal.py +++ b/models/journal.py @@ -81,7 +81,7 @@ def from_record(cls, record: asyncpg.Record) -> JournalEntry: user_id=hikari.Snowflake(record["user_id"]), guild_id=hikari.Snowflake(record["guild_id"]), content=record.get("content"), - author_id=hikari.Snowflake(record.get("author_id")) if record.get("author_id") else None, + author_id=hikari.Snowflake(record["author_id"]) if record.get("author_id") else None, created_at=datetime.datetime.fromtimestamp(record["created_at"]), entry_type=JournalEntryType(record["entry_type"]), ) diff --git a/models/mod_actions.py b/models/mod_actions.py index e99657f..a19869b 100644 --- a/models/mod_actions.py +++ b/models/mod_actions.py @@ -335,7 +335,7 @@ async def remove_timeout_extensions(self, event: hikari.MemberUpdateEvent): return for record in records: - await self.app.scheduler.cancel_timer(record.get("id"), event.guild_id) + await self.app.scheduler.cancel_timer(record["id"], event.guild_id) async def tempban_expire(self, event: TimerCompleteEvent) -> None: """Handle tempban timer expiry and unban user.""" @@ -639,7 +639,7 @@ async def ban( "tempban", ) if record: - await self.app.scheduler.cancel_timer(record.get("id"), moderator.guild_id) + await self.app.scheduler.cancel_timer(record["id"], moderator.guild_id) if soft: await self.app.rest.unban_user(moderator.guild_id, user.id, reason="Automatic unban by softban.") diff --git a/models/rolebutton.py b/models/rolebutton.py index 6c98244..607fd05 100644 --- a/models/rolebutton.py +++ b/models/rolebutton.py @@ -1,6 +1,7 @@ from __future__ import annotations import enum +import typing as t import hikari import miru @@ -95,15 +96,15 @@ async def fetch(cls, id: int) -> RoleButton | None: return None return cls( - id=record.get("entry_id"), - guild_id=hikari.Snowflake(record.get("guild_id")), - channel_id=hikari.Snowflake(record.get("channel_id")), - message_id=hikari.Snowflake(record.get("msg_id")), - emoji=hikari.Emoji.parse(record.get("emoji")), - label=record.get("label"), - style=hikari.ButtonStyle[record.get("style")], - mode=RoleButtonMode(record.get("mode")), - role_id=record.get("role_id"), + id=record["entry_id"], + guild_id=hikari.Snowflake(record["guild_id"]), + channel_id=hikari.Snowflake(record["channel_id"]), + message_id=hikari.Snowflake(record["msg_id"]), + emoji=hikari.Emoji.parse(record["emoji"]), + label=record["label"], + style=hikari.ButtonStyle[record["style"]], + mode=RoleButtonMode(record["mode"]), + role_id=record["role_id"], add_title=record.get("add_title"), add_description=record.get("add_desc"), remove_title=record.get("remove_title"), @@ -111,7 +112,7 @@ async def fetch(cls, id: int) -> RoleButton | None: ) @classmethod - async def fetch_all(cls, guild: hikari.SnowflakeishOr[hikari.PartialGuild]) -> list[RoleButton]: + async def fetch_all(cls, guild: hikari.SnowflakeishOr[hikari.PartialGuild]) -> list[t.Self]: """Fetch all rolebuttons that belong to a given guild. Parameters @@ -131,15 +132,15 @@ async def fetch_all(cls, guild: hikari.SnowflakeishOr[hikari.PartialGuild]) -> l return [ cls( - id=record.get("entry_id"), - guild_id=hikari.Snowflake(record.get("guild_id")), - channel_id=hikari.Snowflake(record.get("channel_id")), - message_id=hikari.Snowflake(record.get("msg_id")), - emoji=hikari.Emoji.parse(record.get("emoji")), + id=record["entry_id"], + guild_id=hikari.Snowflake(record["guild_id"]), + channel_id=hikari.Snowflake(record["channel_id"]), + message_id=hikari.Snowflake(record["msg_id"]), + emoji=hikari.Emoji.parse(record["emoji"]), label=record.get("label"), - style=hikari.ButtonStyle[record.get("style")], - mode=RoleButtonMode(record.get("mode")), - role_id=record.get("role_id"), + style=hikari.ButtonStyle[record["style"]], + mode=RoleButtonMode(record["mode"]), + role_id=record["role_id"], add_title=record.get("add_title"), add_description=record.get("add_desc"), remove_title=record.get("remove_title"), @@ -193,7 +194,7 @@ async def create( """ record = await cls._db.fetchrow("""SELECT entry_id FROM button_roles ORDER BY entry_id DESC""") - id = record.get("entry_id") + 1 if record else 1 + id = record["entry_id"] + 1 if record else 1 role_id = hikari.Snowflake(role) button = miru.Button( diff --git a/models/starboard.py b/models/starboard.py index baaea02..0353dd1 100644 --- a/models/starboard.py +++ b/models/starboard.py @@ -47,7 +47,7 @@ async def fetch(cls, guild: hikari.SnowflakeishOr[hikari.PartialGuild]) -> Starb records = await cls._app.db_cache.get(table="starboard", guild_id=hikari.Snowflake(guild), limit=1) if not records: return cls(guild_id=hikari.Snowflake(guild)) - return cls.from_record(records[0]) + return cls.from_record(records[0]) # type: ignore async def update(self) -> None: """Update the starboard settings in the database, or insert them if they do not yet exist.""" @@ -107,7 +107,7 @@ async def fetch(cls, original_message: hikari.SnowflakeishOr[hikari.PartialMessa ) if not records: return None - return cls.from_record(records[0]) + return cls.from_record(records[0]) # type: ignore async def update(self) -> None: """Update the starboard entry in the database, or insert it if it does not yet exist.""" diff --git a/models/tag.py b/models/tag.py index 9484636..875cde5 100644 --- a/models/tag.py +++ b/models/tag.py @@ -30,7 +30,7 @@ class Tag(DatabaseModel): @classmethod async def fetch( cls, name: str, guild: hikari.SnowflakeishOr[hikari.PartialGuild], add_use: bool = False - ) -> Tag | None: + ) -> t.Self | None: """Fetches a tag from the database. Parameters @@ -60,13 +60,13 @@ async def fetch( return return cls( - guild_id=hikari.Snowflake(record.get("guild_id")), - name=record.get("tagname"), - owner_id=hikari.Snowflake(record.get("owner_id")), - creator_id=hikari.Snowflake(record.get("creator_id")) if record.get("creator_id") else None, + guild_id=hikari.Snowflake(record["guild_id"]), + name=record["tagname"], + owner_id=hikari.Snowflake(record["owner_id"]), + creator_id=hikari.Snowflake(record["creator_id"]) if record.get("creator_id") else None, aliases=record.get("aliases"), - content=record.get("content"), - uses=record.get("uses"), + content=record["content"], + uses=record["uses"], ) @classmethod @@ -140,7 +140,7 @@ async def fetch_all( cls, guild: hikari.SnowflakeishOr[hikari.PartialGuild], owner: hikari.SnowflakeishOr[hikari.PartialUser] | None = None, - ) -> list[Tag]: + ) -> list[t.Self]: """Fetch all tags that belong to a guild, and optionally a user. Parameters @@ -170,13 +170,13 @@ async def fetch_all( return [ cls( - guild_id=hikari.Snowflake(record.get("guild_id")), - name=record.get("tagname"), - owner_id=hikari.Snowflake(record.get("owner_id")), - creator_id=hikari.Snowflake(record.get("creator_id")) if record.get("creator_id") else None, + guild_id=hikari.Snowflake(record["guild_id"]), + name=record["tagname"], + owner_id=hikari.Snowflake(record["owner_id"]), + creator_id=hikari.Snowflake(record["creator_id"]) if record.get("creator_id") else None, aliases=record.get("aliases"), - content=record.get("content"), - uses=record.get("uses"), + content=record["content"], + uses=record["uses"], ) for record in records ] diff --git a/utils/cache.py b/utils/cache.py index 1de2cf3..1397aeb 100644 --- a/utils/cache.py +++ b/utils/cache.py @@ -44,7 +44,7 @@ async def start(self) -> None: """ ) for record in records: - self._cache[record.get("tablename")] = [] + self._cache[record["tablename"]] = [] logger.info("Cache initialized!") self.is_ready = True