From f7fd7992657347f670525023732f2d3a7aa7f7fd Mon Sep 17 00:00:00 2001 From: Maximilian Linhoff Date: Fri, 6 Sep 2024 13:47:37 +0200 Subject: [PATCH] Add special casing for from_name --- src/particle/particle/particle.py | 19 +++++++++++++++++-- tests/particle/test_particle.py | 25 +++++++++++++++++++++---- 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/src/particle/particle/particle.py b/src/particle/particle/particle.py index b5397d3e..f41c7f9d 100644 --- a/src/particle/particle/particle.py +++ b/src/particle/particle/particle.py @@ -63,6 +63,14 @@ class InvalidParticle(RuntimeError): _NON_UNIQUE_PDGIDS[-pdgid1] = -pdgid2 _NON_UNIQUE_PDGIDS[-pdgid2] = -pdgid1 +# lookup for hash and from_name which representation is the preferred one +# this results in the "bag-of-quarks" representation +_PREFERRED_PDGID = {} +for pdgid1, pdgid2 in _NON_UNIQUE_PDGIDS.items(): + sign = -1 if pdgid1 < 0 else 1 + _PREFERRED_PDGID[pdgid1] = sign * min(abs(pdgid1), abs(pdgid2)) + _PREFERRED_PDGID[pdgid2] = _PREFERRED_PDGID[pdgid1] + def _isospin_converter(isospin: str) -> float | None: vals: dict[str | None, float | None] = { @@ -627,8 +635,8 @@ def __eq__(self, other: object) -> bool: return self.pdgid == other def __hash__(self) -> int: - if self.pdgid in _NON_UNIQUE_PDGIDS: - return hash(min(_NON_UNIQUE_PDGIDS[self.pdgid], self.pdgid)) + if self.pdgid in _PREFERRED_PDGID: + return hash(_PREFERRED_PDGID[self.pdgid]) return hash(self.pdgid) # Shared with PDGID @@ -1011,6 +1019,13 @@ def from_name(cls: type[Self], name: str) -> Self: ParticleNotFound If no particle matches the input name uniquely and exactly. """ + # special handling for the particles with two possible pdgids + if name in {"p", "n", "p~", "n~"}: + + def find_preferred_id(p: Self) -> bool: + return int(p.pdgid) in _PREFERRED_PDGID.values() + + return next(filter(find_preferred_id, cls.finditer(name=name))) try: (particle,) = cls.finditer( name=name diff --git a/tests/particle/test_particle.py b/tests/particle/test_particle.py index 87ba2bf9..316a7c59 100644 --- a/tests/particle/test_particle.py +++ b/tests/particle/test_particle.py @@ -711,10 +711,10 @@ def test_evtgen_name(name, pid): # noqa: ARG001 @pytest.mark.parametrize( ("pdgid1", "pdgid2"), [ - pytest.param(2212, 1000010010, id="proton"), - pytest.param(2112, 1000000010, id="neutron"), - pytest.param(-2212, -1000010010, id="anti-proton"), - pytest.param(-2112, -1000000010, id="anti-neutron"), + pytest.param(2212, 1000010010, id="p"), + pytest.param(2112, 1000000010, id="n"), + pytest.param(-2212, -1000010010, id="p~"), + pytest.param(-2112, -1000000010, id="n~"), ], ) def test_eq_non_unique_pdgids(pdgid1, pdgid2): @@ -725,3 +725,20 @@ def test_eq_non_unique_pdgids(pdgid1, pdgid2): assert p1.pdgid != p2.pdgid assert p1 == p2 assert hash(p1) == hash(p2) + + +@pytest.mark.parametrize( + ("name", "pdgid"), + [ + pytest.param("p", 2212, id="p"), + pytest.param("n", 2112, id="n"), + pytest.param("p~", -2212, id="p~"), + pytest.param("n~", -2112, id="n~"), + ], +) +def test_from_name_non_unique_pdgids(name, pdgid): + """The proton and the neutron have two pdgid representations, make sure they still compare equal""" + + p = Particle.from_name(name) + assert p.name == name + assert p.pdgid == pdgid