Skip to content

Commit

Permalink
fix(anta.cli): Evaluate nrfu subcommands args before running the tests (
Browse files Browse the repository at this point in the history
  • Loading branch information
gmuloc authored Jul 4, 2024
1 parent e9aff4a commit 0f88b28
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 22 deletions.
31 changes: 10 additions & 21 deletions anta/cli/nrfu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,14 @@

from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING, get_args

import click

from anta.cli.nrfu import commands
from anta.cli.utils import AliasedGroup, catalog_options, inventory_options
from anta.custom_types import TestStatus
from anta.models import AntaTest
from anta.result_manager import ResultManager
from anta.runner import main

from .utils import anta_progress_bar, print_settings

if TYPE_CHECKING:
from anta.catalog import AntaCatalog
Expand All @@ -37,6 +32,7 @@ def parse_args(self, ctx: click.Context, args: list[str]) -> list[str]:
"""Ignore MissingParameter exception when parsing arguments if `--help` is present for a subcommand."""
# Adding a flag for potential callbacks
ctx.ensure_object(dict)
ctx.obj["args"] = args
if "--help" in args:
ctx.obj["_anta_help"] = True

Expand Down Expand Up @@ -125,29 +121,22 @@ def nrfu(
# If help is invoke somewhere, skip the command
if ctx.obj.get("_anta_help"):
return

# We use ctx.obj to pass stuff to the next Click functions
ctx.ensure_object(dict)
ctx.obj["result_manager"] = ResultManager()
ctx.obj["ignore_status"] = ignore_status
ctx.obj["ignore_error"] = ignore_error
ctx.obj["hide"] = set(hide) if hide else None
print_settings(inventory, catalog)
with anta_progress_bar() as AntaTest.progress:
asyncio.run(
main(
ctx.obj["result_manager"],
inventory,
catalog,
tags=tags,
devices=set(device) if device else None,
tests=set(test) if test else None,
dry_run=dry_run,
)
)
if dry_run:
return
ctx.obj["catalog"] = catalog
ctx.obj["inventory"] = inventory
ctx.obj["tags"] = tags
ctx.obj["device"] = device
ctx.obj["test"] = test
ctx.obj["dry_run"] = dry_run

# Invoke `anta nrfu table` if no command is passed
if ctx.invoked_subcommand is None:
if not ctx.invoked_subcommand:
ctx.invoke(commands.table)


Expand Down
6 changes: 5 additions & 1 deletion anta/cli/nrfu/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from anta.cli.utils import exit_with_code

from .utils import print_jinja, print_json, print_table, print_text
from .utils import print_jinja, print_json, print_table, print_text, run_tests

logger = logging.getLogger(__name__)

Expand All @@ -32,6 +32,7 @@ def table(
group_by: Literal["device", "test"] | None,
) -> None:
"""ANTA command to check network states with table result."""
run_tests(ctx)
print_table(ctx, group_by=group_by)
exit_with_code(ctx)

Expand All @@ -48,6 +49,7 @@ def table(
)
def json(ctx: click.Context, output: pathlib.Path | None) -> None:
"""ANTA command to check network state with JSON result."""
run_tests(ctx)
print_json(ctx, output=output)
exit_with_code(ctx)

Expand All @@ -56,6 +58,7 @@ def json(ctx: click.Context, output: pathlib.Path | None) -> None:
@click.pass_context
def text(ctx: click.Context) -> None:
"""ANTA command to check network states with text result."""
run_tests(ctx)
print_text(ctx)
exit_with_code(ctx)

Expand All @@ -80,5 +83,6 @@ def text(ctx: click.Context) -> None:
)
def tpl_report(ctx: click.Context, template: pathlib.Path, output: pathlib.Path | None) -> None:
"""ANTA command to check network state with templated report."""
run_tests(ctx)
print_jinja(results=ctx.obj["result_manager"], template=template, output=output)
exit_with_code(ctx)
34 changes: 34 additions & 0 deletions anta/cli/nrfu/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

import asyncio
import json
import logging
from typing import TYPE_CHECKING, Literal
Expand All @@ -14,7 +15,9 @@
from rich.progress import BarColumn, MofNCompleteColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn

from anta.cli.console import console
from anta.models import AntaTest
from anta.reporter import ReportJinja, ReportTable
from anta.runner import main

if TYPE_CHECKING:
import pathlib
Expand All @@ -28,6 +31,37 @@
logger = logging.getLogger(__name__)


def run_tests(ctx: click.Context) -> None:
"""Run the tests."""
# Digging up the parameters from the parent context
if ctx.parent is None:
ctx.exit()
nrfu_ctx_params = ctx.parent.params
tags = nrfu_ctx_params["tags"]
device = nrfu_ctx_params["device"] or None
test = nrfu_ctx_params["test"] or None
dry_run = nrfu_ctx_params["dry_run"]

catalog = ctx.obj["catalog"]
inventory = ctx.obj["inventory"]

print_settings(inventory, catalog)
with anta_progress_bar() as AntaTest.progress:
asyncio.run(
main(
ctx.obj["result_manager"],
inventory,
catalog,
tags=tags,
devices=set(device) if device else None,
tests=set(test) if test else None,
dry_run=dry_run,
)
)
if dry_run:
ctx.exit()


def _get_result_manager(ctx: click.Context) -> ResultManager:
"""Get a ResultManager instance based on Click context."""
return ctx.obj["result_manager"].filter(ctx.obj.get("hide")) if ctx.obj.get("hide") is not None else ctx.obj["result_manager"]
Expand Down

0 comments on commit 0f88b28

Please sign in to comment.