diff --git a/frontend/src/@types/piece/piece.d.ts b/frontend/src/@types/piece/piece.d.ts index c04098cf..f3bc45c9 100644 --- a/frontend/src/@types/piece/piece.d.ts +++ b/frontend/src/@types/piece/piece.d.ts @@ -37,11 +37,26 @@ export interface PieceSchema { $defs: Definitions; } +interface ContainerResources { + limits: { + cpu: number; + memory: number; + }; + requests: { + cpu: number; + memory: number; + }; + use_gpu?: boolean; +} + export interface Piece { id: number; name: string; description: string; + container_resources: ContainerResources; + tags: string[]; + repository_id: number; input_schema: PieceSchema; diff --git a/frontend/src/features/workflowEditor/components/WorkflowEditor.tsx b/frontend/src/features/workflowEditor/components/WorkflowEditor.tsx index 4870bc9f..8033a7c5 100644 --- a/frontend/src/features/workflowEditor/components/WorkflowEditor.tsx +++ b/frontend/src/features/workflowEditor/components/WorkflowEditor.tsx @@ -25,8 +25,10 @@ import * as yup from "yup"; import { type IWorkflowPieceData, storageAccessModes } from "../context/types"; import { type GenerateWorkflowsParams } from "../context/workflowsEditor"; -import { containerResourcesSchema } from "../schemas/containerResourcesSchemas"; -import { extractDefaultInputValues, extractDefaultValues } from "../utils"; +import { + extractDefaultContainerResources, + extractDefaultInputValues, +} from "../utils"; import { type Differences, importJsonWorkflow, @@ -352,8 +354,9 @@ export const WorkflowsEditorComponent: React.FC = () => { const defaultInputs = extractDefaultInputValues( piece as unknown as Piece, ); - const defaultContainerResources = extractDefaultValues( - containerResourcesSchema as any, + + const defaultContainerResources = extractDefaultContainerResources( + piece?.container_resources, ); const currentWorkflowPieces = await getForageWorkflowPieces(); diff --git a/frontend/src/features/workflowEditor/utils/jsonSchema.ts b/frontend/src/features/workflowEditor/utils/jsonSchema.ts index 0beef8cd..cd9fec6a 100644 --- a/frontend/src/features/workflowEditor/utils/jsonSchema.ts +++ b/frontend/src/features/workflowEditor/utils/jsonSchema.ts @@ -1,6 +1,12 @@ // Extract default values from Schema -import { type IWorkflowPieceData } from "../context/types"; +import { isEmpty } from "utils"; + +import { defaultContainerResources } from "../components/SidebarForm/ContainerResourceForm"; +import { + type IContainerResourceFormData, + type IWorkflowPieceData, +} from "../context/types"; import { getFromUpstream } from "./getFromUpstream"; @@ -85,7 +91,7 @@ export const extractDefaultValues = ( ) => { output = output === null ? {} : output; - if (schema) { + if (!isEmpty(schema) && "properties" in schema) { const properties = schema.properties; for (const [key, value] of Object.entries(properties)) { if (value?.from_upstream === "always") { @@ -104,3 +110,23 @@ export const extractDefaultValues = ( return output; }; + +export const extractDefaultContainerResources = ( + cr?: Piece["container_resources"], +): IContainerResourceFormData => { + if (cr && !isEmpty(cr) && "limits" in cr && "requests" in cr) { + return { + cpu: { + max: Number(cr.limits.cpu), + min: Number(cr.requests.cpu), + }, + memory: { + max: Number(cr.limits.memory), + min: Number(cr.requests.memory), + }, + useGpu: cr?.use_gpu ?? false, + }; + } else { + return defaultContainerResources; + } +}; diff --git a/rest/constants/schemas/__init__.py b/rest/constants/schemas/__init__.py new file mode 100644 index 00000000..eceaf84b --- /dev/null +++ b/rest/constants/schemas/__init__.py @@ -0,0 +1 @@ +from .container_resources import ContainerResourcesModel \ No newline at end of file diff --git a/rest/constants/schemas/container_resources.py b/rest/constants/schemas/container_resources.py new file mode 100644 index 00000000..d4dfcd1d --- /dev/null +++ b/rest/constants/schemas/container_resources.py @@ -0,0 +1,11 @@ +from pydantic import BaseModel, Field + +class SystemRequirementsModel(BaseModel): + cpu: int = Field(default=128) + memory: int = Field(memory=100) + + +class ContainerResourcesModel(BaseModel): + requests: SystemRequirementsModel = Field(default=SystemRequirementsModel(cpu=100, memory=128)) + limits: SystemRequirementsModel = Field(default=SystemRequirementsModel(cpu=100, memory=128)) + use_gpu: bool = False \ No newline at end of file diff --git a/rest/database/alembic/versions/93da7356c3d7_.py b/rest/database/alembic/versions/93da7356c3d7_.py new file mode 100644 index 00000000..fb54eac9 --- /dev/null +++ b/rest/database/alembic/versions/93da7356c3d7_.py @@ -0,0 +1,30 @@ +"""empty message + +Revision ID: 93da7356c3d7 +Revises: f7214a10a4df +Create Date: 2023-11-29 07:55:27.576939 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '93da7356c3d7' +down_revision = 'f7214a10a4df' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('piece', sa.Column('tags', sa.ARRAY(sa.String()), server_default='{}', nullable=False)) + op.add_column('piece', sa.Column('container_resources', sa.JSON(), server_default=sa.text("'{}'::jsonb"), nullable=False)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('piece', 'container_resources') + op.drop_column('piece', 'tags') + # ### end Alembic commands ### diff --git a/rest/database/models/piece.py b/rest/database/models/piece.py index 64450d8a..3e6f5068 100644 --- a/rest/database/models/piece.py +++ b/rest/database/models/piece.py @@ -1,6 +1,6 @@ from database.models.base import Base, BaseDatabaseModel from sqlalchemy.orm import relationship -from sqlalchemy import Column, String, Integer, JSON, ForeignKey, text +from sqlalchemy import Column, String, Integer, JSON, ForeignKey, text, ARRAY class Piece(Base, BaseDatabaseModel): __tablename__ = "piece" @@ -13,6 +13,8 @@ class Piece(Base, BaseDatabaseModel): input_schema = Column(JSON, nullable=False, server_default=text("'{}'::jsonb")) output_schema = Column(JSON, nullable=False, server_default=text("'{}'::jsonb")) # Using server default empty JSON object to avoid null value in database secrets_schema = Column(JSON, nullable=False, server_default=text("'{}'::jsonb")) + tags = Column(ARRAY(String), nullable=False, server_default="{}") + container_resources = Column(JSON, nullable=False, server_default=text("'{}'::jsonb")) style = Column(JSON, nullable=True) source_url = Column(String, nullable=True) repository_id = Column(Integer, ForeignKey('piece_repository.id', ondelete='cascade'), nullable=False) diff --git a/rest/schemas/responses/piece.py b/rest/schemas/responses/piece.py index e70ea934..2c13c47a 100644 --- a/rest/schemas/responses/piece.py +++ b/rest/schemas/responses/piece.py @@ -11,6 +11,8 @@ class GetPiecesResponse(BaseModel): input_schema: Optional[Dict] = None output_schema: Optional[Dict] = None secrets_schema: Optional[Dict] = None + container_resources: Optional[Dict] = None + tags: Optional[List[str]] = None style: Optional[Dict] = None source_url: Optional[str] = None repository_url: str diff --git a/rest/services/piece_service.py b/rest/services/piece_service.py index 1b5feb70..a0e51cc2 100644 --- a/rest/services/piece_service.py +++ b/rest/services/piece_service.py @@ -1,16 +1,12 @@ from typing import List -import json from schemas.requests.piece import ListPiecesFilters from schemas.exceptions.base import ResourceNotFoundException -from clients.github_rest_client import GithubRestClient - +from constants.schemas import ContainerResourcesModel from core.logger import get_configured_logger -from core.settings import settings from repository.user_repository import UserRepository from repository.workspace_repository import WorkspaceRepository from repository.piece_repository_repository import PieceRepositoryRepository from database.models import Piece, PieceRepository -from database.models.enums import RepositorySource from clients.local_files_client import LocalFilesClient from repository.piece_repository import PieceRepository from schemas.responses.piece import GetPiecesResponse @@ -46,7 +42,7 @@ def list_pieces( Returns: List[GetPiecesResponse]: List of all pieces data """ - + piece_repository = self.piece_repository_repository.find_by_id(id=piece_repository_id) if not piece_repository: raise ResourceNotFoundException(message="Workspace or Piece Repository not found") @@ -58,13 +54,13 @@ def list_pieces( filters=filters.model_dump(exclude_none=True), ) return [ - GetPiecesResponse(**piece.to_dict(),repository_url=piece_repository.url) for piece in pieces + GetPiecesResponse(**piece.to_dict(), repository_url=piece_repository.url) for piece in pieces ] def check_pieces_to_update_github( - self, - repository_id: int, + self, + repository_id: int, compiled_metadata: dict, dependencies_map: dict, ) -> None: @@ -106,6 +102,8 @@ def _update_pieces_from_metadata(self, piece_metadata: dict, dependencies_map: d piece_style = piece_metadata.get("style") name = piece_metadata.get("name") style = get_frontend_node_style(module_name=name, **piece_style) + + container_resources = ContainerResourcesModel(**piece_metadata.get("container_resources", {})) new_piece = Piece( name=piece_metadata.get("name"), dependency=piece_metadata.get("dependency"), @@ -115,6 +113,8 @@ def _update_pieces_from_metadata(self, piece_metadata: dict, dependencies_map: d input_schema=piece_metadata.get("input_schema", {}), output_schema=piece_metadata.get("output_schema", {}), secrets_schema=piece_metadata.get("secrets_schema", {}), + container_resources=container_resources.model_dump(), + tags=piece_metadata.get("tags", []), style=style, repository_id=repository_id ) diff --git a/src/domino/base_piece.py b/src/domino/base_piece.py index 10518ff3..9dcccd6e 100644 --- a/src/domino/base_piece.py +++ b/src/domino/base_piece.py @@ -149,22 +149,6 @@ def format_xcom(self, output_obj: pydantic.BaseModel) -> dict: self.logger.info(f"Piece {self.__class__.__name__} is not returning a valid XCOM object. Auto-generating a base XCOM for it...") xcom_obj = dict() - # Add arguments types to XCOM - # TODO - this is a temporary solution. We should find a better way to do this - # output_schema = output_obj.model_json_schema() - # for k, v in output_schema["properties"].items(): - # if "type" in v: - # # Get file-path and directory-path types - # if v["type"] == "string" and "format" in v: - # v_type = v["format"] - # else: - # v_type = v["type"] - # elif "anyOf" in v: - # if "$ref" in v["anyOf"][0]: - # type_model = v["anyOf"][0]["$ref"].split("/")[-1] - # v_type = output_schema["definitions"][type_model]["type"] - # xcom_obj[f"{k}_type"] = v_type - # Serialize self.display_result and add it to XCOM if isinstance(self.display_result, dict): if "file_type" not in self.display_result: @@ -185,6 +169,7 @@ def format_xcom(self, output_obj: pydantic.BaseModel) -> dict: self.display_result["file_path"] = None self.display_result["file_type"] = "txt" self.display_result["base64_content"] = base64_content + xcom_obj["display_result"] = self.display_result # Update XCOM with extra metadata @@ -240,7 +225,7 @@ def run_piece_function( self, piece_input_data: dict, piece_input_model: pydantic.BaseModel, - piece_output_model: pydantic.BaseModel, + piece_output_model: pydantic.BaseModel, piece_secrets_model: Optional[pydantic.BaseModel] = None, airflow_context: Optional[dict] = None ): @@ -397,7 +382,7 @@ def piece_function(self): It should have all the necessary content for auto-generating json schemas. All arguments should be type annotated and docstring should carry description for each argument. """ - raise NotImplementedError("This method must be implemented in the child class!") + raise NotImplementedError("This method must be implemented in the child class!") def serialize_display_result_file(self, file_path: Union[str, Path], file_type: DisplayResultFileType) -> dict: """ diff --git a/src/domino/custom_operators/k8s_operator.py b/src/domino/custom_operators/k8s_operator.py index 2293ed87..b980df04 100644 --- a/src/domino/custom_operators/k8s_operator.py +++ b/src/domino/custom_operators/k8s_operator.py @@ -13,18 +13,21 @@ from domino.schemas import WorkflowSharedStorage, ContainerResourcesModel from domino.storage.s3 import S3StorageRepository from domino.logger import get_configured_logger +from airflow.exceptions import AirflowException +from airflow.kubernetes.pod_generator import PodDefaults +import json class DominoKubernetesPodOperator(KubernetesPodOperator): def __init__( - self, + self, dag_id: str, task_id: str, - piece_name: str, + piece_name: str, deploy_mode: str, # TODO enum repository_url: str, repository_version: str, - workspace_id: int, - piece_input_kwargs: Optional[Dict] = None, + workspace_id: int, + piece_input_kwargs: Optional[Dict] = None, workflow_shared_storage: WorkflowSharedStorage = None, container_resources: Optional[Dict] = None, **k8s_operator_kwargs @@ -85,19 +88,19 @@ def __init__( def _make_volumes_and_volume_mounts_dev(self): - """ + """ Make volumes and volume mounts for the pod when in DEVELOPMENT mode. """ config.load_incluster_config() k8s_client = client.CoreV1Api() - + all_volumes = [] all_volume_mounts = [] - + repository_raw_project_name = str(self.piece_source_image).split('/')[-1].split(':')[0] persistent_volume_claim_name = 'pvc-{}'.format(str(repository_raw_project_name.lower().replace('_', '-'))) persistent_volume_name = 'pv-{}'.format(str(repository_raw_project_name.lower().replace('_', '-'))) - + pvc_exists = False try: k8s_client.read_namespaced_persistent_volume_claim(name=persistent_volume_claim_name, namespace='default') @@ -122,14 +125,14 @@ def _make_volumes_and_volume_mounts_dev(self): ), ) volume_mount_dev_pieces = k8s.V1VolumeMount( - name='dev-op-{path_name}'.format(path_name=str(repository_raw_project_name.lower().replace('_', '-'))), + name='dev-op-{path_name}'.format(path_name=str(repository_raw_project_name.lower().replace('_', '-'))), mount_path=f'/home/domino/pieces_repository', - sub_path=None, + sub_path=None, read_only=True ) all_volumes.append(volume_dev_pieces) all_volume_mounts.append(volume_mount_dev_pieces) - + ######################## For local domino-py dev ############################################### domino_package_local_claim_name = 'domino-dev-volume-claim' pvc_exists = False @@ -144,32 +147,23 @@ def _make_volumes_and_volume_mounts_dev(self): volume_dev = k8s.V1Volume( name='jobs-persistent-storage-dev', persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(claim_name=domino_package_local_claim_name), - ) + ) """ # TODO Remove deprecated_volume_mount_dev once we have all the pieces repositories updated with the new base pod image """ - volume_mount_dev = k8s.V1VolumeMount( - name='jobs-persistent-storage-dev', - mount_path='/home/domino/domino_py/src/domino', - sub_path=None, - read_only=True - ) - deprecated_volume_mount_dev = k8s.V1VolumeMount( - name='jobs-persistent-storage-dev', - mount_path='/home/domino/domino_py/domino', + volume_mount_pkg = k8s.V1VolumeMount( + name='jobs-persistent-storage-dev', + mount_path='/usr/local/lib/python3.10/site-packages/domino/', sub_path=None, read_only=True ) all_volumes.append(volume_dev) - all_volume_mounts.append(volume_mount_dev) - # TODO remove - all_volume_mounts.append(deprecated_volume_mount_dev) + all_volume_mounts.append(volume_mount_pkg) return all_volumes, all_volume_mounts - def build_pod_request_obj(self, context: Optional['Context'] = None) -> k8s.V1Pod: """ We override this method to add the shared storage to the pod. @@ -382,7 +376,7 @@ def _get_upstream_xcom_data_from_task_ids(task_ids: list, context: Context): return upstream_xcoms_data def _get_piece_kwargs_value_from_upstream_xcom( - self, + self, value: Any ): if isinstance(value, dict) and value.get("type") == "fromUpstream": @@ -395,7 +389,7 @@ def _get_piece_kwargs_value_from_upstream_xcom( return [self._get_piece_kwargs_value_from_upstream_xcom(item) for item in value] elif isinstance(value, dict): return { - k: self._get_piece_kwargs_value_from_upstream_xcom(v) + k: self._get_piece_kwargs_value_from_upstream_xcom(v) for k, v in value.items() } return value @@ -426,7 +420,7 @@ def _update_env_var_value_from_name(self, name: str, value: str): def _prepare_execute_environment(self, context: Context): - """ + """ Runs at the begining of the execute method. Pass extra arguments and configuration as environment variables to the pod """ @@ -447,7 +441,7 @@ def _prepare_execute_environment(self, context: Context): self.upstream_xcoms_data = self._get_upstream_xcom_data_from_task_ids(task_ids=upstream_task_ids, context=context) self._update_piece_kwargs_with_upstream_xcom() self._update_env_var_value_from_name(name='DOMINO_RUN_PIECE_KWARGS', value=str(self.piece_input_kwargs)) - + # Add pieces secrets to environment variables piece_secrets = self._get_piece_secrets( repository_url=self.repository_url, @@ -542,7 +536,60 @@ def _kill_shared_storage_sidecar(self, pod: k8s.V1Pod): self.log.info('Sending signal to delete shared storage sidecar container') self.pod_manager._exec_pod_command(resp, 'kill -s SIGINT 1') + def extract_xcom(self, pod: k8s.V1Pod): + """Retrieves xcom value and kills xcom sidecar container""" + result = self.pod_manager_extract_xcom(pod) + if isinstance(result, str) and result.rstrip() == "__airflow_xcom_result_empty__": + self.log.info("Result file is empty.") + return None + else: + self.log.info("xcom result: \n%s", result) + return json.loads(result) + + def pod_manager_extract_xcom(self, pod: k8s.V1Pod) -> str: + client = kubernetes_stream( + self.pod_manager._client.connect_get_namespaced_pod_exec, + pod.metadata.name, + pod.metadata.namespace, + container=PodDefaults.SIDECAR_CONTAINER_NAME, + command=[ + '/bin/sh', + '-c', + f"if [ -s {PodDefaults.XCOM_MOUNT_PATH}/return.json ]; then cat {PodDefaults.XCOM_MOUNT_PATH}/return.json; else echo __airflow_xcom_result_empty__; fi", + ], + stderr=True, + stdin=False, + stdout=True, + tty=False, + _preload_content=False, + _request_timeout=10, + ) + client.run_forever(timeout=10) + result = client.read_all() + + _ = kubernetes_stream( + self.pod_manager._client.connect_get_namespaced_pod_exec, + pod.metadata.name, + pod.metadata.namespace, + container=PodDefaults.SIDECAR_CONTAINER_NAME, + command=[ + '/bin/sh', + '-c', + 'kill -s SIGINT 1', + ], + stderr=True, + stdin=False, + stdout=True, + tty=False, + _preload_content=True, + _request_timeout=10, + ) + client.close() + + if result is None: + raise AirflowException(f"Failed to extract xcom from pod: {pod.metadata.name}") + return result