Skip to content

Commit

Permalink
Use FooFlags.__new__ directly
Browse files Browse the repository at this point in the history
This speeds up library a **lot**:
```
Benchmarking DMChannel parsing.
[DMChannel] Time using pyvolt ----: 0.107055 seconds
[DMChannel] Time using revolt.py -: 0.150814 seconds
[DMChannel] Time using voltage ---: 2.892504 seconds
Benchmarking GroupChannel parsing.
[GroupChannel] Time using pyvolt ----: 0.185667 seconds
[GroupChannel] Time using revolt.py -: 0.261947 seconds
[GroupChannel] Time using voltage ---: 7.473850 seconds
Benchmarking TextChannel parsing.
[TextChannel] Time using pyvolt ----: 0.592224 seconds
[TextChannel] Time using revolt.py -: 8.655765 seconds
[TextChannel] Time using voltage ---: 15.454577 seconds
Benchmarking Member parsing.
[Member] Time using pyvolt ----: 0.056654 seconds
[Member] Time using revolt.py -: 0.402864 seconds
[Member] Time using voltage ---: 0.529478 seconds
Benchmarking Message parsing.
[Message] Time using pyvolt ----: 0.056601 seconds
[Message] Time using revolt.py -: 0.261532 seconds
[Message] Time using voltage ---: 0.516635 seconds
Benchmarking Server parsing.
[Server] Time using pyvolt ----: 0.684154 seconds
[Server] Time using revolt.py -: 4.294232 seconds
[Server] Time using voltage ---: 11.650132 seconds
Benchmarking User parsing.
[User] Time using pyvolt ----: 0.679173 seconds
[User] Time using revolt.py -: 1.043663 seconds
[User] Time using voltage ---: 3.473125 seconds
```
Library just takes 2.361524 seconds to parse DMs, groups, text channels, members, messages, servers and users!
  • Loading branch information
MCausc78 committed Aug 2, 2024
1 parent 390b5dc commit 1416367
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 74 deletions.
160 changes: 97 additions & 63 deletions pyvolt/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,26 @@
from collections.abc import Callable, Iterator
from typing_extensions import Self


class _MissingSentinel:
__slots__ = ()

def __bool__(self) -> typing.Literal[False]:
return False

def __repr__(self) -> typing.Literal['...']:
return '...'


MISSING: typing.Any = _MissingSentinel()

BF = typing.TypeVar('BF', bound='BaseFlags')


class _Flag(typing.Generic[BF]):
class flag(typing.Generic[BF]):
__slots__ = (
'_func',
'_parent',
'doc',
'name',
'value',
Expand All @@ -21,36 +35,44 @@ class _Flag(typing.Generic[BF]):
'use_any',
)

def __init__(self, func: Callable[[BF], int], *, inverted: bool, use_any: bool, alias: bool) -> None:
self._func: Callable[[BF], int] = func
self.doc: str | None = func.__doc__
self.name: str = func.__name__
def __init__(self, *, inverted: bool = False, use_any: bool = False, alias: bool = False) -> None:
self._func: Callable[[BF], int] = MISSING
self._parent: type[BF] = MISSING
self.doc: str | None = None
self.name: str = ''
self.value: int = 0
self.alias: bool = alias
self.inverted: bool = inverted
self.use_any: bool = use_any

def __call__(self, func: Callable[[BF], int], /) -> Self:
self._func = func
self.doc = func.__doc__
self.name = func.__name__
return self

class _MissingSentinel:
__slots__ = ()

def __repr__(self) -> typing.Literal['IGNORE_THIS']:
return 'IGNORE_THIS'


MISSING: typing.Any = _MissingSentinel()
@typing.overload
def __get__(self, instance: BF, owner: type[BF], /) -> bool: ...

@typing.overload
def __get__(self, instance: None, owner: type[BF], /) -> int: ...

def flag(
*, inverted: bool = False, use_any: bool = False, alias: bool = False, sentinel: BF = MISSING
) -> Callable[[Callable[[BF], int]], bool]:
def decorator(func: Callable[[BF], int]) -> _Flag[BF]:
return _Flag(func, inverted=inverted, use_any=use_any, alias=alias)
def __get__(self, instance: BF | None, owner: type[BF], /) -> bool | int:
if instance is None:
if self._parent:
return self.value
# Needs to be here to allow prepare class
return self # type: ignore
else:
return instance._get(self)

return decorator # type: ignore
def __set__(self, instance: BF, value: bool) -> None:
instance._set(self, value)


class BaseFlags:
"""Base class for flags."""

if typing.TYPE_CHECKING:
ALL_VALUE: typing.ClassVar[int]
INVERTED: typing.ClassVar[bool]
Expand All @@ -65,45 +87,12 @@ class BaseFlags:
def __init_subclass__(cls, *, inverted: bool = False, support_kwargs: bool = True) -> None:
valid_flags = {}
for k, f in inspect.getmembers(cls):
if isinstance(f, _Flag):
if isinstance(f, flag):
f.value = f._func(cls)
if f.alias:
continue
valid_flags[f.name] = f.value

bits = f.value

if f.use_any and f.inverted:

def fget(self: Self) -> bool:
return (self.value & bits) == 0
elif f.use_any:

def fget(self: Self) -> bool:
return (self.value & bits) != 0
elif f.inverted:

def fget(self: Self) -> bool:
return (self.value & bits) == bits
else:

def fget(self: Self) -> bool:
return (self.value & bits) != bits

if f.inverted:

def fset(self, value: bool) -> None:
self &= ~value
else:

def fset(self, value: bool) -> None:
self |= value

prop = property(
fget=fget,
fset=fset,
doc=f.doc,
)

setattr(cls, k, prop)
f._parent = cls

default = 0
if inverted:
Expand Down Expand Up @@ -140,6 +129,13 @@ def init_without_kwargs(self, value: int = cls.NONE_VALUE, /) -> None:
cls.__init__ = init_without_kwargs # type: ignore

cls.__slots__ = ('value',)

if cls.INVERTED:
cls._get = cls._get1
cls._set = cls._set1
else:
cls._get = cls._get2
cls._set = cls._set2
cls.ALL = cls(cls.ALL_VALUE)
cls.NONE = cls(cls.NONE_VALUE)

Expand All @@ -148,6 +144,48 @@ def init_without_kwargs(self, value: int = cls.NONE_VALUE, /) -> None:
def __init__(self, value: int = 0, /, **kwargs: bool) -> None:
pass

def _get(self, other: flag[Self], /) -> bool:
return False

def _set(self, flag: flag[Self], value: bool, /) -> None:
pass

# used if flag is inverted
def _get1(self, other: flag[Self], /) -> bool:
if other.use_any and other.inverted:
return (self.value & ~other.value) == 0
elif other.use_any:
return (self.value & ~other.value) != 0
elif other.inverted:
return (self.value & ~other.value) == other.value
else:
return (self.value & ~other.value) != other.value

# used if flag is uninverted
def _get2(self, other: flag[Self], /) -> bool:
if other.use_any and other.inverted:
return (self.value & other.value) == 0
elif other.use_any:
return (self.value & other.value) != 0
elif other.inverted:
return (self.value & other.value) == other.value
else:
return (self.value & other.value) != other.value

# used if flag is inverted
def _set1(self, flag: flag[Self], value: bool, /) -> None:
if flag.inverted:
self.value &= ~flag.value
else:
self.value |= flag.value

# used if flag is uninverted
def _set2(self, flag: flag[Self], value: bool, /) -> None:
if flag.inverted:
self.value |= flag.value
else:
self.value &= ~flag.value

@classmethod
def all(cls) -> Self:
return cls(cls.ALL_VALUE)
Expand All @@ -167,7 +205,7 @@ def __hash__(self) -> int:

def __iter__(self) -> Iterator[tuple[str, bool]]:
for name, value in self.__class__.__dict__.items():
if isinstance(value, _Flag):
if isinstance(value, flag):
if value.alias:
continue
yield (name, getattr(self, name))
Expand Down Expand Up @@ -263,16 +301,13 @@ def __xor__(self, other: Self | int, /) -> Self:
return self.__class__(self.value ^ other)


F = typing.TypeVar('F', bound='BaseFlags')


def doc_flags(
intro: str,
/,
*,
added_in: str | None = None,
) -> Callable[[type[F]], type[F]]:
def decorator(cls: type[F]) -> type[F]:
) -> Callable[[type[BF]], type[BF]]:
def decorator(cls: type[BF]) -> type[BF]:
directives = ''

if added_in:
Expand Down Expand Up @@ -709,7 +744,6 @@ def spam(cls) -> int:


__all__ = (
'_Flag',
'flag',
'BaseFlags',
'doc_flags',
Expand Down
33 changes: 22 additions & 11 deletions pyvolt/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from copy import copy
from datetime import datetime
from functools import partial
import logging
import typing

Expand Down Expand Up @@ -217,6 +218,10 @@

_EMPTY_DICT: dict[typing.Any, typing.Any] = {}

_new_server_flags = ServerFlags.__new__
_new_user_badges = UserBadges.__new__
_new_user_flags = UserFlags.__new__


class Parser:
def __init__(self, state: State) -> None:
Expand Down Expand Up @@ -395,7 +400,7 @@ def parse_bulk_message_delete_event(
message_ids=d['ids'],
)

def parse_category(self, d: raw.Category) -> Category:
def parse_category(self, d: raw.Category, /) -> Category:
return Category(
id=d['id'],
title=d['title'],
Expand Down Expand Up @@ -1271,7 +1276,7 @@ def parse_response_webhook(self, d: raw.ResponseWebhook) -> Webhook:
token=None,
)

def parse_role(self, d: raw.Role, role_id: str, server_id: str) -> Role:
def parse_role(self, d: raw.Role, role_id: str, server_id: str, /) -> Role:
return Role(
state=self.state,
id=role_id,
Expand Down Expand Up @@ -1309,6 +1314,9 @@ def _parse_server(
icon = d.get('icon')
banner = d.get('banner')

flags = _new_server_flags(ServerFlags)
flags.value = d.get('flags', 0)

return Server(
state=self.state,
id=server_id,
Expand All @@ -1322,7 +1330,7 @@ def _parse_server(
default_permissions=Permissions(d['default_permissions']),
internal_icon=self.parse_asset(icon) if icon else None,
internal_banner=self.parse_asset(banner) if banner else None,
flags=ServerFlags(d.get('flags', 0)),
flags=flags,
nsfw=d.get('nsfw', False),
analytics=d.get('analytics', False),
discoverable=d.get('discoverable', False),
Expand Down Expand Up @@ -1563,18 +1571,15 @@ def parse_spotify_embed_special(self, d: raw.SpotifySpecial) -> SpotifyEmbedSpec
def parse_streamable_embed_special(self, d: raw.StreamableSpecial) -> StreamableEmbedSpecial:
return StreamableEmbedSpecial(id=d['id'])

def parse_system_message_channels(
self,
d: raw.SystemMessageChannels,
) -> SystemMessageChannels:
def parse_system_message_channels(self, d: raw.SystemMessageChannels, /) -> SystemMessageChannels:
return SystemMessageChannels(
user_joined=d.get('user_joined'),
user_left=d.get('user_left'),
user_kicked=d.get('user_kicked'),
user_banned=d.get('user_banned'),
)

def parse_text_channel(self, d: raw.TextChannel) -> ServerTextChannel:
def parse_text_channel(self, d: raw.TextChannel, /) -> ServerTextChannel:
icon = d.get('icon')
default_permissions = d.get('default_permissions')
role_permissions = d.get('role_permissions', {})
Expand Down Expand Up @@ -1626,7 +1631,13 @@ def parse_user(self, d: raw.User, /) -> User | OwnUser:

avatar = d.get('avatar')
status = d.get('status')
# profile = d.get('profile')

badges = _new_user_badges(UserBadges)
badges.value = d.get('badges', 0)

flags = _new_user_flags(UserFlags)
flags.value = d.get('flags', 0)

bot = d.get('bot')

return User(
Expand All @@ -1636,10 +1647,10 @@ def parse_user(self, d: raw.User, /) -> User | OwnUser:
discriminator=d['discriminator'],
display_name=d.get('display_name'),
internal_avatar=self.parse_asset(avatar) if avatar else None,
badges=UserBadges(d.get('badges', 0)),
badges=badges,
status=self.parse_user_status(status) if status else None,
# internal_profile=self.parse_user_profile(profile) if profile else None,
flags=UserFlags(d.get('flags', 0)),
flags=flags,
privileged=d.get('privileged', False),
bot=self.parse_bot_user_info(bot) if bot else None,
relationship=RelationshipStatus(d['relationship']),
Expand Down

0 comments on commit 1416367

Please sign in to comment.