diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 6ac8c1c..00569b4 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -37,17 +37,13 @@ jobs: python-version: ${{ matrix.python-version }} - name: install flatbuffers - run: sudo apt-get install -y flatbuffers-compiler + run: | + wget -O flatc.zip https://github.com/google/flatbuffers/releases/download/v24.3.25/Linux.flatc.binary.clang++-15.zip + unzip flatc.zip && chmod +x flatc && sudo mv flatc /usr/bin/flatc - name: install nox run: python -m pip install nox - - name: Restore cached guest accounts - uses: actions/cache@v3 - with: - path: /tmp/arkprts_auth_cache.json - key: guest-auth-cache-${{ matrix.python-version }} - - name: Run tests run: | python -m nox -s test --verbose diff --git a/README.md b/README.md index f3546ef..6c56758 100644 --- a/README.md +++ b/README.md @@ -122,7 +122,7 @@ Programmatically getting auth tokens from a user on your website. @route("/code") def code(request): auth = arkprts.YostarAuth(request.query["server"], network=...) - await auth.get_token_from_email_code(request.query["email"]) + await auth.send_email_code(request.query["email"]) return "Code sent!" diff --git a/arkprts/__main__.py b/arkprts/__main__.py index 60c28a4..5a155c7 100644 --- a/arkprts/__main__.py +++ b/arkprts/__main__.py @@ -2,39 +2,40 @@ import argparse import asyncio +import json import logging +import sys +import typing import arkprts -parser: argparse.ArgumentParser = argparse.ArgumentParser(description="Get user info from Arknights API.") -parser.add_argument("nickname", type=str, nargs="?", default=None, help="User nickname") -parser.add_argument("--uid", type=str, default=None, help="Channel UID") +parser: argparse.ArgumentParser = argparse.ArgumentParser("Use the arknights API") +parser.add_argument("--log-level", type=str, default="ERROR", help="Logging level") +parser.add_argument("--channel-uid", type=str, default=None, help="Channel UID") parser.add_argument("--token", type=str, default="", help="Yostar Token") -parser.add_argument("--log-level", type=str, default="INFO", help="Logging level") -parser.add_argument("--server", type=str, default="en", help="Server to use, global only") +parser.add_argument("--server", type=str, default="en", help="Server to use") +parser.add_argument("--guest", action="store_true", help="Whether to use a guest account.") +subparsers: argparse._SubParsersAction[argparse.ArgumentParser] = parser.add_subparsers(dest="command", required=True) -async def main() -> None: - """Entry-point.""" - args = parser.parse_args() +parser_search: argparse.ArgumentParser = subparsers.add_parser("search", description="Get user info.") +parser_search.add_argument( + "nickname", + type=str, + nargs="?", + default=None, + help="User nickname, gives your friends by default", +) - logging.basicConfig() - logging.getLogger("arkprts").setLevel(args.log_level.upper()) +parser_api: argparse.ArgumentParser = subparsers.add_parser("api", description="Make a request towards the API.") +parser_api.add_argument("endpoint", type=str, nargs="?", help="Endpoint path, not full url") +parser_api.add_argument("payload", type=str, nargs="?", default="{}", help="JSON payload") - if args.uid and args.token: - auth = arkprts.YostarAuth(args.server) - await auth.login_with_token(args.uid, args.token) - elif args.nickname: - auth = arkprts.GuestAuth(max_sessions=1) - auth.network.default_server = args.server - else: - auth = arkprts.YostarAuth(args.server) - await auth.login_with_email_code() - - client = arkprts.Client(auth) - if args.nickname: - players = await client.search_players(args.nickname, limit=10) +async def search(client: arkprts.Client, nickname: typing.Optional[str] = None) -> None: + """Get user info.""" + if nickname: + players = await client.search_players(nickname, limit=10) else: players = await client.get_friends() print("Friends:", end="\n\n") @@ -67,5 +68,44 @@ async def main() -> None: await client.network.close() +async def api(client: arkprts.Client, endpoint: str, payload: typing.Optional[str] = None) -> None: + """Make a request.""" + client.assets.loaded = True + try: + data = await client.request(endpoint, json=payload and json.loads(payload)) + json.dump(data, sys.stdout, indent=4, ensure_ascii=False) + except arkprts.errors.GameServerError as e: + json.dump(e.data, sys.stdout, indent=4, ensure_ascii=False) + finally: + await client.network.close() + + sys.stdout.write("\n") + + +async def main() -> None: + """Entry-point.""" + args = parser.parse_args() + + logging.basicConfig() + logging.getLogger("arkprts").setLevel(args.log_level.upper()) + + if args.channel_uid and args.token: + auth = arkprts.YostarAuth(args.server) + await auth.login_with_token(args.channel_uid, args.token) + elif args.guest or (args.command == "search" and args.nickname): + auth = arkprts.GuestAuth(max_sessions=1) + auth.network.default_server = args.server + else: + auth = arkprts.YostarAuth(args.server) + await auth.login_with_email_code() + + client = arkprts.Client(auth) + + if args.command == "search": + await search(client, args.nickname) + elif args.command == "api": + await api(client, args.endpoint, args.payload) + + if __name__ == "__main__": asyncio.run(main()) diff --git a/arkprts/assets/base.py b/arkprts/assets/base.py index f77c86d..172bf59 100644 --- a/arkprts/assets/base.py +++ b/arkprts/assets/base.py @@ -8,7 +8,6 @@ import abc import bisect -import importlib.util import json import logging import pathlib @@ -57,16 +56,13 @@ def create( ) -> Assets: """Create a new assets file based on what libraries are available.""" try: - importlib.util.find_spec("UnityPy") + from . import bundle + + return bundle.BundleAssets(path, default_server=default_server, network=network) except ImportError: from . import git - LOGGER.debug("Creating GitAssets due to UnityPy being unavailable") return git.GitAssets(path, default_server=default_server or "en") - else: - from . import bundle - - return bundle.BundleAssets(path, default_server=default_server, network=network) @abc.abstractmethod async def update_assets(self) -> None: diff --git a/arkprts/assets/bundle.py b/arkprts/assets/bundle.py index 76278a0..20b4f0b 100644 --- a/arkprts/assets/bundle.py +++ b/arkprts/assets/bundle.py @@ -16,6 +16,7 @@ import os import pathlib import re +import shlex import subprocess import tempfile import typing @@ -98,29 +99,29 @@ def run_flatbuffers( output_directory: PathLike, ) -> pathlib.Path: """Run the flatbuffers cli. Returns the output filename.""" - code = subprocess.call( - [ # noqa: S603 # check for execution of untrusted input - "flatc", - "-o", - str(output_directory), - str(fbs_schema_path), - "--", - str(fbs_path), - "--json", - "--strict-json", - "--natural-utf8", - "--defaults-json", - "--unknown-json", - "--raw-binary", - # unfortunately not in older versions - # "--no-warnings", - "--force-empty", - ], - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - ) - if code != 0: - raise ValueError(f"flatc failed with code {code}") + args = [ + "flatc", + "-o", + str(output_directory), + str(fbs_schema_path), + "--", + str(fbs_path), + "--json", + "--strict-json", + "--natural-utf8", + "--defaults-json", + "--unknown-json", + "--raw-binary", + # "--no-warnings", + "--force-empty", + ] + result = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False) # noqa: S603, UP022 + if result.returncode != 0: + file = pathlib.Path(tempfile.mktemp(".log")) + file.write_bytes(result.stdout + b"\n\n\n\n" + result.stderr) + raise ValueError( + f"flatc failed with code {result.returncode}: {file} `{shlex.join(args)}` (random exit code likely means a faulty FBS file was provided)", + ) return pathlib.Path(output_directory) / (pathlib.Path(fbs_path).stem + ".json") @@ -145,8 +146,8 @@ async def update_fbs_schema(*, force: bool = False) -> None: continue UPDATED_FBS[server] = True - directory = resolve_fbs_schema_directory(server).parent - await git.update_repository("MooncellWiki/OpenArknightsFBS", directory, branch=branch, force=force) + directory = resolve_fbs_schema_directory(server).parent # pyright: ignore[reportArgumentType] + await git.download_repository("MooncellWiki/OpenArknightsFBS", directory, branch=branch, force=force) def recursively_collapse_keys(obj: typing.Any) -> typing.Any: @@ -333,6 +334,19 @@ def __init__( network: netn.NetworkSession | None = None, json_loads: typing.Callable[[bytes], typing.Any] = json.loads, ) -> None: + try: + # ensure optional dependencies have been installed + import bson # noqa: F401 # type: ignore + import Crypto.Cipher.AES # noqa: F401 # type: ignore + import PIL # noqa: F401 # type: ignore + import UnityPy # noqa: F401 # type: ignore + except ImportError as e: + raise ImportError("Cannot use BundleAssets without arkprts[assets]") from e + try: + subprocess.run(["flatc", "--version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=False) # noqa: S603 + except OSError as e: + raise ImportError("Cannot use BundleAssets without a flatc executable") from e + super().__init__(default_server=default_server or "en", json_loads=json_loads) temporary_directory = pathlib.Path(tempfile.gettempdir()) diff --git a/arkprts/assets/git.py b/arkprts/assets/git.py index a96c302..6e4df89 100644 --- a/arkprts/assets/git.py +++ b/arkprts/assets/git.py @@ -72,6 +72,7 @@ async def download_github_tarball( ) -> pathlib.Path: """Download a tarball from github.""" destination = pathlib.Path(destination or tempfile.mktemp(f"{repository.split('/')[-1]}.tar.gz")) + destination.parent.mkdir(parents=True, exist_ok=True) if branch == "HEAD": url = f"https://api.github.com/repos/{repository}/tarball" else: @@ -93,7 +94,7 @@ def decompress_tarball(path: PathLike, destination: PathLike, *, allow: str = "* members: list[tarfile.TarInfo] = [] for member in tar.getmembers(): - if not fnmatch.fnmatch(allow, member.name): + if not fnmatch.fnmatch(member.name, allow): continue member.name = member.name[len(top_directory + "/") :] @@ -116,7 +117,7 @@ async def download_repository( destination = pathlib.Path(destination) commit_file = destination / "commit.txt" - if not force and commit_file.exists() and time.time() - commit_file.stat().st_mtime < 60 * 60 * 6: + if not force and commit_file.exists() and time.time() - commit_file.stat().st_mtime < 60: LOGGER.debug("%s was updated recently, skipping download", repository) return @@ -138,10 +139,8 @@ async def download_repository( ) LOGGER.debug("Decompressing %s", repository) - tarball_commit = decompress_tarball(tarball_path, destination, allow=allow) - LOGGER.debug("Decompressed %s %s", repository, tarball_commit) - if tarball_commit not in commit: - raise RuntimeError(f"Tarball commit {tarball_commit} does not match github commit {commit}") + decompress_tarball(tarball_path, destination, allow=allow) + LOGGER.debug("Decompressed %s %s", repository, commit) # sometimes the contents are identical, we still want to update the mtime commit_file.write_text(commit) @@ -166,6 +165,9 @@ async def update_git_repository(repository: str, directory: PathLike, *, branch: cwd=directory.parent, ) await proc.wait() + old_directory = directory.parent / repository.split("/", 1)[1] + if directory != old_directory: + old_directory.rename(directory) else: LOGGER.info("Updating %s in %s", repository, directory) proc = await asyncio.create_subprocess_exec("git", "pull", cwd=directory) diff --git a/arkprts/auth.py b/arkprts/auth.py index cafe8e3..ee85f61 100644 --- a/arkprts/auth.py +++ b/arkprts/auth.py @@ -126,8 +126,8 @@ def generate_u8_sign(data: typing.Mapping[str, object]) -> str: """u8 auth sign.""" query = urllib.parse.urlencode(sorted(data.items())) - hama_code = hmac.new(b"91240f70c09a08a6bc72af1a5c8d4670", query.encode(), "sha1") - return hama_code.hexdigest().lower() + code = hmac.new(b"91240f70c09a08a6bc72af1a5c8d4670", query.encode(), "sha1") + return code.hexdigest().lower() @dataclasses.dataclass() @@ -315,6 +315,10 @@ async def _get_secret( LOGGER.info("Logged in with UID %s", uid) return secret + @abc.abstractmethod + async def login_with_token(self, channel_uid: str, token: str, /) -> None: + """Login with a channel uid and token.""" + @classmethod async def from_token( cls, @@ -371,7 +375,7 @@ async def _get_access_token(self, channel_uid: str, yostar_token: str) -> str: data = await self.request_passport("user/login", json=body) return data["accessToken"] - async def _request_yostar_auth(self, email: str) -> None: + async def send_email_code(self, email: str) -> None: """Request to log in with a yostar account.""" LOGGER.debug("Sending code to %s.", email) body = {"platform": "android", "account": email, "authlang": "en"} @@ -430,7 +434,7 @@ async def get_token_from_email_code( email = input("Enter email:") if not code: - await self._request_yostar_auth(email) + await self.send_email_code(email) if not stdin: return "", "" diff --git a/arkprts/automation.py b/arkprts/automation.py index e28132c..75fea0f 100644 --- a/arkprts/automation.py +++ b/arkprts/automation.py @@ -21,46 +21,6 @@ __all__ = ["AutomationClient"] -def recursively_update_dict( - target: typing.MutableMapping[typing.Any, typing.Any], - source: typing.Mapping[typing.Any, typing.Any], -) -> None: - """Recursively update a dictionary. - - This is used to update the player data. - """ - for key, value in source.items(): - if isinstance(value, dict): - recursively_update_dict(target[key], typing.cast("dict[object, object]", value)) - elif isinstance(value, list): - for i, v in enumerate(typing.cast("list[object]", value)): - if isinstance(v, dict): - recursively_update_dict(target[key][i], typing.cast("dict[object, object]", v)) - else: - target[key] = value - - -def recursively_delete_dict( - target: typing.MutableMapping[typing.Any, typing.Any], - source: typing.Mapping[typing.Any, typing.Any], -) -> None: - """Recursively delete items from a dictionary. - - This is used to update the player data. - """ - for key, value in source.items(): - if isinstance(value, dict): - recursively_delete_dict(target[key], typing.cast("dict[object, object]", value)) - elif isinstance(value, list): - for i, v in enumerate(typing.cast("list[object]", value)): - if isinstance(v, dict): - recursively_delete_dict(target[key][i], typing.cast("dict[object, object]", v)) - else: - target[key].remove(v) - else: - del target[key] - - # https://github.com/Rhine-Department-0xf/Rhine-DFramwork/blob/main/client/encryption.py @@ -164,8 +124,7 @@ def update_player_data_delta( Does not implement deleted as it is often overwritten with "modified". """ - recursively_delete_dict(self.data, delta["deleted"]) - recursively_update_dict(self.data, delta["modified"]) + self.data.update(delta["modified"]) async def request(self, endpoint: str, json: typing.Mapping[str, object] = {}, **kwargs: typing.Any) -> typing.Any: """Send an authenticated request to the arknights game server.""" diff --git a/arkprts/errors.py b/arkprts/errors.py index dd0adaf..e83e0b7 100644 --- a/arkprts/errors.py +++ b/arkprts/errors.py @@ -28,12 +28,13 @@ class ArkPrtsError(BaseArkprtsError): def __init__(self, data: typing.Mapping[str, typing.Any]) -> None: self.data = data - super().__init__(f"[{data['result']}] {self.message} {data}") + super().__init__(f"[{data.get('result')}] {self.message} {data}") class GameServerError(BaseArkprtsError): """Game server error.""" + data: typing.Mapping[str, typing.Any] status_code: int error: str code: int @@ -41,6 +42,7 @@ class GameServerError(BaseArkprtsError): info: typing.Mapping[str, typing.Any] def __init__(self, data: typing.Mapping[str, typing.Any]) -> None: + self.data = data self.status_code = data.get("statusCode", 400) self.error = data["error"] self.code = data.get("code", 0) diff --git a/arkprts/network.py b/arkprts/network.py index f35373c..972bc8a 100644 --- a/arkprts/network.py +++ b/arkprts/network.py @@ -168,7 +168,7 @@ async def request( url = url + "/" + endpoint if method is None: - method = "POST" if kwargs.get("json") else "GET" + method = "POST" if kwargs.get("json") is not None else "GET" data = await self.raw_request(method, url, **kwargs) diff --git a/dev-requirements/pytest.txt b/dev-requirements/pytest.txt index 835b78e..ffa1dfb 100644 --- a/dev-requirements/pytest.txt +++ b/dev-requirements/pytest.txt @@ -1,5 +1,5 @@ pytest -pytest-asyncio==0.21.1 +pytest-asyncio pytest-dotenv pytest-cov coverage[toml] \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 8065d87..e8b6527 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "arkprts" requires-python = ">=3.9" -version = "0.3.7" +version = "0.3.8" dynamic = [ "dependencies", "description", diff --git a/setup.py b/setup.py index 660d03f..d7104ea 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name="arkprts", - version="0.3.7", + version="0.3.8", description="Arknights python wrapper.", url="https://github.com/thesadru/arkprts", packages=find_packages(exclude=["tests", "tests.*"]), diff --git a/tests/conftest.py b/tests/conftest.py index ba4d51d..8f0fa15 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,7 +12,10 @@ def event_loop() -> typing.Iterator[asyncio.AbstractEventLoop]: with warnings.catch_warnings(): warnings.simplefilter("ignore") - loop = asyncio.get_event_loop() + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() yield loop loop.close() diff --git a/tests/test_assets.py b/tests/test_assets.py index 522bf57..3cbff07 100644 --- a/tests/test_assets.py +++ b/tests/test_assets.py @@ -1,5 +1,7 @@ """Test game assets.""" +import pytest + import arkprts @@ -9,10 +11,14 @@ async def test_update(client: arkprts.Client) -> None: async def test_bundle_assets() -> None: - assets = arkprts.BundleAssets() - await assets.update_assets() - - await assets.network.close() + try: + assets = arkprts.BundleAssets() + except ImportError: + pytest.skip("Environment is not ready for bundle download.") + else: + await assets.update_assets(force=True) + + await assets.network.close() async def test_git_assets() -> None: