Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: support passing though dm, nicer typing #400

Merged
merged 2 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ repos:
- id: end-of-file-fixer

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.1.5"
rev: "v0.1.6"
hooks:
- id: ruff
args: ["--fix", "--show-fixes"]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.6.1
rev: v1.7.1
hooks:
- id: mypy
files: '^src/decaylanguage/(decay|dec|utils)/'
Expand Down
29 changes: 16 additions & 13 deletions src/decaylanguage/dec/dec.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@

from .. import data
from .._compat.typing import Self
from ..decay.decay import _expand_decay_modes
from ..decay.decay import DecayModeDict, _expand_decay_modes
from ..utils import charge_conjugate_name
from .enums import PhotosEnum

Expand Down Expand Up @@ -747,7 +747,7 @@ def list_decay_modes(self, mother: str, pdg_name: bool = False) -> list[list[str

def _decay_mode_details(
self, decay_mode: Tree, display_photos_keyword: bool = True
) -> tuple[float, list[str], str, str | list[str | Any]]:
) -> DecayModeDict:
"""
Parse a decay mode (Tree instance)
and return the relevant bits of information in it.
Expand All @@ -768,7 +768,9 @@ def _decay_mode_details(
if display_photos_keyword and list(decay_mode.find_data("photos")):
model = "PHOTOS " + model

return (bf, fsp_names, model, model_params)
return DecayModeDict(
bf=bf, fs=fsp_names, model=model, model_params=model_params
)

def print_decay_modes(
self,
Expand Down Expand Up @@ -864,11 +866,13 @@ def print_decay_modes(

ls_dict = {}
for dm in dms:
bf, fsp_names, model, model_params = self._decay_mode_details(
dm, display_photos_keyword
dmdict = self._decay_mode_details(dm, display_photos_keyword)
model_params = [str(i) for i in dmdict["model_params"]]
ls_dict[dmdict["bf"]] = (
dmdict["fs"],
dmdict["model"],
model_params,
)
model_params = [str(i) for i in model_params]
ls_dict[bf] = (fsp_names, model, model_params)

dec_details = list(ls_dict.values())
ls_attrs_aligned = list(
Expand Down Expand Up @@ -937,7 +941,7 @@ def build_decay_chains(
self,
mother: str,
stable_particles: list[str] | set[str] | tuple[str] | tuple[()] = (),
) -> dict[str, list[dict[str, float | str | list[Any]]]]:
) -> dict[str, list[DecayModeDict]]:
"""
Iteratively build the entire decay chains of a given mother particle,
optionally considering, on the fly, certain particles as stable.
Expand Down Expand Up @@ -992,14 +996,12 @@ def build_decay_chains(
>>> p.build_decay_chains('D+', stable_particles=['pi0']) # doctest: +SKIP
{'D+': [{'bf': 1.0, 'fs': ['K-', 'pi+', 'pi+', 'pi0'], 'model': 'PHSP', 'model_params': ''}]}
"""
keys = ("bf", "fs", "model", "model_params")

info = []
for dm in self._find_decay_modes(mother):
list_dm_details = self._decay_mode_details(dm, display_photos_keyword=False)
d = dict(zip(keys, list_dm_details))
d = self._decay_mode_details(dm, display_photos_keyword=False)

for i, fs in enumerate(d["fs"]): # type: ignore[arg-type, var-annotated]
for i, fs in enumerate(d["fs"]):
if fs in stable_particles:
continue

Expand All @@ -1008,14 +1010,15 @@ def build_decay_chains(
# if fs does not have decays defined in the parsed file
# _n_dms = len(self._find_decay_modes(fs))

assert isinstance(fs, str)
_info = self.build_decay_chains(fs, stable_particles)
d["fs"][i] = _info # type: ignore[index]
except DecayNotFound:
pass

info.append(d)

return {mother: info} # type: ignore[dict-item]
return {mother: info}

def __repr__(self) -> str:
if self._parsed_dec_file is not None:
Expand Down
44 changes: 27 additions & 17 deletions src/decaylanguage/decay/decay.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@
from __future__ import annotations

from collections import Counter
from collections.abc import Sequence
from copy import deepcopy
from itertools import product
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List

from particle import PDGID, ParticleNotFound
from particle.converters import EvtGenName2PDGIDBiMap
from particle.exceptions import MatchingIDNotFound

from .._compat.typing import Self
from .._compat.typing import Self, TypedDict
from ..utils import DescriptorFormat, charge_conjugate_name

if TYPE_CHECKING:
Expand All @@ -24,6 +25,16 @@
CounterStr = Counter


class DecayModeDict(TypedDict):
bf: float
fs: Sequence[str | DecayChainDict]
model: str
model_params: str | Sequence[str | Any]


DecayChainDict = Dict[str, List[DecayModeDict]]


class DaughtersDict(CounterStr):
"""
Class holding a decay final state as a dictionary.
Expand Down Expand Up @@ -187,8 +198,12 @@ class DecayMode:
def __init__(
self,
bf: float = 0,
daughters: None
| (DaughtersDict | dict[str, int] | list[str] | tuple[str] | str) = None,
daughters: DaughtersDict
| dict[str, int]
| list[str]
| tuple[str]
| str
| None = None,
**info: Any,
) -> None:
"""
Expand Down Expand Up @@ -241,6 +256,8 @@ def __init__(
True
"""
self.bf = bf
if daughters is None and "fs" in info:
daughters = info.pop("fs")
self.daughters = DaughtersDict(daughters)

self.metadata: dict[str, str | None] = {"model": "", "model_params": ""}
Expand All @@ -249,7 +266,7 @@ def __init__(
@classmethod
def from_dict(
cls,
decay_mode_dict: dict[str, int | float | str | list[str]],
decay_mode_dict: DecayModeDict,
) -> Self:
"""
Constructor from a dictionary of the form
Expand Down Expand Up @@ -285,13 +302,10 @@ def from_dict(
dm = deepcopy(decay_mode_dict)

# Ensure the input dict has the 2 required keys 'bf' and 'fs'
try:
bf = dm.pop("bf")
daughters = dm.pop("fs")
except KeyError as e:
raise RuntimeError("Input not in the expected format!") from e
if not dm.keys() >= {"bf", "fs"}:
raise RuntimeError("Input not in the expected format! Needs 'bf' and 'fs'")

return cls(bf=bf, daughters=daughters, **dm) # type: ignore[arg-type]
return cls(**dm)

@classmethod
def from_pdgids(
Expand Down Expand Up @@ -436,10 +450,6 @@ def __str__(self) -> str:
return repr(self)


DecayModeDict = Dict[str, Union[float, str, List[Any]]]
DecayChainDict = Dict[str, List[DecayModeDict]]


def _has_no_subdecay(ds: list[Any]) -> bool:
"""
Internal function to check whether the input list
Expand Down Expand Up @@ -894,9 +904,9 @@ def _print(
for i_decay in decay_dict[mother]:
print(prefix, arrow if depth > 0 else "", mother, sep="") # noqa: T201
fsps = i_decay["fs"]
n = len(list(fsps)) # type: ignore[arg-type]
n = len(list(fsps))
depth += 1
for j, fsp in enumerate(fsps): # type: ignore[arg-type]
for j, fsp in enumerate(fsps):
prefix = bar if (link and depth > 1) else ""
if last:
prefix = prefix + " " * indent * (depth - 1) + " "
Expand Down
7 changes: 6 additions & 1 deletion tests/dec/test_dec.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,12 @@ def test_decay_mode_details():
p.parse()

tree_Dp = p._find_decay_modes("D+")[0]
output = (1.0, ["K-", "pi+", "pi+", "pi0"], "PHSP", "")
output = {
"bf": 1.0,
"fs": ["K-", "pi+", "pi+", "pi0"],
"model": "PHSP",
"model_params": "",
}
assert p._decay_mode_details(tree_Dp, display_photos_keyword=False) == output


Expand Down