Skip to content

Commit

Permalink
further improvements to type annotation
Browse files Browse the repository at this point in the history
  • Loading branch information
jornbr committed Dec 10, 2024
1 parent 671192f commit 9a91b44
Show file tree
Hide file tree
Showing 9 changed files with 210 additions and 128 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/fabm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ jobs:
- name: Type checking with mypy
if: matrix.python-version != '3.7' && matrix.python-version != '3.8'
run: |
conda install mypy pyyaml types-pyyaml netcdf4
conda install mypy pyyaml types-pyyaml netcdf4 pyqt-stubs
mypy src/pyfabm
- name: Install with customization via command line arguments
run: |
Expand Down
96 changes: 57 additions & 39 deletions src/pyfabm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@

# typing.Final not available in Python 3.7
try:
from typing import Final
from typing import Final, SupportsIndex
except ImportError:
from typing import Any as Final
from typing import Any

Final = SupportsIndex = Any # type: ignore

try:
import importlib.metadata
Expand Down Expand Up @@ -594,24 +596,24 @@ def __init__(self, model: "Model", variable_pointer: ctypes.c_void_p):
self._pvariable = variable_pointer

def __getitem__(self, key: str) -> Union[float, int, bool]:
typecode = self.model.fabm.variable_get_property_type(
typecode: int = self.model.fabm.variable_get_property_type(
self._pvariable, key.encode("ascii")
)
if typecode == DataType.REAL:
return self.model.fabm.variable_get_real_property(
value = self.model.fabm.variable_get_real_property(
self._pvariable, key.encode("ascii"), -1.0
)
return cast(float, value)
elif typecode == DataType.INTEGER:
return self.model.fabm.variable_get_integer_property(
value = self.model.fabm.variable_get_integer_property(
self._pvariable, key.encode("ascii"), 0
)
return cast(int, value)
elif typecode == DataType.LOGICAL:
return (
self.model.fabm.variable_get_logical_property(
self._pvariable, key.encode("ascii"), 0
)
!= 0
value = self.model.fabm.variable_get_logical_property(
self._pvariable, key.encode("ascii"), 0
)
return cast(int, value) != 0
raise KeyError


Expand Down Expand Up @@ -681,12 +683,14 @@ def long_path(self) -> str:
@property
def missing_value(self) -> float:
"""Value that indicates missing data, for instance, on land."""
return self.model.fabm.variable_get_missing_value(self._pvariable)
value: float = self.model.fabm.variable_get_missing_value(self._pvariable)
return value

def getRealProperty(self, name, default=-1.0) -> float:
return self.model.fabm.variable_get_real_property(
def getRealProperty(self, name: str, default: float = -1.0) -> float:
value: float = self.model.fabm.variable_get_real_property(
self._pvariable, name.encode("ascii"), default
)
return value


class Dependency(VariableFromPointer):
Expand Down Expand Up @@ -727,7 +731,8 @@ def link(self, data: np.ndarray):

@property
def required(self) -> bool:
return self.model.fabm.variable_is_required(self._pvariable) != 0
value: int = self.model.fabm.variable_is_required(self._pvariable)
return value != 0


class StateVariable(VariableFromPointer):
Expand All @@ -747,21 +752,25 @@ def value(self, value: npt.ArrayLike):

@property
def background_value(self) -> float:
return self.model.fabm.variable_get_background_value(self._pvariable)
value: float = self.model.fabm.variable_get_background_value(self._pvariable)
return value

@property
def output(self) -> bool:
return self.model.fabm.variable_get_output(self._pvariable) != 0
value: int = self.model.fabm.variable_get_output(self._pvariable)
return value != 0

@property
def no_river_dilution(self) -> bool:
return self.model.fabm.variable_get_no_river_dilution(self._pvariable) != 0
value: int = self.model.fabm.variable_get_no_river_dilution(self._pvariable)
return value != 0

@property
def no_precipitation_dilution(self) -> bool:
return (
self.model.fabm.variable_get_no_precipitation_dilution(self._pvariable) != 0
value: int = self.model.fabm.variable_get_no_precipitation_dilution(
self._pvariable
)
return value != 0


class DiagnosticVariable(VariableFromPointer):
Expand All @@ -784,7 +793,8 @@ def value(self) -> Optional[np.ndarray]:
@property
def output(self) -> bool:
"""Whether this diagnostic is meant to be included in output by default"""
return self.model.fabm.variable_get_output(self._pvariable) != 0
value: int = self.model.fabm.variable_get_output(self._pvariable)
return value != 0

@property
def save(self) -> bool:
Expand Down Expand Up @@ -820,23 +830,23 @@ def __init__(
self._index = index + 1
self._has_default = has_default

def _get_value(self, *, default: bool = False):
def _get_value(self, *, default: bool = False) -> Union[float, int, bool, str]:
idefault = 1 if default else 0
if self._type == DataType.REAL:
return self.model.fabm.get_real_parameter(
value = self.model.fabm.get_real_parameter(
self.model.pmodel, self._index, idefault
)
return cast(float, value)
elif self._type == DataType.INTEGER:
return self.model.fabm.get_integer_parameter(
value = self.model.fabm.get_integer_parameter(
self.model.pmodel, self._index, idefault
)
return cast(int, value)
elif self._type == DataType.LOGICAL:
return (
self.model.fabm.get_logical_parameter(
self.model.pmodel, self._index, idefault
)
!= 0
value = self.model.fabm.get_logical_parameter(
self.model.pmodel, self._index, idefault
)
return cast(int, value) != 0
elif self._type == DataType.STRING:
result = ctypes.create_string_buffer(ATTRIBUTE_LENGTH)
self.model.fabm.get_string_parameter(
Expand Down Expand Up @@ -880,7 +890,7 @@ def default(self) -> Union[float, int, bool, str, None]:
return None
return self._get_value(default=True)

def reset(self):
def reset(self) -> None:
"""Reset this parameter to its default value"""
settings = self.model._save_state()
self.model.fabm.reset_parameter(self.model.pmodel, self._index)
Expand All @@ -897,7 +907,8 @@ def __init__(
@property
def missing_value(self) -> float:
"""Value that indicates missing data, for instance, on land."""
return self.model.fabm.variable_get_missing_value(self._pvariable)
value: float = self.model.fabm.variable_get_missing_value(self._pvariable)
return value


class StandardVariable:
Expand Down Expand Up @@ -953,7 +964,7 @@ def __contains__(self, key: object) -> bool:
return False
return key in self._data

def index(self, key: Union[T, str], *args) -> int:
def index(self, key: Union[T, str], *args: SupportsIndex) -> int:
if isinstance(key, str):
try:
key = self.find(key)
Expand All @@ -977,7 +988,7 @@ def find(self, name: str, case_insensitive: bool = False) -> T:
self._lookup = {obj.name: obj for obj in self._data}
return self._lookup[name]

def clear(self):
def clear(self) -> None:
self._data.clear()
self._lookup = None
self._lookup_ci = None
Expand Down Expand Up @@ -1036,7 +1047,7 @@ def __init__(self, model: "Model", name: str):
model.fabm.get_model_metadata(
model.pmodel, name.encode("ascii"), ATTRIBUTE_LENGTH, strlong_name, iuser
)
self.long_name = strlong_name.value.decode("ascii")
self.long_name: str = strlong_name.value.decode("ascii")
self.user_created = iuser.value != 0


Expand Down Expand Up @@ -1262,15 +1273,17 @@ def save_settings(self, path: str, display: int = DISPLAY_NORMAL):
"""Write model configuration to yaml file"""
self.fabm.save_settings(self.pmodel, path.encode("ascii"), display)

def _save_state(self) -> Tuple:
def _save_state(self) -> Tuple[Mapping[str, np.ndarray], Mapping[str, np.ndarray]]:
environment = {}
for dependency in self.dependencies:
if dependency.value is not None:
environment[dependency.name] = dependency.value
state = {variable.name: variable.value for variable in self.state_variables}
return environment, state

def _restore_state(self, data: Tuple):
def _restore_state(
self, data: Tuple[Mapping[str, np.ndarray], Mapping[str, np.ndarray]]
):
environment, state = data
for dependency in self.dependencies:
if dependency.name in environment:
Expand All @@ -1279,7 +1292,12 @@ def _restore_state(self, data: Tuple):
if variable.name in state:
variable.value = state[variable.name]

def _update_configuration(self, settings: Optional[Tuple] = None):
def _update_configuration(
self,
settings: Optional[
Tuple[Mapping[str, np.ndarray], Mapping[str, np.ndarray]]
] = None,
):
# Get number of model variables per category
nstate_interior = ctypes.c_int()
nstate_surface = ctypes.c_int()
Expand Down Expand Up @@ -1578,10 +1596,10 @@ def get_conserved_quantities(self, out: Optional[np.ndarray] = None) -> np.ndarr
return out

def check_state(self, repair: bool = False) -> bool:
valid = self.fabm.check_state(self.pmodel, repair) != 0
valid: int = self.fabm.check_state(self.pmodel, repair)
if hasError():
raise FABMException(getError())
return valid
return valid != 0

checkState = check_state

Expand Down Expand Up @@ -1756,7 +1774,7 @@ def integrate(
return y


def unload():
def unload() -> None:
global ctypes

for lib in name2lib.values():
Expand Down
7 changes: 6 additions & 1 deletion src/pyfabm/complete_yaml.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import pyfabm


def processFile(infile, outfile, subtract_background=False, add_missing=False):
def processFile(
infile: str,
outfile: str,
subtract_background: bool = False,
add_missing: bool = False,
):
# Create model object from YAML file.
model = pyfabm.Model(infile)
model.save_settings(
Expand Down
Loading

0 comments on commit 9a91b44

Please sign in to comment.