From d8321ef8d003bd10492e37dce585ae93bdcf2460 Mon Sep 17 00:00:00 2001 From: Sudeep Pillai Date: Mon, 28 Aug 2023 17:14:59 -0700 Subject: [PATCH 01/10] Training dreambooth example --- nos/executors/ray.py | 2 +- nos/models/dreambooth/config.py | 136 ++++++++++++++++++++++++++++++++ nos/server/_service.py | 47 +++++++++++ 3 files changed, 184 insertions(+), 1 deletion(-) create mode 100644 nos/models/dreambooth/config.py diff --git a/nos/executors/ray.py b/nos/executors/ray.py index d3190b62..72c2e868 100644 --- a/nos/executors/ray.py +++ b/nos/executors/ray.py @@ -211,7 +211,7 @@ def status(self, job_id: str) -> str: def logs(self, job_id: str) -> str: """Get logs for a job.""" return self.client.get_job_logs(job_id) - + def init(*args, **kwargs) -> bool: """Initialize Ray executor.""" diff --git a/nos/models/dreambooth/config.py b/nos/models/dreambooth/config.py new file mode 100644 index 00000000..4b88ddf0 --- /dev/null +++ b/nos/models/dreambooth/config.py @@ -0,0 +1,136 @@ +import json +import os +import shutil +import uuid +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Any, Dict + +from nos.common.git import cached_repo +from nos.constants import NOS_TMP_DIR +from nos.logging import logger + + +GIT_TAG = "v0.20.1" +RUNTIME_ENVS = { + "diffusers-latest": { + "working_dir": "./nos/experimental/train/dreambooth", + "pip": [f"https://github.com/huggingface/diffusers/archive/refs/tags/{GIT_TAG}.zip", "accelerate>=0.22.0"], + } +} + + +@dataclass +class StableDiffusionTrainingJobConfig: + """Configuration for a training job. + + Training job contents are written to `~/.nos/tmp/{uuid}/`. + {uuid}_metadata.json: Metadata for the training job. + {uuid}_job_config.json: Job configuration for the training job. + """ + + model_name: str + """Model name (e.g `stabilityai/stable-diffusion-2-1`).""" + + method: str + """Stable diffusion training method (choice of `stable-diffusion-dreambooth-lora`).""" + + instance_directory: str + """Image instance directory (e.g. dog).""" + + instance_prompt: str + """Image instance prompt (e.g. A photo of sks dog in a bucket).""" + + max_train_steps: int = 500 + """Maximum number of training steps.""" + + resolution: int = 512 + """Image resolution.""" + + runtime_env: Dict[str, str] = field(default_factory=lambda: RUNTIME_ENVS["diffusers-latest"]) + """The runtime environment to use for the training job.""" + + _uuid: str = field(init=False, default_factory=lambda: str(uuid.uuid4())) + """The UUID for creating a unique training job directory.""" + + _output_directory: str = field(init=False) + """The output directory for the training job.""" + + _repo_directory: str = field( + init=False, + default=cached_repo( + f"https://github.com/huggingface/diffusers/archive/refs/tags/{GIT_TAG}.zip", + repo_name="diffusers", + subdirectory="examples/dreambooth", + ), + ) + """The repository to use for the training job.""" + + def __post_init__(self): + if self.method not in ("stable-diffusion-dreambooth-lora"): + raise ValueError(f"Invalid method: {self.method}, available methods: ['stable-diffusion-dreambooth-lora']") + + # Setup the instance and output directories + logger.debug("Setting up instance and output directories") + working_directory = Path(NOS_TMP_DIR / self._uuid) + working_directory.mkdir(parents=True, exist_ok=True) + logger.debug(f"Finished setting up instance and output directories [working_directory={working_directory}]") + + # Copy the instance directory to the working directory + if not Path(self.instance_directory).exists(): + raise IOError(f"Failed to load instance_directory={self.instance_directory}.") + instance_directory = working_directory / "instances" + shutil.copytree(self.instance_directory, str(instance_directory)) + nfiles = len(os.listdir(instance_directory)) + logger.debug(f"Copied instance directory to {working_directory} [nfiles={nfiles}]") + + # Create an output directory for weights + output_directory = working_directory / "weights" + output_directory.mkdir(parents=True, exist_ok=True) + + # Setup the diffusers working directory + self.runtime_env["working_dir"] = self._repo_directory + + # Set the instance and output directories + self.instance_directory = str(instance_directory) + self._working_directory = str(working_directory) + self._output_directory = str(output_directory) + + # Write the metadata and job configuration files + logger.debug(f"Writing metadata and job configuration files to {working_directory}") + with open(str(working_directory / f"{self._uuid}_metadata.json"), "w") as fp: + json.dump(asdict(self), fp, indent=2) + with open(str(working_directory / f"{self._uuid}_job_config.json"), "w") as fp: + json.dump(asdict(self), fp, indent=2) + logger.debug(f"Finished writing metadata and job configuration files to {working_directory}") + + @property + def entrypoint(self): + """The entrypoint to run for the training job.""" + return ( + f"""accelerate launch train_dreambooth_lora.py""" + f""" --pretrained_model_name_or_path={self.model_name}""" + f""" --instance_data_dir={self.instance_directory}""" + f""" --output_dir={self._output_directory}""" + f''' --instance_prompt="{self.instance_prompt}"''' + f""" --resolution={self.resolution}""" + f""" --train_batch_size=1""" + f""" --gradient_accumulation_steps=1""" + f""" --checkpointing_steps=100""" + f""" --learning_rate=1e-4""" + f''' --lr_scheduler="constant"''' + f""" --lr_warmup_steps=0""" + f""" --max_train_steps={self.max_train_steps}""" + f''' --seed="0"''' + ) + # f''' --validation_prompt="A photo of sks dog in a bucket"''' + # f''' --validation_epochs=50''' + + def job_dict(self) -> Dict[str, Any]: + """The job configuration for the training job.""" + return { + "submission_id": self._uuid, + "entrypoint": self.entrypoint, + "runtime_env": self.runtime_env, + "entrypoint_num_gpus": 1, + } diff --git a/nos/server/_service.py b/nos/server/_service.py index cfeddb9e..2f466165 100644 --- a/nos/server/_service.py +++ b/nos/server/_service.py @@ -19,6 +19,7 @@ from nos.executors.ray import RayExecutor from nos.logging import logger from nos.managers import ModelHandle, ModelManager +from nos.models.dreambooth.config import StableDiffusionTrainingJobConfig from nos.protoc import import_module from nos.version import __version__ @@ -35,6 +36,52 @@ def load_spec(model_name: str, task: TaskType) -> ModelSpec: return model_spec +class TrainingService: + """Ray-executor based training service.""" + + config_cls = { + "stable-diffusion-dreambooth-lora": StableDiffusionTrainingJobConfig, + } + + def __init__(self): + self.executor = RayExecutor.get() + try: + self.executor.init() + except Exception as e: + err_msg = f"Failed to initialize executor [e={e}]" + logger.info(err_msg) + raise RuntimeError(err_msg) + + def train(self, method: str, training_inputs: Dict[str, Any], metadata: Dict[str, Any] = None) -> str: + """Train / Fine-tune a model by submitting a job to the RayJobExecutor. + + Args: + method (str): Training method (e.g. `stable-diffusion-dreambooth-lora`). + training_inputs (Dict[str, Any]): Training inputs. + Returns: + str: Job ID. + """ + try: + config_cls = self.config_cls[method] + except KeyError: + raise NotImplementedError(f"Training not supported for method [method={method}]") + + # Check if the training inputs are correctly specified + config = config_cls(method=method, **training_inputs) + try: + pass + except Exception as e: + raise ValueError(f"Invalid training inputs [training_inputs={training_inputs}, e={e}]") + + # Submit the training job as a Ray job + configd = config.job_dict() + logger.debug("Submitting training job") + logger.debug(f"config\n{configd}]") + job_id = self.executor.jobs.submit(**configd) + logger.debug(f"Submitted training job [job_id={job_id}, config={configd}]") + return job_id + + class InferenceService: """Ray-executor based inference service. From 5a945672a40ad8b91c7bf0ef95a488826b1296a9 Mon Sep 17 00:00:00 2001 From: Sudeep Pillai Date: Tue, 29 Aug 2023 14:25:04 -0700 Subject: [PATCH 02/10] Working local hub registry for stable diffusion lora models --- nos/models/dreambooth/config.py | 136 -------------------------------- nos/server/_service.py | 47 ----------- 2 files changed, 183 deletions(-) delete mode 100644 nos/models/dreambooth/config.py diff --git a/nos/models/dreambooth/config.py b/nos/models/dreambooth/config.py deleted file mode 100644 index 4b88ddf0..00000000 --- a/nos/models/dreambooth/config.py +++ /dev/null @@ -1,136 +0,0 @@ -import json -import os -import shutil -import uuid -from dataclasses import asdict, dataclass, field -from pathlib import Path -from typing import Any, Dict - -from nos.common.git import cached_repo -from nos.constants import NOS_TMP_DIR -from nos.logging import logger - - -GIT_TAG = "v0.20.1" -RUNTIME_ENVS = { - "diffusers-latest": { - "working_dir": "./nos/experimental/train/dreambooth", - "pip": [f"https://github.com/huggingface/diffusers/archive/refs/tags/{GIT_TAG}.zip", "accelerate>=0.22.0"], - } -} - - -@dataclass -class StableDiffusionTrainingJobConfig: - """Configuration for a training job. - - Training job contents are written to `~/.nos/tmp/{uuid}/`. - {uuid}_metadata.json: Metadata for the training job. - {uuid}_job_config.json: Job configuration for the training job. - """ - - model_name: str - """Model name (e.g `stabilityai/stable-diffusion-2-1`).""" - - method: str - """Stable diffusion training method (choice of `stable-diffusion-dreambooth-lora`).""" - - instance_directory: str - """Image instance directory (e.g. dog).""" - - instance_prompt: str - """Image instance prompt (e.g. A photo of sks dog in a bucket).""" - - max_train_steps: int = 500 - """Maximum number of training steps.""" - - resolution: int = 512 - """Image resolution.""" - - runtime_env: Dict[str, str] = field(default_factory=lambda: RUNTIME_ENVS["diffusers-latest"]) - """The runtime environment to use for the training job.""" - - _uuid: str = field(init=False, default_factory=lambda: str(uuid.uuid4())) - """The UUID for creating a unique training job directory.""" - - _output_directory: str = field(init=False) - """The output directory for the training job.""" - - _repo_directory: str = field( - init=False, - default=cached_repo( - f"https://github.com/huggingface/diffusers/archive/refs/tags/{GIT_TAG}.zip", - repo_name="diffusers", - subdirectory="examples/dreambooth", - ), - ) - """The repository to use for the training job.""" - - def __post_init__(self): - if self.method not in ("stable-diffusion-dreambooth-lora"): - raise ValueError(f"Invalid method: {self.method}, available methods: ['stable-diffusion-dreambooth-lora']") - - # Setup the instance and output directories - logger.debug("Setting up instance and output directories") - working_directory = Path(NOS_TMP_DIR / self._uuid) - working_directory.mkdir(parents=True, exist_ok=True) - logger.debug(f"Finished setting up instance and output directories [working_directory={working_directory}]") - - # Copy the instance directory to the working directory - if not Path(self.instance_directory).exists(): - raise IOError(f"Failed to load instance_directory={self.instance_directory}.") - instance_directory = working_directory / "instances" - shutil.copytree(self.instance_directory, str(instance_directory)) - nfiles = len(os.listdir(instance_directory)) - logger.debug(f"Copied instance directory to {working_directory} [nfiles={nfiles}]") - - # Create an output directory for weights - output_directory = working_directory / "weights" - output_directory.mkdir(parents=True, exist_ok=True) - - # Setup the diffusers working directory - self.runtime_env["working_dir"] = self._repo_directory - - # Set the instance and output directories - self.instance_directory = str(instance_directory) - self._working_directory = str(working_directory) - self._output_directory = str(output_directory) - - # Write the metadata and job configuration files - logger.debug(f"Writing metadata and job configuration files to {working_directory}") - with open(str(working_directory / f"{self._uuid}_metadata.json"), "w") as fp: - json.dump(asdict(self), fp, indent=2) - with open(str(working_directory / f"{self._uuid}_job_config.json"), "w") as fp: - json.dump(asdict(self), fp, indent=2) - logger.debug(f"Finished writing metadata and job configuration files to {working_directory}") - - @property - def entrypoint(self): - """The entrypoint to run for the training job.""" - return ( - f"""accelerate launch train_dreambooth_lora.py""" - f""" --pretrained_model_name_or_path={self.model_name}""" - f""" --instance_data_dir={self.instance_directory}""" - f""" --output_dir={self._output_directory}""" - f''' --instance_prompt="{self.instance_prompt}"''' - f""" --resolution={self.resolution}""" - f""" --train_batch_size=1""" - f""" --gradient_accumulation_steps=1""" - f""" --checkpointing_steps=100""" - f""" --learning_rate=1e-4""" - f''' --lr_scheduler="constant"''' - f""" --lr_warmup_steps=0""" - f""" --max_train_steps={self.max_train_steps}""" - f''' --seed="0"''' - ) - # f''' --validation_prompt="A photo of sks dog in a bucket"''' - # f''' --validation_epochs=50''' - - def job_dict(self) -> Dict[str, Any]: - """The job configuration for the training job.""" - return { - "submission_id": self._uuid, - "entrypoint": self.entrypoint, - "runtime_env": self.runtime_env, - "entrypoint_num_gpus": 1, - } diff --git a/nos/server/_service.py b/nos/server/_service.py index 2f466165..cfeddb9e 100644 --- a/nos/server/_service.py +++ b/nos/server/_service.py @@ -19,7 +19,6 @@ from nos.executors.ray import RayExecutor from nos.logging import logger from nos.managers import ModelHandle, ModelManager -from nos.models.dreambooth.config import StableDiffusionTrainingJobConfig from nos.protoc import import_module from nos.version import __version__ @@ -36,52 +35,6 @@ def load_spec(model_name: str, task: TaskType) -> ModelSpec: return model_spec -class TrainingService: - """Ray-executor based training service.""" - - config_cls = { - "stable-diffusion-dreambooth-lora": StableDiffusionTrainingJobConfig, - } - - def __init__(self): - self.executor = RayExecutor.get() - try: - self.executor.init() - except Exception as e: - err_msg = f"Failed to initialize executor [e={e}]" - logger.info(err_msg) - raise RuntimeError(err_msg) - - def train(self, method: str, training_inputs: Dict[str, Any], metadata: Dict[str, Any] = None) -> str: - """Train / Fine-tune a model by submitting a job to the RayJobExecutor. - - Args: - method (str): Training method (e.g. `stable-diffusion-dreambooth-lora`). - training_inputs (Dict[str, Any]): Training inputs. - Returns: - str: Job ID. - """ - try: - config_cls = self.config_cls[method] - except KeyError: - raise NotImplementedError(f"Training not supported for method [method={method}]") - - # Check if the training inputs are correctly specified - config = config_cls(method=method, **training_inputs) - try: - pass - except Exception as e: - raise ValueError(f"Invalid training inputs [training_inputs={training_inputs}, e={e}]") - - # Submit the training job as a Ray job - configd = config.job_dict() - logger.debug("Submitting training job") - logger.debug(f"config\n{configd}]") - job_id = self.executor.jobs.submit(**configd) - logger.debug(f"Submitted training job [job_id={job_id}, config={configd}]") - return job_id - - class InferenceService: """Ray-executor based inference service. From 6e1f87ed3c4cb63a5b157b4f5e02b72ce1c35352 Mon Sep 17 00:00:00 2001 From: Sudeep Pillai Date: Tue, 29 Aug 2023 17:05:27 -0700 Subject: [PATCH 03/10] Overhaul training API with new TrainingService and grpc client tests --- nos/client/grpc.py | 34 ++++++++++++ nos/proto/nos_service.proto | 26 +++++---- nos/server/_service.py | 48 ++++++++++------- .../train/__init__.py | 0 .../train/_train_service.py | 11 ++-- nos/{experimental => server}/train/config.py | 0 .../train/dreambooth/config.py | 7 +-- tests/client/grpc/test_grpc_client.py | 16 ++++++ tests/client/test_client_integration.py | 54 +++++++++++++++++-- tests/server/test_inference_service.py | 1 + tests/server/test_training_service.py | 2 +- 11 files changed, 154 insertions(+), 45 deletions(-) rename nos/{experimental => server}/train/__init__.py (100%) rename nos/{experimental => server}/train/_train_service.py (86%) rename nos/{experimental => server}/train/config.py (100%) rename nos/{experimental => server}/train/dreambooth/config.py (95%) diff --git a/nos/client/grpc.py b/nos/client/grpc.py index 03b92ac1..0696d4c4 100644 --- a/nos/client/grpc.py +++ b/nos/client/grpc.py @@ -1,8 +1,10 @@ """gRPC client for NOS service.""" import secrets import time +import uuid from dataclasses import dataclass, field from functools import cached_property, lru_cache +from pathlib import Path from typing import Any, Callable, Dict, List import grpc @@ -268,6 +270,38 @@ def Run( module: InferenceModule = self.Module(task, model_name) return module(**inputs) + def Train(self, method: str, **inputs: Dict[str, Any]) -> nos_service_pb2.TrainingJobResponse: + """Training module. + + Args: + method (str): Training method (e.g. `stable-diffusion-dreambooth-lora`). + **inputs (Dict[str, Any]): Training inputs. + Returns: + str: Job ID. + Raises: + NosClientException: If the server fails to respond to the request. + """ + try: + request = nos_service_pb2.TrainingJobRequest( + method=method, + inputs=inputs, + ) + response = self.stub.Train(request) + return response.job_id + except grpc.RpcError as e: + raise NosClientException(f"Failed to train model (details={(e.details())})", e) + + def Volume(self, name: str) -> str: + """Remote volume module for NOS. + + Note: This is meant for remote volume mounts especially useful for training. + """ + info = self.GetServiceInfo() + root = Path.home() / ".nos" if info.runtime == "local" else Path.home() / ".nosd" + path = root / f"volumes/{name}_{uuid.uuid4().hex[:8]}" + path.mkdir(parents=True, exist_ok=True) + return str(path) + @dataclass class InferenceModule: diff --git a/nos/proto/nos_service.proto b/nos/proto/nos_service.proto index d1a9b007..434719a6 100644 --- a/nos/proto/nos_service.proto +++ b/nos/proto/nos_service.proto @@ -65,6 +65,7 @@ message PingResponse { // Service information repsonse message ServiceInfoResponse { string version = 1; // (e.g. "0.1.0") + string runtime = 2; // (e.g. "cpu", "gpu", "local" etc) } // Register system shared memory request @@ -77,6 +78,17 @@ message GenericResponse { bytes response_bytes = 1; } + +// Training job request / responses +message TrainingJobRequest { + string method = 1; + map inputs = 2; +} + +message TrainingJobResponse { + bytes response_bytes = 1; +} + // Service definition service InferenceService { // Check health status of the inference server. @@ -94,6 +106,9 @@ service InferenceService { // Run the inference request rpc Run(InferenceRequest) returns (InferenceResponse) {} + // Dispatch a training request + rpc Train(TrainingJobRequest) returns (TrainingJobResponse) {} + // Register shared memory rpc RegisterSystemSharedMemory(GenericRequest) returns (GenericResponse) {} @@ -108,14 +123,3 @@ service InferenceService { // TODO (spillai): To be implemented later (for power-users) // rpc DeleteModel(DeleteModelRequest) returns (DeleteModelResponse) {} } - - -message TrainingRequest { - ModelInfo model = 1; - map inputs = 2; - map outputs = 3; -} - -message TrainingResponse { - bytes response_bytes = 1; -} diff --git a/nos/server/_service.py b/nos/server/_service.py index cfeddb9e..6ff72e05 100644 --- a/nos/server/_service.py +++ b/nos/server/_service.py @@ -20,6 +20,7 @@ from nos.logging import logger from nos.managers import ModelHandle, ModelManager from nos.protoc import import_module +from nos.server.train._train_service import TrainingService from nos.version import __version__ @@ -49,14 +50,10 @@ class InferenceService: """ def __init__(self): - self.model_manager = ModelManager() self.executor = RayExecutor.get() - try: - self.executor.init() - except Exception as e: - err_msg = f"Failed to initialize executor [e={e}]" - logger.info(err_msg) - raise RuntimeError(err_msg) + if not self.executor.is_initialized(): + raise RuntimeError("Ray executor is not initialized") + self.model_manager = ModelManager() if NOS_SHM_ENABLED: self.shm_manager = SharedMemoryTransportManager() else: @@ -115,7 +112,7 @@ def execute(self, model_name: str, task: TaskType = None, inputs: Dict[str, Any] return response -class InferenceServiceImpl(nos_service_pb2_grpc.InferenceServiceServicer, InferenceService): +class InferenceServiceImpl(nos_service_pb2_grpc.InferenceServiceServicer, InferenceService, TrainingService): """ Experimental gRPC-based inference service. @@ -126,6 +123,13 @@ class InferenceServiceImpl(nos_service_pb2_grpc.InferenceServiceServicer, Infere """ def __init__(self, *args, **kwargs): + self.executor = RayExecutor.get() + try: + self.executor.init() + except Exception as e: + err_msg = f"Failed to initialize executor [e={e}]" + logger.info(err_msg) + raise RuntimeError(err_msg) super().__init__(*args, **kwargs) def Ping(self, request: empty_pb2.Empty, context: grpc.ServicerContext) -> nos_service_pb2.PingResponse: @@ -136,7 +140,13 @@ def GetServiceInfo( self, request: empty_pb2.Empty, context: grpc.ServicerContext ) -> nos_service_pb2.ServiceInfoResponse: """Get information on the service.""" - return nos_service_pb2.ServiceInfoResponse(version=__version__) + from nos.common.system import has_gpu, is_inside_docker + + if is_inside_docker(): + runtime = "gpu" if has_gpu() else "cpu" + else: + runtime = "local" + return nos_service_pb2.ServiceInfoResponse(version=__version__, runtime=runtime) def ListModels(self, request: empty_pb2.Empty, context: grpc.ServicerContext) -> nos_service_pb2.ModelListResponse: """List all models.""" @@ -238,23 +248,23 @@ def Run( context.abort(grpc.StatusCode.INTERNAL, "Internal Server Error") def Train( - self, request: nos_service_pb2.TrainingRequest, context: grpc.ServicerContext - ) -> nos_service_pb2.TrainingResponse: - model_request = request.model - logger.debug(f"=> Received training request [task={model_request.task}, model={model_request.name}]") - if model_request.task not in (TaskType.IMAGE_GENERATION.value,): - context.abort(grpc.StatusCode.NOT_FOUND, f"Invalid training task [task={model_request.task}]") + self, request: nos_service_pb2.TrainingJobRequest, context: grpc.ServicerContext + ) -> nos_service_pb2.TrainingJobResponse: + logger.debug(f"=> Received training request [method={request.method}]") + if request.method not in TrainingService.config_cls: + context.abort(grpc.StatusCode.NOT_FOUND, f"Invalid training task [method={request.method}]") try: st = time.perf_counter() - logger.info(f"Training request [task={model_request.task}, model={model_request.name}]") - response = self.train(model_request.name, task=TaskType(model_request.task), inputs=request.inputs) + logger.info(f"Training request [method={request.method}]") + job_id = self.train(request.method, training_inputs=request.inputs) + response = {"job_id": job_id} logger.info( - f"Trained request dispatched [id={id}, task={model_request.task}, model={model_request.name}, elapsed={(time.perf_counter() - st) * 1e3:.1f}ms]" + f"Trained request dispatched [id={id}, method={request.method}, elapsed={(time.perf_counter() - st) * 1e3:.1f}ms]" ) return nos_service_pb2.TrainingResponse(response_bytes=dumps(response)) except (grpc.RpcError, Exception) as e: - msg = f"Failed to train request [task={model_request.task}, model={model_request.name}]" + msg = f"Failed to train request [method={request.method}]" msg += f"{traceback.format_exc()}" logger.error(f"{msg}, e={e}") context.abort(grpc.StatusCode.INTERNAL, "Internal Server Error") diff --git a/nos/experimental/train/__init__.py b/nos/server/train/__init__.py similarity index 100% rename from nos/experimental/train/__init__.py rename to nos/server/train/__init__.py diff --git a/nos/experimental/train/_train_service.py b/nos/server/train/_train_service.py similarity index 86% rename from nos/experimental/train/_train_service.py rename to nos/server/train/_train_service.py index 0a0fc1bf..b829e20a 100644 --- a/nos/experimental/train/_train_service.py +++ b/nos/server/train/_train_service.py @@ -2,9 +2,9 @@ from nos.exceptions import ModelNotFoundError from nos.executors.ray import RayExecutor, RayJobExecutor -from nos.experimental.train.dreambooth.config import StableDiffusionTrainingJobConfig from nos.logging import logger from nos.protoc import import_module +from nos.server.train.dreambooth.config import StableDiffusionTrainingJobConfig nos_service_pb2 = import_module("nos_service_pb2") @@ -19,13 +19,10 @@ class TrainingService: } def __init__(self): + """Initialize the training service.""" self.executor = RayExecutor.get() - try: - self.executor.init() - except Exception as e: - err_msg = f"Failed to initialize executor [e={e}]" - logger.info(err_msg) - raise RuntimeError(err_msg) + if not self.executor.is_initialized(): + raise RuntimeError("Ray executor is not initialized") def train(self, method: str, training_inputs: Dict[str, Any], metadata: Dict[str, Any] = None) -> str: """Train / Fine-tune a model by submitting a job to the RayJobExecutor. diff --git a/nos/experimental/train/config.py b/nos/server/train/config.py similarity index 100% rename from nos/experimental/train/config.py rename to nos/server/train/config.py diff --git a/nos/experimental/train/dreambooth/config.py b/nos/server/train/dreambooth/config.py similarity index 95% rename from nos/experimental/train/dreambooth/config.py rename to nos/server/train/dreambooth/config.py index dae13757..7f275692 100644 --- a/nos/experimental/train/dreambooth/config.py +++ b/nos/server/train/dreambooth/config.py @@ -7,9 +7,9 @@ from typing import Any, Dict from nos.common.git import cached_repo -from nos.experimental.train.config import TrainingJobConfig from nos.logging import logger from nos.models.dreambooth.dreambooth import StableDiffusionDreamboothConfigs +from nos.server.train.config import TrainingJobConfig GIT_TAG = "v0.20.1" @@ -71,8 +71,9 @@ def __post_init__(self): runtime_env["working_dir"] = self.repo_directory # Create a new short unique name using method and uuid (with 8 characters) - job_id = f"{self.method}_{uuid.uuid4().hex[:8]}" - self.job_config = TrainingJobConfig(uuid=job_id, runtime_env=runtime_env) + model_id = f"{self.method}_{uuid.uuid4().hex[:8]}" + self.job_config = TrainingJobConfig(uuid=model_id, runtime_env=runtime_env) + job_id = self.job_config.uuid working_directory = Path(self.job_config.working_directory) # Copy the instance directory to the working directory diff --git a/tests/client/grpc/test_grpc_client.py b/tests/client/grpc/test_grpc_client.py index a3e892b7..b301b1d7 100644 --- a/tests/client/grpc/test_grpc_client.py +++ b/tests/client/grpc/test_grpc_client.py @@ -27,3 +27,19 @@ def predict_module_wrap(): predict_fn = dumps(predict_module_wrap) assert isinstance(predict_fn, bytes) + + def train_wrap(): + return grpc_client.Train( + method="stable-diffusion-dreambooth-lora", + inputs={ + "model_name": "stabilityai/stable-diffusion-2-1", + "instance_directory": "/tmp", + "instance_prompt": "A photo of a bench on the moon", + }, + metadata={ + "name": "sdv21-dreambooth-lora-test-bench", + }, + ) + + train_fn = dumps(train_wrap) + assert isinstance(train_fn, bytes) diff --git a/tests/client/test_client_integration.py b/tests/client/test_client_integration.py index a158afff..328185c7 100644 --- a/tests/client/test_client_integration.py +++ b/tests/client/test_client_integration.py @@ -13,7 +13,7 @@ @pytest.mark.client @pytest.mark.parametrize("runtime", ["cpu", "gpu", "auto"]) -def test_nos_init(runtime): # noqa: F811 +def test_client_init(runtime): # noqa: F811 """Test the NOS server daemon initialization.""" # Initialize the server @@ -44,8 +44,8 @@ def test_nos_init(runtime): # noqa: F811 @pytest.mark.client @pytest.mark.benchmark(group=PyTestGroup.INTEGRATION) @pytest.mark.parametrize("runtime", ["gpu"]) -def test_nos_object_detection_benchmark(runtime): # noqa: F811 - """Test the NOS server daemon initialization.""" +def test_client_inference_benchmark(runtime): # noqa: F811 + """Test and benchmark end-to-end client inference interface.""" from itertools import islice import cv2 @@ -55,7 +55,7 @@ def test_nos_object_detection_benchmark(runtime): # noqa: F811 from nos.logging import logger # Initialize the server - container = nos.init(runtime=runtime, port=GRPC_PORT, utilization=0.8) + container = nos.init(runtime=runtime, port=GRPC_PORT, utilization=1) assert container is not None assert container.id is not None containers = InferenceServiceRuntime.list() @@ -103,3 +103,49 @@ def test_nos_object_detection_benchmark(runtime): # noqa: F811 nos.shutdown() containers = InferenceServiceRuntime.list() assert len(containers) == 0 + + +@pytest.mark.client +@pytest.mark.benchmark(group=PyTestGroup.INTEGRATION) +@pytest.mark.parametrize("runtime", ["local"]) +def test_client_training(runtime): # noqa: F811 + """Test end-to-end client training interface.""" + import shutil + import tempfile + from pathlib import Path + + from nos.logging import logger + from nos.test.utils import NOS_TEST_IMAGE + + # Initialize the server + container = nos.init(runtime=runtime, port=GRPC_PORT, utilization=1) + assert container is not None + assert container.id is not None + containers = InferenceServiceRuntime.list() + assert len(containers) == 1 + + # Test waiting for server to start + # This call should be instantaneous as the server is already ready for the test + client = InferenceClient(f"[::]:{GRPC_PORT}") + assert client.WaitForServer(timeout=180, retry_interval=5) + assert client.IsHealthy() + + logger.debug("Testing training service...") + # Copy test image to temporary directory and test training service + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_image = Path(tmp_dir) / "test_image.jpg" + shutil.copy(NOS_TEST_IMAGE, tmp_image) + + volume_dir = client.Volume("dreambooth_training") + logger.debug(f"Copying files from {tmp_dir} to {volume_dir}...") + shutil.copytree(tmp_dir, volume_dir, dirs_exist_ok=True) + + client.Train( + method="stable-diffusion-dreambooth-lora", + inputs={ + "model_name": "stabilityai/stable-diffusion-2-1", + "instance_directory": volume_dir, + "instance_prompt": "A photo of a bench on the moon", + }, + ) + logger.debug("Training service test passed.") diff --git a/tests/server/test_inference_service.py b/tests/server/test_inference_service.py index 575738be..0059c55b 100644 --- a/tests/server/test_inference_service.py +++ b/tests/server/test_inference_service.py @@ -127,6 +127,7 @@ def test_shm_registry(client_with_server, request): # noqa: F811 # assert len(shm_files) == 0, "Expected no shared memory regions, but found some." +@pytest.mark.benchmark @pytest.mark.parametrize( "client_with_server", ("local_grpc_client_with_server", "grpc_client_with_cpu_backend", "grpc_client_with_gpu_backend"), diff --git a/tests/server/test_training_service.py b/tests/server/test_training_service.py index 4d9f804f..777f201f 100644 --- a/tests/server/test_training_service.py +++ b/tests/server/test_training_service.py @@ -15,7 +15,7 @@ def test_training_service(ray_executor: RayExecutor): # noqa: F811 """Test training service.""" - from nos.experimental.train import TrainingService + from nos.server.train import TrainingService # Test training service svc = TrainingService() From 6ee8363babc6b5dd02abfc992f1e461d9353c048 Mon Sep 17 00:00:00 2001 From: Sudeep Pillai Date: Wed, 30 Aug 2023 12:28:00 -0700 Subject: [PATCH 04/10] Working training API with dynamic model registry on training completion --- nos/client/grpc.py | 58 ++++++++++++++++--- nos/models/__init__.py | 1 + nos/models/dreambooth/dreambooth.py | 18 ++++++ nos/proto/nos_service.proto | 16 ++---- nos/server/_service.py | 41 +++++++++---- nos/server/train/_train_service.py | 61 ++++++++++++++++++-- tests/client/test_client_integration.py | 76 +++++++++++++++---------- 7 files changed, 207 insertions(+), 64 deletions(-) diff --git a/nos/client/grpc.py b/nos/client/grpc.py index 0696d4c4..21183937 100644 --- a/nos/client/grpc.py +++ b/nos/client/grpc.py @@ -167,6 +167,20 @@ def GetServiceVersion(self) -> str: except grpc.RpcError as e: raise NosServerReadyException(f"Failed to get service info (details={e.details()})", e) + def GetServiceRuntime(self) -> str: + """Get service runtime. + + Returns: + str: Service runtime (e.g. cpu, gpu, local). + Raises: + NosClientException: If the server fails to respond to the request. + """ + try: + response: nos_service_pb2.ServiceInfoResponse = self.stub.GetServiceInfo(empty_pb2.Empty()) + return response.runtime + except grpc.RpcError as e: + raise NosServerReadyException(f"Failed to get service info (details={e.details()})", e) + def CheckCompatibility(self) -> bool: """Check if the service version is compatible with the client. @@ -270,24 +284,26 @@ def Run( module: InferenceModule = self.Module(task, model_name) return module(**inputs) - def Train(self, method: str, **inputs: Dict[str, Any]) -> nos_service_pb2.TrainingJobResponse: + def Train( + self, method: str, inputs: Dict[str, Any], metadata: Dict[str, Any] = None + ) -> nos_service_pb2.GenericResponse: """Training module. Args: method (str): Training method (e.g. `stable-diffusion-dreambooth-lora`). - **inputs (Dict[str, Any]): Training inputs. + inputs (Dict[str, Any]): Training inputs. + metadata (Dict[str, Any], optional): Metadata for the training job. Defaults to None. Returns: str: Job ID. Raises: NosClientException: If the server fails to respond to the request. """ try: - request = nos_service_pb2.TrainingJobRequest( - method=method, - inputs=inputs, + request = nos_service_pb2.GenericRequest( + request_bytes=dumps({"method": method, "inputs": inputs, "metadata": metadata}) ) response = self.stub.Train(request) - return response.job_id + return loads(response.response_bytes) except grpc.RpcError as e: raise NosClientException(f"Failed to train model (details={(e.details())})", e) @@ -296,12 +312,38 @@ def Volume(self, name: str) -> str: Note: This is meant for remote volume mounts especially useful for training. """ - info = self.GetServiceInfo() - root = Path.home() / ".nos" if info.runtime == "local" else Path.home() / ".nosd" + runtime = self.GetServiceRuntime() + root = Path.home() / ".nos" if runtime == "local" else Path.home() / ".nosd" path = root / f"volumes/{name}_{uuid.uuid4().hex[:8]}" path.mkdir(parents=True, exist_ok=True) return str(path) + def Wait(self, job_id: str, timeout: int = 60, retry_interval: int = 5) -> None: + """Wait for job to finish. + + Args: + job_id (str): Job ID. + timeout (int, optional): Timeout in seconds. Defaults to 60. + retry_interval (int, optional): Retry interval in seconds. Defaults to 5. + """ + st = time.time() + response = None + while time.time() - st <= timeout: + try: + response: nos_service_pb2.GenericResponse = self.stub.GetJobStatus( + nos_service_pb2.GenericRequest(request_bytes=dumps({"job_id": job_id})) + ) + response = loads(response.response_bytes) + if str(response) != "PENDING" and str(response) != "RUNNING": + return response + else: + logger.debug(f"Waiting for job to finish [job_id={job_id}, response={response}]") + time.sleep(retry_interval) + except Exception: + logger.warning("Failed to fetch job status ... (elapsed={:.0f}s)".format(time.time() - st)) + logger.warning(f"Job timed out [job_id={job_id}]") + return response + @dataclass class InferenceModule: diff --git a/nos/models/__init__.py b/nos/models/__init__.py index eab06523..c562e63c 100644 --- a/nos/models/__init__.py +++ b/nos/models/__init__.py @@ -10,6 +10,7 @@ from ._noop import NoOp # noqa: F401 from .clip import CLIP # noqa: F401 +from .dreambooth.dreambooth import StableDiffusionLoRA # noqa: F401 from .faster_rcnn import FasterRCNN # noqa: F401 from .monodepth import MonoDepth # noqa: F401 from .openmmlab.mmdetection import MMDetection # noqa: F401 diff --git a/nos/models/dreambooth/dreambooth.py b/nos/models/dreambooth/dreambooth.py index c05fd2f6..7b86e72c 100644 --- a/nos/models/dreambooth/dreambooth.py +++ b/nos/models/dreambooth/dreambooth.py @@ -200,3 +200,21 @@ def __call__( height=height if height is not None else self.cfg.resolution, width=width if width is not None else self.cfg.resolution, ).images + + +from nos import hub +from nos.common import Batch, ImageSpec, ImageT, TaskType + + +for model_name in StableDiffusionLoRA.configs.keys(): + logger.debug(f"Registering model: {model_name}") + hub.register( + model_name, + TaskType.IMAGE_GENERATION, + StableDiffusionLoRA, + init_args=(model_name,), + init_kwargs={"dtype": torch.float16}, + method_name="__call__", + inputs={"prompts": Batch[str, 1], "num_images": int, "height": int, "width": int}, + outputs={"images": Batch[ImageT[Image.Image, ImageSpec(shape=(None, None, 3), dtype="uint8")]]}, + ) diff --git a/nos/proto/nos_service.proto b/nos/proto/nos_service.proto index 434719a6..c59307be 100644 --- a/nos/proto/nos_service.proto +++ b/nos/proto/nos_service.proto @@ -78,17 +78,6 @@ message GenericResponse { bytes response_bytes = 1; } - -// Training job request / responses -message TrainingJobRequest { - string method = 1; - map inputs = 2; -} - -message TrainingJobResponse { - bytes response_bytes = 1; -} - // Service definition service InferenceService { // Check health status of the inference server. @@ -107,7 +96,10 @@ service InferenceService { rpc Run(InferenceRequest) returns (InferenceResponse) {} // Dispatch a training request - rpc Train(TrainingJobRequest) returns (TrainingJobResponse) {} + rpc Train(GenericRequest) returns (GenericResponse) {} + + // Job status + rpc GetJobStatus(GenericRequest) returns (GenericResponse) {} // Register shared memory rpc RegisterSystemSharedMemory(GenericRequest) returns (GenericResponse) {} diff --git a/nos/server/_service.py b/nos/server/_service.py index 6ff72e05..14695200 100644 --- a/nos/server/_service.py +++ b/nos/server/_service.py @@ -248,23 +248,44 @@ def Run( context.abort(grpc.StatusCode.INTERNAL, "Internal Server Error") def Train( - self, request: nos_service_pb2.TrainingJobRequest, context: grpc.ServicerContext - ) -> nos_service_pb2.TrainingJobResponse: - logger.debug(f"=> Received training request [method={request.method}]") - if request.method not in TrainingService.config_cls: - context.abort(grpc.StatusCode.NOT_FOUND, f"Invalid training task [method={request.method}]") + self, request: nos_service_pb2.GenericRequest, context: grpc.ServicerContext + ) -> nos_service_pb2.GenericResponse: + request = loads(request.request_bytes) + logger.debug(f"=> Received training request [method={request['method']}]") + if request["method"] not in TrainingService.config_cls: + context.abort(grpc.StatusCode.NOT_FOUND, f"Invalid training task [method={request['method']}]") try: st = time.perf_counter() - logger.info(f"Training request [method={request.method}]") - job_id = self.train(request.method, training_inputs=request.inputs) + logger.info(f"Training request [method={request['method']}]") + job_id = self.train(request["method"], inputs=request["inputs"], metadata=request["metadata"]) response = {"job_id": job_id} logger.info( - f"Trained request dispatched [id={id}, method={request.method}, elapsed={(time.perf_counter() - st) * 1e3:.1f}ms]" + f"Trained request dispatched [id={id}, method={request['method']}, elapsed={(time.perf_counter() - st) * 1e3:.1f}ms]" + ) + return nos_service_pb2.GenericResponse(response_bytes=dumps(response)) + except (grpc.RpcError, Exception) as e: + msg = f"Failed to train request [method={request['method']}]" + msg += f"{traceback.format_exc()}" + logger.error(f"{msg}, e={e}") + context.abort(grpc.StatusCode.INTERNAL, "Internal Server Error") + + def GetJobStatus( + self, request: nos_service_pb2.GenericRequest, context: grpc.ServicerContext + ) -> nos_service_pb2.GenericResponse: + request = loads(request.request_bytes) + logger.debug(f"=> Received job status request [job_id={request['job_id']}]") + + try: + st = time.perf_counter() + logger.info(f"Job status request [job_id={request['job_id']}]") + response = self.jobs.status(request["job_id"]) + logger.info( + f"Job status request [job_id={request['job_id']}, response={response}, elapsed={(time.perf_counter() - st) * 1e3:.1f}ms]" ) - return nos_service_pb2.TrainingResponse(response_bytes=dumps(response)) + return nos_service_pb2.GenericResponse(response_bytes=dumps(response)) except (grpc.RpcError, Exception) as e: - msg = f"Failed to train request [method={request.method}]" + msg = f"Failed to get job status [job_id={request['job_id']}]" msg += f"{traceback.format_exc()}" logger.error(f"{msg}, e={e}") context.abort(grpc.StatusCode.INTERNAL, "Internal Server Error") diff --git a/nos/server/train/_train_service.py b/nos/server/train/_train_service.py index b829e20a..75c0892f 100644 --- a/nos/server/train/_train_service.py +++ b/nos/server/train/_train_service.py @@ -1,3 +1,5 @@ +import threading +import time from typing import Any, Dict from nos.exceptions import ModelNotFoundError @@ -11,6 +13,35 @@ nos_service_pb2_grpc = import_module("nos_service_pb2_grpc") +def register_model(model_name: str, *args, **kwargs): + import torch + from PIL import Image + + from nos import hub + from nos.common import Batch, ImageSpec, ImageT, TaskType + from nos.models.dreambooth.dreambooth import StableDiffusionDreamboothHub, StableDiffusionLoRA + + # Update the registry with newer configs + sd_hub = StableDiffusionDreamboothHub(namespace="custom") + psize = len(sd_hub) + sd_hub.update() + logger.debug(f"Updated registry with newer configs [namespace=custom, size={len(sd_hub)}, prev_size={psize}]") + + # Register the model + logger.debug(f"Registering new model [model={model_name}]") + hub.register( + model_name, + TaskType.IMAGE_GENERATION, + StableDiffusionLoRA, + init_args=(model_name,), + init_kwargs={"dtype": torch.float16}, + method_name="__call__", + inputs={"prompts": Batch[str, 1], "num_images": int, "height": int, "width": int}, + outputs={"images": Batch[ImageT[Image.Image, ImageSpec(shape=(None, None, 3), dtype="uint8")]]}, + ) + logger.debug(f"Registering new model [{model_name}]") + + class TrainingService: """Ray-executor based training service.""" @@ -24,12 +55,12 @@ def __init__(self): if not self.executor.is_initialized(): raise RuntimeError("Ray executor is not initialized") - def train(self, method: str, training_inputs: Dict[str, Any], metadata: Dict[str, Any] = None) -> str: + def train(self, method: str, inputs: Dict[str, Any], metadata: Dict[str, Any] = None) -> str: """Train / Fine-tune a model by submitting a job to the RayJobExecutor. Args: method (str): Training method (e.g. `stable-diffusion-dreambooth-lora`). - training_inputs (Dict[str, Any]): Training inputs. + inputs (Dict[str, Any]): Training inputs. Returns: str: Job ID. """ @@ -39,11 +70,11 @@ def train(self, method: str, training_inputs: Dict[str, Any], metadata: Dict[str raise ModelNotFoundError(f"Training not supported for method [method={method}]") # Check if the training inputs are correctly specified - config = config_cls(method=method, **training_inputs) + config = config_cls(method=method, **inputs) try: pass except Exception as e: - raise ValueError(f"Invalid training inputs [training_inputs={training_inputs}, e={e}]") + raise ValueError(f"Invalid training inputs [inputs={inputs}, e={e}]") # Submit the training job as a Ray job configd = config.job_dict() @@ -53,6 +84,28 @@ def train(self, method: str, training_inputs: Dict[str, Any], metadata: Dict[str logger.debug(f"config\n{configd}]") job_id = self.executor.jobs.submit(**configd) logger.debug(f"Submitted training job [job_id={job_id}, config={configd}]") + + hooks = { + "on_completed": [(register_model, (job_id,), {})], + } + + # Spawn a thread to monitor the job + def monitor_job(job_id: str, timeout: int = 180, retry_interval: int = 5): + st = time.time() + while time.time() - st < timeout: + status = self.executor.jobs.status(job_id) + if str(status) == "SUCCEEDED": + logger.debug(f"Training job completed [job_id={job_id}, status={status}]") + cb, args, kwargs = hooks["on_completed"] + logger.debug(f"Running callback [cb={cb}, args={args}, kwargs={kwargs}]") + cb(*args, **kwargs) + logger.debug(f"Callback completed [cb={cb}, args={args}, kwargs={kwargs}]") + break + else: + logger.debug(f"Training job not completed yet [job_id={job_id}, status={status}]") + time.sleep(retry_interval) + + threading.Thread(target=monitor_job, args=(job_id,), daemon=True).start() return job_id @property diff --git a/tests/client/test_client_integration.py b/tests/client/test_client_integration.py index 328185c7..b8f7b7c6 100644 --- a/tests/client/test_client_integration.py +++ b/tests/client/test_client_integration.py @@ -107,45 +107,61 @@ def test_client_inference_benchmark(runtime): # noqa: F811 @pytest.mark.client @pytest.mark.benchmark(group=PyTestGroup.INTEGRATION) -@pytest.mark.parametrize("runtime", ["local"]) -def test_client_training(runtime): # noqa: F811 +@pytest.mark.parametrize( + "client_with_server", + ("local_grpc_client_with_server",), +) +def test_client_training(client_with_server, request): # noqa: F811 """Test end-to-end client training interface.""" import shutil - import tempfile from pathlib import Path + from nos.common import TaskType from nos.logging import logger from nos.test.utils import NOS_TEST_IMAGE - # Initialize the server - container = nos.init(runtime=runtime, port=GRPC_PORT, utilization=1) - assert container is not None - assert container.id is not None - containers = InferenceServiceRuntime.list() - assert len(containers) == 1 - # Test waiting for server to start # This call should be instantaneous as the server is already ready for the test - client = InferenceClient(f"[::]:{GRPC_PORT}") - assert client.WaitForServer(timeout=180, retry_interval=5) + client = request.getfixturevalue(client_with_server) assert client.IsHealthy() + # Create a temporary volume for training images + volume_dir = client.Volume("dreambooth_training") + logger.debug("Testing training service...") - # Copy test image to temporary directory and test training service - with tempfile.TemporaryDirectory() as tmp_dir: - tmp_image = Path(tmp_dir) / "test_image.jpg" - shutil.copy(NOS_TEST_IMAGE, tmp_image) - - volume_dir = client.Volume("dreambooth_training") - logger.debug(f"Copying files from {tmp_dir} to {volume_dir}...") - shutil.copytree(tmp_dir, volume_dir, dirs_exist_ok=True) - - client.Train( - method="stable-diffusion-dreambooth-lora", - inputs={ - "model_name": "stabilityai/stable-diffusion-2-1", - "instance_directory": volume_dir, - "instance_prompt": "A photo of a bench on the moon", - }, - ) - logger.debug("Training service test passed.") + + # Copy test image to volume and test training service + tmp_image = Path(volume_dir) / "test_image.jpg" + shutil.copy(NOS_TEST_IMAGE, tmp_image) + + # Train a new LoRA model with the image of a bench + response = client.Train( + method="stable-diffusion-dreambooth-lora", + inputs={ + "model_name": "stabilityai/stable-diffusion-2-1", + "instance_directory": volume_dir, + "instance_prompt": "A photo of a bench on the moon", + "max_train_steps": 10, + }, + ) + assert response is not None + model_id = response["job_id"] + logger.debug(f"Training service test completed [model_id={model_id}].") + + # model_id = "stable-diffusion-dreambooth-lora_16cd4490" + + # Wait for the model to be ready + response = client.Wait(job_id=model_id, timeout=180, retry_interval=5) + logger.debug(f"Training service test completed [model_id={model_id}, response={response}].") + time.sleep(10) + + # Test inference with the trained model + logger.debug("Testing inference service...") + response = client.Run( + task=TaskType.IMAGE_GENERATION, + model_name=f"custom/{model_id}", + prompts=["a photo of a bench on the moon"], + width=512, + height=512, + num_images=1, + ) From 860940299a184735efdba2613815dd24a241748b Mon Sep 17 00:00:00 2001 From: Sudeep Pillai Date: Wed, 30 Aug 2023 13:39:04 -0700 Subject: [PATCH 05/10] Fix OMP_NUM_THREADS defaults in entrypoint, if not already set --- docker/.dockerignore | 2 ++ scripts/entrypoint.sh | 6 ++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/docker/.dockerignore b/docker/.dockerignore index 68228799..bbf9890e 100644 --- a/docker/.dockerignore +++ b/docker/.dockerignore @@ -24,3 +24,5 @@ dist bdist *.cache *.ts + +site/ \ No newline at end of file diff --git a/scripts/entrypoint.sh b/scripts/entrypoint.sh index e70c162a..c7967271 100755 --- a/scripts/entrypoint.sh +++ b/scripts/entrypoint.sh @@ -2,9 +2,11 @@ set -e set -x -echo "Starting Ray server with OMP_NUM_THREADS=${OMP_NUM_THREADS}..." +# Get number of cores +NCORES=$(nproc --all) +echo "Starting Ray server with OMP_NUM_THREADS=${OMP_NUM_THREADS:-${NCORES}}..." # Get OMP_NUM_THREADS from environment variable, if set otherwise use 1 -OMP_NUM_THREADS=${OMP_NUM_THREADS} ray start --head +OMP_NUM_THREADS=${OMP_NUM_THREADS:-${NCORES}} ray start --head echo "Starting NOS server..." nos-grpc-server From 2ff9cd925e01a9a9d5ae3c66650f533a8fb70938 Mon Sep 17 00:00:00 2001 From: Sudeep Pillai Date: Wed, 30 Aug 2023 15:05:23 -0700 Subject: [PATCH 06/10] Working volume directory mounts for client-server training api --- docker/.dockerignore | 2 +- nos/client/grpc.py | 11 +++++++---- nos/models/dreambooth/dreambooth.py | 6 ++---- nos/server/train/dreambooth/config.py | 4 ++++ requirements/requirements.server.txt | 2 +- 5 files changed, 15 insertions(+), 10 deletions(-) diff --git a/docker/.dockerignore b/docker/.dockerignore index bbf9890e..ef5efe55 100644 --- a/docker/.dockerignore +++ b/docker/.dockerignore @@ -25,4 +25,4 @@ bdist *.cache *.ts -site/ \ No newline at end of file +site/ diff --git a/nos/client/grpc.py b/nos/client/grpc.py index 21183937..f9a0d11c 100644 --- a/nos/client/grpc.py +++ b/nos/client/grpc.py @@ -20,7 +20,7 @@ NosServerReadyException, ) from nos.common.shm import NOS_SHM_ENABLED, SharedMemoryTransportManager -from nos.constants import DEFAULT_GRPC_PORT, NOS_PROFILING_ENABLED +from nos.constants import DEFAULT_GRPC_PORT, NOS_HOME, NOS_PROFILING_ENABLED from nos.logging import logger from nos.protoc import import_module from nos.version import __version__ @@ -307,14 +307,17 @@ def Train( except grpc.RpcError as e: raise NosClientException(f"Failed to train model (details={(e.details())})", e) - def Volume(self, name: str) -> str: + def Volume(self, name: str = None) -> str: """Remote volume module for NOS. Note: This is meant for remote volume mounts especially useful for training. """ runtime = self.GetServiceRuntime() - root = Path.home() / ".nos" if runtime == "local" else Path.home() / ".nosd" - path = root / f"volumes/{name}_{uuid.uuid4().hex[:8]}" + root = NOS_HOME / "volumes" if runtime == "local" else Path.home() / ".nosd/volumes" + if name is None: + root.mkdir(parents=True, exist_ok=True) + return str(root) + path = root / f"{name}_{uuid.uuid4().hex[:8]}" path.mkdir(parents=True, exist_ok=True) return str(path) diff --git a/nos/models/dreambooth/dreambooth.py b/nos/models/dreambooth/dreambooth.py index 7b86e72c..e9f7be72 100644 --- a/nos/models/dreambooth/dreambooth.py +++ b/nos/models/dreambooth/dreambooth.py @@ -6,6 +6,8 @@ import torch from PIL import Image +from nos import hub +from nos.common import Batch, ImageSpec, ImageT, TaskType from nos.hub.config import NOS_MODELS_DIR from nos.logging import logger @@ -202,10 +204,6 @@ def __call__( ).images -from nos import hub -from nos.common import Batch, ImageSpec, ImageT, TaskType - - for model_name in StableDiffusionLoRA.configs.keys(): logger.debug(f"Registering model: {model_name}") hub.register( diff --git a/nos/server/train/dreambooth/config.py b/nos/server/train/dreambooth/config.py index 7f275692..d881ca89 100644 --- a/nos/server/train/dreambooth/config.py +++ b/nos/server/train/dreambooth/config.py @@ -7,6 +7,7 @@ from typing import Any, Dict from nos.common.git import cached_repo +from nos.constants import NOS_HOME from nos.logging import logger from nos.models.dreambooth.dreambooth import StableDiffusionDreamboothConfigs from nos.server.train.config import TrainingJobConfig @@ -14,6 +15,7 @@ GIT_TAG = "v0.20.1" +NOS_VOLUME_DIR = NOS_HOME / "volumes" RUNTIME_ENVS = { "diffusers-latest": { "working_dir": "./nos/experimental/", @@ -77,6 +79,8 @@ def __post_init__(self): working_directory = Path(self.job_config.working_directory) # Copy the instance directory to the working directory + self.instance_directory = NOS_VOLUME_DIR / self.instance_directory + logger.debug(f"Instance directory [dir={self.instance_directory}]") if not Path(self.instance_directory).exists(): raise IOError(f"Failed to load instance_directory={self.instance_directory}.") instance_directory = working_directory / "instances" diff --git a/requirements/requirements.server.txt b/requirements/requirements.server.txt index 7a222391..4be6cc4d 100644 --- a/requirements/requirements.server.txt +++ b/requirements/requirements.server.txt @@ -3,7 +3,7 @@ diffusers>=0.17.1 huggingface_hub memray pyarrow>=12.0.0 -ray>=2.6.1 +ray[default]>=2.6.1 safetensors>=0.3.0 tabulate timm>=0.9.2 From 4150c1253a9c4b24c20182803580c1d2059de9e3 Mon Sep 17 00:00:00 2001 From: Sudeep Pillai Date: Wed, 30 Aug 2023 15:06:53 -0700 Subject: [PATCH 07/10] Discord training bot with new fine-tuning API for sdv2 lora --- examples/discord/.env.template | 1 + examples/discord/Dockerfile | 10 ++ examples/discord/Makefile | 5 + examples/discord/bot.py | 141 ++++++++++++++++++++++++++++ examples/discord/docker-compose.yml | 36 +++++++ examples/discord/requirements.txt | 3 + 6 files changed, 196 insertions(+) create mode 100644 examples/discord/.env.template create mode 100644 examples/discord/Dockerfile create mode 100644 examples/discord/Makefile create mode 100644 examples/discord/bot.py create mode 100644 examples/discord/docker-compose.yml create mode 100644 examples/discord/requirements.txt diff --git a/examples/discord/.env.template b/examples/discord/.env.template new file mode 100644 index 00000000..fd657628 --- /dev/null +++ b/examples/discord/.env.template @@ -0,0 +1 @@ +DISCORD_BOT_TOKEN= \ No newline at end of file diff --git a/examples/discord/Dockerfile b/examples/discord/Dockerfile new file mode 100644 index 00000000..130ebd0b --- /dev/null +++ b/examples/discord/Dockerfile @@ -0,0 +1,10 @@ +ARG BASE_IMAGE +FROM ${BASE_IMAGE:-autonomi/nos:latest-cpu} + +WORKDIR /tmp/$PROJECT +ADD requirements.txt . +RUN pip install -r requirements.txt + +WORKDIR /app +COPY bot.py . +CMD ["python", "/app/bot.py"] \ No newline at end of file diff --git a/examples/discord/Makefile b/examples/discord/Makefile new file mode 100644 index 00000000..2166bf84 --- /dev/null +++ b/examples/discord/Makefile @@ -0,0 +1,5 @@ +SHELL := /bin/bash + +docker-compose-upd-discord-bot: + pushd ../../ && make docker-build-cpu docker-build-gpu && popd; + docker compose -f docker-compose.yml up --build \ No newline at end of file diff --git a/examples/discord/bot.py b/examples/discord/bot.py new file mode 100644 index 00000000..b6329a5b --- /dev/null +++ b/examples/discord/bot.py @@ -0,0 +1,141 @@ +import time +import io +import os +from pathlib import Path + +import discord +from discord.ext import commands + +import nos + +from nos.client import InferenceClient, TaskType +from nos.constants import NOS_TMP_DIR +from nos.logging import logger +# from nos.server._service import TrainingService + +NOS_PLAYGROUND_CHANNEL = "nos-playground" + +# Init nos server, wait for it to spin up then confirm its healthy: +client = InferenceClient() + +logger.debug("Waiting for server to start...") +client.WaitForServer() + +logger.debug("Confirming server is healthy...") +if not client.IsHealthy(): + raise RuntimeError("NOS server is not healthy") + +logger.debug("Server is healthy!") +NOS_VOLUME_DIR = Path(client.Volume()) +NOS_TRAINING_VOLUME_DIR = Path(client.Volume("nos-playground")) +logger.debug(f"Creating training data volume [volume={NOS_TRAINING_VOLUME_DIR}]") + +# Set permissions for our bot to allow it to read messages: +intents = discord.Intents.default() +intents.message_content = True + +# Create our bot, with the command prefix set to "/": +bot = commands.Bot(command_prefix="/", intents=intents) + +logger.debug("Starting bot, initializing existing threads ...") +THREADS = {} + +async def setup(): + """Initialize all threads in the NOS_PLAYGROUND_CHANNEL""" + logger.debug("initializing all threads") + for channel in bot.get_all_channels(): + if channel.name == NOS_PLAYGROUND_CHANNEL: + for thread in await channel.threads(): + THREADS[thread.name] = thread.id + logger.debug(f"initialized all threads [threads={THREADS}]") + + +@bot.command() +async def generate(ctx, *, prompt): + """Create a callback to read messages and generate images from prompt""" + logger.debug(f"/generate [prompt={prompt}, channel={ctx.channel.name}, channel_id={ctx.channel.id}, user={ctx.author.name}]") + + if ctx.channel.name != NOS_PLAYGROUND_CHANNEL: + logger.debug(f"ignoring [channel={ctx.channel.name}]") + return + + # Pull the channel id so we know which model to run: + st = time.perf_counter() + logger.debug(f"/generate request submitted [id={ctx.message.id}]") + response = client.Run( + TaskType.IMAGE_GENERATION, + "stabilityai/stable-diffusion-2", + prompts=[prompt], + width=512, + height=512, + num_images=1, + ) + logger.debug(f"/generate request completed [id={ctx.message.id}, elapsed={time.perf_counter() - st:.2f}s]") + image, = response["images"] + + # Save the image to a buffer and send it back to the user: + image_bytes = io.BytesIO() + image.save(image_bytes, format="PNG") + image_bytes.seek(0) + await ctx.send(f"{prompt}", file=discord.File(image_bytes, filename=f"{ctx.message.id}.png")) + + +@bot.command() +async def train(ctx, *, prompt): + logger.debug(f"/train [channel={ctx.channel.name}, channel_id={ctx.channel.id}, user={ctx.author.name}]") + + if ctx.channel.name != NOS_PLAYGROUND_CHANNEL: + logger.debug("ignoring [channel={ctx.channel.name}]") + return + + if not ctx.message.attachments: + logger.debug("no attachments to train on, returning!") + return + + if "" not in prompt: + await ctx.send("Please include a in your training prompt!") + return + + # Create a thread for this training job + thread_id = str(ctx.message.id) + thread = await ctx.channel.create_thread(name=thread_id, type=discord.ChannelType.public_thread) + await thread.send(f"Created a new thread for training [id={thread.name}]") + + # Create the training directory for this thread + dirname = NOS_TRAINING_VOLUME_DIR / thread_id + dirname.mkdir(parents=True, exist_ok=True) + logger.debug(f"Created training directory [dirname={dirname}]") + + # Save all attachments to the training directory + for attachment in ctx.message.attachments: + logger.debug(f"Got attachement [filename={attachment.filename}]") + await attachment.save(str(dirname / str(attachment.filename))) + logger.debug(f"Saved attachment [filename={attachment.filename}]") + + # Train a new LoRA model with the image of a bench + response = client.Train( + method="stable-diffusion-dreambooth-lora", + inputs={ + "model_name": "stabilityai/stable-diffusion-2-1", + "instance_directory": dirname.relative_to(NOS_VOLUME_DIR), + "instance_prompt": "A photo of a on the moon", + "max_train_steps": 10, + }, + metadata={ + "name": "sdv21-dreambooth-lora-test", + }, + ) + logger.debug(f"Submitted training job [id={thread_id}, response={response}, dirname={dirname}]") + if response is None: + logger.error(f"Failed to submit training job [id={thread_id}, response={response}, dirname={dirname}]") + await thread.send(f"Failed to train [prompt={prompt}, response={response}, dirname={dirname}]") + + + +# Pull API token out of environment and run the bot: +bot_token = os.environ.get("DISCORD_BOT_TOKEN") +if bot_token is None: + raise Exception("DISCORD_BOT_TOKEN environment variable not set") +logger.debug(f"Starting bot with token [token={bot_token[:5]}****]") +# bot.loop.run_until_complete(setup()) +bot.run(bot_token) diff --git a/examples/discord/docker-compose.yml b/examples/discord/docker-compose.yml new file mode 100644 index 00000000..711147b9 --- /dev/null +++ b/examples/discord/docker-compose.yml @@ -0,0 +1,36 @@ +version: "3.8" + +services: + bot: + image: autonomi/nos:latest-discord-app + build: + context: . + dockerfile: Dockerfile + args: + - BASE_IMAGE=autonomi/nos:latest-cpu + env_file: + - .env + environment: + - NOS_HOME=/app/.nos + - NOS_LOGGING_LEVEL=DEBUG + volumes: + - ~/.nosd:/app/.nos + - /dev/shm:/dev/shm + network_mode: host + ipc: host + + server: + image: autonomi/nos:latest-gpu + environment: + - NOS_HOME=/app/.nos + - NOS_LOGGING_LEVEL=DEBUG + volumes: + - ~/.nosd:/app/.nos + - /dev/shm:/dev/shm + network_mode: host + ipc: host + deploy: + resources: + reservations: + devices: + - capabilities: [gpu] \ No newline at end of file diff --git a/examples/discord/requirements.txt b/examples/discord/requirements.txt new file mode 100644 index 00000000..9133197f --- /dev/null +++ b/examples/discord/requirements.txt @@ -0,0 +1,3 @@ +discord==2.3.2 +discord.py==2.3.2 +docker From d1da82ec7243160f931493dd826753b7631397ec Mon Sep 17 00:00:00 2001 From: Sudeep Pillai Date: Wed, 30 Aug 2023 18:10:53 -0700 Subject: [PATCH 08/10] Fully functional discord bot training --- examples/discord/.env.template | 2 +- examples/discord/Dockerfile | 2 +- examples/discord/Makefile | 6 +- examples/discord/bot.py | 131 +++++++++++++++++------- examples/discord/docker-compose.yml | 38 +++---- examples/discord/requirements.txt | 1 + nos/client/grpc.py | 2 +- nos/server/train/_train_service.py | 18 ++-- tests/client/test_client_integration.py | 6 +- 9 files changed, 134 insertions(+), 72 deletions(-) diff --git a/examples/discord/.env.template b/examples/discord/.env.template index fd657628..1bcaa3cd 100644 --- a/examples/discord/.env.template +++ b/examples/discord/.env.template @@ -1 +1 @@ -DISCORD_BOT_TOKEN= \ No newline at end of file +DISCORD_BOT_TOKEN= diff --git a/examples/discord/Dockerfile b/examples/discord/Dockerfile index 130ebd0b..3a92306e 100644 --- a/examples/discord/Dockerfile +++ b/examples/discord/Dockerfile @@ -7,4 +7,4 @@ RUN pip install -r requirements.txt WORKDIR /app COPY bot.py . -CMD ["python", "/app/bot.py"] \ No newline at end of file +CMD ["python", "/app/bot.py"] diff --git a/examples/discord/Makefile b/examples/discord/Makefile index 2166bf84..a4d0a8d1 100644 --- a/examples/discord/Makefile +++ b/examples/discord/Makefile @@ -1,5 +1,7 @@ SHELL := /bin/bash -docker-compose-upd-discord-bot: +docker-compose-upd-discord-bot: + sudo mkdir -p ~/.nosd/volumes + sudo chown -R $(USER):$(USER) ~/.nosd/volumes pushd ../../ && make docker-build-cpu docker-build-gpu && popd; - docker compose -f docker-compose.yml up --build \ No newline at end of file + docker compose -f docker-compose.yml up --build diff --git a/examples/discord/bot.py b/examples/discord/bot.py index b6329a5b..93fd42d7 100644 --- a/examples/discord/bot.py +++ b/examples/discord/bot.py @@ -1,17 +1,37 @@ -import time import io import os +import time +from dataclasses import dataclass from pathlib import Path import discord from discord.ext import commands - -import nos +from diskcache import Cache from nos.client import InferenceClient, TaskType from nos.constants import NOS_TMP_DIR from nos.logging import logger -# from nos.server._service import TrainingService + + +@dataclass +class LoRAPromptModel: + + thread_id: str + """Discord thread ID""" + + thread_name: str + """Discord thread name""" + + model_id: str + """Training job ID / model ID""" + + prompt: str + """Prompt used to train the model""" + + @property + def job_id(self) -> str: + return self.model_id + NOS_PLAYGROUND_CHANNEL = "nos-playground" @@ -38,46 +58,72 @@ bot = commands.Bot(command_prefix="/", intents=intents) logger.debug("Starting bot, initializing existing threads ...") -THREADS = {} -async def setup(): - """Initialize all threads in the NOS_PLAYGROUND_CHANNEL""" - logger.debug("initializing all threads") - for channel in bot.get_all_channels(): - if channel.name == NOS_PLAYGROUND_CHANNEL: - for thread in await channel.threads(): - THREADS[thread.name] = thread.id - logger.debug(f"initialized all threads [threads={THREADS}]") +# Maps channel_id -> LoRAPromptModel +MODEL_DB = Cache(str(NOS_TMP_DIR / NOS_PLAYGROUND_CHANNEL)) @bot.command() async def generate(ctx, *, prompt): """Create a callback to read messages and generate images from prompt""" - logger.debug(f"/generate [prompt={prompt}, channel={ctx.channel.name}, channel_id={ctx.channel.id}, user={ctx.author.name}]") + logger.debug( + f"/generate [prompt={prompt}, channel={ctx.channel.name}, channel_id={ctx.channel.id}, user={ctx.author.name}]" + ) - if ctx.channel.name != NOS_PLAYGROUND_CHANNEL: - logger.debug(f"ignoring [channel={ctx.channel.name}]") - return - - # Pull the channel id so we know which model to run: st = time.perf_counter() - logger.debug(f"/generate request submitted [id={ctx.message.id}]") - response = client.Run( - TaskType.IMAGE_GENERATION, - "stabilityai/stable-diffusion-2", - prompts=[prompt], - width=512, - height=512, - num_images=1, - ) - logger.debug(f"/generate request completed [id={ctx.message.id}, elapsed={time.perf_counter() - st:.2f}s]") - image, = response["images"] + if ctx.channel.name == NOS_PLAYGROUND_CHANNEL: + # Pull the channel id so we know which model to run: + logger.debug(f"/generate request submitted [id={ctx.message.id}]") + response = client.Run( + task=TaskType.IMAGE_GENERATION, + model_name="stabilityai/stable-diffusion-2", + prompts=[prompt], + width=512, + height=512, + num_images=1, + ) + logger.debug(f"/generate request completed [id={ctx.message.id}, elapsed={time.perf_counter() - st:.2f}s]") + (image,) = response["images"] + thread = ctx + else: + # Pull the channel id so we know which model to run: + thread = ctx.channel + thread_id = ctx.channel.id + model = MODEL_DB.get(thread_id, default=None) + if model is None: + logger.debug(f"Failed to fetch model [thread_id={thread_id}") + await thread.send("No model found") + return + + # Get model info + try: + models = client.ListModels() + info = {m.name: m for m in models}[model.model_id] + logger.debug(f"Got model info [model_id={model.model_id}, info={info}]") + except Exception as e: + logger.debug(f"Failed to fetch model info [model_id={model.model_id}, e={e}]") + await thread.send(f"Failed to fetch model [model_id={model.model_id}]") + return + + logger.debug(f"/generate request submitted [id={ctx.message.id}, model_id={model.model_id}, model={model}]") + response = client.Run( + task=TaskType.IMAGE_GENERATION, + model_name=model.model_id, + prompts=[prompt], + width=512, + height=512, + num_images=1, + ) + logger.debug( + f"/generate request completed [id={ctx.message.id}, model={model}, elapsed={time.perf_counter() - st:.2f}s]" + ) + (image,) = response["images"] # Save the image to a buffer and send it back to the user: image_bytes = io.BytesIO() image.save(image_bytes, format="PNG") image_bytes.seek(0) - await ctx.send(f"{prompt}", file=discord.File(image_bytes, filename=f"{ctx.message.id}.png")) + await thread.send(f"{prompt}", file=discord.File(image_bytes, filename=f"{ctx.message.id}.png")) @bot.command() @@ -91,18 +137,21 @@ async def train(ctx, *, prompt): if not ctx.message.attachments: logger.debug("no attachments to train on, returning!") return - + if "" not in prompt: await ctx.send("Please include a in your training prompt!") return # Create a thread for this training job - thread_id = str(ctx.message.id) - thread = await ctx.channel.create_thread(name=thread_id, type=discord.ChannelType.public_thread) + message_id = str(ctx.message.id) + thread = await ctx.channel.create_thread(name=f"{prompt} ({message_id})", type=discord.ChannelType.public_thread) await thread.send(f"Created a new thread for training [id={thread.name}]") + # Save the thread id + thread_id = thread.id + # Create the training directory for this thread - dirname = NOS_TRAINING_VOLUME_DIR / thread_id + dirname = NOS_TRAINING_VOLUME_DIR / str(thread_id) dirname.mkdir(parents=True, exist_ok=True) logger.debug(f"Created training directory [dirname={dirname}]") @@ -126,12 +175,22 @@ async def train(ctx, *, prompt): }, ) logger.debug(f"Submitted training job [id={thread_id}, response={response}, dirname={dirname}]") + await thread.send(f"Submitted training job [id={thread.name}, model={response['job_id']}]") + + # Save the model + MODEL_DB[thread_id] = LoRAPromptModel( + thread_id=thread_id, + thread_name=thread.name, + model_id=f"custom/{response['job_id']}", + prompt=prompt, + ) + logger.debug(f"Saved model [id={thread_id}, model={MODEL_DB[thread_id]}]") + if response is None: logger.error(f"Failed to submit training job [id={thread_id}, response={response}, dirname={dirname}]") await thread.send(f"Failed to train [prompt={prompt}, response={response}, dirname={dirname}]") - # Pull API token out of environment and run the bot: bot_token = os.environ.get("DISCORD_BOT_TOKEN") if bot_token is None: diff --git a/examples/discord/docker-compose.yml b/examples/discord/docker-compose.yml index 711147b9..09a7ea1e 100644 --- a/examples/discord/docker-compose.yml +++ b/examples/discord/docker-compose.yml @@ -1,24 +1,24 @@ version: "3.8" services: - bot: - image: autonomi/nos:latest-discord-app - build: - context: . - dockerfile: Dockerfile - args: - - BASE_IMAGE=autonomi/nos:latest-cpu - env_file: - - .env - environment: - - NOS_HOME=/app/.nos - - NOS_LOGGING_LEVEL=DEBUG - volumes: - - ~/.nosd:/app/.nos - - /dev/shm:/dev/shm - network_mode: host - ipc: host - + # bot: + # image: autonomi/nos:latest-discord-app + # build: + # context: . + # dockerfile: Dockerfile + # args: + # - BASE_IMAGE=autonomi/nos:latest-cpu + # env_file: + # - .env + # environment: + # - NOS_HOME=/app/.nos + # - NOS_LOGGING_LEVEL=DEBUG + # volumes: + # - ~/.nosd:/app/.nos + # - /dev/shm:/dev/shm + # network_mode: host + # ipc: host + server: image: autonomi/nos:latest-gpu environment: @@ -33,4 +33,4 @@ services: resources: reservations: devices: - - capabilities: [gpu] \ No newline at end of file + - capabilities: [gpu] diff --git a/examples/discord/requirements.txt b/examples/discord/requirements.txt index 9133197f..ffb8316a 100644 --- a/examples/discord/requirements.txt +++ b/examples/discord/requirements.txt @@ -1,3 +1,4 @@ discord==2.3.2 discord.py==2.3.2 +diskcache docker diff --git a/nos/client/grpc.py b/nos/client/grpc.py index f9a0d11c..2a480cbf 100644 --- a/nos/client/grpc.py +++ b/nos/client/grpc.py @@ -592,7 +592,7 @@ def __call__(self, **inputs: Dict[str, Any]) -> Dict[str, Any]: ) # Execute the request st = time.perf_counter() - logger.debug(f"Executing request [model={self._spec.name}]]") + logger.debug(f"Executing request [model={self._spec.name}]") response = self.stub.Run(request) if NOS_PROFILING_ENABLED: logger.debug( diff --git a/nos/server/train/_train_service.py b/nos/server/train/_train_service.py index 75c0892f..ce72ca37 100644 --- a/nos/server/train/_train_service.py +++ b/nos/server/train/_train_service.py @@ -28,18 +28,19 @@ def register_model(model_name: str, *args, **kwargs): logger.debug(f"Updated registry with newer configs [namespace=custom, size={len(sd_hub)}, prev_size={psize}]") # Register the model - logger.debug(f"Registering new model [model={model_name}]") + model_id = f"custom/{model_name}" + logger.debug(f"Registering new model [model={model_id}]") hub.register( - model_name, + model_id, TaskType.IMAGE_GENERATION, StableDiffusionLoRA, - init_args=(model_name,), + init_args=(model_id,), init_kwargs={"dtype": torch.float16}, method_name="__call__", inputs={"prompts": Batch[str, 1], "num_images": int, "height": int, "width": int}, outputs={"images": Batch[ImageT[Image.Image, ImageSpec(shape=(None, None, 3), dtype="uint8")]]}, ) - logger.debug(f"Registering new model [{model_name}]") + logger.debug(f"Registering new model [model={model_id}]") class TrainingService: @@ -85,12 +86,11 @@ def train(self, method: str, inputs: Dict[str, Any], metadata: Dict[str, Any] = job_id = self.executor.jobs.submit(**configd) logger.debug(f"Submitted training job [job_id={job_id}, config={configd}]") - hooks = { - "on_completed": [(register_model, (job_id,), {})], - } + hooks = {"on_completed": (register_model, (job_id,), {})} # Spawn a thread to monitor the job - def monitor_job(job_id: str, timeout: int = 180, retry_interval: int = 5): + def monitor_job_hook(job_id: str, timeout: int = 600, retry_interval: int = 5): + """Hook for monitoring the job status and running callbacks on completion.""" st = time.time() while time.time() - st < timeout: status = self.executor.jobs.status(job_id) @@ -105,7 +105,7 @@ def monitor_job(job_id: str, timeout: int = 180, retry_interval: int = 5): logger.debug(f"Training job not completed yet [job_id={job_id}, status={status}]") time.sleep(retry_interval) - threading.Thread(target=monitor_job, args=(job_id,), daemon=True).start() + threading.Thread(target=monitor_job_hook, args=(job_id,), daemon=True).start() return job_id @property diff --git a/tests/client/test_client_integration.py b/tests/client/test_client_integration.py index b8f7b7c6..27cd595a 100644 --- a/tests/client/test_client_integration.py +++ b/tests/client/test_client_integration.py @@ -148,10 +148,10 @@ def test_client_training(client_with_server, request): # noqa: F811 model_id = response["job_id"] logger.debug(f"Training service test completed [model_id={model_id}].") - # model_id = "stable-diffusion-dreambooth-lora_16cd4490" - # Wait for the model to be ready - response = client.Wait(job_id=model_id, timeout=180, retry_interval=5) + # For e.g. model_id = "stable-diffusion-dreambooth-lora_16cd4490" + # model_id = "stable-diffusion-dreambooth-lora_ef939db5" + response = client.Wait(job_id=model_id, timeout=600, retry_interval=5) logger.debug(f"Training service test completed [model_id={model_id}, response={response}].") time.sleep(10) From d14f3025d3ac896824580c826c7bf5c4afcddff5 Mon Sep 17 00:00:00 2001 From: Sudeep Pillai Date: Wed, 30 Aug 2023 23:03:45 -0700 Subject: [PATCH 09/10] Final fixes for discord bot with fully functional training / inference --- examples/discord/bot.py | 105 ++++++++++++++++++++++---- nos/client/grpc.py | 10 ++- nos/models/dreambooth/dreambooth.py | 4 +- nos/server/_service.py | 2 +- nos/server/train/dreambooth/config.py | 2 +- 5 files changed, 98 insertions(+), 25 deletions(-) diff --git a/examples/discord/bot.py b/examples/discord/bot.py index 93fd42d7..1e064b4b 100644 --- a/examples/discord/bot.py +++ b/examples/discord/bot.py @@ -1,6 +1,8 @@ +import asyncio import io import os import time +import uuid from dataclasses import dataclass from pathlib import Path @@ -32,9 +34,15 @@ class LoRAPromptModel: def job_id(self) -> str: return self.model_id + def __str__(self) -> str: + return f"LoRAPromptModel(thread_id={self.thread_id}, thread_name={self.thread_name}, model_id={self.model_id}, prompt={self.prompt})" + NOS_PLAYGROUND_CHANNEL = "nos-playground" +BASE_MODEL = "runwayml/stable-diffusion-v1-5" +# BASE_MODEL = "stabilityai/stable-diffusion-2-1" + # Init nos server, wait for it to spin up then confirm its healthy: client = InferenceClient() @@ -73,10 +81,15 @@ async def generate(ctx, *, prompt): st = time.perf_counter() if ctx.channel.name == NOS_PLAYGROUND_CHANNEL: # Pull the channel id so we know which model to run: + + # Acknowledge the request, by reacting to the message with a checkmark + logger.debug(f"Request acknowledged [id={ctx.message.id}]") + await ctx.message.add_reaction("✅") + logger.debug(f"/generate request submitted [id={ctx.message.id}]") response = client.Run( task=TaskType.IMAGE_GENERATION, - model_name="stabilityai/stable-diffusion-2", + model_name=BASE_MODEL, prompts=[prompt], width=512, height=512, @@ -91,7 +104,7 @@ async def generate(ctx, *, prompt): thread_id = ctx.channel.id model = MODEL_DB.get(thread_id, default=None) if model is None: - logger.debug(f"Failed to fetch model [thread_id={thread_id}") + logger.debug(f"Failed to fetch model [thread_id={thread_id}]") await thread.send("No model found") return @@ -105,6 +118,11 @@ async def generate(ctx, *, prompt): await thread.send(f"Failed to fetch model [model_id={model.model_id}]") return + # Acknowledge the request, by reacting to the message with a checkmark + logger.debug(f"Request acknowledged [id={ctx.message.id}]") + await ctx.message.add_reaction("✅") + + # Run inference on the trained model logger.debug(f"/generate request submitted [id={ctx.message.id}, model_id={model.model_id}, model={model}]") response = client.Run( task=TaskType.IMAGE_GENERATION, @@ -138,14 +156,14 @@ async def train(ctx, *, prompt): logger.debug("no attachments to train on, returning!") return - if "" not in prompt: - await ctx.send("Please include a in your training prompt!") + if "sks" not in prompt: + await ctx.send("Please include 'sks' in your training prompt!") return # Create a thread for this training job message_id = str(ctx.message.id) - thread = await ctx.channel.create_thread(name=f"{prompt} ({message_id})", type=discord.ChannelType.public_thread) - await thread.send(f"Created a new thread for training [id={thread.name}]") + thread = await ctx.message.create_thread(name=f"{prompt} ({message_id})") + logger.debug(f"Created thread [id={thread.id}, name={thread.name}]") # Save the thread id thread_id = thread.id @@ -156,32 +174,38 @@ async def train(ctx, *, prompt): logger.debug(f"Created training directory [dirname={dirname}]") # Save all attachments to the training directory + logger.debug(f"Saving attachments [dirname={dirname}, attachments={len(ctx.message.attachments)}]") for attachment in ctx.message.attachments: - logger.debug(f"Got attachement [filename={attachment.filename}]") - await attachment.save(str(dirname / str(attachment.filename))) - logger.debug(f"Saved attachment [filename={attachment.filename}]") + filename = str(dirname / f"{str(uuid.uuid4().hex[:8])}_{attachment.filename}") + await attachment.save(filename) + logger.debug(f"Saved attachment [filename={filename}]") + + # Acknowledge the request, by reacting to the message with a checkmark + logger.debug(f"Request acknowledged [id={ctx.message.id}]") + await ctx.message.add_reaction("✅") # Train a new LoRA model with the image of a bench response = client.Train( method="stable-diffusion-dreambooth-lora", inputs={ - "model_name": "stabilityai/stable-diffusion-2-1", + "model_name": BASE_MODEL, "instance_directory": dirname.relative_to(NOS_VOLUME_DIR), - "instance_prompt": "A photo of a on the moon", - "max_train_steps": 10, + "instance_prompt": prompt, + "max_train_steps": 500, }, metadata={ - "name": "sdv21-dreambooth-lora-test", + "name": "sdv15-dreambooth-lora", }, ) + job_id = response["job_id"] logger.debug(f"Submitted training job [id={thread_id}, response={response}, dirname={dirname}]") - await thread.send(f"Submitted training job [id={thread.name}, model={response['job_id']}]") + await thread.send(f"@here Submitted training job [id={thread.name}, model={job_id}]") # Save the model MODEL_DB[thread_id] = LoRAPromptModel( thread_id=thread_id, thread_name=thread.name, - model_id=f"custom/{response['job_id']}", + model_id=f"custom/{job_id}", prompt=prompt, ) logger.debug(f"Saved model [id={thread_id}, model={MODEL_DB[thread_id]}]") @@ -190,6 +214,46 @@ async def train(ctx, *, prompt): logger.error(f"Failed to submit training job [id={thread_id}, response={response}, dirname={dirname}]") await thread.send(f"Failed to train [prompt={prompt}, response={response}, dirname={dirname}]") + # Create a new thread to watch the training job + async def post_on_training_complete_async(): + # Wait for the model to be ready + response = client.Wait(job_id=job_id, timeout=600, retry_interval=10) + logger.debug(f"Training completed [job_id={job_id}, response={response}].") + + # Get the thread + _thread = bot.get_channel(thread_id) + await _thread.send(f"@here Training complete [id={_thread.name}, model={job_id}]") + + # Wait for model to be registered after the job is complete + await asyncio.sleep(5) + + # Run inference on the trained model + st = time.perf_counter() + response = client.Run( + task=TaskType.IMAGE_GENERATION, + model_name=f"custom/{job_id}", + prompts=[prompt], + width=512, + height=512, + num_images=1, + ) + logger.debug(f"/generate request completed [model={job_id}, elapsed={time.perf_counter() - st:.2f}s]") + (image,) = response["images"] + + # Save the image to a buffer and send it back to the user: + image_bytes = io.BytesIO() + image.save(image_bytes, format="PNG") + image_bytes.seek(0) + await _thread.send(f"{prompt}", file=discord.File(image_bytes, filename=f"{ctx.message.id}.png")) + + # def post_on_training_complete(): + # asyncio.run(post_on_training_complete_async()) + + logger.debug(f"Starting thread to watch training job [id={thread_id}, job_id={job_id}]") + # threading.Thread(target=post_on_training_complete, daemon=True).start() + asyncio.run_coroutine_threadsafe(post_on_training_complete_async(), loop) + logger.debug(f"Started thread to watch training job [id={thread_id}, job_id={job_id}]") + # Pull API token out of environment and run the bot: bot_token = os.environ.get("DISCORD_BOT_TOKEN") @@ -197,4 +261,13 @@ async def train(ctx, *, prompt): raise Exception("DISCORD_BOT_TOKEN environment variable not set") logger.debug(f"Starting bot with token [token={bot_token[:5]}****]") # bot.loop.run_until_complete(setup()) -bot.run(bot_token) + + +async def run_bot(): + await bot.start(bot_token) + + +if __name__ == "__main__": + loop = asyncio.get_event_loop() + loop.create_task(run_bot()) + loop.run_forever() diff --git a/nos/client/grpc.py b/nos/client/grpc.py index 2a480cbf..68a04474 100644 --- a/nos/client/grpc.py +++ b/nos/client/grpc.py @@ -340,10 +340,12 @@ def Wait(self, job_id: str, timeout: int = 60, retry_interval: int = 5) -> None: if str(response) != "PENDING" and str(response) != "RUNNING": return response else: - logger.debug(f"Waiting for job to finish [job_id={job_id}, response={response}]") - time.sleep(retry_interval) - except Exception: - logger.warning("Failed to fetch job status ... (elapsed={:.0f}s)".format(time.time() - st)) + logger.debug( + f"Waiting for job to finish [job_id={job_id}, response={response}, elapsed={time.time() - st:.0f}s]" + ) + except Exception as e: + logger.warning(f"Failed to fetch job status ... [elapsed={time.time() - st:.0f}s, e={e}]") + time.sleep(retry_interval) logger.warning(f"Job timed out [job_id={job_id}]") return response diff --git a/nos/models/dreambooth/dreambooth.py b/nos/models/dreambooth/dreambooth.py index e9f7be72..46b4e1b9 100644 --- a/nos/models/dreambooth/dreambooth.py +++ b/nos/models/dreambooth/dreambooth.py @@ -187,8 +187,7 @@ def __call__( self, prompts: Union[str, List[str]], num_images: int = 1, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, + num_inference_steps: int = 30, height: int = None, width: int = None, ) -> List[Image.Image]: @@ -198,7 +197,6 @@ def __call__( return self.pipe( prompts * num_images, num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, height=height if height is not None else self.cfg.resolution, width=width if width is not None else self.cfg.resolution, ).images diff --git a/nos/server/_service.py b/nos/server/_service.py index 14695200..7479395f 100644 --- a/nos/server/_service.py +++ b/nos/server/_service.py @@ -279,7 +279,7 @@ def GetJobStatus( try: st = time.perf_counter() logger.info(f"Job status request [job_id={request['job_id']}]") - response = self.jobs.status(request["job_id"]) + response = str(self.jobs.status(request["job_id"])) logger.info( f"Job status request [job_id={request['job_id']}, response={response}, elapsed={(time.perf_counter() - st) * 1e3:.1f}ms]" ) diff --git a/nos/server/train/dreambooth/config.py b/nos/server/train/dreambooth/config.py index d881ca89..c298e25f 100644 --- a/nos/server/train/dreambooth/config.py +++ b/nos/server/train/dreambooth/config.py @@ -109,7 +109,7 @@ def entrypoint(self): f""" --resolution={self.resolution}""" f""" --train_batch_size=1""" f""" --gradient_accumulation_steps=1""" - f""" --checkpointing_steps=100""" + f""" --checkpointing_steps={self.max_train_steps // 5}""" f""" --learning_rate=1e-4""" f''' --lr_scheduler="constant"''' f""" --lr_warmup_steps=0""" From ed4c5136c54977112397898a44b2c7e792cfe22d Mon Sep 17 00:00:00 2001 From: Sudeep Pillai Date: Thu, 31 Aug 2023 18:28:34 -0700 Subject: [PATCH 10/10] Discord bot with fine-tuning API - added tests for fine-tuning API --- examples/discord/bot.py | 58 ++++++++++++------- examples/discord/docker-compose.yml | 34 +++++------ nos/executors/ray.py | 17 +++++- nos/experimental/discord/nos_bot.py | 52 ----------------- nos/server/_service.py | 2 +- nos/server/train/__init__.py | 2 +- .../train/{_train_service.py => _service.py} | 3 +- tests/server/test_training_service.py | 30 ++-------- 8 files changed, 79 insertions(+), 119 deletions(-) delete mode 100755 nos/experimental/discord/nos_bot.py rename nos/server/train/{_train_service.py => _service.py} (98%) diff --git a/examples/discord/bot.py b/examples/discord/bot.py index 1e064b4b..aed34cd4 100644 --- a/examples/discord/bot.py +++ b/examples/discord/bot.py @@ -1,3 +1,4 @@ +"""Example discord both with Stable Diffusion LoRA fine-tuning support.""" import asyncio import io import os @@ -39,11 +40,9 @@ def __str__(self) -> str: NOS_PLAYGROUND_CHANNEL = "nos-playground" - BASE_MODEL = "runwayml/stable-diffusion-v1-5" -# BASE_MODEL = "stabilityai/stable-diffusion-2-1" -# Init nos server, wait for it to spin up then confirm its healthy: +# Init NOS server, wait for it to spin up then confirm its healthy. client = InferenceClient() logger.debug("Waiting for server to start...") @@ -58,22 +57,34 @@ def __str__(self) -> str: NOS_TRAINING_VOLUME_DIR = Path(client.Volume("nos-playground")) logger.debug(f"Creating training data volume [volume={NOS_TRAINING_VOLUME_DIR}]") -# Set permissions for our bot to allow it to read messages: +# Set permissions for our bot to allow it to read messages intents = discord.Intents.default() intents.message_content = True -# Create our bot, with the command prefix set to "/": +# Create our bot, with the command prefix set to "/" bot = commands.Bot(command_prefix="/", intents=intents) logger.debug("Starting bot, initializing existing threads ...") -# Maps channel_id -> LoRAPromptModel +# Simple persistent dict/database for storing models +# This maps (channel-id -> LoRAPromptModel) MODEL_DB = Cache(str(NOS_TMP_DIR / NOS_PLAYGROUND_CHANNEL)) @bot.command() async def generate(ctx, *, prompt): - """Create a callback to read messages and generate images from prompt""" + """Create a callback to read messages and generate images from prompt + + Usage: + 1. In the main channel, you can run: + /generate a photo of a dog on the moon + to generate an image with the pre-trained SDv1.5 model. + + 2. In a thread that has been created by fine-tuning a new model, + you can run: + /generate a photo of a sks dog on the moon + to generate the specific instance of the dog using the fine-tuned model. + """ logger.debug( f"/generate [prompt={prompt}, channel={ctx.channel.name}, channel_id={ctx.channel.id}, user={ctx.author.name}]" ) @@ -146,6 +157,13 @@ async def generate(ctx, *, prompt): @bot.command() async def train(ctx, *, prompt): + """Fine-tune a new model with the provided prompt and images. + + Example: + Upload a few images of your favorite dog, and then run: + `/train sks a photo of a sks dog on the moon` + """ + logger.debug(f"/train [channel={ctx.channel.name}, channel_id={ctx.channel.id}, user={ctx.author.name}]") if ctx.channel.name != NOS_PLAYGROUND_CHANNEL: @@ -246,28 +264,24 @@ async def post_on_training_complete_async(): image_bytes.seek(0) await _thread.send(f"{prompt}", file=discord.File(image_bytes, filename=f"{ctx.message.id}.png")) - # def post_on_training_complete(): - # asyncio.run(post_on_training_complete_async()) - logger.debug(f"Starting thread to watch training job [id={thread_id}, job_id={job_id}]") - # threading.Thread(target=post_on_training_complete, daemon=True).start() asyncio.run_coroutine_threadsafe(post_on_training_complete_async(), loop) logger.debug(f"Started thread to watch training job [id={thread_id}, job_id={job_id}]") -# Pull API token out of environment and run the bot: -bot_token = os.environ.get("DISCORD_BOT_TOKEN") -if bot_token is None: - raise Exception("DISCORD_BOT_TOKEN environment variable not set") -logger.debug(f"Starting bot with token [token={bot_token[:5]}****]") -# bot.loop.run_until_complete(setup()) - - -async def run_bot(): - await bot.start(bot_token) +async def run_bot(token: str): + """Start the bot with the user-provided token.""" + await bot.start(token) if __name__ == "__main__": + # Get the bot token from the environment + token = os.environ.get("DISCORD_BOT_TOKEN") + if token is None: + raise Exception("DISCORD_BOT_TOKEN environment variable not set") + logger.debug(f"Starting bot with token [token={token[:5]}****]") + + # Start the asyncio event loop, and run the bot loop = asyncio.get_event_loop() - loop.create_task(run_bot()) + loop.create_task(run_bot(token)) loop.run_forever() diff --git a/examples/discord/docker-compose.yml b/examples/discord/docker-compose.yml index 09a7ea1e..336f093e 100644 --- a/examples/discord/docker-compose.yml +++ b/examples/discord/docker-compose.yml @@ -1,23 +1,23 @@ version: "3.8" services: - # bot: - # image: autonomi/nos:latest-discord-app - # build: - # context: . - # dockerfile: Dockerfile - # args: - # - BASE_IMAGE=autonomi/nos:latest-cpu - # env_file: - # - .env - # environment: - # - NOS_HOME=/app/.nos - # - NOS_LOGGING_LEVEL=DEBUG - # volumes: - # - ~/.nosd:/app/.nos - # - /dev/shm:/dev/shm - # network_mode: host - # ipc: host + bot: + image: autonomi/nos:latest-discord-app + build: + context: . + dockerfile: Dockerfile + args: + - BASE_IMAGE=autonomi/nos:latest-cpu + env_file: + - .env + environment: + - NOS_HOME=/app/.nos + - NOS_LOGGING_LEVEL=DEBUG + volumes: + - ~/.nosd:/app/.nos + - /dev/shm:/dev/shm + network_mode: host + ipc: host server: image: autonomi/nos:latest-gpu diff --git a/nos/executors/ray.py b/nos/executors/ray.py index 72c2e868..e0277ac6 100644 --- a/nos/executors/ray.py +++ b/nos/executors/ray.py @@ -211,7 +211,22 @@ def status(self, job_id: str) -> str: def logs(self, job_id: str) -> str: """Get logs for a job.""" return self.client.get_job_logs(job_id) - + + def wait(self, job_id: str, timeout: int = 600, retry_interval: int = 5) -> str: + """Wait for a job to complete.""" + status = None + st = time.time() + while time.time() - st < timeout: + status = self.status(job_id) + if str(status) == "SUCCEEDED": + logger.debug(f"Training job completed [job_id={job_id}, status={status}]") + return status + else: + logger.debug(f"Training job not completed yet [job_id={job_id}, status={status}]") + time.sleep(retry_interval) + logger.warning(f"Training job timed out [job_id={job_id}, status={status}]") + return status + def init(*args, **kwargs) -> bool: """Initialize Ray executor.""" diff --git a/nos/experimental/discord/nos_bot.py b/nos/experimental/discord/nos_bot.py deleted file mode 100755 index a2b535b2..00000000 --- a/nos/experimental/discord/nos_bot.py +++ /dev/null @@ -1,52 +0,0 @@ -#!/usr/bin/env python - -import io -import os - -import discord -from discord.ext import commands - -import nos -from nos.client import InferenceClient, TaskType - - -# Init nos server, wait for it to spin up then confirm its healthy: -nos.init(runtime="gpu") -nos_client = InferenceClient() -nos_client.WaitForServer() -if not nos_client.IsHealthy(): - raise RuntimeError("NOS server is not healthy") - -# Set permissions for our bot to allow it to read messages: -intents = discord.Intents.default() -intents.message_content = True - -# Create our bot: -bot = commands.Bot(command_prefix="$", intents=intents) - -# Create a callback to read messages and generate images from prompt: -@bot.command() -async def generate(ctx, *, prompt): - response = nos_client.Run( - TaskType.IMAGE_GENERATION, - "stabilityai/stable-diffusion-2", - prompts=[prompt], - width=512, - height=512, - num_images=1, - ) - image = response["images"][0] - - image_bytes = io.BytesIO() - image.save(image_bytes, format="PNG") - image_bytes.seek(0) - - await ctx.send(file=discord.File(image_bytes, filename="image.png")) - - -# Pull API token out of environment and run the bot: -bot_token = os.environ.get("BOT_TOKEN") -if bot_token is None: - raise Exception("BOT_TOKEN environment variable not set") - -bot.run(bot_token) diff --git a/nos/server/_service.py b/nos/server/_service.py index 7479395f..b6a20761 100644 --- a/nos/server/_service.py +++ b/nos/server/_service.py @@ -20,7 +20,7 @@ from nos.logging import logger from nos.managers import ModelHandle, ModelManager from nos.protoc import import_module -from nos.server.train._train_service import TrainingService +from nos.server.train._service import TrainingService from nos.version import __version__ diff --git a/nos/server/train/__init__.py b/nos/server/train/__init__.py index bf8dc8e4..ab18c774 100644 --- a/nos/server/train/__init__.py +++ b/nos/server/train/__init__.py @@ -1 +1 @@ -from ._train_service import TrainingService # noqa: F401 +from ._service import TrainingService # noqa: F401 diff --git a/nos/server/train/_train_service.py b/nos/server/train/_service.py similarity index 98% rename from nos/server/train/_train_service.py rename to nos/server/train/_service.py index ce72ca37..97d4de98 100644 --- a/nos/server/train/_train_service.py +++ b/nos/server/train/_service.py @@ -6,7 +6,8 @@ from nos.executors.ray import RayExecutor, RayJobExecutor from nos.logging import logger from nos.protoc import import_module -from nos.server.train.dreambooth.config import StableDiffusionTrainingJobConfig + +from .dreambooth.config import StableDiffusionTrainingJobConfig nos_service_pb2 = import_module("nos_service_pb2") diff --git a/tests/server/test_training_service.py b/tests/server/test_training_service.py index 777f201f..967600e7 100644 --- a/tests/server/test_training_service.py +++ b/tests/server/test_training_service.py @@ -27,7 +27,7 @@ def test_training_service(ray_executor: RayExecutor): # noqa: F811 job_id = svc.train( method="stable-diffusion-dreambooth-lora", - training_inputs={ + inputs={ "model_name": "stabilityai/stable-diffusion-2-1", "instance_directory": tmp_dir, "instance_prompt": "A photo of a bench on the moon", @@ -57,26 +57,8 @@ def test_training_service(ray_executor: RayExecutor): # noqa: F811 assert status is not None logger.debug(f"Status for job {job_id}: {status}") - -def test_inference_service_with_trained_model(ray_executor: RayExecutor): # noqa: F811 - """Test inference service.""" - from nos.server._service import InferenceService - - # Test training service - InferenceService() - # job_id = svc.execute( - # method="stable-diffusion-dreambooth-lora", - # inference_inputs={ - # "model_name": "stabilityai/stable-diffusion-2-1", - # "instance_directory": tmp_dir, - # "instance_prompt": "A photo of a bench on the moon", - # "resolution": 512, - # "max_train_steps": 100, - # "seed": 0, - # }, - # metadata={ - # "name": "sdv21-dreambooth-lora-test-bench", - # }, - # ) - # assert job_id is not None - # logger.debug(f"Submitted job with id: {job_id}") + # Wait for the job to complete + status = svc.jobs.wait(job_id, timeout=600, retry_interval=5) + assert status is not None + logger.debug(f"Status for job {job_id}: {status}") + assert status == "SUCCEEDED"