
Commit

fix list workflow schedule arg
vinicvaz committed Dec 7, 2023
1 parent 42fa0cc commit 46d4088
Showing 2 changed files with 38 additions and 38 deletions.
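The substantive change is in list_workflows in rest/services/workflow_service.py: as the diff's own comment notes, the Airflow REST API (2.7.2) still serves the schedule under the pre-2.4.0 name schedule_interval, optionally wrapped in an object with a value field, so the listing previously read a key the payload does not provide. A minimal sketch of the corrected lookup, with the payload shape assumed from how the response is handled in this diff and with illustrative names rather than the service's own:

from typing import Optional


def extract_schedule(dag_payload: dict) -> Optional[str]:
    # The API exposes the schedule as "schedule_interval", either a plain string
    # or an object such as {"__type": "CronExpression", "value": "0 0 * * *"}.
    schedule = dag_payload.get("schedule_interval")  # previously: dag_payload.get("schedule")
    if isinstance(schedule, dict):
        schedule = schedule.get("value")
    return schedule


# Example:
# extract_schedule({"schedule_interval": {"__type": "CronExpression", "value": "0 0 * * *"}})
# returns "0 0 * * *".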
12 changes: 6 additions & 6 deletions rest/schemas/responses/workflow.py
@@ -57,7 +57,7 @@ class WorkflowConfigResponse(BaseModel):
@field_validator('schedule')
def set_schedule(cls, schedule):
return schedule or ScheduleIntervalTypeResponse.none

class BaseWorkflowModel(BaseModel):
workflow: WorkflowConfigResponse
tasks: Dict[
@@ -82,20 +82,20 @@ class GetWorkflowsResponseData(BaseModel):
@field_validator('schedule')
def set_schedule(cls, schedule):
return schedule or ScheduleIntervalTypeResponse.none


@field_validator('created_at', mode='before')
def add_utc_timezone_created_at(cls, v):
if isinstance(v, datetime) and v.tzinfo is None:
v = v.replace(tzinfo=timezone.utc)
return v

@field_validator('last_changed_at', mode='before')
def add_utc_timezone_last_changed_at(cls, v):
if isinstance(v, datetime) and v.tzinfo is None:
v = v.replace(tzinfo=timezone.utc)
return v

@field_validator('next_dagrun', mode='before')
def add_utc_timezone_next_dagrun(cls, v):
if isinstance(v, datetime) and v.tzinfo is None:
@@ -152,7 +152,7 @@ class GetWorkflowRunsResponseData(BaseModel):
@field_validator('state')
def set_state(cls, state):
return state or WorkflowRunState.none


model_config = ConfigDict(populate_by_name=True)

@@ -178,7 +178,7 @@ class GetWorkflowRunTasksResponseData(BaseModel):
def set_state(cls, state):
return state or WorkflowRunTaskState.none


model_config = ConfigDict(populate_by_name=True)


64 changes: 32 additions & 32 deletions rest/services/workflow_service.py
@@ -67,7 +67,7 @@ def create_workflow(
) -> CreateWorkflowResponse:
# If workflow with this name already exists for the group raise conflict
workflow = self.workflow_repository.find_by_name_and_workspace_id(
body.workflow.name,
body.workflow.name,
workspace_id=workspace_id
)
if workflow:
@@ -89,15 +89,15 @@ def create_workflow(
workspace_id=workspace_id
)
workflow = self.workflow_repository.create(new_workflow)

data_dict = body.model_dump()
data_dict['workflow']['id'] = workflow_id

try:
self._validate_workflow_tasks(tasks_dict=data_dict.get('tasks'), workspace_id=workspace_id)

_, workflow_code, pieces_repositories_ids = self._create_dag_code_from_raw_json(data_dict, workspace_id=workspace_id)

workflow_piece_repository_associations = [
WorkflowPieceRepositoryAssociative(
workflow_id=workflow.id,
@@ -158,7 +158,7 @@ async def list_workflows(
workspace_id=workspace_id,
page=page,
page_size=page_size,
filters=filters.dict(exclude_none=True),
filters=filters.model_dump(exclude_none=True),
descending=True
)

@@ -167,14 +167,14 @@ async def list_workflows(
for workflow, _ in workflows
}

# TODO - Create chunks of N dag_ids if necessary
# TODO - Create chunks of N dag_ids if necessary
airflow_dags_responses = await self.get_airflow_dags_by_id_gather_chunk(list(workflow_uuid_map.keys()))

max_total_errors_limit = 100
import_errors_response = self.airflow_client.list_import_errors(limit=max_total_errors_limit)
if import_errors_response.status_code != 200:
raise BaseException("Error when trying to fetch import errors from airflow webserver.")

import_errors_response_content = import_errors_response.json()
if import_errors_response_content['total_entries'] >= max_total_errors_limit:
# TODO handle to many import errors
@@ -203,15 +203,15 @@ async def list_workflows(
is_paused = False

if response and not is_dag_broken:
schedule = response.get("schedule")
schedule = response.get("schedule_interval")
if isinstance(schedule, dict):
schedule = schedule.get("value")
status = WorkflowStatus.active.value

is_paused = response.get("is_paused")
is_active = response.get("is_active")
next_dagrun = response.get("next_dagrun")

data.append(
GetWorkflowsResponseData(
id=dag_data.id,
@@ -253,7 +253,7 @@ def get_workflow(self, workspace_id: int, workflow_id: str, auth_context: Author
raise ForbiddenException()

airflow_dag_info = self.airflow_client.get_dag_by_id(dag_id=workflow.uuid_name)

if airflow_dag_info.status_code == 404:
airflow_dag_info = {
'is_paused': 'creating',
@@ -272,7 +272,7 @@ def get_workflow(self, workspace_id: int, workflow_id: str, auth_context: Author
}
else:
airflow_dag_info = airflow_dag_info.json()

# Airflow 2.4.0 deprecated schedule_interval in dag but the API (2.7.2) still using it
schedule = airflow_dag_info.pop("schedule_interval")
if isinstance(schedule, dict):
@@ -293,11 +293,11 @@ def get_workflow(self, workspace_id: int, workflow_id: str, auth_context: Author
)

return response


def check_existing_workflow(self):
pass

def _validate_workflow_tasks(self, tasks_dict: dict, workspace_id: int):
# get all repositories ids necessary for tasks
# get all workspaces ids necessary for repositories referenced by tasks
@@ -309,7 +309,7 @@ def _validate_workflow_tasks(self, tasks_dict: dict, workspace_id: int):
pieces_names.add(v['piece']['name'])
pieces_source_images.add(v['piece']['source_image'])
shared_storage_sources.append(v['workflow_shared_storage']['source'])

shared_storage_source = list(set(shared_storage_sources))
if len(shared_storage_source) > 1:
raise BadRequestException("Workflow can't have more than one shared storage source.")
@@ -333,7 +333,7 @@ def _validate_workflow_tasks(self, tasks_dict: dict, workspace_id: int):
for secret in shared_storage_secrets:
if not secret.value and not secret.required:
raise BadRequestException("Missing secrets for shared storage.")

# Find necessary repositories from pieces names and workspace_id
necessary_repositories_and_pieces = self.piece_repository.find_repositories_by_piece_name_and_workspace_id(
pieces_names=pieces_names,
@@ -358,13 +358,13 @@ def _validate_workflow_tasks(self, tasks_dict: dict, workspace_id: int):
def _get_storage_repository_from_tasks(self, tasks: list, workspace_id: int):
if not tasks:
return None

usage_task = tasks[list(tasks.keys())[0]]
workflow_shared_storage = usage_task.get('workflow_shared_storage', None)

shared_source = WorkflowSharedStorageSourceEnum(workflow_shared_storage['source']).name
shared_model = storage_default_piece_model_map.get(shared_source, None)

if not shared_model:
return None

@@ -410,7 +410,7 @@ def _create_dag_code_from_raw_json(self, data: dict, workspace_id: int):

if 'end_date' in workflow_kwargs and not workflow_kwargs['end_date']:
workflow_kwargs.pop('end_date')

# TODO define how to use generate report
remove_from_dag_kwargs = ['name', 'workspace_id', 'generate_report', 'description', 'id']
for key in remove_from_dag_kwargs:
@@ -424,7 +424,7 @@ def _create_dag_code_from_raw_json(self, data: dict, workspace_id: int):
raw_input_kwargs = task_value.get('piece_input_kwargs', {})
workflow_shared_storage = task_value.get('workflow_shared_storage', None)
container_resources = task_value.get('container_resources', None)

if container_resources:
container_resources['requests']['cpu'] = f'{container_resources["requests"]["cpu"]}m'
container_resources['requests']['memory'] = f'{container_resources["requests"]["memory"]}Mi'
@@ -470,7 +470,7 @@ def _create_dag_code_from_raw_json(self, data: dict, workspace_id: int):
input_kwargs[input_key] = array_input_kwargs
else:
input_kwargs[input_key] = input_value['value']


# piece_request = {"id": 1, "name": "SimpleLogPiece"}
piece_db = self.piece_repository.find_repository_by_piece_name_and_workspace_id(
@@ -500,7 +500,7 @@ def _create_dag_code_from_raw_json(self, data: dict, workspace_id: int):
)
workflow_processed_schema['tasks'] = stream_tasks_dict
io_obj = io.StringIO()
stream.dump(io_obj)
stream.dump(io_obj)
py_code = io_obj.getvalue()

return workflow_processed_schema, py_code, pieces_repositories_ids
@@ -523,7 +523,7 @@ def run_workflow(self, workflow_id: int):
"is_paused": False
}
update_response = self.airflow_client.update_dag(
dag_id=airflow_workflow_id,
dag_id=airflow_workflow_id,
payload=payload
)
if update_response.status_code == 404:
@@ -553,7 +553,7 @@ async def delete_workflow_files(self, workflow_uuid):
path=f"{settings.DOMINO_LOCAL_WORKFLOWS_REPOSITORY}/{workflow_uuid}.py"
)
return

self.github_rest_client.delete_file(
repo_name=settings.DOMINO_GITHUB_WORKFLOWS_REPOSITORY,
file_path=f"workflows/{workflow_uuid}.py"
@@ -617,11 +617,11 @@ def list_workflow_runs(self, workflow_id: int, page: int, page_size: int):
)
)
return response

records = len(response_data['dag_runs'])
total = response_data['total_entries']
last_page = ceil(total / page_size) - 1

# Airflow API returns always the same number of elements defined by page size.
# So if we are in the last page we need to remove the extra elements.
# Example: total = 12, page_size = 10, last_page = 1.
@@ -651,7 +651,7 @@ def list_run_tasks(self, workflow_id: int, workflow_run_id: str, page: int, page
workflow = self.workflow_repository.find_by_id(id=workflow_id)
if not workflow:
raise ResourceNotFoundException("Workflow not found")

airflow_workflow_id = workflow.uuid_name

response = self.airflow_client.get_all_run_tasks_instances(
@@ -704,17 +704,17 @@ def parse_log(log_text: str):
# If the log is empty probably it is still running so we can parse all logs from the start pattern
if not log:
log = re.findall(f"[^\n]*{start_command_pattern}.*", log_text, re.DOTALL)

if not log:
return []

# Parse the log lines
output_lines = []
for line in log[0].split('\n')[:-1]:
# Remove pod manager info if it exists
l = re.sub(r"{pod_manager.py:[0-9]*}", '', line)
# Get datetime pattern
datetime_pattern = r'\[*\d{4}[-/]\d{2}[-/]\d{2} \d{2}:\d{2}:\d{2}(,|.)\d+\]*'
datetime_pattern = r'\[*\d{4}[-/]\d{2}[-/]\d{2} \d{2}:\d{2}:\d{2}(,|.)\d+\]*'
# Remove duplicated datetimes if they exist (they are added by the airflow logger and the domino so in some cases we have 2 datetimes in one line)
matches = list(re.finditer(datetime_pattern, l))
if len(matches) > 1:
@@ -750,8 +750,8 @@ def parse_log(log_text: str):
if l:
output_lines.append(l)
continue
# Is header, get the header and the values for the domino header and the content

# Is header, get the header and the values for the domino header and the content
header_value = header_match.group()
content_value = l.replace(header_value, '')
content_value = content_value.strip()
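The same listing path also swaps the deprecated Pydantic v1-style filters.dict(exclude_none=True) for the v2 filters.model_dump(exclude_none=True), in line with the model_dump() and ConfigDict usage elsewhere in the diff. A minimal sketch of the equivalent call, assuming Pydantic v2 and a purely hypothetical filter model:

from typing import Optional

from pydantic import BaseModel


class WorkflowFilters(BaseModel):
    # Hypothetical filter model, for illustration only.
    name: Optional[str] = None
    status: Optional[str] = None


filters = WorkflowFilters(name="etl")

# Pydantic v2 renames .dict() to .model_dump(); exclude_none drops fields whose value is None.
assert filters.model_dump(exclude_none=True) == {"name": "etl"}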
