Skip to content

Commit

Permalink
Add some type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
VeckoTheGecko committed Aug 21, 2024
1 parent 217e1c2 commit a28bd6b
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 29 deletions.
4 changes: 3 additions & 1 deletion parcels/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import ast
import datetime
import os
from typing import Any, Literal, get_args
from typing import Any, Callable, Literal, get_args


class ParcelsAST(ast.AST):
Expand All @@ -36,6 +36,8 @@ class ParcelsAST(ast.AST):
UpdateStatus = Literal["not_updated", "first_updated", "updated"] # mostly corresponds with `update_status`
TimePeriodic = float | datetime.timedelta | Literal[False] # mostly corresponds with `update_status`

KernelFunction = Callable[..., None]


def ensure_is_literal_value(value: Any, literal: Any) -> None:
"""Ensures that a value is a valid option for the provided Literal type annotation."""
Expand Down
2 changes: 1 addition & 1 deletion parcels/interaction/interactionkernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
py_ast=None,
funcvars=None,
c_include="",
delete_cfiles=True,
delete_cfiles: bool = True,
):
if MPI is not None and MPI.COMM_WORLD.Get_size() > 1:
raise NotImplementedError(
Expand Down
4 changes: 2 additions & 2 deletions parcels/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def __init__(
funccode=None,
py_ast=None,
funcvars: list[str] | None = None,
c_include="",
delete_cfiles=True,
c_include: str = "",
delete_cfiles: bool = True,
):
self._fieldset = fieldset
self.field_args = None
Expand Down
4 changes: 2 additions & 2 deletions parcels/particledata.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ def __init__(self, pclass, lon, lat, depth, time, lonlatdepth_dtype, pid_orig, n
assert (
depth is not None
), "particle's initial depth is None - incompatible with the ParticleData class. Invalid state."
assert lon.size == lat.size and lon.size == depth.size, "lon, lat, depth don" "t all have the same lenghts."
assert lon.size == lat.size and lon.size == depth.size, "lon, lat, depth don't all have the same lenghts."

assert lon.size == time.size, "time and positions (lon, lat, depth) don" "t have the same lengths."
assert lon.size == time.size, "time and positions (lon, lat, depth) don't have the same lengths."

# If a partitioning function for MPI runs has been passed into the
# particle creation with the "partition_function" kwarg, retrieve it here.
Expand Down
65 changes: 43 additions & 22 deletions parcels/particleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,20 @@
import sys
from copy import copy
from datetime import date, datetime, timedelta
from typing import Callable, Literal

import cftime
import numpy as np
import xarray as xr
from tqdm import tqdm

import parcels
from parcels._typing import KernelFunction, PathLike

try:
from mpi4py import MPI
except ModuleNotFoundError:
MPI = None

try:
from pykdtree.kdtree import KDTree
except ModuleNotFoundError:
Expand All @@ -21,6 +24,7 @@
from parcels.application_kernels.advection import AdvectionRK4
from parcels.compilation.codecompiler import GNUCompiler
from parcels.field import NestedField
from parcels.fieldset import FieldSet
from parcels.grid import CurvilinearGrid, GridType
from parcels.interaction.interactionkernel import InteractionKernel
from parcels.interaction.neighborsearch import (
Expand All @@ -30,7 +34,7 @@
KDTreeFlatNeighborSearch,
)
from parcels.kernel import Kernel
from parcels.particle import JITParticle, Variable
from parcels.particle import JITParticle, ScipyParticle, Variable
from parcels.particledata import ParticleData, ParticleDataIterator
from parcels.particlefile import ParticleFile
from parcels.tools.converters import _get_cftime_calendars, convert_to_flat_array
Expand Down Expand Up @@ -90,13 +94,13 @@ class ParticleSet:

def __init__(
self,
fieldset,
fieldset: FieldSet,
pclass=JITParticle,
lon=None,
lat=None,
lon: list[float] | None = None,
lat: list[float] | None = None,
depth=None,
time=None,
repeatdt=None,
repeatdt: timedelta | float | None = None,
lonlatdepth_dtype=None,
pid_orig=None,
interaction_distance=None,
Expand Down Expand Up @@ -166,7 +170,7 @@ def ArrayClass_init(self, *args, **kwargs):
depth = np.ones(lon.size) * mindepth
else:
depth = convert_to_flat_array(depth)
assert lon.size == lat.size and lon.size == depth.size, "lon, lat, depth don" "t all have the same lenghts"
assert lon.size == lat.size and lon.size == depth.size, "lon, lat, depth don't all have the same lenghts"

time = convert_to_flat_array(time)
time = np.repeat(time, lon.size) if time.size == 1 else time
Expand All @@ -176,7 +180,7 @@ def ArrayClass_init(self, *args, **kwargs):
if time.size > 0 and isinstance(time[0], np.timedelta64) and not self.time_origin:
raise NotImplementedError("If fieldset.time_origin is not a date, time of a particle must be a double")
time = np.array([self.time_origin.reltime(t) if _convert_to_reltime(t) else t for t in time])
assert lon.size == time.size, "time and positions (lon, lat, depth) don" "t have the same lengths."
assert lon.size == time.size, "time and positions (lon, lat, depth) don't have the same lengths."

if lonlatdepth_dtype is None:
lonlatdepth_dtype = self.lonlatdepth_dtype_from_field_interp_method(fieldset.U)
Expand Down Expand Up @@ -476,7 +480,16 @@ def populate_indices(self):

@classmethod
def from_list(
cls, fieldset, pclass, lon, lat, depth=None, time=None, repeatdt=None, lonlatdepth_dtype=None, **kwargs
cls,
fieldset: FieldSet,
pclass: JITParticle | ScipyParticle,
lon: list[float],
lat: list[float],
depth: list[float] | None = None,
time: list[float] | None = None,
repeatdt: datetime | float | None = None,
lonlatdepth_dtype=None,
**kwargs,
):
"""Initialise the ParticleSet from lists of lon and lat.
Expand Down Expand Up @@ -576,7 +589,7 @@ def from_line(
)

@classmethod
def monte_carlo_sample(cls, start_field, size, mode="monte_carlo"):
def monte_carlo_sample(cls, start_field, size, mode: Literal["monte_carlo"] = "monte_carlo"):
"""Converts a starting field into a monte-carlo sample of lons and lats.
Parameters
Expand Down Expand Up @@ -635,16 +648,16 @@ def monte_carlo_sample(cls, start_field, size, mode="monte_carlo"):
)
return list(lon), list(lat)
else:
raise NotImplementedError(f'Mode {mode} not implemented. Please use "monte carlo" algorithm instead.')
raise NotImplementedError(f'Mode {mode} not implemented. Please use "monte_carlo" algorithm instead.')

@classmethod
def from_field(
cls,
fieldset,
pclass,
start_field,
size,
mode="monte_carlo",
size: int,
mode: Literal["monte_carlo"] = "monte_carlo",
depth=None,
time=None,
repeatdt=None,
Expand Down Expand Up @@ -690,7 +703,15 @@ def from_field(

@classmethod
def from_particlefile(
cls, fieldset, pclass, filename, restart=True, restarttime=None, repeatdt=None, lonlatdepth_dtype=None, **kwargs
cls,
fieldset: FieldSet,
pclass,
filename: PathLike,
restart=True,
restarttime=None,
repeatdt=None,
lonlatdepth_dtype=None,
**kwargs,
):
"""Initialise the ParticleSet from a zarr ParticleFile.
This creates a new ParticleSet based on locations of all particles written
Expand All @@ -702,7 +723,7 @@ def from_particlefile(
mod:`parcels.fieldset.FieldSet` object from which to sample velocity
pclass : parcels.particle.JITParticle or parcels.particle.ScipyParticle
Particle class. May be a particle class as defined in parcels, or a subclass defining a custom particle.
filename : str
filename :
Name of the particlefile from which to read initial conditions
restart : bool
BSignal if pset is used for a restart (default is True).
Expand Down Expand Up @@ -799,7 +820,7 @@ def from_particlefile(
**kwargs,
)

def Kernel(self, pyfunc, c_include="", delete_cfiles=True):
def Kernel(self, pyfunc: KernelFunction | list[KernelFunction], c_include="", delete_cfiles=True):
"""Wrapper method to convert a `pyfunc` into a :class:`parcels.kernel.Kernel` object.
Conversion is based on `fieldset` and `ptype` of the ParticleSet.
Expand Down Expand Up @@ -881,7 +902,7 @@ def error_particles(self):
return ParticleDataIterator(self.particledata, subset=error_indices)

@property
def num_error_particles(self):
def num_error_particles(self) -> int:
"""Get the number of particles that are in an error state.
Returns
Expand Down Expand Up @@ -909,12 +930,12 @@ def execute(
pyfunc_inter=None,
endtime=None,
runtime=None,
dt=1.0,
output_file=None,
verbose_progress=True,
postIterationCallbacks=None,
dt: timedelta | float = 1.0,
output_file: parcels.ParticleFile | None = None,
verbose_progress: bool = True,
postIterationCallbacks: list[Callable[[], None]] | None = None,
callbackdt=None,
delete_cfiles=True,
delete_cfiles: bool = True,
):
"""Execute a given kernel function over the particle set for multiple timesteps.
Expand Down
4 changes: 3 additions & 1 deletion parcels/tools/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import inspect
from datetime import timedelta
from math import cos, pi
from typing import Any

import cftime
import numpy as np
import xarray as xr
from numpy.typing import ArrayLike, NDArray

__all__ = [
"UnitConverter",
Expand All @@ -20,7 +22,7 @@
]


def convert_to_flat_array(var):
def convert_to_flat_array(var: list[float] | float | int | NDArray[Any] | ArrayLike) -> NDArray[Any]:
"""Convert lists and single integers/floats to one-dimensional numpy arrays
Parameters
Expand Down

0 comments on commit a28bd6b

Please sign in to comment.