def _dump_dict(model):
    """Dump the column attributes of a model instance to a dictionary.

    The mapper of ``model`` is looked up through ``sa.inspect`` and each of
    its column attributes becomes one entry: the attribute ``key`` is the
    dictionary key and the value is read from the instance with ``getattr``.
    """
    column_attrs = sa.inspect(model).mapper.column_attrs
    dumped = {}
    for attr in column_attrs:
        dumped[attr.key] = getattr(model, attr.key)
    return dumped
@classmethod
def generate_args(cls, job):
    """Generate the args of a new run from the job's default args.

    The job's ``default_args`` are deep-copied and every string value found
    in them (at any nesting level of dicts/lists) is rendered as a Jinja
    template. Non-string values are kept untouched.

    The template context contains:

    - ``job``: the dumped job.
    - ``last_run``: the dump of ``job.last_run``, or ``None``.
    - ``last_runs``: for each ``RunStatusEnum`` member, keyed by its
      lower-cased name, the dump of the most recently created run of the job
      with that status, or ``None``.

    It's important that the Jinja template context only includes "safe"
    values, i.e. no DB model classes or Python objects or functions.
    Otherwise we risk that users could execute arbitrary code, or perform
    harmful DB operations (e.g. delete rows).

    :param job: the job to generate run args for.
    :returns: the rendered args; an empty dict if the job has no default
        args.
    """
    args = deepcopy(job.default_args)
    if args is None:
        # ``default_args`` is a nullable column: there is nothing to render,
        # and ``None`` must never reach the template engine.
        return {}

    # Read the property once instead of evaluating it twice.
    last_run = job.last_run
    ctx = {
        "job": job.dump(),
        "last_run": last_run.dump() if last_run else None,
    }

    # Add the most recent run for every possible status.
    last_runs = {}
    for status in RunStatusEnum:
        run = job.runs.filter_by(status=status).order_by(cls.created.desc()).first()
        last_runs[status.name.lower()] = run.dump() if run else None
    ctx["last_runs"] = last_runs

    def _render(val):
        # Only strings can be Jinja templates; numbers, booleans and nulls
        # stored in the JSON args are passed through unchanged.
        return eval_tpl_str(val, ctx) if isinstance(val, str) else val

    walk_values(args, _render)
    return args
def create_run(self, entry):
    """Create and commit a run for the job referenced by a JobEntry.

    :param entry: the schedule entry; only ``entry.job.id`` is read.
    :returns: the newly created ``Run``.
    :raises: the "no result" error of ``Query.one()`` if the job no longer
        exists in the database.
    """
    # ``Query.get`` returns ``None`` for a missing row, which would only fail
    # later with an unrelated ``AttributeError`` while generating the run
    # args. ``one()`` fails right here with a clear lookup error instead.
    job = Job.query.filter_by(id=entry.job.id).one()
    run = Run.create(job=job, task_id=uuid.uuid4())
    # Add the run explicitly so that persisting it does not depend on
    # relationship/backref cascading from the job. Adding an object that is
    # already pending in the session is a no-op.
    db.session.add(run)
    db.session.commit()
    return run
def eval_tpl_str(val, ctx):
    """Evaluate a Jinja template string.

    The template is rendered with the module-level sandboxed ``jinja_env``
    and ``ctx`` as keyword context. The rendered text is then passed through
    ``ast.literal_eval`` so that results which look like Python literals
    (e.g. ``"2"``, ``"None"``) are returned as the corresponding Python
    value; any other text is returned as the rendered string.

    :param val: the template string. Non-string values are returned as-is,
        since only strings can be templates.
    :param ctx: mapping of names made available to the template.
    :returns: the evaluated literal, the rendered string, or ``val`` itself
        when it is not a string.
    """
    if not isinstance(val, str):
        # Args may legitimately contain numbers, booleans or nulls; handing
        # them to the template engine would raise instead of rendering.
        return val

    rendered = jinja_env.from_string(val).render(**ctx)

    try:
        return ast.literal_eval(rendered)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Best effort: these are the errors ``literal_eval`` documents for
        # input that is not a literal; keep the rendered text in that case.
        return rendered


def walk_values(obj, transform_fn):
    """Recursively apply a function in-place to the values of a dict or list.

    Nested dicts and lists are descended into; every other value is replaced
    by ``transform_fn(value)``.

    :param obj: the dict or list to update in-place, or a single value.
    :param transform_fn: callable taking a value and returning its
        replacement.
    :returns: ``None`` when ``obj`` is a dict or a list (it is modified
        in-place); otherwise ``transform_fn(obj)``.
    """
    if isinstance(obj, dict):
        items = obj.items()
    elif isinstance(obj, list):
        items = enumerate(obj)
    else:
        return transform_fn(obj)

    for key, val in items:
        if isinstance(val, (dict, list)):
            walk_values(val, transform_fn)
        else:
            # Only existing keys/indexes are reassigned, so the container
            # keeps its size while it is being iterated.
            obj[key] = transform_fn(val)
@patch.object(execute_run, "apply_async")
def test_job_template_args(mock_apply_async, app, db, client, user):
    """Check that templated job args are rendered when runs are created.

    A job is created with Jinja templates in its ``default_args``; the job
    itself must keep the raw templates, while each triggered run must expose
    the rendered values. The second run checks the template that depends on
    the previous run.
    """
    client = user.login(client)
    template_args = {
        "arg1": "{{ 1 + 1 }}",
        "arg2": "{{ job.title | upper }}",
        "kwarg1": "{{ last_run.created.isoformat() if last_run else None }}",
    }

    def expected_run_for(run_json, job_id, kwarg1):
        """Build the expected run payload around the server-generated fields."""
        run_id = run_json["id"]
        run_url = f"https://127.0.0.1:5000/api/jobs/{job_id}/runs/{run_id}"
        return {
            "id": run_id,
            "job_id": job_id,
            "task_id": run_json["task_id"],
            "started_by_id": int(user.id),
            "started_by": {
                "id": str(user.id),
                "username": user.username,
                "profile": user._user_profile,
                "links": {
                    # "self": f"https://127.0.0.1:5000/api/users/{user.id}",
                },
                "identities": {},
                "is_current_user": True,
                "type": "user",
            },
            "started_at": run_json["started_at"],
            "finished_at": run_json["finished_at"],
            "status": "QUEUED",
            "message": None,
            "title": None,
            "args": {
                "arg1": 2,
                "arg2": "JOB WITH TEMPLATE ARGS",
                "kwarg1": kwarg1,
            },
            "queue": "celery",
            "created": run_json["created"],
            "updated": run_json["updated"],
            "links": {
                "self": run_url,
                "logs": f"{run_url}/logs",
                "stop": f"{run_url}/actions/stop",
            },
        }

    # Create a job
    job_payload = {
        "title": "Job with template args",
        "task": "tasks.mock_task",
        "default_args": dict(template_args),
    }
    res = client.post("/jobs", json=job_payload)
    assert res.status_code == 201
    job_id = res.json["id"]
    job_url = f"https://127.0.0.1:5000/api/jobs/{job_id}"
    expected_job = {
        "id": job_id,
        "title": "Job with template args",
        "description": None,
        "active": True,
        "task": "tasks.mock_task",
        "default_queue": "celery",
        # The job keeps the raw, unrendered templates.
        "default_args": dict(template_args),
        "schedule": None,
        "last_run": None,
        "created": res.json["created"],
        "updated": res.json["updated"],
        "links": {
            "self": job_url,
            "runs": f"{job_url}/runs",
        },
    }
    assert res.json == expected_job

    # Create/trigger a run: there is no previous run yet
    res = client.post(f"/jobs/{job_id}/runs")
    assert res.status_code == 201
    assert res.json == expected_run_for(res.json, job_id, kwarg1=None)
    last_run_created = res.json["created"].replace("+00:00", "")

    # Trigger another run to test the kwarg1 template depending on the last run
    res = client.post(f"/jobs/{job_id}/runs")
    assert res.status_code == 201
    assert res.json == expected_run_for(res.json, job_id, kwarg1=last_run_created)