Skip to content

Commit

Permalink
cleared deprecated transforms, added momentum split #175
Browse files Browse the repository at this point in the history
  • Loading branch information
jacanchaplais committed May 31, 2024
1 parent 6428304 commit c0969c8
Showing 1 changed file with 193 additions and 122 deletions.
315 changes: 193 additions & 122 deletions graphicle/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,158 +4,229 @@
Utilities for manipulating the graph structure of particle data.
.. deprecated:: 0.3.1
Module is out of date, and will be removed in 0.4.0.
"""
import networkx as _nx
import cmath
import operator as op
import typing as ty

import numpy as np
from deprecation import deprecated
from typicle import Types
from typicle.convert import cast_array
import numpy.typing as npt

import graphicle as gcl

from . import base
_complex_unpack = op.attrgetter("real", "imag")


class SphericalAngle(ty.NamedTuple):
"""Pair of inclination and azimuthal angles, respectively."""

theta: float
phi: float


class SphericalAxis(ty.NamedTuple):
x: float
y: float
z: float


__all__ = ["particle_as_node", "centre_angle", "centre_pseudorapidity"]
def _angles_to_axis(angles: SphericalAngle) -> SphericalAxis:
axis = gcl.calculate._angles_to_axis(np.array([angles.phi, angles.theta]))
return SphericalAxis(*axis.tolist())

_types = Types()

def _axis_to_angles(axis: SphericalAxis) -> SphericalAngle:
phi_polar = complex(axis.x, axis.y)
pt, phi = cmath.polar(phi_polar)
theta_polar = complex(pt, axis.z)
theta = cmath.phase(theta_polar)
return SphericalAngle(theta=theta, phi=phi)

@deprecated(
deprecated_in="0.3.1",
removed_in="0.4.0",
details="See ``networkx.line_graph()`` for potential replacement.",
)
def particle_as_node(adj_list: gcl.AdjacencyList) -> gcl.AdjacencyList:
"""Converts an ``AdjacencyList`` in which the particles are
represented as edges, to one in which the particles are the nodes.
The order of the nodes in the resulting ``AdjacencyList`` retains
the same particle ordering of the initial edge list.

:group: transform
def _momentum_to_numpy(momenta: gcl.MomentumArray) -> gcl.base.DoubleVector:
return momenta.data.view(np.float64).reshape(-1, 4)

.. versionadded:: 0.1.0

def _cos_sin(angle: float) -> ty.Tuple[float, float]:
"""Returns a tuple containing the sine and cosine of an angle."""
return _complex_unpack(cmath.rect(1.0, angle))


def soft_hard_axis(momenta: gcl.MomentumArray) -> SphericalAngle:
"""Calculates the axis defined by the plane swept out between the
hardest and softest particles in ``momenta``.
Parameters
----------
adj_list : AdjacencyList
The edge-as-particle representation.
momenta : MomentumArray
Point cloud of particle four-momenta.
Returns
-------
node_adj : AdjacencyList
The node-as-particle representation.
SphericalAngle
Normal axis to the soft-hard plane, defined in terms of
inclination and azimuthal angles.
"""
data = momenta.data.view(np.float64).reshape(-1, 4)
softest = data[momenta.pt.argmin()]
hardest = data[momenta.pt.argmax()]
axis = np.cross(softest[:3], hardest[:3])
return _axis_to_angles(SphericalAxis(*axis.tolist()))

Examples
--------
>>> from graphicle import transform
>>> # restructuring existing graph:
>>> graph.adj = transform.particle_as_node(graph.adj)

def rotation_matrix(
angle: float, axis: SphericalAngle
) -> npt.NDArray[np.float64]:
"""Computes the matrix operator to rotate a 3D vector with respect
to an arbitrary ``axis`` by a given ``angle``.
Parameters
----------
angle : float
Desired angular displacement after matrix multiplication.
axis : SphericalAngle
Inclination and azimuthal angles, defining the axis about which
the rotation is to be performed.
Returns
-------
ndarray[float64]
A 3x3 matrix which, when acting to the left of a 3D vector, will
rotate it about the provided ``axis`` by the given ``angle``.
Notes
-----
This is a matrix implementation of Rodrigues' rotation formula [1]_.
References
----------
.. [1] https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
"""
# create the networkx edge graph
nx_edge_graph = _nx.MultiDiGraph()
graph_dicts = adj_list.to_dicts()
nx_edge_graph.add_edges_from(graph_dicts["edges"])
# transform into node graph (with edge tuples rep'ing nodes)
nx_node_graph = _nx.line_graph(G=nx_edge_graph, create_using=_nx.DiGraph)
# create a node index for each particle
edges = adj_list.edges
num_pcls = len(edges)
node_idxs = np.empty(num_pcls, dtype=_types.int)
# translate nodes represented with edge tuples into node indices
edge_node_type = _types.edge.copy()
edge_node_type.append(("key", _types.int)) # if > 1 pcl between vtxs
edges_as_nodes = cast_array(
np.array(tuple(nx_node_graph.nodes)), edge_node_type
)

def check_sign(x):
"""Returns -1 if x <= 0, and +1 if x > 0."""
sign = np.sign(x)
sign = sign + int(not sign) # if sign = 0 => sign = +1
return sign

# node labels set as particle indices in original edge array
for i, node_triplet in enumerate(edges_as_nodes):
key = node_triplet["key"]
node = node_triplet[["src", "dst"]]
sign = -1 * check_sign(node["src"] * node["dst"])
node_idxs[i] = sign * (np.where(edges == node)[0][key] + 1)
nx_node_graph = _nx.relabel_nodes(
nx_node_graph,
{n: idx for n, idx in zip(nx_node_graph, node_idxs)},
)
return gcl.AdjacencyList(np.array(nx_node_graph.edges))


@deprecated(
deprecated_in="0.3.1",
removed_in="0.4.0",
details="Use ``calculate.resultant_coords()`` and "
"``data.MomentumArray.shift_phi()`` instead.",
)
def centre_angle(
angle: base.DoubleVector, pt: base.DoubleVector
) -> base.DoubleVector:
"""Shifts angles so transverse momentum weighted centroid is at
``0``.
:group: transform
.. versionadded:: 0.1.0
axis_3d = _angles_to_axis(axis)
skew_sym = np.zeros((3, 3), dtype=np.float64)
upper_idxs = np.triu_indices(3, 1)
skew_sym[upper_idxs] = -axis_3d.z, axis_3d.y, -axis_3d.x
skew_sym[np.tril_indices(3, -1)] = -skew_sym[upper_idxs]
cos_alpha, sin_alpha = _cos_sin(angle)
rot = sin_alpha * skew_sym + (1.0 - cos_alpha) * (skew_sym @ skew_sym)
np.fill_diagonal(rot, rot.diagonal() + 1.0)
return rot


def split_momentum(
momentum: gcl.MomentumArray,
z: float,
angle: float,
axis: ty.Union[ty.Tuple[float, float], SphericalAngle],
) -> gcl.MomentumArray:
"""Splits the momentum of the given particle into two momenta.
Energy and 3-momentum is conserved. Hardness and collinearity of the
split are determined by ``z`` and ``angle``.
Parameters
----------
angle : array
Angular displacements.
pt : array
Transverse momenta.
momentum : MomentumArray
Four-momentum prior to splitting. Must contain only one element.
z : float
Energy fraction retained by the first child after the split.
Must be in range ``0.0 < z <= 0.5``.
angle : float
Angular displacement of the first child after the split.
axis : SphericalAngle or tuple[float, float], optional
The theta and phi values of the axis vector about which to
rotate the momentum. If ``None``, will choose the axis normal to
the plane swept out by the hardest and softest momentum
constituents.
Returns
-------
centred_angle : array
Shifted angular displacements, with centroid at 0.
MomentumArray
Four-momenta of two particles produced by splitting.
See Also
--------
soft_hard_axis : Axis of plane swept by softest and hardest momenta.
rotation_matrix : Matrix to rotate 3-vectors about a given axis.
"""
# convert angles into complex polar positions
pos = np.exp(1.0j * angle)
# obtain weighted sum positions ie. un-normalised midpoint
pos_wt_mid = (pos * pt).sum()
# convert to U(1) rotation operator e^(-i delta x)
rot_op = (pos_wt_mid / np.abs(pos_wt_mid)).conjugate()
# rotate positions so midpoint is at 0
pos_centred = rot_op * pos
return np.angle(pos_centred) # type: ignore


@deprecated(
deprecated_in="0.3.1",
removed_in="0.4.0",
details="Use ``calculate.resultant_coords()`` and "
"``data.MomentumArray.shift_eta()`` instead.",
)
def centre_pseudorapidity(
eta: base.DoubleVector, pt: base.DoubleVector
) -> base.DoubleVector:
"""Shifts pseudorapidities so pt weighted midpoint is at ``0``.
:group: transform
.. versionadded:: 0.1.0
if len(momentum) > 1:
raise ValueError("momentum must have only one element.")
if not isinstance(axis, SphericalAngle):
axis = SphericalAngle(*axis)
parent = _momentum_to_numpy(momentum)
children = np.tile(parent, (2, 1))
children[0, :] *= z
children[0, :3] @= rotation_matrix(angle, axis).T
children[1, :] -= children[0, :]
return gcl.MomentumArray(children)


def split_hardest(
momenta: gcl.MomentumArray,
z: float,
angle: float,
axis: ty.Optional[ty.Union[ty.Tuple[float, float], SphericalAngle]] = None,
) -> gcl.MomentumArray:
"""Splits the momentum of the hardest particle into two momenta.
Energy and 3-momentum is conserved over the whole MomentumArray.
Hardness and collinearity of the split are determined by function
parameters.
Parameters
----------
eta : ndarray[float64]
Values of pseudorapidity for the particle set.
pt : ndarray[float64]
Values of transverse momenta for the particle set.
momenta : MomentumArray
Set of four-momenta, representing the point cloud of particles,
prior to splitting.
z : float
Energy fraction retained by the first child after the split.
Must be in range ``0.0 < z <= 0.5``.
angle : float
Angular displacement of the first child after the split.
axis : SphericalAngle or tuple[float, float], optional
The theta and phi values of the axis vector about which to
rotate the momenta. If ``None``, will choose the axis normal to
the plane swept out by the hardest and softest momentum
constituents.
Returns
-------
eta_centred : ndarray[float64]
Pseudorapidity values relative to the centre of transverse
momentum.
MomentumArray
Set of four-momenta after splitting, such that the length is
increased by one. The highest transverse momentum element has
been removed from the set, and replaced with two momenta
elements. The first and second children of the split are the
penultimate and final elements of the MomentumArray,
respectively.
See Also
--------
soft_hard_axis : Axis of plane swept by softest and hardest momenta.
rotation_matrix : Matrix to rotate 3-vectors about a given axis.
Notes
-----
This function is intended to check the IRC safety of our GNN jet
clustering algorithms. It is implemented from the description given
in a jet tagging paper [1]_, which defined the IRC safe message
passing procedure used in this work.
References
----------
.. [1] https://doi.org/10.48550/arXiv.2109.14636
"""
pt_norm = pt / pt.sum()
eta_wt_mid = (eta * pt_norm).sum()
return eta - eta_wt_mid # type: ignore
if not (0.0 < z <= 0.5):
raise ValueError("z must be in range (0, 0.5]")
if axis is None:
if len(momenta) < 2:
raise ValueError(
"If axis is not provided, there must be at least two elements "
"in momenta."
)
axis = soft_hard_axis(momenta)
data = _momentum_to_numpy(momenta)
hard_idx = momenta.pt.argmax()
out = np.empty((len(momenta) + 1, 4), dtype=np.float64)
out[:hard_idx] = data[:hard_idx]
out[hard_idx:-2] = data[(hard_idx + 1) :]
children = split_momentum(momenta[hard_idx], z, angle, axis)
out[-2:, ...] = _momentum_to_numpy(children)[...]
return gcl.MomentumArray(out)

0 comments on commit c0969c8

Please sign in to comment.