Skip to content

Commit

Permalink
Merge pull request #2 from ukaea/fix-defuse-sources
Browse files Browse the repository at this point in the history
Fix XSX and update XDC source for DEFUSE
  • Loading branch information
jameshod5 committed Aug 15, 2024
2 parents 2872c18 + 0d774e5 commit 5c3821d
Show file tree
Hide file tree
Showing 5 changed files with 407 additions and 346 deletions.
7 changes: 3 additions & 4 deletions jobs/freia_write_datasets.qsub
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,10 @@ num_workers=$3
export PATH="/home/rt2549/dev/:$PATH"

random_string=$(head /dev/urandom | tr -dc A-Za-z0-9 | head -c 16)

temp_dir="/common/tmp/sjackson/local_cache/$random_string"
metadata_dir="/common/tmp/sjackson/data/uda/"

# Run script
# --force --source_names abm ada adg aga ahx aim air ait alp ama amb amc amh amm ams anb ane ant anu aoe arp asb asm asx ayc aye efm esm esx rba rbb rbc rca rco rgb rgc rir rit xmo xpc xsx
# --force --source_names amc ayc efm xmo xsx
time mpirun -np $num_workers \
python3 -m src.archive.main $temp_dir $summary_file $bucket_path --force \
--source_names ${@:4}
python3 -m src.main $temp_dir $summary_file --metadata_dir $metadata_dir --bucket_path $bucket_path --file_format zarr --upload --force --source_names ${@:4}
80 changes: 60 additions & 20 deletions src/task.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from pathlib import Path
import os
import xarray as xr
import pandas as pd
import traceback
import shutil
import subprocess
import logging
from pathlib import Path
import xarray as xr
import pandas as pd

from src.transforms import MASTPipelineRegistry, MASTUPipelineRegistry
from src.mast import MASTClient
Expand Down Expand Up @@ -83,30 +84,68 @@ def __init__(
self.pipelines = MASTUPipelineRegistry()

def __call__(self):
    """Run the full dataset-writing pipeline for this shot.

    Top-level entry point: delegates all work to ``_main`` and converts
    any exception into an error log entry (with traceback) so that one
    failing shot does not abort a whole batch run.
    """
    try:
        self._main()
    except Exception as e:
        # Keep the full traceback in the log so failures are diagnosable post-hoc.
        trace = traceback.format_exc()
        logging.error(f"Error reading sources for shot {self.shot}: {e}\n{trace}")

def _main(self):
signal_infos, source_infos = self._read_metadata()

if len(self.source_names) > 0:
signal_infos = signal_infos.loc[signal_infos.source.isin(self.source_names)]
if signal_infos is None or signal_infos is None:
return

signal_infos = self._filter_signals(signal_infos)

self.writer.write_metadata()

for key, group_index in signal_infos.groupby("source").groups.items():
signal_infos_for_source = signal_infos.loc[group_index]
signal_datasets = self.load_source(signal_infos_for_source)
pipeline = self.pipelines.get(key)
dataset = pipeline(signal_datasets)
source_info = source_infos.loc[source_infos["name"] == key].iloc[0]
source_info = source_info.to_dict()
dataset.attrs.update(source_info)
self.writer.write_dataset(dataset)
for source_name, source_group_index in signal_infos.groupby("source").groups.items():
source_info = self._get_source_metadata(source_name, source_infos)
signal_infos_for_source = self._get_signals_for_source(source_name, source_group_index, signal_infos)
self._process_source(source_name, signal_infos_for_source, source_info)

self.writer.consolidate_dataset()

def _process_source(self, source_name: str, signal_infos: pd.DataFrame, source_info: dict):
signal_datasets = self.load_source(signal_infos)
pipeline = self.pipelines.get(source_name)
dataset = pipeline(signal_datasets)
dataset.attrs.update(source_info)
self.writer.write_dataset(dataset)

def _get_source_metadata(self, source_name, source_infos: pd.DataFrame) -> dict:
source_info = source_infos.loc[source_infos["name"] == source_name].iloc[0]
source_info = source_info.to_dict()
return source_info

def _get_signals_for_source(self, source_name: str, source_group_index: pd.Series, signal_infos: pd.DataFrame):
signal_infos_for_source = signal_infos.loc[source_group_index]
if source_name == 'xdc':
signal_infos_for_source = signal_infos_for_source.loc[signal_infos_for_source.name == 'xdc/ip_t_ipref']
return signal_infos_for_source

def _read_metadata(self) -> tuple[pd.DataFrame, pd.DataFrame]:
try:
signal_infos = self.read_signal_info()
source_infos = self.read_source_info()
except FileNotFoundError:
message = f"Could not find source/signal metadata file for shot {self.shot}"
logging.warning(message)
return None, None

return signal_infos, source_infos


def _filter_signals(self, signal_infos: pd.DataFrame) -> pd.DataFrame:
if len(self.signal_names) > 0:
signal_infos = signal_infos.loc[signal_infos.name.isin(self.signal_names)]

if len(self.source_names) > 0:
signal_infos = signal_infos.loc[signal_infos.source.isin(self.source_names)]

return signal_infos

def load_source(self, group: pd.DataFrame) -> dict[str, xr.Dataset]:
datasets = {}
for _, info in group.iterrows():
Expand All @@ -126,7 +165,8 @@ def load_source(self, group: pd.DataFrame) -> dict[str, xr.Dataset]:
shot_num=self.shot, name=info["uda_name"]
)
except Exception as e:
logging.error(f"Error reading dataset {name} for shot {self.shot}: {e}")
uda_name = info["uda_name"]
logging.warning(f"Could not read dataset {name} ({uda_name}) for shot {self.shot}: {e}")
continue

dataset.attrs.update(info)
Expand Down
Loading

0 comments on commit 5c3821d

Please sign in to comment.