Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FrontendManager to manage non-default front-end impl #3897

Merged
merged 18 commits into from
Jul 16, 2024
Merged
4 changes: 4 additions & 0 deletions .github/workflows/test-ui.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,7 @@ jobs:
npm run test:generate
npm test -- --verbose
working-directory: ./tests-ui
- name: Run Unit Tests
run: |
pip install -r tests-unit/requirements.txt
python -m pytest tests-unit
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ venv/
!/web/extensions/core/
/tests-ui/data/object_info.json
/user/
*.log
*.log
web_custom_versions/
Empty file added app/__init__.py
Empty file.
220 changes: 220 additions & 0 deletions app/frontend_management.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
import argparse
import logging
import os
import re
import tempfile
import zipfile
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import TypedDict

import requests
from typing_extensions import NotRequired


REQUEST_TIMEOUT = 10 # seconds


class Asset(TypedDict):
url: str


class Release(TypedDict):
id: int
tag_name: str
name: str
prerelease: bool
created_at: str
published_at: str
body: str
assets: NotRequired[list[Asset]]


@dataclass
class FrontEndProvider:
owner: str
repo: str

@property
def folder_name(self) -> str:
return f"{self.owner}_{self.repo}"

@property
def release_url(self) -> str:
return f"https://api.github.com/repos/{self.owner}/{self.repo}/releases"

@cached_property
def all_releases(self) -> list[Release]:
releases = []
api_url = self.release_url
while api_url:
response = requests.get(api_url, timeout=REQUEST_TIMEOUT)
response.raise_for_status() # Raises an HTTPError if the response was an error
releases.extend(response.json())
# GitHub uses the Link header to provide pagination links. Check if it exists and update api_url accordingly.
if "next" in response.links:
api_url = response.links["next"]["url"]
else:
api_url = None
return releases

@cached_property
def latest_release(self) -> Release:
latest_release_url = f"{self.release_url}/latest"
response = requests.get(latest_release_url, timeout=REQUEST_TIMEOUT)
response.raise_for_status() # Raises an HTTPError if the response was an error
return response.json()

def get_release(self, version: str) -> Release:
if version == "latest":
return self.latest_release
else:
for release in self.all_releases:
if release["tag_name"] in [version, f"v{version}"]:
return release
raise ValueError(f"Version {version} not found in releases")


def download_release_asset_zip(release: Release, destination_path: str) -> None:
"""Download dist.zip from github release."""
asset_url = None
for asset in release.get("assets", []):
if asset["name"] == "dist.zip":
asset_url = asset["url"]
break

if not asset_url:
raise ValueError("dist.zip not found in the release assets")

# Use a temporary file to download the zip content
with tempfile.TemporaryFile() as tmp_file:
headers = {"Accept": "application/octet-stream"}
response = requests.get(
asset_url, headers=headers, allow_redirects=True, timeout=REQUEST_TIMEOUT
)
response.raise_for_status() # Ensure we got a successful response

# Write the content to the temporary file
tmp_file.write(response.content)

# Go back to the beginning of the temporary file
tmp_file.seek(0)

# Extract the zip file content to the destination path
with zipfile.ZipFile(tmp_file, "r") as zip_ref:
zip_ref.extractall(destination_path)


class FrontendManager:
# The default built-in provider hosted under web/
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
DEFAULT_FRONTEND_PATH = str(Path(__file__).parents[1] / "web")
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")

@classmethod
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
"""
Args:
value (str): The version string to parse.

Returns:
tuple[str, str]: A tuple containing provider name and version.

Raises:
argparse.ArgumentTypeError: If the version string is invalid.
"""
VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(\d+\.\d+\.\d+|latest)$"
match_result = re.match(VERSION_PATTERN, value)
if match_result is None:
raise argparse.ArgumentTypeError(f"Invalid version string: {value}")

return match_result.group(1), match_result.group(2), match_result.group(3)

@classmethod
def add_argument(cls, parser: argparse.ArgumentParser):
parser.add_argument(
"--front-end-version",
type=str,
default=cls.DEFAULT_VERSION_STRING,
help="""
Specifies the version of the frontend to be used. This command needs internet connectivity to query and
download available frontend implementations from GitHub releases.

The version string should be in the format of:
[repoOwner]/[repoName]@[version]
where version is one of: "latest" or a valid version number (e.g. "1.0.0")
""",
)

def is_valid_directory(path: str | None) -> str | None:
"""Validate if the given path is a directory."""
if path is None:
return None

if not os.path.isdir(path):
raise argparse.ArgumentTypeError(f"{path} is not a valid directory.")
return path

parser.add_argument(
"--front-end-root",
type=is_valid_directory,
default=None,
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
)

@classmethod
def init_frontend_unsafe(cls, version_string: str) -> str:
"""
Initializes the frontend for the specified version.

Args:
version_string (str): The version string.

Returns:
str: The path to the initialized frontend.

Raises:
Exception: If there is an error during the initialization process.
main error source might be request timeout or invalid URL.
"""
if version_string == cls.DEFAULT_VERSION_STRING:
return cls.DEFAULT_FRONTEND_PATH

repo_owner, repo_name, version = cls.parse_version_string(version_string)
provider = FrontEndProvider(repo_owner, repo_name)
release = provider.get_release(version)

semantic_version = release["tag_name"].lstrip("v")
web_root = str(
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
)
if not os.path.exists(web_root):
os.makedirs(web_root, exist_ok=True)
logging.info(
"Downloading frontend(%s) version(%s) to (%s)",
provider.folder_name,
semantic_version,
web_root,
)
logging.debug(release)
download_release_asset_zip(release, destination_path=web_root)
return web_root

@classmethod
def init_frontend(cls, version_string: str) -> str:
"""
Initializes the frontend with the specified version string.

Args:
version_string (str): The version string to initialize the frontend with.

Returns:
str: The path of the initialized frontend.
"""
try:
return cls.init_frontend_unsafe(version_string)
except Exception as e:
logging.error("Failed to initialize frontend: %s", e)
logging.info("Falling back to the default frontend.")
return cls.DEFAULT_FRONTEND_PATH
3 changes: 3 additions & 0 deletions comfy/cli_args.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import argparse
import enum
import comfy.options
from app.frontend_management import FrontendManager
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comfy/cli_args.py should not import this file.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Things inside the comfy folder should not depend on things outside of it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inlined.



class EnumAction(argparse.Action):
"""
Expand Down Expand Up @@ -124,6 +126,7 @@ class LatentPreviewMethod(enum.Enum):

parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")

FrontendManager.add_argument(parser)

if comfy.options.args_parsing:
args = parser.parse_args()
Expand Down
7 changes: 5 additions & 2 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
[pytest]
markers =
inference: mark as inference test (deselect with '-m "not inference"')
testpaths = tests
addopts = -s
testpaths =
tests
tests-unit
addopts = -s
pythonpath = .
11 changes: 8 additions & 3 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@
from comfy.cli_args import args
import comfy.utils
import comfy.model_management

from app.frontend_management import FrontendManager
from app.user_manager import UserManager


class BinaryEventTypes:
PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2
Expand Down Expand Up @@ -83,8 +84,12 @@ def __init__(self, loop):
max_upload_size = round(args.max_upload_size * 1024 * 1024)
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
self.sockets = dict()
self.web_root = os.path.join(os.path.dirname(
os.path.realpath(__file__)), "web")
self.web_root = (
FrontendManager.init_frontend(args.front_end_version)
if args.front_end_root is None
else args.front_end_root
)
logging.info(f"[Prompt Server] web root: {self.web_root}")
routes = web.RouteTableDef()
self.routes = routes
self.last_node_id = None
Expand Down
8 changes: 8 additions & 0 deletions tests-unit/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Pytest Unit Tests

## Install test dependencies

`pip install -r tests-units/requirements.txt`

## Run tests
`pytest tests-units/`
Empty file added tests-unit/__init__.py
Empty file.
Empty file added tests-unit/comfy/__init__.py
Empty file.
Loading
Loading