diff --git a/cli/codeflash/verification/test_results.py b/cli/codeflash/verification/test_results.py new file mode 100644 index 0000000..2906ccd --- /dev/null +++ b/cli/codeflash/verification/test_results.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +import logging +from enum import Enum +from typing import Optional, Iterator, List + +from pydantic.dataclasses import dataclass + +from codeflash.verification.comparator import comparator + + +class TestType(Enum): + EXISTING_UNIT_TEST = 1 + INSPIRED_REGRESSION = 2 + GENERATED_REGRESSION = 3 + + def to_name(self) -> str: + if self == TestType.EXISTING_UNIT_TEST: + return "⚙️ Existing Unit Tests" + elif self == TestType.INSPIRED_REGRESSION: + return "🎨 Inspired Regression Tests" + elif self == TestType.GENERATED_REGRESSION: + return "🌀 Generated Regression Tests" + + +@dataclass(frozen=True) +class InvocationId: + test_module_path: str # The fully qualified name of the test module + test_class_name: Optional[str] # The name of the class where the test is defined + test_function_name: ( + str # The name of the test_function. Does not include the components of the file_name + ) + function_getting_tested: str + iteration_id: Optional[str] + + # test_module_path:TestSuiteClass.test_function_name:function_tested:iteration_id + def id(self): + return f"{self.test_module_path}:{self.test_class_name or ''}.{self.test_function_name}:{self.function_getting_tested}:{self.iteration_id}" + + @staticmethod + def from_str_id(string_id: str): + components = string_id.split(":") + assert len(components) == 4 + second_components = components[1].split(".") + if len(second_components) == 1: + test_class_name = None + test_function_name = second_components[0] + else: + test_class_name = second_components[0] + test_function_name = second_components[1] + return InvocationId( + test_module_path=components[0], + test_class_name=test_class_name, + test_function_name=test_function_name, + function_getting_tested=components[2], + iteration_id=components[3], + ) + + +@dataclass(frozen=True) +class FunctionTestInvocation: + id: InvocationId # The fully qualified name of the function invocation (id) + file_name: str # The file where the test is defined + did_pass: bool # Whether the test this function invocation was part of, passed or failed + runtime: Optional[int] # Time in nanoseconds + test_framework: str # unittest or pytest + test_type: TestType + return_value: Optional[object] # The return value of the function invocation + + +class TestResults: + test_results: list[FunctionTestInvocation] + + def __init__(self, test_results=None): + if test_results is None: + test_results = [] + self.test_results = test_results + + def add(self, function_test_invocation: FunctionTestInvocation) -> None: + self.test_results.append(function_test_invocation) + + def merge(self, other: "TestResults") -> None: + self.test_results.extend(other.test_results) + + def get_by_id(self, invocation_id: InvocationId) -> Optional[FunctionTestInvocation]: + return next((r for r in self.test_results if r.id == invocation_id), None) + + def get_all_ids(self) -> List[InvocationId]: + return [test_result.id for test_result in self.test_results] + + def get_test_pass_fail_report(self) -> str: + passed = 0 + failed = 0 + for test_result in self.test_results: + if test_result.did_pass: + passed += 1 + else: + failed += 1 + return f"Passed: {passed}, Failed: {failed}" + + def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]: + report = {} + for test_type in TestType: + report[test_type] = {"passed": 0, "failed": 0} + for test_result in self.test_results: + if test_result.did_pass: + report[test_result.test_type]["passed"] += 1 + else: + report[test_result.test_type]["failed"] += 1 + return report + + @staticmethod + def report_to_string(report: dict[TestType, dict[str, int]]) -> str: + return " ".join( + [ + f"{test_type.to_name()}- (Passed: {report[test_type]['passed']}, Failed: {report[test_type]['failed']})" + for test_type in TestType + ] + ) + + def total_passed_runtime(self) -> int: + for result in self.test_results: + if result.did_pass and result.runtime is None: + logging.debug(f"Ignoring test case that passed but had no runtime -> {result.id}") + timing = sum( + [ + result.runtime + for result in self.test_results + if (result.did_pass and result.runtime is not None) + ] + ) + return timing + + def __iter__(self) -> Iterator[FunctionTestInvocation]: + return iter(self.test_results) + + def __len__(self) -> int: + return len(self.test_results) + + def __getitem__(self, index: int) -> FunctionTestInvocation: + return self.test_results[index] + + def __setitem__(self, index: int, value: FunctionTestInvocation) -> None: + self.test_results[index] = value + + def __delitem__(self, index: int) -> None: + del self.test_results[index] + + def __contains__(self, value: FunctionTestInvocation) -> bool: + return value in self.test_results + + def __bool__(self) -> bool: + return bool(self.test_results) + + def __eq__(self, other: TestResults): + # Unordered comparison + if type(self) != type(other): + return False + if len(self) != len(other): + return False + for test_result in self: + other_test_result = other.get_by_id(test_result.id) + if other_test_result is None: + return False + if ( + test_result.file_name != other_test_result.file_name + or test_result.did_pass != other_test_result.did_pass + or test_result.runtime != other_test_result.runtime + or test_result.test_framework != other_test_result.test_framework + or test_result.test_type != other_test_result.test_type + or not comparator(test_result.return_value, other_test_result.return_value) + ): + return False + return True