Skip to content

Commit

Permalink
Support new license format. (#144)
Browse files Browse the repository at this point in the history
* Support new license format.

* Fix.

* revert.

* revert.

* Add warning.

* Handle deprecated field.

* Improve error message.
  • Loading branch information
robinjhuang committed Jul 29, 2024
1 parent 923c7ac commit bf3a042
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 18 deletions.
5 changes: 2 additions & 3 deletions comfy_cli/cmdline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from rich import print
from rich.console import Console
from typing_extensions import Annotated, List

from comfy_cli import constants, env_checker, logging, tracking, ui, utils
from comfy_cli import constants, env_checker, logging, tracking, utils, ui
from comfy_cli.command import custom_nodes
from comfy_cli.command import install as install_inner
from comfy_cli.command import run as run_inner
Expand Down Expand Up @@ -334,7 +333,7 @@ def update(
"comfy",
help="[all|comfy]",
autocompletion=utils.create_choice_completer(["all", "comfy"]),
)
),
):
if target not in ["all", "comfy"]:
typer.echo(
Expand Down
35 changes: 26 additions & 9 deletions comfy_cli/registry/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
import requests

# Reduced global imports from comfy_cli.registry
from comfy_cli.registry.types import Node, NodeVersion, PublishNodeVersionResponse
from comfy_cli.registry.types import (
Node,
NodeVersion,
PublishNodeVersionResponse,
PyProjectConfig,
License,
)


class RegistryAPI:
Expand All @@ -17,11 +23,13 @@ def determine_base_url(self):
if env == "dev":
return "http://localhost:8080"
elif env == "staging":
return "https://staging.comfyregistry.org"
return "https://stagingapi.comfy.org"
else:
return "https://api.comfy.org"

def publish_node_version(self, node_config, token) -> PublishNodeVersionResponse:
def publish_node_version(
self, node_config: PyProjectConfig, token
) -> PublishNodeVersionResponse:
"""
Publishes a new version of a node.
Expand All @@ -33,7 +41,6 @@ def publish_node_version(self, node_config, token) -> PublishNodeVersionResponse
PublishNodeVersionResponse: The response object from the API server.
"""
# Local import to prevent circular dependency

if not node_config.tool_comfy.publisher_id:
raise Exception(
"Publisher ID is required in pyproject.toml to publish a node version"
Expand All @@ -43,24 +50,26 @@ def publish_node_version(self, node_config, token) -> PublishNodeVersionResponse
raise Exception(
"Project name is required in pyproject.toml to publish a node version"
)

url = f"{self.base_url}/publishers/{node_config.tool_comfy.publisher_id}/nodes/{node_config.project.name}/versions"
headers = {"Content-Type": "application/json"}
body = {
license_json = serialize_license(node_config.project.license)
request_body = {
"personal_access_token": token,
"node": {
"id": node_config.project.name,
"description": node_config.project.description,
"icon": node_config.tool_comfy.icon,
"name": node_config.tool_comfy.display_name,
"license": node_config.project.license,
"license": license_json,
"repository": node_config.project.urls.repository,
},
"node_version": {
"version": node_config.project.version,
"dependencies": node_config.project.dependencies,
},
}
print(request_body)
url = f"{self.base_url}/publishers/{node_config.tool_comfy.publisher_id}/nodes/{node_config.project.name}/versions"
headers = {"Content-Type": "application/json"}
body = request_body

response = requests.post(url, headers=headers, data=json.dumps(body))

Expand Down Expand Up @@ -177,3 +186,11 @@ def map_node_to_node_class(api_node_data):
else None
),
)


def serialize_license(license: License) -> str:
if license.file:
return json.dumps({"file": license.file})
if license.text:
return json.dumps({"text": license.text})
return "{}"
28 changes: 26 additions & 2 deletions comfy_cli/registry/config_parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import subprocess
from typing import Optional
import typer

import tomlkit
import tomlkit.exceptions
Expand All @@ -11,6 +12,7 @@
Model,
ProjectConfig,
PyProjectConfig,
License,
URLs,
)

Expand Down Expand Up @@ -140,13 +142,35 @@ def extract_node_configuration(
urls_data = project_data.get("urls", {})
comfy_data = data.get("tool", {}).get("comfy", {})

license_data = project_data.get("license", {})
if isinstance(license_data, str):
license = License(text=license_data)
typer.echo(
'Warning: License should be in one of these two formats: license = {file = "LICENSE"} OR license = {text = "MIT License"}. Please check the documentation: https://docs.comfy.org/registry/specifications.'
)
elif isinstance(license_data, dict):
if "file" in license_data or "text" in license_data:
license = License(
file=license_data.get("file", ""), text=license_data.get("text", "")
)
else:
typer.echo(
'Warning: License should be in one of these two formats: license = {file = "LICENSE"} OR license = {text = "MIT License"}. Please check the documentation: https://docs.comfy.org/registry/specifications.'
)
license = License()
else:
license = License()
typer.echo(
'Warning: License should be in one of these two formats: license = {file = "LICENSE"} OR license = {text = "MIT License"}. Please check the documentation: https://docs.comfy.org/registry/specifications.'
)

project = ProjectConfig(
name=project_data.get("name", ""),
description=project_data.get("description", ""),
version=project_data.get("version", ""),
requires_python=project_data.get("requires-pyton", ""),
requires_python=project_data.get("requires-python", ""),
dependencies=project_data.get("dependencies", []),
license=project_data.get("license", ""),
license=license,
urls=URLs(
homepage=urls_data.get("Homepage", ""),
documentation=urls_data.get("Documentation", ""),
Expand Down
7 changes: 5 additions & 2 deletions comfy_cli/registry/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,18 @@ class ComfyConfig:
icon: str = ""
models: List[Model] = field(default_factory=list)


@dataclass
class License:
file: str = ""
text: str = ""
@dataclass
class ProjectConfig:
name: str = ""
description: str = ""
version: str = "1.0.0"
requires_python: str = ">= 3.9"
dependencies: List[str] = field(default_factory=list)
license: str = ""
license: License = field(default_factory=License)
urls: URLs = field(default_factory=URLs)


Expand Down
4 changes: 2 additions & 2 deletions tests/comfy_cli/registry/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from comfy_cli.registry import PyProjectConfig
from comfy_cli.registry.api import RegistryAPI
from comfy_cli.registry.types import ComfyConfig, ProjectConfig, URLs
from comfy_cli.registry.types import ComfyConfig, ProjectConfig, URLs, License


class TestRegistryAPI(unittest.TestCase):
Expand All @@ -16,7 +16,7 @@ def setUp(self):
version="0.1.0",
requires_python=">= 3.9",
dependencies=["dep1", "dep2"],
license="MIT",
license=License(file="LICENSE"),
urls=URLs(repository="https://github.com/test/test_node"),
),
tool_comfy=ComfyConfig(
Expand Down
125 changes: 125 additions & 0 deletions tests/comfy_cli/registry/test_config_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from unittest.mock import patch, mock_open
import pytest
from comfy_cli.registry.config_parser import extract_node_configuration
from comfy_cli.registry.types import (
PyProjectConfig,
ProjectConfig,
License,
URLs,
ComfyConfig,
Model,
)


@pytest.fixture
def mock_toml_data():
return {
"project": {
"name": "test-project",
"description": "A test project",
"version": "1.0.0",
"requires-python": ">=3.7",
"dependencies": ["requests"],
"license": {"file": "LICENSE"},
"urls": {
"Homepage": "https://example.com",
"Documentation": "https://docs.example.com",
"Repository": "https://github.com/example/test-project",
"Issues": "https://github.com/example/test-project/issues",
},
},
"tool": {
"comfy": {
"PublisherId": "test-publisher",
"DisplayName": "Test Project",
"Icon": "icon.png",
"Models": [
{
"location": "model1.bin",
"model_url": "https://example.com/model1",
},
{
"location": "model2.bin",
"model_url": "https://example.com/model2",
},
],
}
},
}


def test_extract_node_configuration_success(mock_toml_data):
with patch("os.path.isfile", return_value=True), patch(
"builtins.open", mock_open()
), patch("tomlkit.load", return_value=mock_toml_data):
result = extract_node_configuration("fake_path.toml")

assert isinstance(result, PyProjectConfig)
assert result.project.name == "test-project"
assert result.project.description == "A test project"
assert result.project.version == "1.0.0"
assert result.project.requires_python == ">=3.7"
assert result.project.dependencies == ["requests"]
assert result.project.license == License(file="LICENSE")
assert result.project.urls == URLs(
homepage="https://example.com",
documentation="https://docs.example.com",
repository="https://github.com/example/test-project",
issues="https://github.com/example/test-project/issues",
)
assert result.tool_comfy.publisher_id == "test-publisher"
assert result.tool_comfy.display_name == "Test Project"
assert result.tool_comfy.icon == "icon.png"
assert len(result.tool_comfy.models) == 2
assert result.tool_comfy.models[0] == Model(
location="model1.bin", model_url="https://example.com/model1"
)


def test_extract_node_configuration_license_text():
mock_data = {
"project": {
"license": "MIT License",
},
}
with patch("os.path.isfile", return_value=True), patch(
"builtins.open", mock_open()
), patch("tomlkit.load", return_value=mock_data):
result = extract_node_configuration("fake_path.toml")
assert result is not None, "Expected PyProjectConfig, got None"
assert isinstance(result, PyProjectConfig)
assert result.project.license == License(text="MIT License")


def test_extract_node_configuration_license_text_dict():
mock_data = {
"project": {
"license": {
"text": "MIT License\n\nCopyright (c) 2023 Example Corp\n\nPermission is hereby granted..."
},
},
}
with patch("os.path.isfile", return_value=True), patch(
"builtins.open", mock_open()
), patch("tomlkit.load", return_value=mock_data):
result = extract_node_configuration("fake_path.toml")

assert result is not None, "Expected PyProjectConfig, got None"
assert isinstance(result, PyProjectConfig)
assert result.project.license == License(
text="MIT License\n\nCopyright (c) 2023 Example Corp\n\nPermission is hereby granted..."
)


def test_extract_license_incorrect_format():
mock_data = {
"project": {"license": "MIT"},
}
with patch("os.path.isfile", return_value=True), patch(
"builtins.open", mock_open()
), patch("tomlkit.load", return_value=mock_data):
result = extract_node_configuration("fake_path.toml")

assert result is not None, "Expected PyProjectConfig, got None"
assert isinstance(result, PyProjectConfig)
assert result.project.license == License(text="MIT")

0 comments on commit bf3a042

Please sign in to comment.