Skip to content

Commit

Permalink
refactor(anta): Clean reporter typing (#438)
Browse files Browse the repository at this point in the history
* Refactor(anta): Remove ColorManager class

* Refactor: Remove ListResult

* Refactor: Separate functions to render get_results

* Test: Make python 3.8 happy again

* Doc: Remove ColorManager from doc

* Doc: Remove ListResult from doc
  • Loading branch information
gmuloc authored Oct 27, 2023
1 parent abd22ac commit a78db01
Show file tree
Hide file tree
Showing 13 changed files with 132 additions and 216 deletions.
2 changes: 2 additions & 0 deletions anta/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class RICH_COLOR_PALETTE:
SUCCESS = "green4"
SKIPPED = "bold orange4"
HEADER = "cyan"
UNSET = "grey74"


# Dictionary to use in a Rich.Theme: custom_theme = Theme(RICH_COLOR_THEME)
Expand All @@ -40,4 +41,5 @@ class RICH_COLOR_PALETTE:
"skipped": RICH_COLOR_PALETTE.SKIPPED,
"failure": RICH_COLOR_PALETTE.FAILURE,
"error": RICH_COLOR_PALETTE.ERROR,
"unset": RICH_COLOR_PALETTE.UNSET,
}
13 changes: 6 additions & 7 deletions anta/cli/nrfu/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright (c) 2023 Arista Networks, Inc.
# Use of this source code is governed by the Apache License 2.0
# that can be found in the LICENSE file.

"""
Utils functions to use with anta.cli.check.commands module.
"""
Expand Down Expand Up @@ -56,27 +55,27 @@ def print_json(results: ResultManager, output: pathlib.Path | None = None) -> No
"""Print result in a json format"""
console.print()
console.print(Panel("JSON results of all tests", style="cyan"))
rich.print_json(results.get_results(output_format="json"))
rich.print_json(results.get_json_results())
if output is not None:
with open(output, "w", encoding="utf-8") as fout:
fout.write(results.get_results(output_format="json"))
fout.write(results.get_json_results())


def print_list(results: ResultManager, output: pathlib.Path | None = None) -> None:
"""Print result in a list"""
console.print()
console.print(Panel.fit("List results of all tests", style="cyan"))
pprint(results.get_results(output_format="list"))
pprint(results.get_results())
if output is not None:
with open(output, "w", encoding="utf-8") as fout:
fout.write(str(results.get_results(output_format="list")))
fout.write(str(results.get_results()))


def print_text(results: ResultManager, search: str | None = None, skip_error: bool = False) -> None:
"""Print results as simple text"""
console.print()
regexp = re.compile(search or ".*")
for line in results.get_results(output_format="list"):
for line in results.get_results():
if any(regexp.match(entry) for entry in [line.name, line.test]) and (not skip_error or line.result != "error"):
message = f" ({str(line.messages[0])})" if len(line.messages) > 0 else ""
console.print(f"{line.name} :: {line.test} :: [{line.result}]{line.result.upper()}[/{line.result}]{message}", highlight=False)
Expand All @@ -86,7 +85,7 @@ def print_jinja(results: ResultManager, template: pathlib.Path, output: pathlib.
"""Print result based on template."""
console.print()
reporter = ReportJinja(template_path=template)
json_data = json.loads(results.get_results(output_format="json"))
json_data = json.loads(results.get_json_results())
report = reporter.render(json_data)
console.print(report)
if output is not None:
Expand Down
77 changes: 54 additions & 23 deletions anta/reporter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

from jinja2 import Template
from rich.table import Table
from rich.text import Text

from anta import RICH_COLOR_PALETTE
from anta import RICH_COLOR_PALETTE, RICH_COLOR_THEME
from anta.custom_types import TestStatus
from anta.result_manager import ResultManager

from .models import ColorManager

logger = logging.getLogger(__name__)


Expand All @@ -30,11 +30,7 @@ def __init__(self) -> None:
"""
__init__ Class constructor
"""
self.colors = []
self.colors.append(ColorManager(level="success", color=RICH_COLOR_PALETTE.SUCCESS))
self.colors.append(ColorManager(level="failure", color=RICH_COLOR_PALETTE.FAILURE))
self.colors.append(ColorManager(level="error", color=RICH_COLOR_PALETTE.ERROR))
self.colors.append(ColorManager(level="skipped", color=RICH_COLOR_PALETTE.SKIPPED))
self.color_manager = ColorManager()

def _split_list_to_txt_list(self, usr_list: list[str], delimiter: Optional[str] = None) -> str:
"""
Expand Down Expand Up @@ -71,24 +67,18 @@ def _build_headers(self, headers: list[str], table: Table) -> Table:
table.add_column(header, justify="left")
return table

def _color_result(self, status: str, output_type: str = "Text") -> Any:
def _color_result(self, status: TestStatus) -> str:
"""
Helper to implement color based on test status.
It gives output for either standard str or Text() colorized with Style()
Return a colored string based on the status value.
Args:
status (str): status value to colorized
output_type (str, optional): Which format to output code. Defaults to 'Text'.
status (TestStatus): status value to color
Returns:
Any: Can be either str or Text with Style
str: the colored string
"""
# TODO refactor this code as it looks quite surprising
if len([result for result in self.colors if str(result.level).upper() == status.upper()]) == 1:
code: ColorManager = [result for result in self.colors if str(result.level).upper() == status.upper()][0]
return code.style_rich() if output_type == "Text" else code.string()
return None
color = RICH_COLOR_THEME.get(status, "")
return f"[{color}]{status}" if color != "" else str(status)

def report_all(
self,
Expand All @@ -115,10 +105,10 @@ def report_all(
headers = ["Device", "Test Name", "Test Status", "Message(s)", "Test description", "Test category"]
table = self._build_headers(headers=headers, table=table)

for result in result_manager.get_results(output_format="list"):
for result in result_manager.get_results():
# pylint: disable=R0916
if (host is None and testcase is None) or (host is not None and str(result.name) == host) or (testcase is not None and testcase == str(result.test)):
state = self._color_result(status=str(result.result), output_type="str")
state = self._color_result(result.result)
message = self._split_list_to_txt_list(result.messages) if len(result.messages) > 0 else ""
categories = ", ".join(result.categories)
table.add_row(str(result.name), result.test, state, message, result.description, categories)
Expand Down Expand Up @@ -238,7 +228,7 @@ def render(self, data: list[dict[str, Any]], trim_blocks: bool = True, lstrip_bl
Report is built based on a J2 template provided by user.
Data structure sent to template is:
>>> data = ResultManager.get_results(output_format="json")
>>> data = ResultManager.get_json_results()
>>> print(data)
[
{
Expand All @@ -263,3 +253,44 @@ def render(self, data: list[dict[str, Any]], trim_blocks: bool = True, lstrip_bl
template = Template(file_.read(), trim_blocks=trim_blocks, lstrip_blocks=lstrip_blocks)

return template.render({"data": data})


class ColorManager:
"""Color management for status report."""

def get_color(self, level: TestStatus) -> str:
"""Return the color attributed to the status in RICH_COLOR_THEME.
Args:
level (TestStatus): The status to colorized
Returns:
str: the colors attributed to this or empty string
"""
return RICH_COLOR_THEME.get(level, "")

def style_rich(self, level: TestStatus) -> Text:
"""
Build a rich Text syntax with color
Args:
level (TestStatus): The status to colorized
Returns:
Text: object with level string and its associated color.
"""
return Text(level, style=self.get_color(level))

def string(self, level: TestStatus) -> str:
"""
Build an str with color code
Args:
level (TestStatus): The status to colorized
Returns:
str: String with level and its associated color
"""
color = self.get_color(level)
return f"[{color}]{level}" if color != "" else str(level)
39 changes: 0 additions & 39 deletions anta/reporter/models.py

This file was deleted.

55 changes: 16 additions & 39 deletions anta/result_manager/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@

import json
import logging
from typing import Any

from pydantic import TypeAdapter

from anta.custom_types import TestStatus
from anta.result_manager.models import ListResult, TestResult
from anta.result_manager.models import TestResult
from anta.tools.pydantic import pydantic_to_dict

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -89,7 +88,7 @@ def __init__(self) -> None:
If the status of the added test is error, the status is untouched and the
error_status is set to True.
"""
self._result_entries = ListResult()
self._result_entries: list[TestResult] = []
# Initialize status
self.status: TestStatus = "unset"
self.error_status = False
Expand Down Expand Up @@ -140,33 +139,25 @@ def get_status(self, ignore_error: bool = False) -> str:
"""
return "error" if self.error_status and not ignore_error else self.status

def get_results(self, output_format: str = "native") -> Any:
def get_results(self) -> list[TestResult]:
"""
Expose list of all test results in different format
Support multiple format:
- native: ListResults format
- list: a list of TestResult
- json: a native JSON format
Args:
output_format (str, optional): format selector. Can be either native/list/json. Defaults to 'native'.
Returns:
any: List of results.
"""
if output_format == "list":
return list(self._result_entries)
return self._result_entries

if output_format == "json":
return json.dumps(pydantic_to_dict(self._result_entries), indent=4)
def get_json_results(self) -> str:
"""
Expose list of all test results in JSON
if output_format == "native":
# Default return for native format.
return self._result_entries
raise ValueError(f"{output_format} is not a valid value ['list', 'json', 'native']")
Returns:
str: JSON dumps of the list of results
"""
return json.dumps(pydantic_to_dict(self._result_entries), indent=4)

def get_result_by_test(self, test_name: str, output_format: str = "native") -> Any:
def get_result_by_test(self, test_name: str) -> list[TestResult]:
"""
Get list of test result for a given test.
Expand All @@ -177,16 +168,9 @@ def get_result_by_test(self, test_name: str, output_format: str = "native") -> A
Returns:
list[TestResult]: List of results related to the test.
"""
if output_format == "list":
return [result for result in self._result_entries if str(result.test) == test_name]
return [result for result in self._result_entries if str(result.test) == test_name]

result_manager_filtered = ListResult()
for result in self._result_entries:
if result.test == test_name:
result_manager_filtered.append(result)
return result_manager_filtered

def get_result_by_host(self, host_ip: str, output_format: str = "native") -> Any:
def get_result_by_host(self, host_ip: str) -> list[TestResult]:
"""
Get list of test result for a given host.
Expand All @@ -195,16 +179,9 @@ def get_result_by_host(self, host_ip: str, output_format: str = "native") -> Any
output_format (str, optional): format selector. Can be either native/list. Defaults to 'native'.
Returns:
Any: List of results related to the host.
list[TestResult]: List of results related to the host.
"""
if output_format == "list":
return [result for result in self._result_entries if str(result.name) == host_ip]

result_manager_filtered = ListResult()
for result in self._result_entries:
if str(result.name) == host_ip:
result_manager_filtered.append(result)
return result_manager_filtered
return [result for result in self._result_entries if str(result.name) == host_ip]

def get_testcases(self) -> list[str]:
"""
Expand Down
37 changes: 1 addition & 36 deletions anta/result_manager/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@
"""Models related to anta.result_manager module."""
from __future__ import annotations

from collections.abc import Iterator

# Need to keep List for pydantic in 3.8
from typing import List, Optional

from pydantic import BaseModel, ConfigDict, RootModel
from pydantic import BaseModel, ConfigDict

from anta.custom_types import TestStatus

Expand Down Expand Up @@ -95,36 +93,3 @@ def __str__(self) -> str:
Returns a human readable string of this TestResult
"""
return f"Test {self.test} on device {self.name} has result {self.result}"


class ListResult(RootModel[List[TestResult]]):
"""
list result for all tests on all devices.
Attributes:
__root__ (list[TestResult]): A list of TestResult objects.
"""

root: List[TestResult] = []

def extend(self, values: list[TestResult]) -> None:
"""Add support for extend method."""
self.root.extend(values)

def append(self, value: TestResult) -> None:
"""Add support for append method."""
self.root.append(value)

def __iter__(self) -> Iterator[TestResult]: # type: ignore
"""Use custom iter method."""
# TODO - mypy is not happy because we overwrite BaseModel.__iter__
# return type and are breaking Liskov Substitution Principle.
return iter(self.root)

def __getitem__(self, item: int) -> TestResult:
"""Use custom getitem method."""
return self.root[item]

def __len__(self) -> int:
"""Support for length of __root__"""
return len(self.root)
Loading

0 comments on commit a78db01

Please sign in to comment.