From f8dbdbcfa011692d82897cd9222d4a97ccfe608f Mon Sep 17 00:00:00 2001 From: olijeffers0n <69084614+olijeffers0n@users.noreply.github.com> Date: Wed, 1 Dec 2021 22:33:46 +0000 Subject: [PATCH] Add RateLimiter --- rustplus/__init__.py | 2 +- rustplus/api/api.py | 66 ++++++++++++++++++++++++--- rustplus/api/socket/__init__.py | 3 +- rustplus/api/socket/echo_client.py | 4 +- rustplus/api/socket/repeated_timer.py | 33 ++++++++++++++ rustplus/api/socket/token_bucket.py | 31 +++++++++++++ rustplus/exceptions/__init__.py | 2 +- rustplus/exceptions/exceptions.py | 8 ++++ setup.py | 2 +- 9 files changed, 138 insertions(+), 13 deletions(-) create mode 100644 rustplus/api/socket/repeated_timer.py create mode 100644 rustplus/api/socket/token_bucket.py diff --git a/rustplus/__init__.py b/rustplus/__init__.py index b764171..60e1a0b 100644 --- a/rustplus/__init__.py +++ b/rustplus/__init__.py @@ -3,7 +3,7 @@ """ from .api import RustSocket -from .exceptions import ClientError, ImageError, ServerNotResponsiveError, ClientNotConnectedError, PrefixNotDefinedError, CommandsNotEnabledError +from .exceptions import ClientError, ImageError, ServerNotResponsiveError, ClientNotConnectedError, PrefixNotDefinedError, CommandsNotEnabledError, RustSocketDestroyedError from .commands import CommandOptions, Command __name__ = "rustplus" diff --git a/rustplus/api/api.py b/rustplus/api/api.py index 1846fa1..6c3cf92 100644 --- a/rustplus/api/api.py +++ b/rustplus/api/api.py @@ -6,16 +6,16 @@ from importlib import resources import asyncio -from .socket import EchoClient +from .socket import EchoClient, TokenBucket from .rustplus_pb2 import * from .structures import RustTime, RustInfo, RustMap, RustMarker, RustChatMessage, RustSuccess, RustTeamInfo, RustTeamMember, RustTeamNote, RustEntityInfo, RustContents, RustItem from ..utils import MonumentNameToImage, TimeParser, CoordUtil, ErrorChecker, IdToName, MapMarkerConverter -from ..exceptions import ImageError, ServerNotResponsiveError, ClientNotConnectedError, CommandsNotEnabledError +from ..exceptions import ImageError, ServerNotResponsiveError, ClientNotConnectedError, CommandsNotEnabledError, RustSocketDestroyedError, RateLimitError from ..commands import CommandOptions, RustCommandHandler class RustSocket: - def __init__(self, ip : str, port : str, steamid : int, playertoken : int, command_options : CommandOptions = None) -> None: + def __init__(self, ip : str, port : str, steamid : int, playertoken : int, command_options : CommandOptions = None, raise_ratelimit_exception : bool = True, ratelimit_limit : int = 25, ratelimit_refill : int = 3) -> None: self.seq = 1 self.ip = ip @@ -28,6 +28,8 @@ def __init__(self, ip : str, port : str, steamid : int, playertoken : int, comma self.prefix = None self.command_handler = None self.ws = None + self.bucket = TokenBucket(ratelimit_limit, ratelimit_limit, 1, ratelimit_refill) + self.raise_ratelimit_exception = raise_ratelimit_exception if command_options is not None: @@ -71,6 +73,21 @@ async def __sendAndRecieve(self, request, response = True) -> AppMessage: return None + async def __handle_ratelimit(self, cost = 1) -> None: + + while True: + + if self.bucket.can_consume(cost): + self.bucket.consume(cost) + return + + if self.raise_ratelimit_exception: + raise RateLimitError("Out of tokens") + + await asyncio.sleep(1) + + ## End of Utility Functions + async def __getTime(self) -> RustTime: request = self.__initProto() @@ -282,7 +299,7 @@ async def __getCurrentEvents(self): async def __start_websocket(self) -> None: self.ws = EchoClient(ip=self.ip, port=self.port, api=self, protocols=['http-only', 'chat']) - self.ws.daemon = False + self.ws.daemon = True self.ws.connect() async def connect(self) -> None: @@ -290,9 +307,12 @@ async def connect(self) -> None: Connect to the Rust Server """ + if self.bucket is None: + raise RustSocketDestroyedError("Socket is terminated") + await self.__start_websocket() - #except: + # TODO Make a `create_connection` request to the server to ping & check it is online # raise ServerNotResponsiveError("The sever is not available to connect to - your ip/port are either correct or the server is offline") async def closeConnection(self) -> None: @@ -311,28 +331,50 @@ async def disconnect(self) -> None: """ await self.closeConnection() + async def terminate(self) -> None: + """ + Closes and shuts down any processes like the rate limit manager. + You CANNOT reconnect after this. + You take your own responsibilty for managing your rate limit. + """ + await self.closeConnection() + self.bucket.refiller.stop() + self.bucket.refiller = None + self.bucket = None + async def getTime(self) -> RustTime: """ Gets the current in-game time """ + + self.__handle_ratelimit() return await self.__getTime() async def getInfo(self) -> RustInfo: """ Gets information on the Rust Server """ + + self.__handle_ratelimit() return await self.__getInfo() async def getRawMapData(self) -> RustMap: """ Returns the list of monuments on the server. This is a relatively expensive operation as the monuments are part of the map data """ + await self.__handle_ratelimit(6) return await self.__getRawMapData() async def getMap(self, addIcons : bool = False, addEvents : bool = False, addVendingMachines : bool = False, overrideImages : dict = {}) -> Image: """ Returns the Map of the server with the option to add icons. """ + + cost = 6 + if addIcons or addEvents or addVendingMachines: + cost += 1 + + await self.__handle_ratelimit(cost) return await self.__getAndFormatMap(addIcons, addEvents, addVendingMachines, overrideImages) async def getMarkers(self) -> List[RustMarker]: @@ -340,6 +382,7 @@ async def getMarkers(self) -> List[RustMarker]: Gets the map markers for the server. Returns a list of them """ + await self.__handle_ratelimit() return await self.__getMarkers() async def getTeamChat(self) -> List[RustChatMessage]: @@ -347,6 +390,7 @@ async def getTeamChat(self) -> List[RustChatMessage]: Returns a list of RustChatMessage objects """ + await self.__handle_ratelimit() return await self.__getTeamChat() async def sendTeamMessage(self, message : str) -> RustSuccess: @@ -354,6 +398,7 @@ async def sendTeamMessage(self, message : str) -> RustSuccess: Sends a team chat message as yourself. Returns the success data back from the server. Can be ignored """ + await self.__handle_ratelimit(2) return await self.__sendTeamChatMessage(message) async def getTeamInfo(self) -> RustTeamInfo: @@ -361,6 +406,7 @@ async def getTeamInfo(self) -> RustTeamInfo: Returns an AppTeamInfo object of the players in your team, as well as a lot of data about them """ + await self.__handle_ratelimit() return await self.__getTeamInfo() async def turnOnSmartSwitch(self, EID : int) -> RustSuccess: @@ -368,6 +414,7 @@ async def turnOnSmartSwitch(self, EID : int) -> RustSuccess: Turns on a smart switch on the server """ + await self.__handle_ratelimit() return await self.__updateSmartDevice(EID, True) async def turnOffSmartSwitch(self, EID : int) -> RustSuccess: @@ -375,6 +422,7 @@ async def turnOffSmartSwitch(self, EID : int) -> RustSuccess: Turns off a smart switch on the server """ + await self.__handle_ratelimit() return await self.__updateSmartDevice(EID, False) async def getEntityInfo(self, EID : int) -> RustEntityInfo: @@ -382,6 +430,7 @@ async def getEntityInfo(self, EID : int) -> RustEntityInfo: Get the entity info from a given entity ID """ + await self.__handle_ratelimit() return await self.__getEntityInfo(EID) async def promoteToTeamLeader(self, SteamID : int) -> RustSuccess: @@ -389,6 +438,7 @@ async def promoteToTeamLeader(self, SteamID : int) -> RustSuccess: Promotes a given user to the team leader by their 64-bit Steam ID """ + await self.__handle_ratelimit() return await self.__promoteToTeamLeader(SteamID) async def getTCStorageContents(self, EID : int, combineStacks : bool = False) -> RustContents: @@ -396,7 +446,8 @@ async def getTCStorageContents(self, EID : int, combineStacks : bool = False) -> Gets the Information about TC Upkeep and Contents. Do not use this for any other storage monitor than a TC """ - + + await self.__handle_ratelimit() return await self.__getTCStorage(EID, combineStacks) async def getCurrentEvents(self) -> List[RustMarker]: @@ -410,7 +461,8 @@ async def getCurrentEvents(self) -> List[RustMarker]: Returns the MapMarker for the event """ - + + await self.__handle_ratelimit() return await self.__getCurrentEvents() def command(self, coro) -> None: diff --git a/rustplus/api/socket/__init__.py b/rustplus/api/socket/__init__.py index 67e5434..b22a215 100644 --- a/rustplus/api/socket/__init__.py +++ b/rustplus/api/socket/__init__.py @@ -1 +1,2 @@ -from .echo_client import EchoClient \ No newline at end of file +from .echo_client import EchoClient +from .token_bucket import TokenBucket \ No newline at end of file diff --git a/rustplus/api/socket/echo_client.py b/rustplus/api/socket/echo_client.py index ba2c61a..796d829 100644 --- a/rustplus/api/socket/echo_client.py +++ b/rustplus/api/socket/echo_client.py @@ -12,10 +12,10 @@ def __init__(self, ip, port, api, protocols=None, extensions=None, heartbeat_fre self.api = api def opened(self): - pass + return def closed(self, code, reason): - pass + return def received_message(self, message): diff --git a/rustplus/api/socket/repeated_timer.py b/rustplus/api/socket/repeated_timer.py new file mode 100644 index 0000000..b935e83 --- /dev/null +++ b/rustplus/api/socket/repeated_timer.py @@ -0,0 +1,33 @@ +from threading import Timer + +class RepeatedTimer(object): + + def __init__(self, interval, function): + + self._timer = None + self.interval = interval + self.function = function + self.is_running = False + self.stopped = False + self.start() + + def _run(self): + + if self.stopped: + return + self.is_running = False + self.start() + self.function() + + def start(self): + + if not self.is_running: + self._timer = Timer(self.interval, self._run) + self._timer.start() + self.is_running = True + + def stop(self): + + self._timer.cancel() + self.is_running = False + self.stopped = True \ No newline at end of file diff --git a/rustplus/api/socket/token_bucket.py b/rustplus/api/socket/token_bucket.py new file mode 100644 index 0000000..67bd6ae --- /dev/null +++ b/rustplus/api/socket/token_bucket.py @@ -0,0 +1,31 @@ +from .repeated_timer import RepeatedTimer + + +class TokenBucket: + + def __init__(self, current: int, maximum: int, refresh_rate, refresh_amount) -> None: + + self.current = current + self.max = maximum + self.refresh_rate = refresh_rate + self.refresh_amount = refresh_amount + + self.refiller = RepeatedTimer(self.refresh_rate, self.refresh) + + def can_consume(self, amount: int = 1) -> bool: + + if (self.current - amount) >= 0: + return True + + return False + + def consume(self, amount: int = 1) -> None: + + self.current -= amount + + def refresh(self) -> None: + + if (self.max - self.current) > self.refresh_amount: + self.current += self.refresh_amount + else: + self.current = self.max \ No newline at end of file diff --git a/rustplus/exceptions/__init__.py b/rustplus/exceptions/__init__.py index 1aca11c..2dbc81e 100644 --- a/rustplus/exceptions/__init__.py +++ b/rustplus/exceptions/__init__.py @@ -1 +1 @@ -from .exceptions import ClientError, ImageError, ServerNotResponsiveError, ClientNotConnectedError, PrefixNotDefinedError, CommandsNotEnabledError +from .exceptions import ClientError, ImageError, ServerNotResponsiveError, ClientNotConnectedError, PrefixNotDefinedError, CommandsNotEnabledError, RustSocketDestroyedError, RateLimitError diff --git a/rustplus/exceptions/exceptions.py b/rustplus/exceptions/exceptions.py index fc547d0..e9959e4 100644 --- a/rustplus/exceptions/exceptions.py +++ b/rustplus/exceptions/exceptions.py @@ -24,4 +24,12 @@ class PrefixNotDefinedError(Error): class CommandsNotEnabledError(Error): """Raised when events are not enabled""" + pass + +class RustSocketDestroyedError(Error): + """Raised when the RustSocket has had #terminate called and tries to reconnect""" + pass + +class RateLimitError(Error): + """Raised When an issue with the ratelimit has occurred""" pass \ No newline at end of file diff --git a/setup.py b/setup.py index 3d9349c..ae1a6d3 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ }, version="4.1.2", include_package_data=True, - packages = ['rustplus', 'rustplus.api', 'rustplus.api.icons', 'rustplus.exceptions', 'rustplus.utils', 'rustplus.api.structures', 'rustplus.commands'], + packages = ['rustplus', 'rustplus.api', 'rustplus.api.icons', 'rustplus.exceptions', 'rustplus.utils', 'rustplus.api.structures', 'rustplus.commands', 'rustplus.api.socket'], license='MIT', description='A python wrapper for the Rust Plus API', long_description=readme,