From 20ebc00e2859893e707fc21db8ea430032c9bf1d Mon Sep 17 00:00:00 2001 From: Michael Thies Date: Mon, 1 Apr 2024 16:30:41 +0200 Subject: [PATCH] Allow storing float not-a-number values in MySQL database (as NULL value) --- shc/interfaces/mysql.py | 13 ++++++++++--- test/interfaces/test_mysql.py | 7 ++++++- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/shc/interfaces/mysql.py b/shc/interfaces/mysql.py index 07df609a..8ac324da 100644 --- a/shc/interfaces/mysql.py +++ b/shc/interfaces/mysql.py @@ -3,6 +3,7 @@ import enum import json import logging +import math from typing import Optional, Type, Generic, List, Tuple, Any, Dict, Callable import aiomysql @@ -316,7 +317,8 @@ def _get_to_mysql_converter(type_: Type[T]) -> Callable[[T], Any]: elif issubclass(type_, int): return lambda value: int(value) # type: ignore # type_ is equivalent to T -> int(a: int) is valid elif issubclass(type_, float): - return lambda value: float(value) # type: ignore # type_ is equivalent to T -> float(a: float) is valid + # type_ is equivalent to T -> float(a: float) is valid + return lambda value: None if math.isnan(value) else float(value) # type: ignore elif issubclass(type_, str): return lambda value: str(value) elif issubclass(type_, enum.Enum): @@ -328,10 +330,15 @@ def _get_to_mysql_converter(type_: Type[T]) -> Callable[[T], Any]: def _get_from_mysql_converter(type_: Type[T]) -> Callable[[Any], T]: if type_ is bool: return lambda x: bool(x) # type: ignore # type_ is equivalent to T -> type_ is bool here - elif type_ in (int, float, str): + elif type_ in (int, str): return lambda x: x - elif issubclass(type_, (bool, int, float, str, enum.Enum)): + elif type_ is float: + return lambda x: x if x is not None else float("nan") + elif issubclass(type_, (bool, int, str, enum.Enum)): return lambda value: type_(value) # type: ignore # type_ is equivalent to T -> type_() is an instance of T + elif issubclass(type_, float): + # type_ is equivalent to T -> type_() is an instance of T + return lambda value: type_(value) if value is not None else type_(float("nan")) # type: ignore else: return lambda value: from_json(type_, json.loads(value)) diff --git a/test/interfaces/test_mysql.py b/test/interfaces/test_mysql.py index e4fcc28e..194bb9f0 100644 --- a/test/interfaces/test_mysql.py +++ b/test/interfaces/test_mysql.py @@ -1,5 +1,6 @@ import datetime import asyncio +import math import unittest import urllib.parse from typing import Tuple, Type, Iterable, Dict, NamedTuple, Sequence, Any @@ -117,6 +118,7 @@ async def _create_log_variable_with_data(self, type_: Type[T], data: Iterable[Tu async def test_persistence_variables(self) -> None: data: Sequence[Tuple[str, Sequence[Any]]] = [ ("test_int", [5, 7]), + ("test_float", [3.14, float("nan")]), ("test_string", ["foo", "bar"]), ("test_tuple", [ExampleTuple(42, 3.14, "foo"), ExampleTuple(56, 0.414, "[{barπŸ™‚πŸ˜•}]")]), ] @@ -125,4 +127,7 @@ async def test_persistence_variables(self) -> None: variable = self.interface.persistence_variable(type(value_list[0]), variable_name) for value in value_list: await variable.write(value, [self]) - self.assertEqual(value_list[-1], await variable.read()) + if isinstance(value_list[-1], float) and math.isnan(value_list[-1]): + self.assertTrue(math.isnan(await variable.read())) + else: + self.assertEqual(value_list[-1], await variable.read())