config.py is ready for danra
sadamov committed May 31, 2024
1 parent 59c4947 commit 4e457ed
Showing 2 changed files with 200 additions and 232 deletions.
207 changes: 135 additions & 72 deletions neural_lam/config.py
@@ -60,52 +60,109 @@ def coords_projection(self):
         proj_params = proj_config.get("kwargs", {})
         return proj_class(**proj_params)
 
     @functools.cached_property
     def param_names(self):
         """Return parameter names."""
-        surface_names = self.values["state"]["surface"]
-        atmosphere_names = [
+        surface_vars_names = self.values["state"]["surface_vars"]
+        atmosphere_vars_names = [
             f"{var}_{level}"
-            for var in self.values["state"]["atmosphere"]
+            for var in self.values["state"]["atmosphere_vars"]
             for level in self.values["state"]["levels"]
         ]
-        return surface_names + atmosphere_names
+        return surface_vars_names + atmosphere_vars_names
 
     @functools.cached_property
     def param_units(self):
         """Return parameter units."""
-        surface_units = self.values["state"]["surface_units"]
-        atmosphere_units = [
+        surface_vars_units = self.values["state"]["surface_vars_units"]
+        atmosphere_vars_units = [
             unit
-            for unit in self.values["state"]["atmosphere_units"]
+            for unit in self.values["state"]["atmosphere_vars_units"]
             for _ in self.values["state"]["levels"]
         ]
-        return surface_units + atmosphere_units
+        return surface_vars_units + atmosphere_vars_units
 
-    def num_data_vars(self, key):
-        """Return the number of data variables for a given key."""
-        surface_vars = len(self.values[key]["surface"])
-        atmosphere_vars = len(self.values[key]["atmosphere"])
-        levels = len(self.values[key]["levels"])
-        return surface_vars + atmosphere_vars * levels
+    @functools.lru_cache()
+    def num_data_vars(self, category):
+        """Return the number of data variables for a given category."""
+        surface_vars = self.values[category].get("surface_vars", [])
+        atmosphere_vars = self.values[category].get("atmosphere_vars", [])
+        levels = self.values[category].get("levels", [])
 
-    def projection(self):
-        """Return the projection."""
-        proj_config = self.values["projections"]["class"]
-        proj_class = getattr(ccrs, proj_config["proj_class"])
-        proj_params = proj_config["proj_params"]
-        return proj_class(**proj_params)
+        surface_vars_count = (
+            len(surface_vars) if surface_vars is not None else 0
+        )
+        atmosphere_vars_count = (
+            len(atmosphere_vars) if atmosphere_vars is not None else 0
+        )
+        levels_count = len(levels) if levels is not None else 0
 
-    def open_zarr(self, dataset_name):
-        """Open a dataset specified by the dataset name."""
-        dataset_path = self.zarrs[dataset_name].path
-        if dataset_path is None or not os.path.exists(dataset_path):
-            print(
-                f"Dataset '{dataset_name}' "
-                f"not found at path: {dataset_path}"
-            )
-            return None
-        dataset = xr.open_zarr(dataset_path, consolidated=True)
+        return surface_vars_count + atmosphere_vars_count * levels_count
 
+    @functools.lru_cache(maxsize=None)
+    def open_zarr(self, category):
+        """Open a dataset specified by the category."""
+        zarr_config = self.zarrs[category]
+
+        if isinstance(zarr_config, list):
+            try:
+                datasets = []
+                for config in zarr_config:
+                    dataset_path = config["path"]
+                    dataset = xr.open_zarr(dataset_path, consolidated=True)
+                    datasets.append(dataset)
+                return xr.merge(datasets)
+            except Exception:
+                print(f"Invalid zarr configuration for category: {category}")
+                return None
+
+        else:
+            try:
+                dataset_path = zarr_config["path"]
+                return xr.open_zarr(dataset_path, consolidated=True)
+            except Exception:
+                print(f"Invalid zarr configuration for category: {category}")
+                return None
+
+    def stack_grid(self, dataset):
+        """Stack grid dimensions."""
+        dims = dataset.to_array().dims
+
+        if "grid" not in dims and "x" in dims and "y" in dims:
+            dataset = dataset.squeeze().stack(grid=("x", "y")).to_array()
+        else:
+            try:
+                dataset = dataset.squeeze().to_array()
+            except ValueError:
+                print("Failed to stack grid dimensions.")
+                return None
+
+        if "time" in dataset.dims:
+            dataset = dataset.transpose("time", "grid", "variable")
+        else:
+            dataset = dataset.transpose("grid", "variable")
+        return dataset
+
+    @functools.lru_cache()
+    def get_nwp_xy(self, category):
+        """Get the x and y coordinates for the NWP grid."""
+        dataset = self.open_zarr(category)
+        lon_name = self.zarrs[category].lat_lon_names.lon
+        lat_name = self.zarrs[category].lat_lon_names.lat
+        if lon_name in dataset and lat_name in dataset:
+            lon = dataset[lon_name].values
+            lat = dataset[lat_name].values
+        else:
+            raise ValueError(
+                f"Dataset does not contain {lon_name} or {lat_name}"
+            )
+        if lon.ndim == 1:
+            lon, lat = np.meshgrid(lat, lon)
+        lonlat = np.stack((lon, lat), axis=0)
+
+        return lonlat
+
     @functools.cached_property
     def load_normalization_stats(self):
         """Load normalization statistics from Zarr archive."""
         normalization_path = self.normalization.zarr
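
A minimal illustration (not part of the commit) of the nested structure the renamed keys imply once data_config.yaml is loaded; the variable names, units, and levels below are invented:

# Hypothetical contents of self.values["state"] after this rename;
# only the keys touched by the hunk above are shown.
state = {
    "surface_vars": ["t2m", "msl"],
    "surface_vars_units": ["K", "Pa"],
    "atmosphere_vars": ["u", "v"],
    "atmosphere_vars_units": ["m/s", "m/s"],
    "levels": [850, 500],
}

# With this config, param_names expands to
# ["t2m", "msl", "u_850", "u_500", "v_850", "v_500"]
# and num_data_vars("state") returns 2 + 2 * 2 = 6.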
@@ -120,10 +177,11 @@ def load_normalization_stats(self):
         )
         return normalization_stats
 
-    def process_dataset(self, dataset_name, split="train", stack=True):
+    @functools.lru_cache(maxsize=None)
+    def process_dataset(self, category, split="train"):
         """Process a single dataset specified by the dataset name."""
 
-        dataset = self.open_zarr(dataset_name)
+        dataset = self.open_zarr(category)
         if dataset is None:
             return None
 
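A hedged sketch (paths and names are placeholders, not taken from the diff) of the two shapes of self.zarrs[category] that the reworked open_zarr() accepts:

zarrs = {
    # A list of archives: each is opened separately, then combined
    # with xr.merge().
    "state": [
        {
            "path": "data/surface.zarr",
            "dims": {"time": "time", "x": "x", "y": "y"},
            "lat_lon_names": {"lon": "longitude", "lat": "latitude"},
        },
        {
            "path": "data/atmosphere.zarr",
            "dims": {"time": "time", "level": "level", "x": "x", "y": "y"},
            "lat_lon_names": {"lon": "longitude", "lat": "latitude"},
        },
    ],
    # A single mapping: opened directly.
    "static": {
        "path": "data/static.zarr",
        "dims": {"x": "x", "y": "y"},
        "lat_lon_names": {"lon": "longitude", "lat": "latitude"},
    },
}

Because open_zarr() is wrapped in functools.lru_cache, repeated calls for the same category reuse the first result rather than reopening the archives.
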
@@ -132,48 +190,64 @@ def process_dataset(self, dataset_name, split="train", stack=True):
             self.splits[split].end,
         )
         dataset = dataset.sel(time=slice(start, end))
 
+        dims_mapping = {}
+        zarr_configs = self.zarrs[category]
+        if isinstance(zarr_configs, list):
+            for zarr_config in zarr_configs:
+                dims_mapping.update(zarr_config["dims"])
+        else:
+            dims_mapping.update(zarr_configs["dims"].values)
+
         dataset = dataset.rename_dims(
             {
                 v: k
-                for k, v in self.zarrs[dataset_name].dims.values.items()
-                if k not in dataset.dims
+                for k, v in dims_mapping.items()
+                if k not in dataset.dims and v in dataset.dims
             }
         )
+        dataset = dataset.rename_vars(
+            {v: k for k, v in dims_mapping.items() if v in dataset.coords}
+        )
 
-        vars_surface = []
-        if self[dataset_name].surface:
-            vars_surface = dataset[self[dataset_name].surface]
+        surface_vars = []
+        if self[category].surface_vars:
+            surface_vars = dataset[self[category].surface_vars]
 
-        vars_atmosphere = []
-        if self[dataset_name].atmosphere:
-            vars_atmosphere = xr.merge(
+        atmosphere_vars = []
+        if self[category].atmosphere_vars:
+            atmosphere_vars = xr.merge(
                 [
                     dataset[var]
                     .sel(level=level, drop=True)
                     .rename(f"{var}_{level}")
-                    for var in self[dataset_name].atmosphere
-                    for level in self[dataset_name].levels
+                    for var in self[category].atmosphere_vars
+                    for level in self[category].levels
                 ]
             )
 
-        if vars_surface and vars_atmosphere:
-            dataset = xr.merge([vars_surface, vars_atmosphere])
-        elif vars_surface:
-            dataset = vars_surface
-        elif vars_atmosphere:
-            dataset = vars_atmosphere
+        if surface_vars and atmosphere_vars:
+            dataset = xr.merge([surface_vars, atmosphere_vars])
+        elif surface_vars:
+            dataset = surface_vars
+        elif atmosphere_vars:
+            dataset = atmosphere_vars
         else:
-            print(f"No variables found in dataset {dataset_name}")
+            print(f"No variables found in dataset {category}")
             return None
 
+        zarr_configs = self.zarrs[category]
+        lat_lon_names = {}
+        if isinstance(self.zarrs[category], list):
+            for zarr_configs in self.zarrs[category]:
+                lat_lon_names.update(zarr_configs["lat_lon_names"])
+        else:
+            lat_lon_names.update(self.zarrs[category]["lat_lon_names"].values)
+
         if not all(
-            lat_lon in self.zarrs[dataset_name].dims.values.values()
-            for lat_lon in self.zarrs[
-                dataset_name
-            ].lat_lon_names.values.values()
+            lat_lon in lat_lon_names.values() for lat_lon in lat_lon_names
         ):
-            lat_name = self.zarrs[dataset_name].lat_lon_names.lat
-            lon_name = self.zarrs[dataset_name].lat_lon_names.lon
+            lat_name, lon_name = lat_lon_names[:2]
             if dataset[lat_name].ndim == 2:
                 dataset[lat_name] = dataset[lat_name].isel(x=0, drop=True)
             if dataset[lon_name].ndim == 2:
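
Illustration only, with made-up data: the isel() reduction above relies on a regular grid, where latitude varies only along y and longitude only along x, so a single column or row carries the full coordinate:

import numpy as np
import xarray as xr

# 2D latitude field on a regular 4 x 3 (y, x) grid: all columns are equal.
lat2d = np.tile(np.linspace(50.0, 60.0, 4)[:, None], (1, 3))
ds = xr.Dataset({"latitude": (("y", "x"), lat2d)})

# Collapsing along x loses nothing, mirroring the
# dataset[lat_name].isel(x=0, drop=True) call in the hunk above.
lat1d = ds["latitude"].isel(x=0, drop=True)  # dims: ("y",)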
@@ -182,26 +256,15 @@ def process_dataset(self, dataset_name, split="train", stack=True):
                 x=dataset[lon_name], y=dataset[lat_name]
             )
 
-        if stack:
-            dataset = self.stack_grid(dataset)
-
+        dataset = dataset.rename(
+            {v: k for k, v in dims_mapping.items() if v in dataset.coords}
+        )
+        dataset = self.stack_grid(dataset)
         return dataset
 
-    def stack_grid(self, dataset):
-        """Stack grid dimensions."""
-        dataset = dataset.squeeze().stack(grid=("x", "y")).to_array()
-
-        if "time" in dataset.dims:
-            dataset = dataset.transpose("time", "grid", "variable")
-        else:
-            dataset = dataset.transpose("grid", "variable")
-        return dataset
-
-    def get_nwp_xy(self):
-        """Get the x and y coordinates for the NWP grid."""
-        x = self.process_dataset("static", stack=False).x.values
-        y = self.process_dataset("static", stack=False).y.values
-        xx, yy = np.meshgrid(y, x)
-        xy = np.stack((xx, yy), axis=0)
-
-        return xy
+
+config = Config.from_file("neural_lam/data_config.yaml")
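
A hedged usage sketch of the reworked interface; the category names follow the diff, but the call sequence itself is an assumption, not code from this commit:

from neural_lam.config import Config

config = Config.from_file("neural_lam/data_config.yaml")

# Number of state variables: surface + atmosphere * levels.
n_state = config.num_data_vars("state")

# Fully processed dataset, transposed to ("time", "grid", "variable").
state = config.process_dataset("state", split="train")

# Stacked lon/lat coordinates for the grid, shape (2, ...).
xy = config.get_nwp_xy("static")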