diff --git a/neural_lam/config.py b/neural_lam/config.py index aa20030c..7df993d0 100644 --- a/neural_lam/config.py +++ b/neural_lam/config.py @@ -60,52 +60,109 @@ def coords_projection(self): proj_params = proj_config.get("kwargs", {}) return proj_class(**proj_params) + @functools.cached_property def param_names(self): """Return parameter names.""" - surface_names = self.values["state"]["surface"] - atmosphere_names = [ + surface_vars_names = self.values["state"]["surface_vars"] + atmosphere_vars_names = [ f"{var}_{level}" - for var in self.values["state"]["atmosphere"] + for var in self.values["state"]["atmosphere_vars"] for level in self.values["state"]["levels"] ] - return surface_names + atmosphere_names + return surface_vars_names + atmosphere_vars_names + @functools.cached_property def param_units(self): """Return parameter units.""" - surface_units = self.values["state"]["surface_units"] - atmosphere_units = [ + surface_vars_units = self.values["state"]["surface_vars_units"] + atmosphere_vars_units = [ unit - for unit in self.values["state"]["atmosphere_units"] + for unit in self.values["state"]["atmosphere_vars_units"] for _ in self.values["state"]["levels"] ] - return surface_units + atmosphere_units + return surface_vars_units + atmosphere_vars_units - def num_data_vars(self, key): - """Return the number of data variables for a given key.""" - surface_vars = len(self.values[key]["surface"]) - atmosphere_vars = len(self.values[key]["atmosphere"]) - levels = len(self.values[key]["levels"]) - return surface_vars + atmosphere_vars * levels + @functools.lru_cache() + def num_data_vars(self, category): + """Return the number of data variables for a given category.""" + surface_vars = self.values[category].get("surface_vars", []) + atmosphere_vars = self.values[category].get("atmosphere_vars", []) + levels = self.values[category].get("levels", []) - def projection(self): - """Return the projection.""" - proj_config = self.values["projections"]["class"] - proj_class = getattr(ccrs, proj_config["proj_class"]) - proj_params = proj_config["proj_params"] - return proj_class(**proj_params) + surface_vars_count = ( + len(surface_vars) if surface_vars is not None else 0 + ) + atmosphere_vars_count = ( + len(atmosphere_vars) if atmosphere_vars is not None else 0 + ) + levels_count = len(levels) if levels is not None else 0 - def open_zarr(self, dataset_name): - """Open a dataset specified by the dataset name.""" - dataset_path = self.zarrs[dataset_name].path - if dataset_path is None or not os.path.exists(dataset_path): - print( - f"Dataset '{dataset_name}' " - f"not found at path: {dataset_path}" - ) - return None - dataset = xr.open_zarr(dataset_path, consolidated=True) + return surface_vars_count + atmosphere_vars_count * levels_count + + @functools.lru_cache(maxsize=None) + def open_zarr(self, category): + """Open a dataset specified by the category.""" + zarr_config = self.zarrs[category] + + if isinstance(zarr_config, list): + try: + datasets = [] + for config in zarr_config: + dataset_path = config["path"] + dataset = xr.open_zarr(dataset_path, consolidated=True) + datasets.append(dataset) + return xr.merge(datasets) + except Exception: + print(f"Invalid zarr configuration for category: {category}") + return None + + else: + try: + dataset_path = zarr_config["path"] + return xr.open_zarr(dataset_path, consolidated=True) + except Exception: + print(f"Invalid zarr configuration for category: {category}") + return None + + def stack_grid(self, dataset): + """Stack grid dimensions.""" + dims = dataset.to_array().dims + + if "grid" not in dims and "x" in dims and "y" in dims: + dataset = dataset.squeeze().stack(grid=("x", "y")).to_array() + else: + try: + dataset = dataset.squeeze().to_array() + except ValueError: + print("Failed to stack grid dimensions.") + return None + + if "time" in dataset.dims: + dataset = dataset.transpose("time", "grid", "variable") + else: + dataset = dataset.transpose("grid", "variable") return dataset + @functools.lru_cache() + def get_nwp_xy(self, category): + """Get the x and y coordinates for the NWP grid.""" + dataset = self.open_zarr(category) + lon_name = self.zarrs[category].lat_lon_names.lon + lat_name = self.zarrs[category].lat_lon_names.lat + if lon_name in dataset and lat_name in dataset: + lon = dataset[lon_name].values + lat = dataset[lat_name].values + else: + raise ValueError( + f"Dataset does not contain " f"{lon_name} or {lat_name}" + ) + if lon.ndim == 1: + lon, lat = np.meshgrid(lat, lon) + lonlat = np.stack((lon, lat), axis=0) + + return lonlat + + @functools.cached_property def load_normalization_stats(self): """Load normalization statistics from Zarr archive.""" normalization_path = self.normalization.zarr @@ -120,10 +177,11 @@ def load_normalization_stats(self): ) return normalization_stats - def process_dataset(self, dataset_name, split="train", stack=True): + @functools.lru_cache(maxsize=None) + def process_dataset(self, category, split="train"): """Process a single dataset specified by the dataset name.""" - dataset = self.open_zarr(dataset_name) + dataset = self.open_zarr(category) if dataset is None: return None @@ -132,48 +190,64 @@ def process_dataset(self, dataset_name, split="train", stack=True): self.splits[split].end, ) dataset = dataset.sel(time=slice(start, end)) + + dims_mapping = {} + zarr_configs = self.zarrs[category] + if isinstance(zarr_configs, list): + for zarr_config in zarr_configs: + dims_mapping.update(zarr_config["dims"]) + else: + dims_mapping.update(zarr_configs["dims"].values) + dataset = dataset.rename_dims( { v: k - for k, v in self.zarrs[dataset_name].dims.values.items() - if k not in dataset.dims + for k, v in dims_mapping.items() + if k not in dataset.dims and v in dataset.dims } ) + dataset = dataset.rename_vars( + {v: k for k, v in dims_mapping.items() if v in dataset.coords} + ) - vars_surface = [] - if self[dataset_name].surface: - vars_surface = dataset[self[dataset_name].surface] + surface_vars = [] + if self[category].surface_vars: + surface_vars = dataset[self[category].surface_vars] - vars_atmosphere = [] - if self[dataset_name].atmosphere: - vars_atmosphere = xr.merge( + atmosphere_vars = [] + if self[category].atmosphere_vars: + atmosphere_vars = xr.merge( [ dataset[var] .sel(level=level, drop=True) .rename(f"{var}_{level}") - for var in self[dataset_name].atmosphere - for level in self[dataset_name].levels + for var in self[category].atmosphere_vars + for level in self[category].levels ] ) - if vars_surface and vars_atmosphere: - dataset = xr.merge([vars_surface, vars_atmosphere]) - elif vars_surface: - dataset = vars_surface - elif vars_atmosphere: - dataset = vars_atmosphere + if surface_vars and atmosphere_vars: + dataset = xr.merge([surface_vars, atmosphere_vars]) + elif surface_vars: + dataset = surface_vars + elif atmosphere_vars: + dataset = atmosphere_vars else: - print(f"No variables found in dataset {dataset_name}") + print(f"No variables found in dataset {category}") return None + zarr_configs = self.zarrs[category] + lat_lon_names = {} + if isinstance(self.zarrs[category], list): + for zarr_configs in self.zarrs[category]: + lat_lon_names.update(zarr_configs["lat_lon_names"]) + else: + lat_lon_names.update(self.zarrs[category]["lat_lon_names"].values) + if not all( - lat_lon in self.zarrs[dataset_name].dims.values.values() - for lat_lon in self.zarrs[ - dataset_name - ].lat_lon_names.values.values() + lat_lon in lat_lon_names.values() for lat_lon in lat_lon_names ): - lat_name = self.zarrs[dataset_name].lat_lon_names.lat - lon_name = self.zarrs[dataset_name].lat_lon_names.lon + lat_name, lon_name = lat_lon_names[:2] if dataset[lat_name].ndim == 2: dataset[lat_name] = dataset[lat_name].isel(x=0, drop=True) if dataset[lon_name].ndim == 2: @@ -182,26 +256,15 @@ def process_dataset(self, dataset_name, split="train", stack=True): x=dataset[lon_name], y=dataset[lat_name] ) - if stack: - dataset = self.stack_grid(dataset) - + dataset = dataset.rename( + {v: k for k, v in dims_mapping.items() if v in dataset.coords} + ) + dataset = self.stack_grid(dataset) return dataset - def stack_grid(self, dataset): - """Stack grid dimensions.""" - dataset = dataset.squeeze().stack(grid=("x", "y")).to_array() + dataset = self.stack_grid(dataset) - if "time" in dataset.dims: - dataset = dataset.transpose("time", "grid", "variable") - else: - dataset = dataset.transpose("grid", "variable") return dataset - def get_nwp_xy(self): - """Get the x and y coordinates for the NWP grid.""" - x = self.process_dataset("static", stack=False).x.values - y = self.process_dataset("static", stack=False).y.values - xx, yy = np.meshgrid(y, x) - xy = np.stack((xx, yy), axis=0) - return xy +config = Config.from_file("neural_lam/data_config.yaml") diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml index a4417a65..ff14a231 100644 --- a/neural_lam/data_config.yaml +++ b/neural_lam/data_config.yaml @@ -1,191 +1,96 @@ -zarrs: # List of zarrs containing fields related to state +name: danra +zarrs: state: - path: /scratch/sadamov/template.zarr # Path to zarr - dims: # Name of dimensions in zarr, to be used for indexing - time: time - level: z - x: x # Either give "grid" (flattened) dimension or "x" and "y" - y: y - lat_lon_names: - lon: lon - lat: lat + - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr" + dims: + time: time + level: null + x: x + y: y + grid: null + lat_lon_names: + lon: lon + lat: lat + - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/height_levels.zarr" + dims: + time: time + level: altitude + x: x + y: y + grid: null + lat_lon_names: + lon: lon + lat: lat static: - path: /scratch/sadamov/template.zarr + path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr" dims: - level: z + level: null x: x y: y + grid: null lat_lon_names: lon: lon lat: lat forcing: - path: /scratch/sadamov/template.zarr + path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr" dims: time: time - level: z + level: null x: x y: y + grid: null lat_lon_names: lon: lon lat: lat - boundary: - path: /scratch/sadamov/era5_template.zarr - dims: - time: time - level: level - x: longitude - y: latitude - lat_lon_names: - lon: longitude - lat: latitude - mask: boundary_mask # Name of variable containing boolean mask, true for grid nodes to be used in boundary. -state: # Variables forecasted by the model - surface: # Single-field variables - - CLCT - - PMSL - - PS - - T_2M - - TOT_PREC - - U_10M - - V_10M +state: + surface_vars: + - u10m + - v10m + - t2m surface_units: - - "%" - - $\mathrm{Pa}$ - - $\mathrm{Pa}$ - - $\mathrm{K}$ - - $\mathrm{kg}/\mathrm{m}^2$ - - $\mathrm{m}/\mathrm{s}$ - - $\mathrm{m}/\mathrm{s}$ - atmosphere: # Variables with vertical levels - - PP - - QV - - RELHUM - - T - - U - - V - - W + - m/s + - m/s + - K + atmosphere_vars: + - u + - v + - t atmosphere_units: - - $\mathrm{Pa}$ - - $\mathrm{kg}/\mathrm{kg}$ - - "%" - - $\mathrm{K}$ - - $\mathrm{m}/\mathrm{s}$ - - $\mathrm{m}/\mathrm{s}$ - - $\mathrm{Pa}/\mathrm{s}$ - levels: # Levels to use for atmosphere variables - - 0 - - 5 - - 8 - - 11 - - 13 - - 15 - - 19 - - 22 - - 26 - - 30 - - 38 - - 44 - - 59 -static: # Static inputs - surface: - - HSURF - atmosphere: - - FI - levels: - - 0 - - 5 - - 8 - - 11 - - 13 - - 15 - - 19 - - 22 - - 26 - - 30 - - 38 - - 44 - - 59 -forcing: # Forcing variables, dynamic inputs to the model - surface: - - ASOB_S - atmosphere: - - T + - m/s + - m/s + - K levels: - - 0 - - 5 - - 8 - - 11 - - 13 - - 15 - - 19 - - 22 - - 26 - - 30 - - 38 - - 44 - - 59 - window: 3 # Number of time steps to use for forcing (odd) -boundary: # Boundary conditions - surface: - - 10m_u_component_of_wind - # - 10m_v_component_of_wind - # - 2m_dewpoint_temperature - # - 2m_temperature - # - mean_sea_level_pressure - # - mean_surface_latent_heat_flux - # - mean_surface_net_long_wave_radiation_flux - # - mean_surface_net_short_wave_radiation_flux - # - mean_surface_sensible_heat_flux - # - surface_pressure - # - total_cloud_cover - # - total_column_water_vapour - # - total_precipitation_12hr - # - total_precipitation_24hr - # - total_precipitation_6hr - # - geopotential_at_surface - atmosphere: - - divergence - # - geopotential - # - relative_humidity - # - specific_humidity - # - temperature - # - u_component_of_wind - # - v_component_of_wind - # - vertical_velocity - # - vorticity - levels: - - 50 - 100 - - 150 - - 200 - - 250 - - 300 - - 400 - - 500 - - 600 - - 700 - - 850 - - 925 - - 1000 - window: 3 # Number of time steps to use for boundary (odd) +static: + surface_vars: + - pres0m # just as a technical test + atmosphere_vars: null + levels: null +forcing: + surface_vars: + - cape_column # just as a technical test + atmosphere_vars: null + levels: null + window: 3 # Number of time steps to use for forcing (odd) grid_shape_state: x: 582 y: 390 splits: train: - start: 2015-01-01T00 - end: 2024-12-31T23 + start: 1990-09-01T00 + end: 1990-09-11T00 val: - start: 2015-01-01T00 - end: 2024-12-31T23 + start: 1990-09-11T03 + end: 1990-09-13T09 test: - start: 2015-01-01T00 - end: 2024-12-31T23 + start: 1990-09-11T03 + end: 1990-09-13T09 projection: - class: RotatedPole # Name of class in cartopy.crs - kwargs: # Parsed and used directly as kwargs to projection-class above - pole_longitude: 10.0 - pole_latitude: -43.0 + class: LambertConformal # Name of class in cartopy.crs + kwargs: + central_longitude: 6.22 + central_latitude: 56.0 + standard_parallels: [47.6, 64.4] normalization: zarr: normalization.zarr vars: