Skip to content

Commit

Permalink
added test generation code to pydra generate
Browse files Browse the repository at this point in the history
  • Loading branch information
tclose committed Aug 18, 2023
1 parent 8e9db03 commit 16d9755
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 12 deletions.
78 changes: 66 additions & 12 deletions pydra/pydra-auto-gen.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import os
import sys
import attrs
from pathlib import Path
import subprocess as sp
import typing as ty
from importlib import import_module
import logging
import re
import click
import black.report
import black.parsing


logging.basicConfig(level=logging.WARNING)

logger = logging.getLogger("pydra-auto-gen")
Expand Down Expand Up @@ -41,18 +45,21 @@
help="whether to log errors (and skip to the next tool) instead of raising them",
)
def auto_gen_mrtrix3_pydra(cmd_dir: Path, output_dir: Path, log_errors: bool):

cmd_dir = cmd_dir.absolute()

# Insert output dir to path so we can load the generated tasks in order to
# generate the tests
sys.path.insert(0, str(output_dir))

for cmd_name in sorted(os.listdir(cmd_dir)):
if cmd_name.startswith("_") or "." in cmd_name or cmd_name in IGNORE:
continue
cmd = [str(cmd_dir / cmd_name)]
auto_gen_cmd(cmd, cmd_name, output_dir, log_errors)
auto_gen_test(cmd_name, output_dir, log_errors)


def auto_gen_cmd(cmd: ty.List[str], cmd_name: str, output_dir: Path, log_errors: bool):

try:
code_str = sp.check_output(cmd + ["__print_usage_pydra__"]).decode("utf-8")
except sp.CalledProcessError:
Expand All @@ -74,15 +81,9 @@ def auto_gen_cmd(cmd: ty.List[str], cmd_name: str, output_dir: Path, log_errors:
if cmd_name.startswith("5tt"):
old_name = cmd_name
cmd_name = old_name.replace("5tt", "fivetissuetype")
code_str = code_str.replace(
f"class {old_name}", f"class {cmd_name}"
)
code_str = code_str.replace(
f"{old_name}_input_spec", f"{cmd_name}_input_spec"
)
code_str = code_str.replace(
f"{old_name}_output_spec", f"{cmd_name}_input_spec"
)
code_str = code_str.replace(f"class {old_name}", f"class {cmd_name}")
code_str = code_str.replace(f"{old_name}_input_spec", f"{cmd_name}_input_spec")
code_str = code_str.replace(f"{old_name}_output_spec", f"{cmd_name}_input_spec")
try:
code_str = black.format_file_contents(
code_str, fast=False, mode=black.FileMode()
Expand All @@ -100,15 +101,68 @@ def auto_gen_cmd(cmd: ty.List[str], cmd_name: str, output_dir: Path, log_errors:
logger.info("%s", cmd_name)


def auto_gen_test(cmd_name: str, output_dir: Path, log_errors: bool):

tests_dir = output_dir / "tests"
tests_dir.mkdir(exist_ok=True)
module = import_module(cmd_name)
task = getattr(module, cmd_name)()

code_str = f"""# Auto-generated tests for {cmd_name}
from .{cmd_name} import {cmd_name}
def test_{cmd_name}():
task = {cmd_name}(
"""
for field in attrs.fields(type(task.inputs)):
if field.default is not attrs.NOTHING:
value = field.default
else:
raise NotImplementedError

code_str += f" {field.name}={value},\n"

code_str += """
)
result = task(plugin="serial")
assert result.exit_code == 0
"""

try:
code_str = black.format_file_contents(
code_str, fast=False, mode=black.FileMode()
)
except black.report.NothingChanged:
pass
except black.parsing.InvalidInput:
if log_errors:
logger.error("Could not parse generated interface for '%s'", cmd_name)
else:
raise

with open(tests_dir / f"test_{cmd_name}.py", "w") as f:
f.write(code_str)


if __name__ == "__main__":
from pathlib import Path

script_dir = Path(__file__).parent

mrtrix_version = sp.check_output(
"git describe --tags --abbrev=0", cwd=script_dir, shell=True
).decode("utf-8")
pkg_version = "v" + "_".join(mrtrix_version.split(".")[:2])

auto_gen_mrtrix3_pydra(
[
str(script_dir.parent / "bin"),
str(script_dir / "src" / "pydra" / "tasks" / "mrtrix3" / "latest"),
str(script_dir / "src" / "pydra" / "tasks" / "mrtrix3" / pkg_version),
"--raise-errors",
]
)

latest = script_dir / "src" / "pydra" / "tasks" / "mrtrix3" / "latest.py"
latest.write_text(f"# Auto-generated, do not edit\n\nfrom .{pkg_version} import *\n")
print(f"Generated pydra.tasks.mrtrix3.{pkg_version} package")
1 change: 1 addition & 0 deletions pydra/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
black >=22.12.0
click >=8.1.3
attrs >=23.1.0

0 comments on commit 16d9755

Please sign in to comment.