Skip to content

Commit

Permalink
pretend to load measurements, parse info, color and tooltips
Browse files Browse the repository at this point in the history
  • Loading branch information
JoschD committed Nov 14, 2023
1 parent c1ab766 commit 25ffd27
Show file tree
Hide file tree
Showing 6 changed files with 303 additions and 73 deletions.
15 changes: 10 additions & 5 deletions omc3_gui/segment_by_segment/controller.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from pathlib import Path
from typing import List
from omc3_gui.utils.base_classes import Controller
from omc3_gui.utils.dialogs import OpenDirectoriesDialog, OpenDirectoryDialog
from omc3_gui.utils.file_dialogs import OpenDirectoriesDialog, OpenDirectoryDialog
from omc3_gui.segment_by_segment.view import SbSWindow
from omc3_gui.segment_by_segment.model import Measurement, Settings
from omc3_gui.segment_by_segment.model import Settings
from qtpy.QtCore import Qt, Signal, Slot
from qtpy.QtWidgets import QFileDialog
from omc3_gui.segment_by_segment.measurement_model import OpticsMeasurement
import logging

LOG = logging.getLogger(__name__)
Expand All @@ -21,7 +23,7 @@ def __init__(self):
self._last_selected_optics_path = None


def add_measurement(self, measurement: Measurement):
def add_measurement(self, measurement: OpticsMeasurement):
self._view.get_measurement_list().add_item(measurement)


Expand All @@ -44,7 +46,10 @@ def open_measurements(self):
for filename in filenames:
self._last_selected_optics_path = filename.parent
LOG.debug(f"User selected: {filename}")
loaded_measurements.add_item(Measurement(filename))
optics_measurement = OpticsMeasurement.from_path(filename)
try:
loaded_measurements.add_item(optics_measurement)
except ValueError as e:
LOG.error(str(e))



165 changes: 165 additions & 0 deletions omc3_gui/segment_by_segment/measurement_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@

from dataclasses import dataclass, fields
from pathlib import Path
import types
from typing import Any, Dict, List, Sequence, Union, Optional, ClassVar
from accwidgets.graph import StaticPlotWidget
import pyqtgraph as pg
from omc3.segment_by_segment.segments import Segment
from omc3.definitions.optics import OpticsMeasurement as OpticsMeasurementCollection
from omc3.optics_measurements.constants import MODEL_DIRECTORY, KICK_NAME, PHASE_NAME, BETA_NAME, EXT
from omc3.model.constants import TWISS_DAT
from tfs.reader import read_headers
import logging


SEQUENCE = "SEQUENCE"
DATE = "DATE"
LHC_MODEL_YEARS = (2012, 2015, 2016, 2017, 2018, 2022, 2023) # TODO: get from omc3

FILES_TO_LOOK_FOR = (f"{name}{plane}" for name in (KICK_NAME, PHASE_NAME, BETA_NAME) for plane in ("x", "y"))

LOGGER = logging.getLogger(__name__)

@dataclass
class OpticsMeasurement:
measurement_dir: Path # Path to the optics-measurement folder
model_dir: Path = None # Path to the model folder
accel: str = None # Name of the accelerator
output_dir: Optional[Path] = None # Path to the sbs-output folder
elements: Optional[Dict[str, Segment]] = None # List of elements
segments: Optional[Dict[str, Segment]] = None # List of segments
year: Optional[str] = None # Year of the measurement (accelerator)
ring: Optional[int] = None # Ring of the accelerator
beam: Optional[int] = None # LHC-Beam (not part of SbS input!)

DEFAULT_OUTPUT_DIR: ClassVar[str] = "sbs"

def __post_init__(self):
if self.output_dir is None:
self.output_dir = self.measurement_dir / self.DEFAULT_OUTPUT_DIR

def display(self) -> str:
return str(self.measurement_dir.name)

def tooltip(self) -> str:
parts = (
("Optics Measurement", self.measurement_dir),
("Model", self.model_dir),
("Accelerator", self.accel),
("Beam", self.beam),
("Year", self.year),
("Ring", self.ring),
)
l = max(len(name) for name, _ in parts)
return "\n".join(f"{name:{l}s}: {value}" for name, value in parts if value is not None)

@classmethod
def from_path(cls, path: Path) -> "OpticsMeasurement":
""" Creates an OpticsMeasurement from a folder, by trying
to parse information from the data in the folder.
Args:
path (Path): Path to the folder.
Returns:
OpticsMeasurement: OpticsMeasurement instance.
"""
model_dir = None
info = {}
try:
model_dir = _parse_model_dir_from_optics_measurement(path)
except FileNotFoundError as e:
LOGGER.error(str(e))
else:
info = _parse_info_from_model_dir(model_dir)

meas = cls(measurement_dir=path, model_dir=model_dir, **info)
if (
any(getattr(meas, name) is None for name in ("model_dir", "accel", "output_dir"))
or (meas.accel == 'lhc' and (meas.year is None or meas.beam is None))
or (meas.accel == 'psb' and meas.ring is None)
):
LOGGER.error(f"Info parsed from measurement folder '{path!s}' is incomplete. Adjust manually!!")
return meas


def _parse_model_dir_from_optics_measurement(measurement_path: Path) -> Path:
"""Tries to find the model directory in the headers of one of the optics measurement files.
Args:
measurement_path (Path): Path to the folder.
Returns:
Path: Path to the (associated) model directory.
"""
LOGGER.debug(f"Searching for model dir in {measurement_path!s}")
for file_name in FILES_TO_LOOK_FOR:
LOGGER.debug(f"Checking {file_name!s} for model dir.")
try:
headers = read_headers((measurement_path / file_name).with_suffix(EXT))
except FileNotFoundError:
LOGGER.debug(f"{file_name!s} not found in {measurement_path!s}.")
else:
if MODEL_DIRECTORY in headers:
LOGGER.debug(f"{MODEL_DIRECTORY!s} found in {file_name!s}!")
break

LOGGER.debug(f"{MODEL_DIRECTORY!s} not found in {file_name!s}.")
else:
raise FileNotFoundError(f"Could not find '{MODEL_DIRECTORY}' in any of {FILES_TO_LOOK_FOR!r} in {measurement_path!r}")
path = Path(headers[MODEL_DIRECTORY])
LOGGER.debug(f"Associated model dir found: {path!s}")
return path


def _parse_info_from_model_dir(model_dir: Path) -> Dict[str, Any]:
""" Checking twiss.dat for more info about the accelerator.
Args:
model_dir (Path): Path to the model-directory.
Returns:
Dict[str, Any]: Containing the additional info found (accel, beam, year, ring).
"""
result = {}

try:
headers = read_headers(model_dir / TWISS_DAT)
except FileNotFoundError as e:
LOGGER.debug(str(e))
return result

sequence = headers.get(SEQUENCE)
if sequence is not None:
sequence = sequence.lower()
if "lhc" in sequence:
result['accel'] = "lhc"
result['beam'] = int(sequence[-1])
result['year'] = _get_lhc_model_year(headers.get(DATE))
elif "psb" in sequence:
result['accel'] = "psb"
result['ring'] = int(sequence[-1])
else:
result['accel'] = sequence
LOGGER.debug(f"Associated info found in model dir '{model_dir!s}':\n {result!s}")
return result


def _get_lhc_model_year(date: Union[str, None]) -> Union[str, None]:
""" Parses the year from the date in the LHC twiss.dat file
and tries to find the closest model-year."""
if date is None:
return None
try:
found_year = int(f"20{date.split('/')[-1]}")
except ValueError:
LOGGER.debug(f"Could not parse year from '{date}'!")
return None

for year in sorted(LHC_MODEL_YEARS, reverse=True):
if year <= found_year:
return str(year)

LOGGER.debug(f"Could not parse year from '{date}'!")
return None
Empty file.
78 changes: 45 additions & 33 deletions omc3_gui/segment_by_segment/model.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,25 @@
from dataclasses import dataclass
from pathlib import Path
import enum
import types
from typing import Any, Dict, List, Sequence, Union
from accwidgets.graph import StaticPlotWidget
from dataclasses import dataclass, fields
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Optional, Sequence, Union

import pyqtgraph as pg
from accwidgets.graph import StaticPlotWidget
from omc3.segment_by_segment.segments import Segment
from qtpy import QtCore, QtWidgets
from qtpy.QtCore import QModelIndex, Qt

from typing import List
from qtpy import QtCore
from qtpy.QtCore import Qt


@dataclass
class Measurement:
measurement_dir: Path
output_dir: Path = None
elements: Dict[str, Segment] = None
segments: Dict[str, Segment] = None
model_dir: Path = None
accel: str = None
year: str = None
ring: int = None

def __str__(self):
return str(self.measurement_dir)

from omc3_gui.segment_by_segment.measurement_model import OpticsMeasurement


@dataclass
class Settings:
pass



class ItemDictModel:
""" Mixin-Class for a class that has a dictionary of items. """

def __init__(self):
self.items = {}
Expand All @@ -42,8 +28,9 @@ def try_emit(self, emit: bool = True):
if not emit:
return

if hasattr(self, "layoutChanged"):
self.layoutChanged.emit()
if hasattr(self, "dataChanged"):
# TODO: return which data has actually changed?
self.dataChanged.emit(self.index(0), self.index(len(self.items)-1), [Qt.EditRole])

def update_item(self, item):
self.items[str(item)] = item
Expand Down Expand Up @@ -84,24 +71,49 @@ def get_item_at(self, index: int) -> Any:
return list(self.items.values())[index]



class MeasurementListModel(QtCore.QAbstractListModel, ItemDictModel):

items: Dict[str, Measurement] # for the IDE
items: Dict[str, OpticsMeasurement] # for the IDE

class ColorRoles(enum.IntEnum):
NONE = 0
BEAM1 = enum.auto()
BEAM2 = enum.auto()
RING1 = enum.auto()
RING2 = enum.auto()
RING3 = enum.auto()
RING4 = enum.auto()

@classmethod
def get_color(cls, meas: OpticsMeasurement) -> int:
if meas.accel == "lhc":
return getattr(cls, f"BEAM{meas.beam}")

if meas.accel == "psb":
return getattr(cls, f"RING{meas.ring}")

return cls.NONE

def __init__(self, *args, **kwargs):
super(QtCore.QAbstractListModel, self).__init__(*args, **kwargs)
super(ItemDictModel, self).__init__()

def data(self, index: QtCore.QModelIndex, role: int = Qt.DisplayRole):
meas: Measurement = self.get_item_at(index.row())

meas: OpticsMeasurement = self.get_item_at(index.row())
if role == Qt.DisplayRole: # https://doc.qt.io/qt-5/qt.html#ItemDataRole-enum
return str(meas)

return meas.display()

if role == Qt.ToolTipRole:
return meas.tooltip()

if role == Qt.TextColorRole:
return self.ColorRoles.get_color(meas)

if role == Qt.EditRole:
return meas

def rowCount(self, index):
def rowCount(self, index: QtCore.QModelIndex = None):
return len(self.items)


Expand All @@ -128,7 +140,7 @@ def rowCount(self, parent=QtCore.QModelIndex()):
def columnCount(self, parent=QtCore.QModelIndex()):
return len(self._COLUMNS)

def data(self, index, role=QtCore.Qt.DisplayRole):
def data(self, index: QtCore.QModelIndex, role=QtCore.Qt.DisplayRole):
i = index.row()
j = index.column()
segment: Segment = self.get_item_at(i)
Expand Down
Loading

0 comments on commit 25ffd27

Please sign in to comment.