Skip to content

Commit

Permalink
global: support Jinja templating for job args
Browse files Browse the repository at this point in the history
* Closes #36.
  • Loading branch information
slint committed Jun 17, 2024
1 parent 47a91bd commit 5e83e28
Show file tree
Hide file tree
Showing 6 changed files with 230 additions and 26 deletions.
48 changes: 47 additions & 1 deletion invenio_jobs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from datetime import timedelta
from inspect import signature

import sqlalchemy as sa
from celery import current_app as current_celery_app
from celery.schedules import crontab
from invenio_accounts.models import User
Expand All @@ -23,6 +24,8 @@
from sqlalchemy_utils.types import ChoiceType, JSONType, UUIDType
from werkzeug.utils import cached_property

from .utils import eval_tpl_str, walk_values

JSON = (
db.JSON()
.with_variant(postgresql.JSONB(none_as_null=True), "postgresql")
Expand All @@ -31,6 +34,11 @@
)


def _dump_dict(model):
"""Dump a model to a dictionary."""
return {c.key: getattr(model, c.key) for c in sa.inspect(model).mapper.column_attrs}


class Job(db.Model, Timestamp):
"""Job model."""

Expand All @@ -44,7 +52,6 @@ class Job(db.Model, Timestamp):
default_args = db.Column(JSON, default=lambda: dict(), nullable=True)
schedule = db.Column(JSON, nullable=True)

# TODO: See if we move this to an API class
@property
def last_run(self):
"""Last run of the job."""
Expand All @@ -63,6 +70,10 @@ def parsed_schedule(self):
elif stype == "interval":
return timedelta(**schedule)

def dump(self):
"""Dump the job as a dictionary."""
return _dump_dict(self)


class RunStatusEnum(enum.Enum):
"""Enumeration of a run's possible states."""
Expand Down Expand Up @@ -109,6 +120,41 @@ def started_by(self):
args = db.Column(JSON, default=lambda: dict(), nullable=True)
queue = db.Column(db.String(64), nullable=False)

@classmethod
def create(cls, job, **kwargs):
"""Create a new run."""
if "args" not in kwargs:
kwargs["args"] = cls.generate_args(job)
if "queue" not in kwargs:
kwargs["queue"] = job.default_queue

return cls(job=job, **kwargs)

@classmethod
def generate_args(cls, job):
"""Generate new run args.
We allow a templating mechanism to generate the args for the run. It's important
that the Jinja template context only includes "safe" values, i.e. no DB model
classes or Python objects or functions. Otherwise we risk that users could
execute arbitrary code, or perform harfmul DB operations (e.g. delete rows).
"""
args = deepcopy(job.default_args)
ctx = {"job": job.dump()}
# Add last runs
last_runs = {}
for status in RunStatusEnum:
run = job.runs.filter_by(status=status).order_by(cls.created.desc()).first()
last_runs[status.name.lower()] = run.dump() if run else None
ctx["last_runs"] = last_runs
ctx["last_run"] = job.last_run.dump() if job.last_run else None
walk_values(args, lambda val: eval_tpl_str(val, ctx))
return args

def dump(self):
"""Dump the run as a dictionary."""
return _dump_dict(self)


class Task:
"""Celery Task model."""
Expand Down
32 changes: 12 additions & 20 deletions invenio_jobs/services/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,11 @@

import traceback
import uuid
from typing import Any

from celery.beat import ScheduleEntry, Scheduler, logger
from invenio_db import db
from sqlalchemy import and_

from invenio_jobs.models import Job, Run, Task
from invenio_jobs.models import Job, Run
from invenio_jobs.tasks import execute_run


Expand Down Expand Up @@ -49,27 +47,23 @@ class RunScheduler(Scheduler):
Entry = JobEntry
entries = {}

def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize the database scheduler."""
super().__init__(*args, **kwargs)

@property
def schedule(self):
"""Get currently scheduled entries."""
return self.entries

# Celery internal override
#
# Celery overrides
#
def setup_schedule(self):
"""Setup schedule."""
self.sync()

# Celery internal override
def reserve(self, entry):
"""Update entry to next run execution time."""
new_entry = self.schedule[entry.job.id] = next(entry)
return new_entry

# Celery internal override
def apply_entry(self, entry, producer=None):
"""Create and apply a JobEntry."""
with self.app.flask_app.app_context():
Expand All @@ -93,26 +87,24 @@ def apply_entry(self, entry, producer=None):
else:
logger.debug("%s sent.", entry.task)

# Celery internal override
def sync(self):
"""Sync Jobs from db to the scheduler."""
# TODO Should we also have a cleaup task for runs? "stale" run (status running, starttime > hour, Run pending for > 1 hr)
with self.app.flask_app.app_context():
jobs = Job.query.filter(
and_(Job.active == True, Job.schedule != None)
).all()
Job.active.is_(True),
Job.schedule.isnot(None),
)
self.entries = {} # because some jobs might be deactivated
for job in jobs:
self.entries[job.id] = JobEntry.from_job(job)

#
# Helpers
#
def create_run(self, entry):
"""Create run from a JobEntry."""
job = Job.query.filter_by(id=entry.job.id).one()
run = Run(
job=job,
args=job.default_args,
queue=job.default_queue,
task_id=uuid.uuid4(),
)
job = Job.query.get(entry.job.id)
run = Run.create(job=job, task_id=uuid.uuid4())
db.session.commit()
return run
7 changes: 3 additions & 4 deletions invenio_jobs/services/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,11 +236,10 @@ def create(self, identity, job_id, data, uow=None):
raise_errors=True,
)

valid_data.setdefault("queue", job.default_queue)
run = Run(
id=str(uuid.uuid4()),
run = Run.create(
job=job,
task_id=uuid.uuid4(),
id=str(uuid.uuid4()),
task_id=str(uuid.uuid4()),
started_by_id=identity.id,
status=RunStatusEnum.QUEUED,
**valid_data,
Expand Down
1 change: 1 addition & 0 deletions invenio_jobs/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# under the terms of the MIT License; see LICENSE file for more details.

"""Tasks."""

from datetime import datetime, timezone

from celery import shared_task
Expand Down
43 changes: 43 additions & 0 deletions invenio_jobs/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2024 CERN.
#
# Invenio-Jobs is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.

"""Utilities."""

import ast

from jinja2.sandbox import SandboxedEnvironment

jinja_env = SandboxedEnvironment()


def eval_tpl_str(val, ctx):
"""Evaluate a Jinja template string."""
tpl = jinja_env.from_string(val)
res = tpl.render(**ctx)

try:
res = ast.literal_eval(res)
except Exception:
pass

return res


def walk_values(obj, transform_fn):
"""Recursively apply a function in-place to the value of dictionary or list."""
if isinstance(obj, dict):
items = obj.items()
elif isinstance(obj, list):
items = enumerate(obj)
else:
return transform_fn(obj)

for key, val in items:
if isinstance(val, (dict, list)):
walk_values(val, transform_fn)
else:
obj[key] = transform_fn(val)
125 changes: 124 additions & 1 deletion tests/resources/test_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

"""Resource tests."""

import pdb
from unittest.mock import patch

from invenio_jobs.tasks import execute_run
Expand Down Expand Up @@ -429,3 +428,127 @@ def test_jobs_delete(db, client, jobs):
assert res.json["hits"]["total"] == 2
hits = res.json["hits"]["hits"]
assert all(j["id"] != jobs.simple.id for j in hits)


@patch.object(execute_run, "apply_async")
def test_job_template_args(mock_apply_async, app, db, client, user):
client = user.login(client)
job_payload = {
"title": "Job with template args",
"task": "tasks.mock_task",
"default_args": {
"arg1": "{{ 1 + 1 }}",
"arg2": "{{ job.title | upper }}",
"kwarg1": "{{ last_run.created.isoformat() if last_run else None }}",
},
}

# Create a job
res = client.post("/jobs", json=job_payload)
assert res.status_code == 201
job_id = res.json["id"]
expected_job = {
"id": job_id,
"title": "Job with template args",
"description": None,
"active": True,
"task": "tasks.mock_task",
"default_queue": "celery",
"default_args": {
"arg1": "{{ 1 + 1 }}",
"arg2": "{{ job.title | upper }}",
"kwarg1": "{{ last_run.created.isoformat() if last_run else None }}",
},
"schedule": None,
"last_run": None,
"created": res.json["created"],
"updated": res.json["updated"],
"links": {
"self": f"https://127.0.0.1:5000/api/jobs/{job_id}",
"runs": f"https://127.0.0.1:5000/api/jobs/{job_id}/runs",
},
}
assert res.json == expected_job

# Create/trigger a run
res = client.post(f"/jobs/{job_id}/runs")
assert res.status_code == 201
run_id = res.json["id"]
expected_run = {
"id": run_id,
"job_id": job_id,
"task_id": res.json["task_id"],
"started_by_id": int(user.id),
"started_by": {
"id": str(user.id),
"username": user.username,
"profile": user._user_profile,
"links": {
# "self": f"https://127.0.0.1:5000/api/users/{user.id}",
},
"identities": {},
"is_current_user": True,
"type": "user",
},
"started_at": res.json["started_at"],
"finished_at": res.json["finished_at"],
"status": "QUEUED",
"message": None,
"title": None,
"args": {
"arg1": 2,
"arg2": "JOB WITH TEMPLATE ARGS",
"kwarg1": None,
},
"queue": "celery",
"created": res.json["created"],
"updated": res.json["updated"],
"links": {
"self": f"https://127.0.0.1:5000/api/jobs/{job_id}/runs/{run_id}",
"logs": f"https://127.0.0.1:5000/api/jobs/{job_id}/runs/{run_id}/logs",
"stop": f"https://127.0.0.1:5000/api/jobs/{job_id}/runs/{run_id}/actions/stop",
},
}
assert res.json == expected_run
last_run_created = res.json["created"].replace("+00:00", "")

# Trigger another run to test the kwarg1 template depending on the last run
res = client.post(f"/jobs/{job_id}/runs")
assert res.status_code == 201
run_id = res.json["id"]
expected_run = {
"id": run_id,
"job_id": job_id,
"task_id": res.json["task_id"],
"started_by_id": int(user.id),
"started_by": {
"id": str(user.id),
"username": user.username,
"profile": user._user_profile,
"links": {
# "self": f"https://127.0.0.1:5000/api/users/{user.id}",
},
"identities": {},
"is_current_user": True,
"type": "user",
},
"started_at": res.json["started_at"],
"finished_at": res.json["finished_at"],
"status": "QUEUED",
"message": None,
"title": None,
"args": {
"arg1": 2,
"arg2": "JOB WITH TEMPLATE ARGS",
"kwarg1": last_run_created,
},
"queue": "celery",
"created": res.json["created"],
"updated": res.json["updated"],
"links": {
"self": f"https://127.0.0.1:5000/api/jobs/{job_id}/runs/{run_id}",
"logs": f"https://127.0.0.1:5000/api/jobs/{job_id}/runs/{run_id}/logs",
"stop": f"https://127.0.0.1:5000/api/jobs/{job_id}/runs/{run_id}/actions/stop",
},
}
assert res.json == expected_run

0 comments on commit 5e83e28

Please sign in to comment.