diff --git a/pyproject.toml b/pyproject.toml index b098173f..97e5f0c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ dependencies = [ "httpx>=0.27.2", "numpy>=1.9.0", "pandas>=2.2.3", - "pandera>=0.20.4", + "pandera==0.20.4", # https://github.com/unionai-oss/pandera/issues/1856 "polyfactory>=2.17.0", "pydantic>=2.9.2", "rdflib>=7.0.0", diff --git a/src/cimsparql/data_models.py b/src/cimsparql/data_models.py index 65029e9f..d9e4bc6b 100644 --- a/src/cimsparql/data_models.py +++ b/src/cimsparql/data_models.py @@ -2,15 +2,16 @@ from typing import Self import pandera as pa +from pandera.api.pandas.model_config import BaseConfig from pandera.typing import DataFrame, Index, Series -class JsonSchemaOut(pa.DataFrameModel): - class Config: +class CoercingSchema(pa.DataFrameModel): + class Config(BaseConfig): coerce = True -class FullModelSchema(JsonSchemaOut): +class FullModelSchema(CoercingSchema): model: Series[str] time: Series[str] profile: Series[str] @@ -25,7 +26,7 @@ def unique_time(cls, df: DataFrame[Self]) -> bool: FullModelDataFrame = DataFrame[FullModelSchema] -class MridResourceSchema(JsonSchemaOut): +class MridResourceSchema(CoercingSchema): """ Common class for resources with an mrid as index """ @@ -56,7 +57,7 @@ class MarketDatesSchema(NamedResourceSchema): MarketDatesDataFrame = DataFrame[MarketDatesSchema] -class BusDataSchema(JsonSchemaOut): +class BusDataSchema(CoercingSchema): node: Index[str] = pa.Field(unique=True) busname: Series[str] = pa.Field() substation: Series[str] = pa.Field() @@ -142,7 +143,7 @@ class ExchangeSchema(NamedMarketResourceSchema): ExchangeDataFrame = DataFrame[ExchangeSchema] -class PhaseTapChangerSchema(JsonSchemaOut): +class PhaseTapChangerSchema(CoercingSchema): mrid: Series[str] = pa.Field() phase_shift_increment: Series[float] = pa.Field() enabled: bool = pa.Field() @@ -186,7 +187,7 @@ class TransfConToConverterSchema(NamedResourceSchema): TransfConToConverterDataFrame = DataFrame[TransfConToConverterSchema] -class CoordinatesSchema(JsonSchemaOut): +class CoordinatesSchema(CoercingSchema): mrid: Series[str] = pa.Field() x: Series[float] = pa.Field() y: Series[float] = pa.Field() @@ -230,7 +231,7 @@ class AcLinesSchema(ShuntComponentSchema): AcLinesDataFrame = DataFrame[AcLinesSchema] -class TransformersSchema(JsonSchemaOut): +class TransformersSchema(CoercingSchema): name: Series[str] = pa.Field() p_mrid: Series[str] = pa.Field() w_mrid: Series[str] = pa.Field() @@ -253,7 +254,7 @@ class TransformerWindingSchema(ShuntComponentSchema): TransformerWindingDataFrame = DataFrame[TransformerWindingSchema] -class SubstationVoltageSchema(JsonSchemaOut): +class SubstationVoltageSchema(CoercingSchema): substation: Index[str] = pa.Field() container: Series[str] = pa.Field() v: Series[float] = pa.Field() @@ -262,7 +263,7 @@ class SubstationVoltageSchema(JsonSchemaOut): SubstationVoltageDataFrame = DataFrame[SubstationVoltageSchema] -class DisconnectedSchema(JsonSchemaOut): +class DisconnectedSchema(CoercingSchema): mrid: Series[str] = pa.Field(unique=True) @@ -305,7 +306,7 @@ class RegionsSchema(MridResourceSchema): RegionsDataFrame = DataFrame[RegionsSchema] -class StationGroupCodeNameSchema(JsonSchemaOut): +class StationGroupCodeNameSchema(CoercingSchema): station_group: Index[str] = pa.Field(unique=True) name: Series[str] = pa.Field() alias_name: Series[str] = pa.Field(nullable=True) @@ -322,7 +323,7 @@ class HVDCBidzonesSchema(MridResourceSchema): HVDCBidzonesDataFrame = DataFrame[HVDCBidzonesSchema] -class TransformerWindingsSchema(JsonSchemaOut): +class TransformerWindingsSchema(CoercingSchema): mrid: Series[str] = pa.Field() end_number: Series[int] = pa.Field(gt=0) w_mrid: Index[str] = pa.Field(unique=True) @@ -331,7 +332,7 @@ class TransformerWindingsSchema(JsonSchemaOut): TransformerWindingsDataFrame = DataFrame[TransformerWindingsSchema] -class SvInjectionSchema(JsonSchemaOut): +class SvInjectionSchema(CoercingSchema): node: Series[str] = pa.Field(unique=True) p: Series[float] = pa.Field() q: Series[float] = pa.Field() @@ -340,7 +341,7 @@ class SvInjectionSchema(JsonSchemaOut): SvInjectionDataFrame = DataFrame[SvInjectionSchema] -class RASEquipmentSchema(JsonSchemaOut): +class RASEquipmentSchema(CoercingSchema): mrid: Series[str] = pa.Field(unique=True) equipment_mrid: Series[str] = pa.Field() name: Series[str] = pa.Field() @@ -349,7 +350,7 @@ class RASEquipmentSchema(JsonSchemaOut): RASEquipmentDataFrame = DataFrame[RASEquipmentSchema] -class Switches(JsonSchemaOut): +class Switches(CoercingSchema): mrid: Index[str] = pa.Field(unique=True) is_open: Series[bool] = pa.Field() equipment_type: Series[str] = pa.Field() @@ -360,7 +361,7 @@ class Switches(JsonSchemaOut): SwitchesDataFrame = DataFrame[Switches] -class ConnectivityNode(JsonSchemaOut): +class ConnectivityNode(CoercingSchema): mrid: Index[str] = pa.Field(unique=True) container: Series[str] = pa.Field() container_name: Series[str] = pa.Field() @@ -373,7 +374,7 @@ class ConnectivityNode(JsonSchemaOut): ConnectivityNodeDataFrame = DataFrame[ConnectivityNode] -class SvPowerDeviationSchema(JsonSchemaOut): +class SvPowerDeviationSchema(CoercingSchema): node: Series[str] = pa.Field(unique=True) sum_terminal_flow: Series[float] = pa.Field() reported_sv_injection: Series[float] = pa.Field() @@ -384,7 +385,7 @@ class SvPowerDeviationSchema(JsonSchemaOut): SvPowerDeviationDataFrame = DataFrame[SvPowerDeviationSchema] -class HVDC(JsonSchemaOut): +class HVDC(CoercingSchema): converter_mrid_1: Series[str] = pa.Field() converter_mrid_2: Series[str] = pa.Field() name: Series[str] = pa.Field() @@ -394,7 +395,7 @@ class HVDC(JsonSchemaOut): HVDCDataFrame = DataFrame[HVDC] -class BaseVoltage(JsonSchemaOut): +class BaseVoltage(CoercingSchema): mrid: str = pa.Field(unique=True) un: float = pa.Field() operating_voltage: float = pa.Field() @@ -403,7 +404,7 @@ class BaseVoltage(JsonSchemaOut): BaseVoltageDataFrame = DataFrame[BaseVoltage] -class AssociatedSwitches(JsonSchemaOut): +class AssociatedSwitches(CoercingSchema): mrid: str = pa.Field(unique=True) name: str switch_mrids: str diff --git a/tests/test_type_mapper.py b/tests/test_type_mapper.py index 8b312a49..0c13eed0 100644 --- a/tests/test_type_mapper.py +++ b/tests/test_type_mapper.py @@ -11,7 +11,7 @@ from pandas.testing import assert_frame_equal from cimsparql import type_mapper -from cimsparql.data_models import JsonSchemaOut +from cimsparql.data_models import CoercingSchema from cimsparql.graphdb import RestApi from cimsparql.model import ServiceConfig @@ -126,8 +126,8 @@ def test_have_cim_version(httpserver: HTTPServer, cim_version: int, have_cim_ver assert tm.have_cim_version(str(cim_version)) == have_cim_version -class FloatSchema(JsonSchemaOut): - float_col: pa.typing.Series[float] = pa.Field(nullable=True) +class FloatSchema(CoercingSchema): + float_col: float = pa.Field(nullable=True) @pytest.mark.parametrize(