Skip to content

Commit

Permalink
Merge pull request #147 from jacanchaplais/feature/unstable-145
Browse files Browse the repository at this point in the history
Explicit numerical instability handling #145
  • Loading branch information
jacanchaplais authored Sep 3, 2023
2 parents 49d9aaf + b2123b4 commit 7c4eac6
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 28 deletions.
7 changes: 7 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,10 @@ Algorithms to query event record, providing masks which select specific
regions of collision events.

.. python-apigen-group:: select

Exceptions and warnings
-------------------
Custom classes indicating to users specific issues relating to the unique
cross-over between high energy physics and graph algorithms.

.. python-apigen-group:: errors_warnings
11 changes: 11 additions & 0 deletions graphicle/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"AdjacencyBase",
"MaskBase",
"MaskLike",
"NumericalStabilityWarning",
]

DoubleVector = npt.NDArray[np.float64]
Expand Down Expand Up @@ -248,3 +249,13 @@ def __invert__(self) -> "MaskBase":
@abstractmethod
def __bool__(self) -> bool:
pass


class NumericalStabilityWarning(UserWarning):
"""Raised when the result of a calculation may not be numerically
stable.
:group: errors_warnings
.. versionadded:: 0.3.1
"""
36 changes: 34 additions & 2 deletions graphicle/calculate.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,37 @@ def flow_trace(
return traces


@nb.njit("float64[:](float64[:], float64[:], float64)")
def _rapidity(
energy: base.DoubleVector, z: base.DoubleVector, zero_tol: float
) -> base.DoubleVector:
"""Numpy ufunc to calculate the rapidity of a set of particles.
Parameters
----------
energy, z : array_like
Components of the particles' four-momenta.
zero_tol : float
Absolute tolerance for energy values to be considered close to
zero.
Returns
-------
ndarray or float
Rapidity of the particles.
"""
rap = np.empty_like(energy)
for i in range(len(rap)):
z_ = abs(z[i])
diff = energy[i] - z_
if abs(diff) < zero_tol:
rap_ = math.inf
else:
rap_ = 0.5 * math.log((energy[i] + z_) / diff)
rap[i] = math.copysign(rap_, z[i])
return rap


@nb.vectorize([nb.float64(nb.float64, nb.float64)])
def _root_diff_two_squares(
x1: base.DoubleUfunc, x2: base.DoubleUfunc
Expand Down Expand Up @@ -441,9 +472,10 @@ def _root_diff_two_squares(
Root difference of two squares. This is a scalar if both `x1`
and `x2` are scalars.
"""
diff = x1 - x2
x1_, x2_ = abs(x1), abs(x2)
diff = x1_ - x2_
sqrt_diff = math.copysign(math.sqrt(abs(diff)), diff)
sqrt_sum = math.sqrt(x1 + x2)
sqrt_sum = math.sqrt(x1_ + x2_)
return sqrt_diff * sqrt_sum # type: ignore


Expand Down
55 changes: 40 additions & 15 deletions graphicle/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
base.IntVector, base.VoidVector, ty.Sequence[ty.Tuple[int, int]]
]
DHUGE = np.finfo(np.dtype("<f8")).max * 0.1
ZERO_TOL = 1.0e-10


def _map_invert(mapping: ty.Dict[str, ty.Set[str]]) -> ty.Dict[str, str]:
Expand Down Expand Up @@ -1312,16 +1313,37 @@ def eta(self) -> base.DoubleVector:

@fn.cached_property
def rapidity(self) -> base.DoubleVector:
"""Rapidity component of the particle momenta, :math:`y`."""
with warnings.catch_warnings():
warnings.simplefilter("ignore")
rap = 0.5 * np.log((self.energy + self.z) / (self.energy - self.z))
return rap.reshape(-1)
"""Rapidity component of the particle momenta, :math:`y`.
.. versionchanged:: 0.3.1
Explicitly handled zero division, replacing with ``np.inf``
with the appropriate sign.
"""
return calculate._rapidity(self.energy, self.z, ZERO_TOL).reshape(-1)

@fn.cached_property
def phi(self) -> base.DoubleVector:
"""Azimuth component of particle momenta, :math:`\\phi`."""
return np.angle(self._xy_pol).reshape(-1)
"""Azimuth component of particle momenta, :math:`\\phi`.
.. versionchanged:: 0.3.1
Where :math:`p_T` is very small, rendering azimuthal angles
numerically unstable, ``np.nan`` is given to enable user
handling.
"""
invalid = np.isclose(self.pt, 0.0, atol=ZERO_TOL)
phi_ = np.angle(self._xy_pol).reshape(-1)
if np.any(invalid):
num_nan = np.sum(invalid, dtype=np.int32).item()
e_tol = ZERO_TOL * 1.0e9
warnings.warn(
f"The transverse momenta of {num_nan} particles fall below "
f"{e_tol} eV. This may result in these particles giving "
"unstable or invalid values for the azimuthal angle. These "
"angles have been replaced with NaN.",
base.NumericalStabilityWarning,
)
phi_[invalid] = np.nan
return phi_

@fn.cached_property
def theta(self) -> base.DoubleVector:
Expand All @@ -1337,7 +1359,11 @@ def mass(self) -> base.DoubleVector:

@fn.cached_property
def mass_t(self) -> base.DoubleVector:
"""Transverse component of particle mass, :math:`m_T`."""
"""Transverse component of particle mass, :math:`m_T`.
.. versionchanged:: 0.3.1
Fixed bug for momenta with negative :math:`p_z`.
"""
return calculate._root_diff_two_squares(self.energy, self.z).reshape(
-1
)
Expand Down Expand Up @@ -1373,7 +1399,7 @@ def shift_eta(
shift: ty.Union[float, base.DoubleVector],
experimental: bool = False,
max_corrections: int = 10,
abs_tol: float = 1.0e-14,
abs_tol: float = ZERO_TOL,
) -> "MomentumArray":
"""Performs a Lorentz boost to a new frame, with a
pseudorapidity increased by ``shift``.
Expand Down Expand Up @@ -1414,7 +1440,7 @@ def shift_eta(
Warns
-----
UserWarning
NumericalStabilityWarning
If the method is unable to converge within ``abs_tol`` after
``max_corrections`` corrective iterations.
Expand Down Expand Up @@ -1453,7 +1479,8 @@ def shift_eta(
eta_mid, _ = calculate.resultant_coords(pmu, pseudo=True)
if converged is not True:
warnings.warn(
f"Unable to converge within a tolerance of {abs_tol}."
f"Unable to converge within a tolerance of {abs_tol}.",
base.NumericalStabilityWarning,
)
return pmu

Expand Down Expand Up @@ -1534,11 +1561,9 @@ def delta_R(
passing the same instance to the ``other`` parameter.
"""
get_rapidity = op.attrgetter("eta")
if pseudo is False:
if not pseudo:
get_rapidity = op.attrgetter("rapidity")
with warnings.catch_warnings():
warnings.simplefilter("ignore")
rap1, rap2 = get_rapidity(self), get_rapidity(other)
rap1, rap2 = get_rapidity(self), get_rapidity(other)
with calculate._thread_scope(threads):
if self is other:
return calculate._delta_R_symmetric(rap1, self._xy_pol)
Expand Down
67 changes: 56 additions & 11 deletions tests/test_data.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
"""
``test_data``
=============
Unit tests for the data structures, probing their attributes and
methods.
"""
import cmath
import dataclasses as dc
import math
import random

import numpy as np
import pytest

import graphicle as gcl

Expand Down Expand Up @@ -45,24 +54,60 @@ def to_momentum_array(self) -> gcl.MomentumArray:
return gcl.MomentumArray([(self.px, self.py, self.pz, self.energy)])


def test_pdgs():
def test_pdg_quark_names():
"""Tests that the PDG names are correcly identified for the quarks."""
pdg_vals = np.arange(1, 7, dtype=np.int32)
pdgs = gcl.PdgArray(pdg_vals)
assert pdgs.name.tolist() == ["d", "u", "s", "c", "b", "t"]


def test_pmu_coords() -> None:
"""Tests that the components of the momentum are correctly stored
and calculated.
"""
momentum = MomentumExample()
pmu = momentum.to_momentum_array()
assert math.isclose(pmu.pt.item(), momentum.pt)
assert math.isclose(pmu.phi.item(), momentum.phi * math.pi)
assert math.isclose(pmu.mass.item(), 0.0, abs_tol=ZERO_TOL)
assert math.isclose(pmu.theta.item(), math.atan(momentum.pt / momentum.pz))


def test_pmu_transform() -> None:
correct_pt = math.isclose(pmu.pt.item(), momentum.pt)
assert correct_pt, "Incorrect pT."
correct_phi = math.isclose(pmu.phi.item(), momentum.phi * math.pi)
assert correct_phi, "Incorrect phi."
correct_mass = math.isclose(pmu.mass.item(), 0.0, abs_tol=ZERO_TOL)
assert correct_mass, "Nonzero mass."
correct_theta = math.isclose(
pmu.theta.item(), math.atan(momentum.pt / momentum.pz)
)
assert correct_theta, "Incorrect theta."


def test_pmu_transform_invertible() -> None:
"""Tests that the ``MomentumArray`` transforms are invertible."""
momentum = MomentumExample()
pmu = momentum.to_momentum_array()
shift = random.uniform(0.0, math.tau)
phi_invertible = np.allclose(pmu, pmu.shift_phi(shift).shift_phi(-shift))
assert phi_invertible, "Azimuth shift is not invertible."
rap_invertible = np.allclose(
pmu, pmu.shift_rapidity(shift).shift_rapidity(-shift)
)
assert rap_invertible, "Rapidity shift is not invertible."
eta_invertible = np.allclose(pmu, pmu.shift_eta(shift).shift_eta(-shift))
assert eta_invertible, "Pseudorapidity shift is not invertible."


def test_pmu_zero_pt() -> None:
"""Tests that when antiparallel momenta in the xy plane are added,
they have the correct properties, and the azimuth is flagged as
invalid.
"""
momentum = MomentumExample()
pmu = momentum.to_momentum_array()
zero_transverse = pmu.shift_phi(math.pi) + pmu
assert math.isclose(0.0, zero_transverse.pt.item(), abs_tol=ZERO_TOL)
assert math.isclose(zero_transverse.mass.item(), 6.0)
pmu_zero_pt = pmu.shift_phi(math.pi) + pmu
zero_pt = math.isclose(0.0, pmu_zero_pt.pt.item(), abs_tol=ZERO_TOL)
assert zero_pt, "Transverse momentum not properly cancelled"
correct_mass = math.isclose(pmu_zero_pt.mass.item(), 6.0)
assert correct_mass, "Mass generated is incorrect."
eta_inf = math.isinf(pmu_zero_pt.eta.item())
assert eta_inf, "Pseudorapidity is not infinite when longitudinal."
with pytest.warns(gcl.base.NumericalStabilityWarning):
phi_invalid = math.isnan(pmu_zero_pt.phi.item())
assert phi_invalid, "Azimuth is not NaN when pT is low"

0 comments on commit 7c4eac6

Please sign in to comment.