diff --git a/jdaviz/app.py b/jdaviz/app.py index e06a26924d..5bcbb01c76 100644 --- a/jdaviz/app.py +++ b/jdaviz/app.py @@ -1,3 +1,6 @@ +from datetime import datetime +import glob +import json import os import pathlib import re @@ -281,6 +284,9 @@ def __init__(self, configuration=None, *args, **kwargs): # can reference their state easily since glue does not store viewers self._viewer_store = {} + # Record the origins of the data loaded into data_collection + self._data_origins = {} + # Parse the yaml configuration file used to compose the front-end UI self.load_configuration(configuration) @@ -647,6 +653,19 @@ def load_data(self, file_obj, parser_reference=None, **kwargs): finally: self.loading = False + # Record where this data came from + if "data_label" not in kwargs or kwargs["data_label"] is None: + data_label = self.data_collection.labels[-1] + else: + data_label = kwargs["data_label"] + + if not isinstance(file_obj, str): + file_obj = f"Object: {str(type(file_obj))}" + + self._data_origins[data_label] = {"file_obj": file_obj, + "parser_reference": parser_reference, + "kwargs": kwargs} + def get_viewer(self, viewer_reference): """ Return a `~glue_jupyter.bqplot.common.BqplotBaseView` viewer instance. @@ -2570,6 +2589,123 @@ def _reset_state(self): self.state = ApplicationState() self._application_handler._tools = {} + def save_state(self, output_file=None, save_dir=None): + """ + Saves the application's glue-related state (viewers, layers, subsets, etc) + to a json file + + Parameters + ---------- + output_file: str, optional + File in which to save the output json. Defaults to a timestamped-filename. + save_dir: str, optional + Directory in which to save the output file. Defaults to making a ``saved_states`` + directory in the location where Jdaviz is being run. + """ + if save_dir is None: + save_dir = "saved_states" + pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True) + + if output_file is None: + output_file = f"glue-state-{datetime.now().strftime("%Y%m%d-%H:%M:%S")}.json" + output_file = pathlib.Path(save_dir, output_file) + + output_json = {"viewers": {}, "subsets": {}, "data_origins": self._data_origins} + + # Get link type + if self.config == "imviz": + link_plugin = self._jdaviz_helper.plugins['Links Control'] + output_json["link_type"] = link_plugin.link_type.selected + + # Get states for viewers+layers, not including subset layers. + for viewer_id, viewer in self._viewer_store.items(): + viewer_dict = viewer.state.as_dict() + _ = viewer_dict.pop('layers') + for key in viewer_dict: + if hasattr(viewer_dict[key], "label"): + viewer_dict[key] = viewer_dict[key].label + + output_json['viewers'][viewer_id] = {"state": viewer_dict, + "layers": {}} + + for layer in viewer.layers: + layer_dict = layer.state.as_dict() + # Need to remove the Data from the state dict, since we can't json-serialize it. + for key in layer_dict: + if hasattr(layer_dict[key], "label"): + layer_dict[key] = layer_dict[key].label + elif hasattr(layer_dict[key], "name"): + layer_dict[key] = layer_dict[key].name + output_json['viewers'][viewer_id]["layers"][layer.label] = layer_dict + + print(json.dumps(output_json, indent=2)) + + with open(output_file, "w") as f: + json.dump(output_json, f, indent=2) + + def restore_state(self, input_json=None, save_dir=None): + glue_to_plot_option = { + "v_min": "stretch_vmin", + "v_max": "stretch_vmax", + "stretch": "stretch_function", + "cmap": "image_colormap", + "visible": "image_visible" + } + if save_dir is None: + save_dir = "saved_states" + if input_json is None: + save_files = glob.glob(f"{save_dir}/glue-state*.json") + save_files.sort() + input_json=save_files[-1] + + with open(input_json, "r") as f: + state_json = json.load(f) + + # Determine if same data was loaded, error if not and we can't find the data on disk. + for data_label, origin_dict in state_json["data_origins"].items(): + if data_label in self.data_collection.labels: + continue + if origin_dict["file_obj"][:7] == "Object:": + print(f"Data loaded from {origin_dict["file_obj"]} is not in data_collection and" + " could not be automatically loaded.") + else: + # Need to replace null data label with whatever the parser came up with + origin_dict["kwargs"]["data_label"] = data_label.split("[")[0] + # Load the data! + self.load_data(origin_dict["file_obj"], + parser_reference=origin_dict["parser_reference"], + **origin_dict["kwargs"]) + + # Create any viewers that do not exist for Imviz + if self.config == "imviz": + for viewer_id in state_json['viewers']: + if viewer_id not in self._viewer_store: + self._jdaviz_helper.create_image_viewer(viewer_id) + + # Load/unload data into appropriate viewers and set plot options + plot_options = self._jdaviz_helper.plugins['Plot Options'] + for viewer_id, viewer in self._viewer_store.items(): + plot_options.viewer = viewer_id + layer_labels = [layer.label for layer in viewer.layers] + saved_viewer_state = state_json['viewers'][viewer_id] + for layer_id, layer_dict in saved_viewer_state['layers'].items(): + data_label = layer_id.split(" ")[0] + if data_label not in self.data_collection.labels: + print(f"Couldn't find or load {data_label}.") + continue + data_label = layer_id.split(" ")[0] + if data_label not in layer_labels: + self.add_data_to_viewer(viewer_id, data_label) + plot_options.layer = data_label + for glue_att, plot_att in glue_to_plot_option.items(): + setattr(plot_options, plot_att, layer_dict[glue_att]) + + # Update link type if we're in Imviz + if self.config == "imviz": + link_plugin = self._jdaviz_helper.plugins['Links Control'] + link_plugin.link_type = state_json["link_type"] + + def get_configuration(self, path=None, section=None): """Returns a copy of the application configuration.