Skip to content

Commit

Permalink
Add special casing for from_name
Browse files Browse the repository at this point in the history
  • Loading branch information
maxnoe committed Sep 6, 2024
1 parent 0052da4 commit f7fd799
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 6 deletions.
19 changes: 17 additions & 2 deletions src/particle/particle/particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ class InvalidParticle(RuntimeError):
_NON_UNIQUE_PDGIDS[-pdgid1] = -pdgid2
_NON_UNIQUE_PDGIDS[-pdgid2] = -pdgid1

# lookup for hash and from_name which representation is the preferred one
# this results in the "bag-of-quarks" representation
_PREFERRED_PDGID = {}
for pdgid1, pdgid2 in _NON_UNIQUE_PDGIDS.items():
sign = -1 if pdgid1 < 0 else 1
_PREFERRED_PDGID[pdgid1] = sign * min(abs(pdgid1), abs(pdgid2))
_PREFERRED_PDGID[pdgid2] = _PREFERRED_PDGID[pdgid1]


def _isospin_converter(isospin: str) -> float | None:
vals: dict[str | None, float | None] = {
Expand Down Expand Up @@ -627,8 +635,8 @@ def __eq__(self, other: object) -> bool:
return self.pdgid == other

def __hash__(self) -> int:
if self.pdgid in _NON_UNIQUE_PDGIDS:
return hash(min(_NON_UNIQUE_PDGIDS[self.pdgid], self.pdgid))
if self.pdgid in _PREFERRED_PDGID:
return hash(_PREFERRED_PDGID[self.pdgid])
return hash(self.pdgid)

# Shared with PDGID
Expand Down Expand Up @@ -1011,6 +1019,13 @@ def from_name(cls: type[Self], name: str) -> Self:
ParticleNotFound
If no particle matches the input name uniquely and exactly.
"""
# special handling for the particles with two possible pdgids
if name in {"p", "n", "p~", "n~"}:

def find_preferred_id(p: Self) -> bool:
return int(p.pdgid) in _PREFERRED_PDGID.values()

return next(filter(find_preferred_id, cls.finditer(name=name)))
try:
(particle,) = cls.finditer(
name=name
Expand Down
25 changes: 21 additions & 4 deletions tests/particle/test_particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,10 +711,10 @@ def test_evtgen_name(name, pid): # noqa: ARG001
@pytest.mark.parametrize(
("pdgid1", "pdgid2"),
[
pytest.param(2212, 1000010010, id="proton"),
pytest.param(2112, 1000000010, id="neutron"),
pytest.param(-2212, -1000010010, id="anti-proton"),
pytest.param(-2112, -1000000010, id="anti-neutron"),
pytest.param(2212, 1000010010, id="p"),
pytest.param(2112, 1000000010, id="n"),
pytest.param(-2212, -1000010010, id="p~"),
pytest.param(-2112, -1000000010, id="n~"),
],
)
def test_eq_non_unique_pdgids(pdgid1, pdgid2):
Expand All @@ -725,3 +725,20 @@ def test_eq_non_unique_pdgids(pdgid1, pdgid2):
assert p1.pdgid != p2.pdgid
assert p1 == p2
assert hash(p1) == hash(p2)


@pytest.mark.parametrize(
("name", "pdgid"),
[
pytest.param("p", 2212, id="p"),
pytest.param("n", 2112, id="n"),
pytest.param("p~", -2212, id="p~"),
pytest.param("n~", -2112, id="n~"),
],
)
def test_from_name_non_unique_pdgids(name, pdgid):
"""The proton and the neutron have two pdgid representations, make sure they still compare equal"""

p = Particle.from_name(name)
assert p.name == name
assert p.pdgid == pdgid

0 comments on commit f7fd799

Please sign in to comment.