Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

global: support Jinja templating for job args #37

Merged
merged 1 commit into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion invenio_jobs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from datetime import timedelta
from inspect import signature

import sqlalchemy as sa
from celery import current_app as current_celery_app
from celery.schedules import crontab
from invenio_accounts.models import User
Expand All @@ -23,6 +24,8 @@
from sqlalchemy_utils.types import ChoiceType, JSONType, UUIDType
from werkzeug.utils import cached_property

from .utils import eval_tpl_str, walk_values

JSON = (
db.JSON()
.with_variant(postgresql.JSONB(none_as_null=True), "postgresql")
Expand All @@ -31,6 +34,11 @@
)


def _dump_dict(model):
    """Return a plain dict of *model*'s column attributes and their values."""
    columns = sa.inspect(model).mapper.column_attrs
    return dict((col.key, getattr(model, col.key)) for col in columns)


class Job(db.Model, Timestamp):
"""Job model."""

Expand All @@ -44,7 +52,6 @@ class Job(db.Model, Timestamp):
default_args = db.Column(JSON, default=lambda: dict(), nullable=True)
schedule = db.Column(JSON, nullable=True)

# TODO: See if we move this to an API class
@property
def last_run(self):
"""Last run of the job."""
Expand All @@ -63,6 +70,10 @@ def parsed_schedule(self):
elif stype == "interval":
return timedelta(**schedule)

def dump(self):
"""Dump the job as a dictionary."""
return _dump_dict(self)


class RunStatusEnum(enum.Enum):
"""Enumeration of a run's possible states."""
Expand Down Expand Up @@ -109,6 +120,41 @@ def started_by(self):
args = db.Column(JSON, default=lambda: dict(), nullable=True)
queue = db.Column(db.String(64), nullable=False)

@classmethod
def create(cls, job, **kwargs):
"""Create a new run."""
if "args" not in kwargs:
kwargs["args"] = cls.generate_args(job)
if "queue" not in kwargs:
kwargs["queue"] = job.default_queue

return cls(job=job, **kwargs)

    @classmethod
    def generate_args(cls, job):
        """Generate new run args from the job's ``default_args``.

        We allow a templating mechanism to generate the args for the run. It's
        important that the Jinja template context only includes "safe" values,
        i.e. no DB model classes or Python objects or functions. Otherwise we
        risk that users could execute arbitrary code, or perform harmful DB
        operations (e.g. delete rows).
        """
        args = deepcopy(job.default_args)
        # Template context only contains plain-dict dumps (see docstring).
        ctx = {"job": job.dump()}
        # Add the most recent run per status (dumped, or None if none exists).
        last_runs = {}
        for status in RunStatusEnum:
            run = job.runs.filter_by(status=status).order_by(cls.created.desc()).first()
            last_runs[status.name.lower()] = run.dump() if run else None
        ctx["last_runs"] = last_runs
        ctx["last_run"] = job.last_run.dump() if job.last_run else None
        # Render every leaf string of the copied args in-place.
        # NOTE(review): if ``default_args`` were a bare scalar (not dict/list),
        # walk_values would RETURN the transformed value, which is discarded
        # here — assumes default_args is always a dict; TODO confirm.
        walk_values(args, lambda val: eval_tpl_str(val, ctx))
        return args

def dump(self):
"""Dump the run as a dictionary."""
return _dump_dict(self)


class Task:
"""Celery Task model."""
Expand Down
32 changes: 12 additions & 20 deletions invenio_jobs/services/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,11 @@

import traceback
import uuid
from typing import Any

from celery.beat import ScheduleEntry, Scheduler, logger
from invenio_db import db
from sqlalchemy import and_

from invenio_jobs.models import Job, Run, Task
from invenio_jobs.models import Job, Run
from invenio_jobs.tasks import execute_run


Expand Down Expand Up @@ -49,27 +47,23 @@ class RunScheduler(Scheduler):
Entry = JobEntry
entries = {}

def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize the database scheduler."""
super().__init__(*args, **kwargs)

    @property
    def schedule(self):
        """Get currently scheduled entries (maps job id -> JobEntry)."""
        return self.entries

# Celery internal override
#
# Celery overrides
#
    def setup_schedule(self):
        """Setup the schedule by syncing jobs from the database."""
        self.sync()

# Celery internal override
def reserve(self, entry):
"""Update entry to next run execution time."""
new_entry = self.schedule[entry.job.id] = next(entry)
return new_entry

# Celery internal override
def apply_entry(self, entry, producer=None):
"""Create and apply a JobEntry."""
with self.app.flask_app.app_context():
Expand All @@ -93,26 +87,24 @@ def apply_entry(self, entry, producer=None):
else:
logger.debug("%s sent.", entry.task)

# Celery internal override
def sync(self):
"""Sync Jobs from db to the scheduler."""
# TODO Should we also have a cleaup task for runs? "stale" run (status running, starttime > hour, Run pending for > 1 hr)
with self.app.flask_app.app_context():
jobs = Job.query.filter(
and_(Job.active == True, Job.schedule != None)
).all()
Job.active.is_(True),
Job.schedule.isnot(None),
)
self.entries = {} # because some jobs might be deactivated
for job in jobs:
self.entries[job.id] = JobEntry.from_job(job)

#
# Helpers
#
def create_run(self, entry):
"""Create run from a JobEntry."""
job = Job.query.filter_by(id=entry.job.id).one()
run = Run(
job=job,
args=job.default_args,
queue=job.default_queue,
task_id=uuid.uuid4(),
)
job = Job.query.get(entry.job.id)
run = Run.create(job=job, task_id=uuid.uuid4())
db.session.commit()
return run
7 changes: 3 additions & 4 deletions invenio_jobs/services/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,11 +236,10 @@ def create(self, identity, job_id, data, uow=None):
raise_errors=True,
)

valid_data.setdefault("queue", job.default_queue)
run = Run(
id=str(uuid.uuid4()),
run = Run.create(
job=job,
task_id=uuid.uuid4(),
id=str(uuid.uuid4()),
task_id=str(uuid.uuid4()),
started_by_id=identity.id,
status=RunStatusEnum.QUEUED,
**valid_data,
Expand Down
1 change: 1 addition & 0 deletions invenio_jobs/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# under the terms of the MIT License; see LICENSE file for more details.

"""Tasks."""

from datetime import datetime, timezone

from celery import shared_task
Expand Down
43 changes: 43 additions & 0 deletions invenio_jobs/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2024 CERN.
#
# Invenio-Jobs is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.

"""Utilities."""

import ast

from jinja2.sandbox import SandboxedEnvironment

jinja_env = SandboxedEnvironment()


def eval_tpl_str(val, ctx):
    """Render *val* as a Jinja template string and coerce the result.

    The rendered output is parsed back into a Python literal (number, dict,
    list, None, ...) when possible; otherwise the rendered string itself is
    returned unchanged.
    """
    rendered = jinja_env.from_string(val).render(**ctx)
    try:
        return ast.literal_eval(rendered)
    except Exception:
        # Not a Python literal (e.g. plain text) — keep the rendered string.
        return rendered


def walk_values(obj, transform_fn):
    """Recursively apply *transform_fn* in-place to every leaf of a dict/list.

    Containers (dicts and lists) are mutated in place and yield ``None``;
    a non-container *obj* is returned transformed instead.
    """
    if not isinstance(obj, (dict, list)):
        return transform_fn(obj)

    # Snapshot the keys/indices so we can assign back while walking.
    keys = list(obj) if isinstance(obj, dict) else range(len(obj))
    for key in keys:
        val = obj[key]
        if isinstance(val, (dict, list)):
            walk_values(val, transform_fn)
        else:
            obj[key] = transform_fn(val)
125 changes: 124 additions & 1 deletion tests/resources/test_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

"""Resource tests."""

import pdb
from unittest.mock import patch

from invenio_jobs.tasks import execute_run
Expand Down Expand Up @@ -429,3 +428,127 @@ def test_jobs_delete(db, client, jobs):
assert res.json["hits"]["total"] == 2
hits = res.json["hits"]["hits"]
assert all(j["id"] != jobs.simple.id for j in hits)


@patch.object(execute_run, "apply_async")
def test_job_template_args(mock_apply_async, app, db, client, user):
    """Jinja templates in ``default_args`` are rendered when a run is created.

    Covers arithmetic (``1 + 1``), filters over the job context
    (``job.title | upper``), and a template depending on the previous run
    (``last_run.created``), which is None for the first run.
    """
    client = user.login(client)
    job_payload = {
        "title": "Job with template args",
        "task": "tasks.mock_task",
        "default_args": {
            "arg1": "{{ 1 + 1 }}",
            "arg2": "{{ job.title | upper }}",
            "kwarg1": "{{ last_run.created.isoformat() if last_run else None }}",
        },
    }

    # Create a job
    res = client.post("/jobs", json=job_payload)
    assert res.status_code == 201
    job_id = res.json["id"]
    expected_job = {
        "id": job_id,
        "title": "Job with template args",
        "description": None,
        "active": True,
        "task": "tasks.mock_task",
        "default_queue": "celery",
        # The job stores the raw (unrendered) templates.
        "default_args": {
            "arg1": "{{ 1 + 1 }}",
            "arg2": "{{ job.title | upper }}",
            "kwarg1": "{{ last_run.created.isoformat() if last_run else None }}",
        },
        "schedule": None,
        "last_run": None,
        "created": res.json["created"],
        "updated": res.json["updated"],
        "links": {
            "self": f"https://127.0.0.1:5000/api/jobs/{job_id}",
            "runs": f"https://127.0.0.1:5000/api/jobs/{job_id}/runs",
        },
    }
    assert res.json == expected_job

    # Create/trigger a run
    res = client.post(f"/jobs/{job_id}/runs")
    assert res.status_code == 201
    run_id = res.json["id"]
    expected_run = {
        "id": run_id,
        "job_id": job_id,
        "task_id": res.json["task_id"],
        "started_by_id": int(user.id),
        "started_by": {
            "id": str(user.id),
            "username": user.username,
            "profile": user._user_profile,
            "links": {
                # "self": f"https://127.0.0.1:5000/api/users/{user.id}",
            },
            "identities": {},
            "is_current_user": True,
            "type": "user",
        },
        "started_at": res.json["started_at"],
        "finished_at": res.json["finished_at"],
        "status": "QUEUED",
        "message": None,
        "title": None,
        # Templates have been rendered and literal-parsed.
        "args": {
            "arg1": 2,
            "arg2": "JOB WITH TEMPLATE ARGS",
            "kwarg1": None,  # no previous run yet
        },
        "queue": "celery",
        "created": res.json["created"],
        "updated": res.json["updated"],
        "links": {
            "self": f"https://127.0.0.1:5000/api/jobs/{job_id}/runs/{run_id}",
            "logs": f"https://127.0.0.1:5000/api/jobs/{job_id}/runs/{run_id}/logs",
            "stop": f"https://127.0.0.1:5000/api/jobs/{job_id}/runs/{run_id}/actions/stop",
        },
    }
    assert res.json == expected_run
    last_run_created = res.json["created"].replace("+00:00", "")

    # Trigger another run to test the kwarg1 template depending on the last run
    res = client.post(f"/jobs/{job_id}/runs")
    assert res.status_code == 201
    run_id = res.json["id"]
    expected_run = {
        "id": run_id,
        "job_id": job_id,
        "task_id": res.json["task_id"],
        "started_by_id": int(user.id),
        "started_by": {
            "id": str(user.id),
            "username": user.username,
            "profile": user._user_profile,
            "links": {
                # "self": f"https://127.0.0.1:5000/api/users/{user.id}",
            },
            "identities": {},
            "is_current_user": True,
            "type": "user",
        },
        "started_at": res.json["started_at"],
        "finished_at": res.json["finished_at"],
        "status": "QUEUED",
        "message": None,
        "title": None,
        "args": {
            "arg1": 2,
            "arg2": "JOB WITH TEMPLATE ARGS",
            # Second run sees the first run's creation timestamp.
            "kwarg1": last_run_created,
        },
        "queue": "celery",
        "created": res.json["created"],
        "updated": res.json["updated"],
        "links": {
            "self": f"https://127.0.0.1:5000/api/jobs/{job_id}/runs/{run_id}",
            "logs": f"https://127.0.0.1:5000/api/jobs/{job_id}/runs/{run_id}/logs",
            "stop": f"https://127.0.0.1:5000/api/jobs/{job_id}/runs/{run_id}/actions/stop",
        },
    }
    assert res.json == expected_run