diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 8823c7ca2..10049f345 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -178,9 +178,7 @@ jobs:
     name: "TypeChecking: pixi run typing"
     runs-on: ubuntu-latest
     needs: [should-skip-ci, cache-pixi-lock]
-    # TODO v4: Enable typechecking again
-    # needs.should-skip-ci.outputs.value == 'false'
-    if: false
+    if: needs.should-skip-ci.outputs.value == 'false'
     steps:
       - name: Checkout
         uses: actions/checkout@v5
diff --git a/pyproject.toml b/pyproject.toml
index ec5508606..fb133ac55 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -160,12 +160,13 @@ known-first-party = ["parcels"]
 
 [tool.mypy]
 files = [
-    "parcels/_typing.py",
-    "parcels/tools/*.py",
-    "parcels/grid.py",
-    "parcels/field.py",
-    "parcels/fieldset.py",
+    "src/parcels/_typing.py",
+    "src/parcels/_core/xgrid.py",
+    "src/parcels/_core/uxgrid.py",
+    "src/parcels/_core/field.py",
+    "src/parcels/_core/fieldset.py",
 ]
+disable_error_code = "attr-defined,assignment,operator,call-overload,index,valid-type,override,misc,union-attr"
 
 [[tool.mypy.overrides]]
 module = [
@@ -174,9 +175,16 @@ module = [
     "scipy.spatial",
     "sklearn.cluster",
     "zarr",
+    "zarr.storage",
+    "uxarray",
+    "xgcm",
     "cftime",
     "pykdtree.kdtree",
     "netCDF4",
     "pooch",
 ]
 ignore_missing_imports = true
+
+[[tool.mypy.overrides]] # TODO: This module should stabilize before release of v4
+module = "parcels.interpolators"
+ignore_errors = true
diff --git a/src/parcels/_core/field.py b/src/parcels/_core/field.py
index 14aeb6ea8..64d583ac1 100644
--- a/src/parcels/_core/field.py
+++ b/src/parcels/_core/field.py
@@ -259,9 +259,9 @@ def __init__(
         self.igrid = U.igrid
 
         if W is None:
-            _assert_same_time_interval((U, V))
+            _assert_same_time_interval([U, V])
         else:
-            _assert_same_time_interval((U, V, W))
+            _assert_same_time_interval([U, V, W])
 
         self.time_interval = U.time_interval
diff --git a/src/parcels/_core/fieldset.py b/src/parcels/_core/fieldset.py
index ba7536de0..282b30fe6 100644
--- a/src/parcels/_core/fieldset.py
+++ b/src/parcels/_core/fieldset.py
@@ -20,7 +20,6 @@
 from parcels._logger import logger
 from parcels._reprs import fieldset_repr
 from parcels._typing import Mesh
-from parcels.convert import _ds_rename_using_standard_names
 from parcels.interpolators import (
     CGrid_Velocity,
     Ux_Velocity,
@@ -182,7 +181,7 @@ def add_constant(self, name, value):
 
     @property
     def gridset(self) -> list[BaseGrid]:
-        grids = []
+        grids: list[BaseGrid] = []
         for field in self.fields.values():
             if field.grid not in grids:
                 grids.append(field.grid)
@@ -416,7 +415,8 @@ def _datetime_to_msg(example_datetime: TimeLike) -> str:
     return msg
 
 
-def _format_calendar_error_message(field: Field, reference_datetime: TimeLike) -> str:
+def _format_calendar_error_message(field: Field | VectorField, reference_datetime: TimeLike) -> str:
+    assert field.time_interval is not None
     return f"Expected field {field.name!r} to have calendar compatible with datetime object {_datetime_to_msg(reference_datetime)}. Got field with calendar {_datetime_to_msg(field.time_interval.left)}. Have you considered using xarray to update the time dimension of the dataset to have a compatible calendar?"
@@ -456,6 +456,16 @@ def _format_calendar_error_message(field: Field, reference_datetime: TimeLike) -
     }
 
 
+def _ds_rename_using_standard_names(ds: xr.Dataset | ux.UxDataset, name_dict: dict[str, str]) -> xr.Dataset:
+    for standard_name, rename_to in name_dict.items():
+        name = ds.cf[standard_name].name
+        ds = ds.rename({name: rename_to})
+        logger.info(
+            f"cf_xarray found variable {name!r} with CF standard name {standard_name!r} in dataset, renamed it to {rename_to!r} for Parcels simulation."
+        )
+    return ds
+
+
 def _discover_ux_U_and_V(ds: ux.UxDataset) -> ux.UxDataset:
     # Common variable names for U and V found in UxDatasets
     common_ux_UV = [("unod", "vnod"), ("u", "v")]
diff --git a/src/parcels/_core/index_search.py b/src/parcels/_core/index_search.py
index b9b0ae497..49522afc2 100644
--- a/src/parcels/_core/index_search.py
+++ b/src/parcels/_core/index_search.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-from datetime import datetime
 from typing import TYPE_CHECKING
 
 import numpy as np
@@ -9,8 +8,8 @@
 from parcels._core.utils.time import timedelta_to_float
 
 if TYPE_CHECKING:
+    from parcels import XGrid
     from parcels._core.field import Field
-    from parcels.xgrid import XGrid
 
 
 GRID_SEARCH_ERROR = -3
@@ -21,7 +20,7 @@
 def _search_1d_array(
     arr: np.array,
     x: float,
-) -> tuple[int, int]:
+) -> tuple[np.ndarray, np.ndarray]:
     """
     Searches for particle locations in a 1D array and returns barycentric coordinate along dimension.
@@ -63,14 +62,14 @@
     return np.atleast_1d(index), np.atleast_1d(bcoord)
 
 
-def _search_time_index(field: Field, time: datetime):
+def _search_time_index(field: Field, time: float):
     """Find and return the index and relative coordinate in the time array associated with a given time.
 
     Parameters
     ----------
     field: Field
-    time: datetime
+    time: float
+        The amount of time, in seconds (a time delta), since the Unix epoch.
 
     Note that we normalize to either the first or the last index if the sampled value is outside the time value range.
@@ -172,6 +171,8 @@ def _search_indices_curvilinear_2d(
     """
     if np.any(xi):
         # If an initial guess is provided, we first perform a point in cell check for all guessed indices
+        assert xi is not None
+        assert yi is not None
         is_in_cell, coords = curvilinear_point_in_cell(grid, y, x, yi, xi)
         y_check = y[is_in_cell == 0]
         x_check = x[is_in_cell == 0]
diff --git a/src/parcels/_core/kernel.py b/src/parcels/_core/kernel.py
index 3fe8b717d..30811e2d3 100644
--- a/src/parcels/_core/kernel.py
+++ b/src/parcels/_core/kernel.py
@@ -2,7 +2,6 @@
 
 import types
 import warnings
-from typing import TYPE_CHECKING
 
 import numpy as np
 
@@ -24,9 +23,6 @@
     AdvectionRK45,
 )
 
-if TYPE_CHECKING:
-    from collections.abc import Callable
-
 __all__ = ["Kernel"]
 
 
@@ -84,7 +80,7 @@ def __init__(
         # if (pyfunc is AdvectionRK4_3D) and fieldset.U.gridindexingtype == "croco":
         #     pyfunc = AdvectionRK4_3D_CROCO
 
-        self._pyfuncs: list[Callable] = pyfuncs
+        self._pyfuncs: list[types.FunctionType] = pyfuncs
 
     @property  #! Ported from v3. To be removed in v4? (/find another way to name kernels in output file)
     def funcname(self):
diff --git a/src/parcels/_core/particle.py b/src/parcels/_core/particle.py
index 86ddc5138..2ca0b49f6 100644
--- a/src/parcels/_core/particle.py
+++ b/src/parcels/_core/particle.py
@@ -37,7 +37,7 @@ class Variable:
     def __init__(
         self,
         name,
-        dtype: np.dtype = np.float32,
+        dtype: type[np.float32 | np.float64 | np.int32 | np.int64] = np.float32,
         initial=0,
         to_write: bool | Literal["once"] = True,
         attrs: dict | None = None,
@@ -122,7 +122,7 @@ def _assert_no_duplicate_variable_names(*, existing_vars: list[Variable], new_va
         raise ValueError(f"Variable name already exists: {var.name}")
 
 
-def get_default_particle(spatial_dtype: np.float32 | np.float64) -> ParticleClass:
+def get_default_particle(spatial_dtype: type[np.float32 | np.float64]) -> ParticleClass:
     if spatial_dtype not in [np.float32, np.float64]:
         raise ValueError(f"spatial_dtype must be np.float32 or np.float64. Got {spatial_dtype=!r}")
diff --git a/src/parcels/_core/particleset.py b/src/parcels/_core/particleset.py
index c4cd5ffd8..75040d43f 100644
--- a/src/parcels/_core/particleset.py
+++ b/src/parcels/_core/particleset.py
@@ -1,8 +1,11 @@
+from __future__ import annotations
+
 import datetime
 import sys
+import types
 import warnings
 from collections.abc import Iterable
-from typing import Literal
+from typing import TYPE_CHECKING, Literal
 
 import numpy as np
 import xarray as xr
@@ -21,6 +24,9 @@
 from parcels._logger import logger
 from parcels._reprs import _format_zarr_output_location, particleset_repr
 
+if TYPE_CHECKING:
+    from parcels import FieldSet, ParticleClass, ParticleFile
+
 __all__ = ["ParticleSet"]
 
 
@@ -58,10 +64,10 @@ class ParticleSet:
 
     def __init__(
         self,
-        fieldset,
-        pclass=Particle,
-        lon=None,
-        lat=None,
+        fieldset: FieldSet,
+        pclass: ParticleClass = Particle,
+        lon: np.ndarray | None = None,
+        lat: np.ndarray | None = None,
         z=None,
         time=None,
         trajectory_ids=None,
@@ -376,12 +382,12 @@ def set_variable_write_status(self, var, write_status):
 
     def execute(
         self,
-        pyfunc,
+        pyfunc: types.FunctionType | Kernel,
         dt: datetime.timedelta | np.timedelta64 | float,
         endtime: np.timedelta64 | np.datetime64 | None = None,
         runtime: datetime.timedelta | np.timedelta64 | float | None = None,
-        output_file=None,
-        verbose_progress=True,
+        output_file: ParticleFile | None = None,
+        verbose_progress: bool = True,
     ):
         """Execute a given kernel function over the particle set for multiple timesteps.
diff --git a/src/parcels/_core/utils/interpolation.py b/src/parcels/_core/utils/interpolation.py
index e4893a813..3957d5885 100644
--- a/src/parcels/_core/utils/interpolation.py
+++ b/src/parcels/_core/utils/interpolation.py
@@ -22,7 +22,7 @@ def phi1D_quad(xsi: float) -> list[float]:
     return phi
 
 
-def phi2D_lin(eta: float, xsi: float) -> list[float]:
+def phi2D_lin(eta: float, xsi: float) -> np.ndarray:
     phi = np.column_stack([(1-xsi) * (1-eta),
                            xsi * (1-eta),
                            xsi * eta,
diff --git a/src/parcels/_core/utils/sgrid.py b/src/parcels/_core/utils/sgrid.py
index 9fd87e143..99780f4a6 100644
--- a/src/parcels/_core/utils/sgrid.py
+++ b/src/parcels/_core/utils/sgrid.py
@@ -175,7 +175,7 @@ def to_attrs(self) -> dict[str, str | int]:
             d["vertical_dimensions"] = dump_mappings(self.vertical_dimensions)
         return d
 
-    def rename(self, names_dict: dict[str, str]) -> Self:
+    def rename(self, names_dict: dict[str, str]) -> Grid2DMetadata:
         return _metadata_rename(self, names_dict)
 
     def get_value_by_id(self, id: str) -> str:
@@ -285,7 +285,7 @@ def to_attrs(self) -> dict[str, str | int]:
         d["node_coordinates"] = dump_mappings(self.node_coordinates)
         return d
 
-    def rename(self, dims_dict: dict[str, str]) -> Self:
+    def rename(self, dims_dict: dict[str, str]) -> Grid3DMetadata:
         return _metadata_rename(self, dims_dict)
 
     def get_value_by_id(self, id: str) -> str:
diff --git a/src/parcels/_core/utils/time.py b/src/parcels/_core/utils/time.py
index 76a7a54cc..07a9f4681 100644
--- a/src/parcels/_core/utils/time.py
+++ b/src/parcels/_core/utils/time.py
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
+from collections.abc import Callable
 from datetime import datetime, timedelta
-from typing import TYPE_CHECKING, TypeVar
+from typing import TYPE_CHECKING, Generic, Literal, TypeVar
 
 import cftime
 import numpy as np
@@ -14,7 +15,7 @@
 T = TypeVar("T", bound="TimeLike")
 
 
-class TimeInterval:
+class TimeInterval(Generic[T]):
     """A class representing a time interval between two datetime or np.timedelta64 objects.
 
     Parameters
@@ -29,7 +30,7 @@ class TimeInterval:
     For the purposes of this codebase, the interval can be thought of as closed on the left and right.
     """
 
-    def __init__(self, left: T, right: T) -> None:
+    def __init__(self, left: T, right: T):
         if not isinstance(left, (np.timedelta64, datetime, cftime.datetime, np.datetime64)):
             raise ValueError(
                 f"Expected right to be a np.timedelta64, datetime, cftime.datetime, or np.datetime64. Got {type(left)}."
@@ -130,7 +131,7 @@ def get_datetime_type_calendar(
     return type(example_datetime), calendar
 
 
-_TD_PRECISION_GETTER_FOR_UNIT = (
+_TD_PRECISION_GETTER_FOR_UNIT: tuple[tuple[Callable[[timedelta], int], Literal["D", "s", "us"]], ...] = (
     (lambda dt: dt.days, "D"),
     (lambda dt: dt.seconds, "s"),
     (lambda dt: dt.microseconds, "us"),
@@ -142,14 +143,14 @@ def maybe_convert_python_timedelta_to_numpy(dt: timedelta | np.timedelta64) -> n
         return dt
 
     try:
-        dts = []
+        dts: list[np.timedelta64] = []
         for get_value_for_unit, np_unit in _TD_PRECISION_GETTER_FOR_UNIT:
             value = get_value_for_unit(dt)
             if value != 0:
                 dts.append(np.timedelta64(value, np_unit))
 
         if dts:
-            return sum(dts)
+            return np.sum(dts)
         else:
             return np.timedelta64(0, "s")
     except Exception as e:
diff --git a/src/parcels/_core/uxgrid.py b/src/parcels/_core/uxgrid.py
index a1d45e796..78b1c15a3 100644
--- a/src/parcels/_core/uxgrid.py
+++ b/src/parcels/_core/uxgrid.py
@@ -18,7 +18,7 @@ class UxGrid(BaseGrid):
     for interpolation on unstructured grids.
""" - def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray, mesh) -> UxGrid: + def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray, mesh): """ Initializes the UxGrid with a uxarray grid and vertical coordinate array. diff --git a/src/parcels/_core/xgrid.py b/src/parcels/_core/xgrid.py index 2de3de998..280778ba9 100644 --- a/src/parcels/_core/xgrid.py +++ b/src/parcels/_core/xgrid.py @@ -10,16 +10,15 @@ from parcels._core.basegrid import BaseGrid from parcels._core.index_search import _search_1d_array, _search_indices_curvilinear_2d from parcels._reprs import xgrid_repr -from parcels._typing import assert_valid_mesh +from parcels._typing import CfAxis, assert_valid_mesh _XGRID_AXES = Literal["X", "Y", "Z"] _XGRID_AXES_ORDERING: Sequence[_XGRID_AXES] = "ZYX" -_XGCM_AXIS_DIRECTION = Literal["X", "Y", "Z", "T"] _XGCM_AXIS_POSITION = Literal["center", "left", "right", "inner", "outer"] -_XGCM_AXES = Mapping[_XGCM_AXIS_DIRECTION, xgcm.Axis] +_XGCM_AXES = Mapping[CfAxis, xgcm.Axis] -_FIELD_DATA_ORDERING: Sequence[_XGCM_AXIS_DIRECTION] = "TZYX" +_FIELD_DATA_ORDERING: Sequence[CfAxis] = "TZYX" _DEFAULT_XGCM_KWARGS = {"periodic": False} @@ -282,7 +281,7 @@ def _gtype(self): TODO: Remove """ - from parcels.grid import GridType + from parcels._core.basegrid import GridType if len(self.lon.shape) <= 1: if self.depth is None or len(self.depth.shape) <= 1: @@ -384,7 +383,7 @@ def get_axis_dim_mapping(self, dims: list[str]) -> dict[_XGRID_AXES, str]: return result -def get_axis_from_dim_name(axes: _XGCM_AXES, dim: str) -> _XGCM_AXIS_DIRECTION | None: +def get_axis_from_dim_name(axes: _XGCM_AXES, dim: Hashable) -> CfAxis | None: """For a given dimension name in a grid, returns the direction axis it is on.""" for axis_name, axis in axes.items(): if dim in axis.coords.values(): @@ -421,7 +420,7 @@ def assert_valid_field_array(da: xr.DataArray, axes: _XGCM_AXES): assert_all_dimensions_correspond_with_axis(da, axes) dim_to_axis = {dim: get_axis_from_dim_name(axes, dim) for dim in da.dims} - dim_to_axis = cast(dict[Hashable, _XGCM_AXIS_DIRECTION], dim_to_axis) + dim_to_axis = cast(dict[Hashable, CfAxis], dim_to_axis) # Assert all dimensions are present if set(dim_to_axis.values()) != {"T", "Z", "Y", "X"}: diff --git a/src/parcels/_typing.py b/src/parcels/_typing.py index 7c2e12b8f..440a9369b 100644 --- a/src/parcels/_typing.py +++ b/src/parcels/_typing.py @@ -14,6 +14,8 @@ import numpy as np from cftime import datetime as cftime_datetime +CfAxis = Literal["X", "Y", "Z", "T"] + InterpMethodOption = Literal[ "linear", "nearest", diff --git a/src/parcels/convert.py b/src/parcels/convert.py index b981e318b..c4a579c6d 100644 --- a/src/parcels/convert.py +++ b/src/parcels/convert.py @@ -12,20 +12,16 @@ from __future__ import annotations -import typing - import numpy as np import xarray as xr from parcels._core.utils import sgrid from parcels._logger import logger +from parcels._typing import CfAxis -if typing.TYPE_CHECKING: - import uxarray as ux - -_NEMO_DIMENSION_COORD_NAMES = ["x", "y", "time", "x", "x_center", "y", "y_center", "depth", "glamf", "gphif"] +_NEMO_DIMENSION_COORD_NAMES: list[str] = ["x", "y", "time", "x", "x_center", "y", "y_center", "depth", "glamf", "gphif"] -_NEMO_AXIS_VARNAMES = { +_NEMO_AXIS_VARNAMES: dict[str, CfAxis] = { "x": "X", "x_center": "X", "y": "Y", @@ -34,7 +30,7 @@ "time": "T", } -_NEMO_VARNAMES_MAPPING = { +_NEMO_VARNAMES_MAPPING: dict[str, str] = { "time_counter": "time", "depthw": "depth", "uo": "U", @@ -42,7 +38,7 @@ "wo": "W", } -_MITGCM_AXIS_VARNAMES = { 
+_MITGCM_AXIS_VARNAMES: dict[str, CfAxis] = {
     "XC": "X",
     "XG": "X",
     "Xp1": "X",
@@ -57,13 +53,13 @@
     "time": "T",
 }
 
-_MITGCM_VARNAMES_MAPPING = {
+_MITGCM_VARNAMES_MAPPING: dict[str, str] = {
     "XG": "lon",
     "YG": "lat",
     "Zl": "depth",
 }
 
-_COPERNICUS_MARINE_AXIS_VARNAMES = {
+_COPERNICUS_MARINE_AXIS_VARNAMES: dict[CfAxis, str] = {
     "X": "lon",
     "Y": "lat",
     "Z": "depth",
@@ -71,7 +67,7 @@
 }
 
 
-def _maybe_bring_UV_depths_to_depth(ds):
+def _maybe_bring_UV_depths_to_depth(ds: xr.Dataset):
     if "U" in ds.variables and "depthu" in ds.U.coords and "depth" in ds.coords:
         ds["U"] = ds["U"].assign_coords(depthu=ds["depth"].values).rename({"depthu": "depth"})
     if "V" in ds.variables and "depthv" in ds.V.coords and "depth" in ds.coords:
@@ -79,14 +75,14 @@
     return ds
 
 
-def _maybe_create_depth_dim(ds):
+def _maybe_create_depth_dim(ds: xr.Dataset):
     if "depth" not in ds.dims:
         ds = ds.expand_dims({"depth": [0]})
         ds["depth"] = xr.DataArray([0], dims=["depth"])
     return ds
 
 
-def _maybe_rename_coords(ds, axis_varnames):
+def _maybe_rename_coords(ds: xr.Dataset, axis_varnames: dict[CfAxis, str]):
     try:
         for axis, [coord] in ds.cf.axes.items():
             ds = ds.rename({coord: axis_varnames[axis]})
@@ -95,61 +91,51 @@
     return ds
 
 
-def _maybe_rename_variables(ds, varnames_mapping):
+def _maybe_rename_variables(ds: xr.Dataset, varnames_mapping: dict[str, str]):
     rename_dict = {old: new for old, new in varnames_mapping.items() if (old in ds.data_vars) or (old in ds.coords)}
     if rename_dict:
         ds = ds.rename(rename_dict)
     return ds
 
 
-def _assign_dims_as_coords(ds, dimension_names):
+def _assign_dims_as_coords(ds: xr.Dataset, dimension_names: list[str]):
     for axis in dimension_names:
         if axis in ds.dims and axis not in ds.coords:
             ds = ds.assign_coords({axis: np.arange(ds.sizes[axis])})
     return ds
 
 
-def _drop_unused_dimensions_and_coords(ds, dimension_and_coord_names):
+def _drop_unused_dimensions_and_coords(ds: xr.Dataset, dimension_and_coord_names: list[str]):
     for dim in ds.dims:
         if dim not in dimension_and_coord_names:
-            ds = ds.drop_dims(dim, errors="ignore")
+            ds = ds.drop_dims([dim], errors="ignore")
     for coord in ds.coords:
         if coord not in dimension_and_coord_names:
-            ds = ds.drop_vars(coord, errors="ignore")
+            ds = ds.drop_vars([coord], errors="ignore")
     return ds
 
 
-def _set_coords(ds, dimension_names):
+def _set_coords(ds: xr.Dataset, dimension_names: list[str]):
     for varname in dimension_names:
         if varname in ds and varname not in ds.coords:
             ds = ds.set_coords([varname])
     return ds
 
 
-def _maybe_remove_depth_from_lonlat(ds):
+def _maybe_remove_depth_from_lonlat(ds: xr.Dataset):
     for coord in ["glamf", "gphif"]:
         if coord in ds.coords and "depth" in ds[coord].dims:
             ds[coord] = ds[coord].squeeze("depth", drop=True)
     return ds
 
 
-def _set_axis_attrs(ds, dim_axis):
+def _set_axis_attrs(ds: xr.Dataset, dim_axis: dict[str, CfAxis]):
     for dim, axis in dim_axis.items():
         if dim in ds:
             ds[dim].attrs["axis"] = axis
     return ds
 
 
-def _ds_rename_using_standard_names(ds: xr.Dataset | ux.UxDataset, name_dict: dict[str, str]) -> xr.Dataset:
-    for standard_name, rename_to in name_dict.items():
-        name = ds.cf[standard_name].name
-        ds = ds.rename({name: rename_to})
-        logger.info(
-            f"cf_xarray found variable {name!r} with CF standard name {standard_name!r} in dataset, renamed it to {rename_to!r} for Parcels simulation."
-        )
-    return ds
-
-
 def _maybe_swap_depth_direction(ds: xr.Dataset) -> xr.Dataset:
     if ds["depth"].size > 1:
         if ds["depth"][0] > ds["depth"][-1]:
@@ -160,39 +146,6 @@
     return ds
 
 
-# TODO is this function still needed, now that we require users to provide field names explicitly?
-def _discover_U_and_V(ds: xr.Dataset, cf_standard_names_fallbacks) -> xr.Dataset:
-    # Assumes that the dataset has U and V data
-
-    if "W" not in ds:
-        for cf_standard_name_W in cf_standard_names_fallbacks["W"]:
-            if cf_standard_name_W in ds.cf.standard_names:
-                ds = _ds_rename_using_standard_names(ds, {cf_standard_name_W: "W"})
-                break
-
-    if "U" in ds and "V" in ds:
-        return ds  # U and V already present
-    elif "U" in ds or "V" in ds:
-        raise ValueError(
-            "Dataset has only one of the two variables 'U' and 'V'. Please rename the appropriate variable in your dataset to have both 'U' and 'V' for Parcels simulation."
-        )
-
-    for cf_standard_name_U, cf_standard_name_V in cf_standard_names_fallbacks["UV"]:
-        if cf_standard_name_U in ds.cf.standard_names:
-            if cf_standard_name_V not in ds.cf.standard_names:
-                raise ValueError(
-                    f"Dataset has variable with CF standard name {cf_standard_name_U!r}, "
-                    f"but not the matching variable with CF standard name {cf_standard_name_V!r}. "
-                    "Please rename the appropriate variables in your dataset to have both 'U' and 'V' for Parcels simulation."
-                )
-        else:
-            continue
-
-        ds = _ds_rename_using_standard_names(ds, {cf_standard_name_U: "U", cf_standard_name_V: "V"})
-        break
-    return ds
-
-
 def nemo_to_sgrid(*, fields: dict[str, xr.Dataset | xr.DataArray], coords: xr.Dataset):
     # TODO: Update docstring
     """Create a FieldSet from a xarray.Dataset from NEMO netcdf files.
diff --git a/src/parcels/interpolators/_xinterpolators.py b/src/parcels/interpolators/_xinterpolators.py
index 14f92cb00..075f7d415 100644
--- a/src/parcels/interpolators/_xinterpolators.py
+++ b/src/parcels/interpolators/_xinterpolators.py
@@ -19,7 +19,7 @@ def ZeroInterpolator(
     particle_positions: dict[str, float | np.ndarray],
     grid_positions: dict[_XGRID_AXES, dict[str, int | float | np.ndarray]],
     field: Field,
-) -> np.float32 | np.float64:
+) -> float:
     """Template function used for the signature check of the lateral interpolation methods."""
     return 0.0
 
@@ -28,7 +28,7 @@ def ZeroInterpolator_Vector(
     particle_positions: dict[str, float | np.ndarray],
     grid_positions: dict[_XGRID_AXES, dict[str, int | float | np.ndarray]],
     vectorfield: VectorField,
-) -> np.float32 | np.float64:
+) -> float:
     """Template function used for the signature check of the interpolation methods for velocity fields."""
     return 0.0
 
@@ -392,8 +392,8 @@ def _Spatialslip(
     particle_positions: dict[str, float | np.ndarray],
     grid_positions: dict[_XGRID_AXES, dict[str, int | float | np.ndarray]],
     vectorfield: VectorField,
-    a: np.float32,
-    b: np.float32,
+    a: float,
+    b: float,
 ):
     """Helper function for spatial boundary condition interpolation for velocity fields."""
     xi, xsi = grid_positions["X"]["index"], grid_positions["X"]["bcoord"]
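
Note on the `TimeInterval` change in `src/parcels/_core/utils/time.py`: making the class `Generic[T]` lets mypy track whether a given interval holds absolute datetimes or relative `np.timedelta64` bounds, instead of collapsing both cases into one unparameterized `TimeInterval`. A minimal sketch of the pattern, not the Parcels implementation (`Interval` is a stand-in name):

```python
from datetime import datetime
from typing import Generic, TypeVar

import numpy as np

T = TypeVar("T", datetime, np.timedelta64)


class Interval(Generic[T]):
    """Stand-in for TimeInterval: both bounds share the single time type T."""

    def __init__(self, left: T, right: T) -> None:
        self.left = left
        self.right = right


abs_interval = Interval(datetime(2000, 1, 1), datetime(2000, 2, 1))
rel_interval = Interval(np.timedelta64(0, "s"), np.timedelta64(3600, "s"))

# mypy now infers abs_interval.left as datetime and rel_interval.left as
# np.timedelta64, rather than a union of the two for both intervals.
```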
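
The `sum(dts)` to `np.sum(dts)` change in `maybe_convert_python_timedelta_to_numpy` looks typing-driven: builtin `sum` starts from the `int` value `0`, so its result is typed `int | np.timedelta64`, whereas `np.sum` stays within `np.timedelta64`. A quick sanity check of the decomposition the helper performs, as a standalone sketch using only public numpy behaviour:

```python
from datetime import timedelta

import numpy as np

dt = timedelta(days=2, seconds=30, microseconds=500)

# Decompose into one np.timedelta64 per unit, mirroring what the helper does
# via _TD_PRECISION_GETTER_FOR_UNIT, then recombine.
parts = [
    np.timedelta64(dt.days, "D"),
    np.timedelta64(dt.seconds, "s"),
    np.timedelta64(dt.microseconds, "us"),
]
total = np.sum(parts)  # numpy promotes all parts to the finest unit (us here)

assert total == np.timedelta64(dt)
```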
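
`_ds_rename_using_standard_names`, moved from `parcels/convert.py` into `fieldset.py`, leans on cf_xarray's `.cf` accessor to look variables up by their CF `standard_name` attribute. A toy dataset showing that lookup-and-rename step (the variable name `uo` and its attrs here are made up for illustration):

```python
import cf_xarray  # noqa: F401  -- importing registers the .cf accessor
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {
        "uo": xr.DataArray(
            np.zeros((2, 2)),
            dims=("y", "x"),
            attrs={"standard_name": "sea_water_x_velocity"},
        )
    }
)

name = ds.cf["sea_water_x_velocity"].name  # resolves to "uo"
ds = ds.rename({name: "U"})

assert "U" in ds.data_vars
```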