Skip to content

Commit

Permalink
Fix bug in set comparisons
Browse files Browse the repository at this point in the history
  • Loading branch information
niknetniko committed May 30, 2024
1 parent 2f70198 commit bae7d44
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 11 deletions.
6 changes: 3 additions & 3 deletions tested/oracles/value.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
TextOutputChannel,
ValueOutputChannel,
)
from tested.utils import sorted_no_duplicates
from tested.utils import sorted_no_duplicates, sorting_value_extract

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -106,12 +106,12 @@ def _prepare_value_for_type_check(value: Value) -> Value:
basic_type = as_basic_type(value)
if basic_type.type == BasicSequenceTypes.SET:
value.data = sorted_no_duplicates(
value.data, recursive_key=lambda x: x.data # type: ignore
value.data, recursive_key=sorting_value_extract
)
elif isinstance(value, ObjectType):
assert isinstance(value, ObjectType)
value.data = sorted_no_duplicates(
value.data, key=lambda x: x.key, recursive_key=lambda x: x.data # type: ignore
value.data, key=lambda x: x.key, recursive_key=sorting_value_extract
)
else:
assert isinstance(value.type, SimpleTypes)
Expand Down
8 changes: 7 additions & 1 deletion tested/serialisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from collections.abc import Iterable
from decimal import Decimal
from enum import StrEnum, auto, unique
from functools import reduce
from functools import reduce, total_ordering
from types import NoneType
from typing import Any, Literal, Optional, Union, cast

Expand Down Expand Up @@ -611,6 +611,7 @@ def serialize_from_python(value: Any, type_: AllTypes | None = None) -> Value:
raise TypeError(f"No clue how to convert {value} into TESTed value.")


@total_ordering
class ComparableFloat:
def __init__(self, value):
self.value = value
Expand All @@ -624,6 +625,11 @@ def __eq__(self, other):
except Exception:
return False

def __lt__(self, other):
if not isinstance(other, ComparableFloat):
return NotImplemented
return self.value < other.value

Check warning on line 631 in tested/serialisation.py

View check run for this annotation

Codecov / codecov/patch

tested/serialisation.py#L629-L631

Added lines #L629 - L631 were not covered by tests

def __str__(self):
return str(self.value)

Expand Down
35 changes: 29 additions & 6 deletions tested/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,13 @@ def recursive_dict_merge(one: dict, two: dict) -> dict:
return new_dictionary


def sorting_value_extract(maybe_value: Any) -> Any:
if hasattr(maybe_value, "data"):
return maybe_value.data

Check warning on line 167 in tested/utils.py

View check run for this annotation

Codecov / codecov/patch

tested/utils.py#L166-L167

Added lines #L166 - L167 were not covered by tests
else:
return maybe_value

Check warning on line 169 in tested/utils.py

View check run for this annotation

Codecov / codecov/patch

tested/utils.py#L169

Added line #L169 was not covered by tests


def sorted_no_duplicates(
iterable: Iterable[T],
key: Callable[[T], K] = lambda x: x,
Expand Down Expand Up @@ -200,18 +207,34 @@ def order(x: Any, y: Any) -> int:
:param y: second value
:return: 1 if x < y else -1 if x > y else 0
"""
# Attempt to use a key to extract the data if needed.
if recursive_key: # Parent function parameter
if x is not None:
x = recursive_key(x)
if y is not None:
y = recursive_key(y)
cmp = type_order(x, y)
if cmp != 0:
return cmp
elif not isinstance(x, str) and isinstance(x, Iterable):

# Try to compare types; this might be enough.
type_compare = type_order(x, y)
if type_compare != 0:
return type_compare

Check warning on line 220 in tested/utils.py

View check run for this annotation

Codecov / codecov/patch

tested/utils.py#L218-L220

Added lines #L218 - L220 were not covered by tests

# Next, if we have iterables, attempt to use those (but not for strings)
# Both should be iterable in this case, since the types are the same.
if (

Check warning on line 224 in tested/utils.py

View check run for this annotation

Codecov / codecov/patch

tested/utils.py#L224

Added line #L224 was not covered by tests
not isinstance(x, str)
and not isinstance(y, str)
and isinstance(x, Iterable)
and isinstance(y, Iterable)
):
return order_iterable(x, y)
else:
return int(x < y) - int(x > y)

# Finally, attempt to use the values themselves.
try:
return int(x < y) - int(x > y) # type: ignore
except TypeError:

Check warning on line 235 in tested/utils.py

View check run for this annotation

Codecov / codecov/patch

tested/utils.py#L233-L235

Added lines #L233 - L235 were not covered by tests
# These types cannot be compared, so fallback to string comparison.
return str(x) < str(y)

Check warning on line 237 in tested/utils.py

View check run for this annotation

Codecov / codecov/patch

tested/utils.py#L237

Added line #L237 was not covered by tests

# Sort functions, custom implementation needed for efficient recursive ordering
# of values that have different types
Expand Down
24 changes: 23 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@

import yaml

from tested.utils import sorted_no_duplicates
from tested.datatypes import (
AdvancedNothingTypes,
AdvancedSequenceTypes,
BasicNumericTypes,
BasicSequenceTypes,
)
from tested.serialisation import NothingType, NumberType, SequenceType
from tested.utils import sorted_no_duplicates, sorting_value_extract
from tests.manual_utils import assert_valid_output, configuration, execute_config


Expand Down Expand Up @@ -210,6 +217,21 @@ def test_sort_empty():
assert [] == sorted_no_duplicates([])


def test_can_sort_nested_sets():
value = SequenceType(
type=AdvancedSequenceTypes.LIST,
data=[
NumberType(type=BasicNumericTypes.INTEGER, data=5, diagnostic=None),
NothingType(
type=AdvancedNothingTypes.UNDEFINED, data=None, diagnostic=None
),
],
diagnostic=None,
)
result = sorted_no_duplicates([value, value], key=sorting_value_extract)
assert [value] == result


def test_valid_yaml_and_json():
"""
Test to validate if all YAML and JSON can be parsed correctly.
Expand Down

0 comments on commit bae7d44

Please sign in to comment.