diff --git a/frontend/src/features/workflowEditor/components/SidebarForm/ContainerResourceForm.tsx b/frontend/src/features/workflowEditor/components/SidebarForm/ContainerResourceForm.tsx index 6bc4a237..d1d92796 100644 --- a/frontend/src/features/workflowEditor/components/SidebarForm/ContainerResourceForm.tsx +++ b/frontend/src/features/workflowEditor/components/SidebarForm/ContainerResourceForm.tsx @@ -73,7 +73,7 @@ const ContainerResourceForm: React.FC = () => { { None: self.task_id = task_id @@ -30,6 +31,7 @@ def __init__( self.workspace_id = workspace_id self.piece_input_kwargs = piece_input_kwargs self.workflow_shared_storage = workflow_shared_storage + self.container_resources = container_resources or {} # Environment variables self.environment = { @@ -76,11 +78,17 @@ def __init__( ), ) + self.device_requests = [] + if self.container_resources.get('use_gpu', False): + self.device_requests=[ + docker.types.DeviceRequest(count=-1, capabilities=[['gpu']]) + ] super().__init__( **docker_operator_kwargs, task_id=task_id, docker_url='tcp://docker-proxy:2375', mounts=mounts, + device_requests=self.device_requests, environment=self.environment, ) diff --git a/src/domino/task.py b/src/domino/task.py index 15d51372..b36daa6f 100644 --- a/src/domino/task.py +++ b/src/domino/task.py @@ -8,7 +8,6 @@ from domino.custom_operators.docker_operator import DominoDockerOperator from domino.custom_operators.python_operator import PythonOperator from domino.custom_operators.worker_operator import DominoWorkerOperator -from domino.utils import dict_deep_update from domino.logger import get_configured_logger from domino.schemas import shared_storage_map, StorageSource @@ -140,15 +139,17 @@ def _set_operator(self) -> BaseOperator: workspace_id=self.workspace_id, piece_input_kwargs=self.piece_input_kwargs, workflow_shared_storage=self.workflow_shared_storage, + container_resources=self.container_resources, # ----------------- Docker ----------------- + # TODO uncoment image=self.piece["source_image"], + entrypoint=["domino", "run-piece-docker"], do_xcom_push=True, mount_tmp_dir=False, tty=True, xcom_all=False, retrieve_output=True, retrieve_output_path='/airflow/xcom/return.out', - entrypoint=["domino", "run-piece-docker"], ) def __call__(self) -> Callable: