Feature: interpolator, WIP
Magnus-SI committed Nov 22, 2024
1 parent 923b266 commit 780630b
Showing 5 changed files with 163 additions and 56 deletions.
35 changes: 21 additions & 14 deletions src/anemoi/training/data/datamodule.py
@@ -68,13 +68,6 @@ def __init__(self, config: DictConfig) -> None:
self.model_comm_group_rank,
)

# Set the maximum rollout to be expected
self.rollout = (
self.config.training.rollout.max
if self.config.training.rollout.epoch_increment > 0
else self.config.training.rollout.start
)

# Set the training end date if not specified
if self.config.dataloader.training.end is None:
LOGGER.info(
@@ -102,6 +95,23 @@ def metadata(self) -> dict:
@cached_property
def data_indices(self) -> IndexCollection:
return IndexCollection(self.config, self.ds_train.name_to_index)

@cached_property
def relative_date_indices(self) -> list:
"""Determine a list of relative time indices to load for each batch"""
if hasattr(self.config.training, "explicit_times"):
return sorted(self.config.training.explicit_times.input + self.config.training.explicit_times.target)

else:  # fall back to the previous defaults of multistep, timeincrement and rollout
# Use the maximum rollout to be expected
rollout = (
self.config.training.rollout.max
if self.config.training.rollout.epoch_increment > 0
else self.config.training.rollout.start
)  # NOTE: with gradual rollout, dates for the maximum rollout are always fetched, as was already the case in datamodule.py

multi_step = self.config.training.multistep_input
return [self.timeincrement * mstep for mstep in range(multi_step + rollout)]
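
A minimal standalone sketch of what this property evaluates to in the two modes; the explicit_times values below are hypothetical and timeincrement is taken as 1:

explicit_times = {"input": [0, 6], "target": [1, 2, 3, 4, 5]}
print(sorted(explicit_times["input"] + explicit_times["target"]))        # [0, 1, 2, 3, 4, 5, 6]

multi_step, rollout, timeincrement = 2, 3, 1
print([timeincrement * mstep for mstep in range(multi_step + rollout)])  # [0, 1, 2, 3, 4]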

@cached_property
def timeincrement(self) -> int:
@@ -140,7 +150,8 @@ def ds_train(self) -> NativeGridDataset:

@cached_property
def ds_valid(self) -> NativeGridDataset:
r = max(self.rollout, self.config.dataloader.get("validation_rollout", 1))
# r = max(self.rollout, self.config.dataloader.get("validation_rollout", 1))
# NOTE: temporarily left unimplemented until it is clear how best to do this with the new relative_date_indices object

assert self.config.dataloader.training.end < self.config.dataloader.validation.start, (
f"Training end date {self.config.dataloader.training.end} is not before"
@@ -149,7 +160,7 @@ def ds_valid(self) -> NativeGridDataset:
return self._get_dataset(
open_dataset(OmegaConf.to_container(self.config.dataloader.validation, resolve=True)),
shuffle=False,
rollout=r,
# rollout=r,  # NOTE: see the note above
label="validation",
)

@@ -173,15 +184,11 @@ def _get_dataset(
self,
data_reader: Callable,
shuffle: bool = True,
rollout: int = 1,
label: str = "generic",
) -> NativeGridDataset:
r = max(rollout, self.rollout)
data = NativeGridDataset(
data_reader=data_reader,
rollout=r,
multistep=self.config.training.multistep_input,
timeincrement=self.timeincrement,
relative_date_indices=self.relative_date_indices,
model_comm_group_rank=self.model_comm_group_rank,
model_comm_group_id=self.model_comm_group_id,
model_comm_num_groups=self.model_comm_num_groups,
33 changes: 11 additions & 22 deletions src/anemoi/training/data/dataset.py
@@ -33,9 +33,7 @@ class NativeGridDataset(IterableDataset):
def __init__(
self,
data_reader: Callable,
rollout: int = 1,
multistep: int = 1,
timeincrement: int = 1,
relative_date_indices: list = [0,1,2],
model_comm_group_rank: int = 0,
model_comm_group_id: int = 0,
model_comm_num_groups: int = 1,
@@ -48,12 +46,8 @@ def __init__(
----------
data_reader : Callable
user function that opens and returns the zarr array data
rollout : int, optional
length of rollout window, by default 12
timeincrement : int, optional
time increment between samples, by default 1
multistep : int, optional
collate (t-1, ... t - multistep) into the input state vector, by default 1
relative_date_indices : list
list of time indices to load from the data relative to the current sample i in __iter__
model_comm_group_rank : int, optional
process rank in the torch.distributed group (important when running on multiple GPUs), by default 0
model_comm_group_id: int, optional
@@ -70,9 +64,6 @@ def __init__(

self.data = data_reader

self.rollout = rollout
self.timeincrement = timeincrement

# lazy init
self.n_samples_per_epoch_total: int = 0
self.n_samples_per_epoch_per_worker: int = 0
@@ -89,11 +80,12 @@ def __init__(
self.shuffle = shuffle

# Data dimensions
self.multi_step = multistep
assert self.multi_step > 0, "Multistep value must be greater than zero."
self.ensemble_dim: int = 2
self.ensemble_size = self.data.shape[self.ensemble_dim]

# relative index of dates to extract
self.relative_date_indices = relative_date_indices  # conversion to np.array(dtype=np.int32) happens in valid_date_indices

@cached_property
def statistics(self) -> dict:
"""Return dataset statistics."""
@@ -126,7 +118,7 @@ def valid_date_indices(self) -> np.ndarray:
dataset length minus rollout minus additional multistep inputs
(if time_increment is 1).
"""
return get_usable_indices(self.data.missing, len(self.data), self.rollout, self.multi_step, self.timeincrement)
return get_usable_indices(self.data.missing, len(self.data), np.array(self.relative_date_indices, dtype=np.int32))

def per_worker_init(self, n_workers: int, worker_id: int) -> None:
"""Called by worker_init_func on each copy of dataset.
@@ -230,10 +222,9 @@ def __iter__(self) -> torch.Tensor:
)

for i in shuffled_chunk_indices:
start = i - (self.multi_step - 1) * self.timeincrement
end = i + (self.rollout + 1) * self.timeincrement

x = self.data[start : end : self.timeincrement]
# TODO: self.data[self.relative_date_indices + i] is intended here, but array indices do not seem to be
# supported in anemoi-datasets, and a tuple of indices without a regular structure did not work either
x = self.data[slice(self.relative_date_indices[0] + i, self.relative_date_indices[-1] + i + 1, 1)]
x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables")
self.ensemble_dim = 1
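
To illustrate the TODO above with hypothetical values: the intended indexing would fetch only the requested dates, whereas the slice workaround loads every date in the contiguous window (the two are equivalent whenever the relative indices are contiguous, as with explicit_times input + target):

relative_date_indices = [0, 1, 2, 6]  # hypothetical, non-contiguous
i = 100

# intended (not currently possible through anemoi-datasets here):
#   x = self.data[np.array(relative_date_indices) + i]                              # dates 100, 101, 102, 106
# workaround used above: a single contiguous slice over the full window
window = slice(relative_date_indices[0] + i, relative_date_indices[-1] + i + 1, 1)  # dates 100 .. 106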

@@ -243,9 +234,7 @@ def __repr__(self) -> str:
return f"""
{super().__repr__()}
Dataset: {self.data}
Rollout: {self.rollout}
Multistep: {self.multi_step}
Timeincrement: {self.timeincrement}
Relative dates: {self.relative_date_indices}
"""


118 changes: 118 additions & 0 deletions src/anemoi/training/train/interpolator.py
@@ -0,0 +1,118 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import logging
import math
import os
from collections import defaultdict
from collections.abc import Generator
from collections.abc import Mapping
from typing import Optional
from typing import Union
from operator import itemgetter

import numpy as np
import pytorch_lightning as pl
import torch
from anemoi.models.data_indices.collection import IndexCollection
from anemoi.models.interface import AnemoiModelInterface
from anemoi.utils.config import DotDict
from hydra.utils import instantiate
from omegaconf import DictConfig
from omegaconf import OmegaConf
from timm.scheduler import CosineLRScheduler
from torch.distributed.distributed_c10d import ProcessGroup
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.utils.checkpoint import checkpoint
from torch_geometric.data import HeteroData

from anemoi.training.losses.utils import grad_scaler
from anemoi.training.losses.weightedloss import BaseWeightedLoss
from anemoi.training.utils.jsonify import map_config_to_primitives
from anemoi.training.utils.masks import Boolean1DMask
from anemoi.training.utils.masks import NoOutputMask

from anemoi.training.train.forecaster import GraphForecaster

LOGGER = logging.getLogger(__name__)

class GraphInterpolator(GraphForecaster):
"""Graph neural network interpolator for PyTorch Lightning."""

def __init__(
self,
*,
config: DictConfig,
graph_data: HeteroData,
statistics: dict,
data_indices: IndexCollection,
metadata: dict,
) -> None:
"""Initialize graph neural network interpolator.
Parameters
----------
config : DictConfig
Job configuration
graph_data : HeteroData
Graph object
statistics : dict
Statistics of the training data
data_indices : IndexCollection
Indices of the training data,
metadata : dict
Provenance information
"""
super().__init__(config=config, graph_data=graph_data, statistics=statistics, data_indices=data_indices, metadata=metadata)
self.target_forcing_indices = itemgetter(*config.training.target_forcing.data)(data_indices.data.input.name_to_index)
self.boundary_times = config.training.explicit_times.input
self.interp_times = config.training.explicit_times.target
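
As an aside on the itemgetter line above, a small sketch with a hypothetical name_to_index mapping and target-forcing variable list:

from operator import itemgetter

name_to_index = {"cos_julian_day": 10, "sin_julian_day": 11, "insolation": 12}  # hypothetical
target_forcing_vars = ["cos_julian_day", "insolation"]                          # hypothetical config.training.target_forcing.data

target_forcing_indices = itemgetter(*target_forcing_vars)(name_to_index)
print(target_forcing_indices)  # (10, 12)
# note: with a single variable, itemgetter returns a bare value rather than a tuple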


def _step(
self,
batch: torch.Tensor,
batch_idx: int,
validation_mode: bool = False,
) -> tuple[torch.Tensor, Mapping[str, torch.Tensor], list[torch.Tensor]]:

del batch_idx
loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False)
metrics = {}
y_preds = []

batch = self.model.pre_processors(batch)
x_bound = batch[:, self.boundary_times][..., self.data_indices.data.input.full] # (bs, time, ens, latlon, nvar)

tfi = self.target_forcing_indices
target_forcing = torch.empty(batch.shape[0], batch.shape[2], batch.shape[3], len(tfi) + 1, device=self.device, dtype=batch.dtype)
for interp_step in self.interp_times:
# get the forcing information for the target interpolation time:
target_forcing[..., :len(tfi)] = batch[:, interp_step, :, :, tfi]
target_forcing[..., -1] = (interp_step - self.boundary_times[0]) / (self.boundary_times[1] - self.boundary_times[0])
# TODO: make the fractional time one of a configurable set of arbitrary custom forcing functions

y_pred = self(x_bound, target_forcing)
y = batch[:, interp_step, :, :, self.data_indices.data.output.full]

loss += checkpoint(self.loss, y_pred, y, use_reentrant=False)

metrics_next = {}
if validation_mode:
metrics_next = self.calculate_val_metrics(y_pred, y, interp_step - 1)  # expects a rollout step, but can be repurposed here
metrics.update(metrics_next)
y_preds.extend(y_pred)

loss *= 1.0 / len(self.interp_times)
return loss, metrics, y_preds

def forward(self, x: torch.Tensor, target_forcing: torch.Tensor) -> torch.Tensor:
return self.model(x, target_forcing, self.model_comm_group)
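
For intuition, a minimal sketch of the scalar time fraction that _step appends as the last channel of target_forcing, using hypothetical explicit_times of input [0, 6] and target [1, 2, 3, 4, 5]:

boundary_times = [0, 6]          # hypothetical config.training.explicit_times.input
interp_times = [1, 2, 3, 4, 5]   # hypothetical config.training.explicit_times.target

for interp_step in interp_times:
    frac = (interp_step - boundary_times[0]) / (boundary_times[1] - boundary_times[0])
    print(interp_step, round(frac, 3))
# 1 0.167, 2 0.333, 3 0.5, 4 0.667, 5 0.833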
11 changes: 7 additions & 4 deletions src/anemoi/training/train/train.py
@@ -16,6 +16,7 @@
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING
import importlib

import hydra
import numpy as np
@@ -33,7 +34,6 @@
from anemoi.training.diagnostics.logger import get_tensorboard_logger
from anemoi.training.diagnostics.logger import get_wandb_logger
from anemoi.training.distributed.strategy import DDPGroupStrategy
from anemoi.training.train.forecaster import GraphForecaster
from anemoi.training.utils.jsonify import map_config_to_primitives
from anemoi.training.utils.seeding import get_base_seed

@@ -135,7 +135,7 @@ def graph_data(self) -> HeteroData:
)

@cached_property
def model(self) -> GraphForecaster:
def model(self) -> pl.LightningModule:
"""Provide the model instance."""
kwargs = {
"config": self.config,
@@ -144,10 +144,13 @@ def model(self) -> GraphForecaster:
"metadata": self.metadata,
"statistics": self.datamodule.statistics,
}
train_module = importlib.import_module(getattr(self.config.training, "train_module", "anemoi.training.train.forecaster"))
train_func = getattr(train_module, getattr(self.config.training, "train_function", "GraphForecaster"))
# NOTE: hydra's instantiate would be preferable, but "config" clashes with the first keyword argument of instantiate itself
if self.load_weights_only:
LOGGER.info("Restoring only model weights from %s", self.last_checkpoint)
return GraphForecaster.load_from_checkpoint(self.last_checkpoint, **kwargs)
return GraphForecaster(**kwargs)
return train_func.load_from_checkpoint(self.last_checkpoint, **kwargs)
return train_func(**kwargs)
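
A minimal sketch of how the train_module / train_function config keys select the training class; the values below are hypothetical, and when the keys are absent the defaults above fall back to the standard GraphForecaster:

import importlib

module_name = "anemoi.training.train.interpolator"   # hypothetical value of config.training.train_module
class_name = "GraphInterpolator"                      # hypothetical value of config.training.train_function

train_module = importlib.import_module(module_name)
train_func = getattr(train_module, class_name)
# train_func is then used exactly as GraphForecaster was: train_func(**kwargs)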

@rank_zero_only
def _get_mlflow_run_id(self) -> str:
22 changes: 6 additions & 16 deletions src/anemoi/training/utils/usable_indices.py
@@ -16,9 +16,7 @@
def get_usable_indices(
missing_indices: set[int] | None,
series_length: int,
rollout: int,
multistep: int,
timeincrement: int = 1,
relative_indices: np.ndarray,
) -> np.ndarray:
"""Get the usable indices of a series whit missing indices.
@@ -28,32 +26,24 @@ def get_usable_indices(
Dataset to be used.
series_length : int
Length of the series.
rollout : int
Number of steps to roll out.
multistep : int
Number of previous indices to include as predictors.
timeincrement : int
Time increment, by default 1.
relative_indices : np.ndarray
Array of relative indices requested at each index i.
Returns
-------
usable_indices : np.array
Array of usable indices.
"""
prev_invalid_dates = (multistep - 1) * timeincrement
next_invalid_dates = rollout * timeincrement

usable_indices = np.arange(series_length) # set of all indices

if missing_indices is None:
missing_indices = set()

missing_indices |= {-1, series_length} # to filter initial and final indices
missing_indices |= {series_length}  # filter the final index

# Missing indices
for i in missing_indices:
usable_indices = usable_indices[
(usable_indices < i - next_invalid_dates) + (usable_indices > i + prev_invalid_dates)
]
rel_missing = i - relative_indices  # start indices whose relative date indices would hit the missing index
usable_indices = usable_indices[np.all(usable_indices != rel_missing[:, np.newaxis], axis=0)]

return usable_indices
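
A small worked example of the new masking, with hypothetical inputs (series of length 10, index 5 missing, relative indices [0, 1, 2]):

import numpy as np

series_length = 10
missing_indices = {5}
relative_indices = np.array([0, 1, 2], dtype=np.int32)

usable_indices = np.arange(series_length)
missing_indices |= {series_length}  # filter the final index, as above
for i in missing_indices:
    rel_missing = i - relative_indices
    usable_indices = usable_indices[np.all(usable_indices != rel_missing[:, np.newaxis], axis=0)]

print(usable_indices)  # [0 1 2 6 7]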
