From a28bd6bd133c58974fd22e30eeff243332fa02aa Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 21 Aug 2024 17:30:04 +0200 Subject: [PATCH] Add some type annotations --- parcels/_typing.py | 4 +- parcels/interaction/interactionkernel.py | 2 +- parcels/kernel.py | 4 +- parcels/particledata.py | 4 +- parcels/particleset.py | 65 ++++++++++++++++-------- parcels/tools/converters.py | 4 +- 6 files changed, 54 insertions(+), 29 deletions(-) diff --git a/parcels/_typing.py b/parcels/_typing.py index 53f9b7e42..637393a07 100644 --- a/parcels/_typing.py +++ b/parcels/_typing.py @@ -9,7 +9,7 @@ import ast import datetime import os -from typing import Any, Literal, get_args +from typing import Any, Callable, Literal, get_args class ParcelsAST(ast.AST): @@ -36,6 +36,8 @@ class ParcelsAST(ast.AST): UpdateStatus = Literal["not_updated", "first_updated", "updated"] # mostly corresponds with `update_status` TimePeriodic = float | datetime.timedelta | Literal[False] # mostly corresponds with `update_status` +KernelFunction = Callable[..., None] + def ensure_is_literal_value(value: Any, literal: Any) -> None: """Ensures that a value is a valid option for the provided Literal type annotation.""" diff --git a/parcels/interaction/interactionkernel.py b/parcels/interaction/interactionkernel.py index 1b9da571e..1b5f4d36d 100644 --- a/parcels/interaction/interactionkernel.py +++ b/parcels/interaction/interactionkernel.py @@ -36,7 +36,7 @@ def __init__( py_ast=None, funcvars=None, c_include="", - delete_cfiles=True, + delete_cfiles: bool = True, ): if MPI is not None and MPI.COMM_WORLD.Get_size() > 1: raise NotImplementedError( diff --git a/parcels/kernel.py b/parcels/kernel.py index 3c961b65d..efbf33300 100644 --- a/parcels/kernel.py +++ b/parcels/kernel.py @@ -58,8 +58,8 @@ def __init__( funccode=None, py_ast=None, funcvars: list[str] | None = None, - c_include="", - delete_cfiles=True, + c_include: str = "", + delete_cfiles: bool = True, ): self._fieldset = fieldset self.field_args = None diff --git a/parcels/particledata.py b/parcels/particledata.py index 199a9c08b..1f32cfa26 100644 --- a/parcels/particledata.py +++ b/parcels/particledata.py @@ -68,9 +68,9 @@ def __init__(self, pclass, lon, lat, depth, time, lonlatdepth_dtype, pid_orig, n assert ( depth is not None ), "particle's initial depth is None - incompatible with the ParticleData class. Invalid state." - assert lon.size == lat.size and lon.size == depth.size, "lon, lat, depth don" "t all have the same lenghts." + assert lon.size == lat.size and lon.size == depth.size, "lon, lat, depth don't all have the same lenghts." - assert lon.size == time.size, "time and positions (lon, lat, depth) don" "t have the same lengths." + assert lon.size == time.size, "time and positions (lon, lat, depth) don't have the same lengths." # If a partitioning function for MPI runs has been passed into the # particle creation with the "partition_function" kwarg, retrieve it here. diff --git a/parcels/particleset.py b/parcels/particleset.py index 4d1c2ecec..4acbe9a07 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -2,17 +2,20 @@ import sys from copy import copy from datetime import date, datetime, timedelta +from typing import Callable, Literal import cftime import numpy as np import xarray as xr from tqdm import tqdm +import parcels +from parcels._typing import KernelFunction, PathLike + try: from mpi4py import MPI except ModuleNotFoundError: MPI = None - try: from pykdtree.kdtree import KDTree except ModuleNotFoundError: @@ -21,6 +24,7 @@ from parcels.application_kernels.advection import AdvectionRK4 from parcels.compilation.codecompiler import GNUCompiler from parcels.field import NestedField +from parcels.fieldset import FieldSet from parcels.grid import CurvilinearGrid, GridType from parcels.interaction.interactionkernel import InteractionKernel from parcels.interaction.neighborsearch import ( @@ -30,7 +34,7 @@ KDTreeFlatNeighborSearch, ) from parcels.kernel import Kernel -from parcels.particle import JITParticle, Variable +from parcels.particle import JITParticle, ScipyParticle, Variable from parcels.particledata import ParticleData, ParticleDataIterator from parcels.particlefile import ParticleFile from parcels.tools.converters import _get_cftime_calendars, convert_to_flat_array @@ -90,13 +94,13 @@ class ParticleSet: def __init__( self, - fieldset, + fieldset: FieldSet, pclass=JITParticle, - lon=None, - lat=None, + lon: list[float] | None = None, + lat: list[float] | None = None, depth=None, time=None, - repeatdt=None, + repeatdt: timedelta | float | None = None, lonlatdepth_dtype=None, pid_orig=None, interaction_distance=None, @@ -166,7 +170,7 @@ def ArrayClass_init(self, *args, **kwargs): depth = np.ones(lon.size) * mindepth else: depth = convert_to_flat_array(depth) - assert lon.size == lat.size and lon.size == depth.size, "lon, lat, depth don" "t all have the same lenghts" + assert lon.size == lat.size and lon.size == depth.size, "lon, lat, depth don't all have the same lenghts" time = convert_to_flat_array(time) time = np.repeat(time, lon.size) if time.size == 1 else time @@ -176,7 +180,7 @@ def ArrayClass_init(self, *args, **kwargs): if time.size > 0 and isinstance(time[0], np.timedelta64) and not self.time_origin: raise NotImplementedError("If fieldset.time_origin is not a date, time of a particle must be a double") time = np.array([self.time_origin.reltime(t) if _convert_to_reltime(t) else t for t in time]) - assert lon.size == time.size, "time and positions (lon, lat, depth) don" "t have the same lengths." + assert lon.size == time.size, "time and positions (lon, lat, depth) don't have the same lengths." if lonlatdepth_dtype is None: lonlatdepth_dtype = self.lonlatdepth_dtype_from_field_interp_method(fieldset.U) @@ -476,7 +480,16 @@ def populate_indices(self): @classmethod def from_list( - cls, fieldset, pclass, lon, lat, depth=None, time=None, repeatdt=None, lonlatdepth_dtype=None, **kwargs + cls, + fieldset: FieldSet, + pclass: JITParticle | ScipyParticle, + lon: list[float], + lat: list[float], + depth: list[float] | None = None, + time: list[float] | None = None, + repeatdt: datetime | float | None = None, + lonlatdepth_dtype=None, + **kwargs, ): """Initialise the ParticleSet from lists of lon and lat. @@ -576,7 +589,7 @@ def from_line( ) @classmethod - def monte_carlo_sample(cls, start_field, size, mode="monte_carlo"): + def monte_carlo_sample(cls, start_field, size, mode: Literal["monte_carlo"] = "monte_carlo"): """Converts a starting field into a monte-carlo sample of lons and lats. Parameters @@ -635,7 +648,7 @@ def monte_carlo_sample(cls, start_field, size, mode="monte_carlo"): ) return list(lon), list(lat) else: - raise NotImplementedError(f'Mode {mode} not implemented. Please use "monte carlo" algorithm instead.') + raise NotImplementedError(f'Mode {mode} not implemented. Please use "monte_carlo" algorithm instead.') @classmethod def from_field( @@ -643,8 +656,8 @@ def from_field( fieldset, pclass, start_field, - size, - mode="monte_carlo", + size: int, + mode: Literal["monte_carlo"] = "monte_carlo", depth=None, time=None, repeatdt=None, @@ -690,7 +703,15 @@ def from_field( @classmethod def from_particlefile( - cls, fieldset, pclass, filename, restart=True, restarttime=None, repeatdt=None, lonlatdepth_dtype=None, **kwargs + cls, + fieldset: FieldSet, + pclass, + filename: PathLike, + restart=True, + restarttime=None, + repeatdt=None, + lonlatdepth_dtype=None, + **kwargs, ): """Initialise the ParticleSet from a zarr ParticleFile. This creates a new ParticleSet based on locations of all particles written @@ -702,7 +723,7 @@ def from_particlefile( mod:`parcels.fieldset.FieldSet` object from which to sample velocity pclass : parcels.particle.JITParticle or parcels.particle.ScipyParticle Particle class. May be a particle class as defined in parcels, or a subclass defining a custom particle. - filename : str + filename : Name of the particlefile from which to read initial conditions restart : bool BSignal if pset is used for a restart (default is True). @@ -799,7 +820,7 @@ def from_particlefile( **kwargs, ) - def Kernel(self, pyfunc, c_include="", delete_cfiles=True): + def Kernel(self, pyfunc: KernelFunction | list[KernelFunction], c_include="", delete_cfiles=True): """Wrapper method to convert a `pyfunc` into a :class:`parcels.kernel.Kernel` object. Conversion is based on `fieldset` and `ptype` of the ParticleSet. @@ -881,7 +902,7 @@ def error_particles(self): return ParticleDataIterator(self.particledata, subset=error_indices) @property - def num_error_particles(self): + def num_error_particles(self) -> int: """Get the number of particles that are in an error state. Returns @@ -909,12 +930,12 @@ def execute( pyfunc_inter=None, endtime=None, runtime=None, - dt=1.0, - output_file=None, - verbose_progress=True, - postIterationCallbacks=None, + dt: timedelta | float = 1.0, + output_file: parcels.ParticleFile | None = None, + verbose_progress: bool = True, + postIterationCallbacks: list[Callable[[], None]] | None = None, callbackdt=None, - delete_cfiles=True, + delete_cfiles: bool = True, ): """Execute a given kernel function over the particle set for multiple timesteps. diff --git a/parcels/tools/converters.py b/parcels/tools/converters.py index b404e93ef..b077c9a36 100644 --- a/parcels/tools/converters.py +++ b/parcels/tools/converters.py @@ -2,10 +2,12 @@ import inspect from datetime import timedelta from math import cos, pi +from typing import Any import cftime import numpy as np import xarray as xr +from numpy.typing import ArrayLike, NDArray __all__ = [ "UnitConverter", @@ -20,7 +22,7 @@ ] -def convert_to_flat_array(var): +def convert_to_flat_array(var: list[float] | float | int | NDArray[Any] | ArrayLike) -> NDArray[Any]: """Convert lists and single integers/floats to one-dimensional numpy arrays Parameters