Skip to content

Commit

Permalink
added momentum splitting to transform module
Browse files Browse the repository at this point in the history
  • Loading branch information
jacanchaplais committed Mar 4, 2024
1 parent ef73bf1 commit be08644
Showing 1 changed file with 116 additions and 123 deletions.
239 changes: 116 additions & 123 deletions graphicle/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,158 +4,151 @@
Utilities for manipulating the graph structure of particle data.
.. deprecated:: 0.3.1
Module is out of date, and will be removed in 0.4.0.
"""
import networkx as _nx
import cmath
import operator as op
import typing as ty

import numpy as np
from deprecation import deprecated
from typicle import Types
from typicle.convert import cast_array
import numpy.typing as npt

import graphicle as gcl

from . import base
_complex_unpack = op.attrgetter("real", "imag")


__all__ = ["particle_as_node", "centre_angle", "centre_pseudorapidity"]
def _cos_sin(angle: float) -> tuple[float, float]:
"""Returns a tuple containing the sine and cosine of an angle."""
return _complex_unpack(cmath.rect(1.0, angle))

_types = Types()

class SphericalAngle(ty.NamedTuple):
"""Pair of inclination and azimuthal angles, respectively."""

@deprecated(
deprecated_in="0.3.1",
removed_in="0.4.0",
details="See ``networkx.line_graph()`` for potential replacement.",
)
def particle_as_node(adj_list: gcl.AdjacencyList) -> gcl.AdjacencyList:
"""Converts an ``AdjacencyList`` in which the particles are
represented as edges, to one in which the particles are the nodes.
The order of the nodes in the resulting ``AdjacencyList`` retains
the same particle ordering of the initial edge list.
theta: float
phi: float

:group: transform

.. versionadded:: 0.1.0
def soft_hard_axis(momenta: gcl.MomentumArray) -> SphericalAngle:
"""Calculates the axis defined by the plane swept out between the
hardest and softest particles in ``momenta``.
Parameters
----------
adj_list : AdjacencyList
The edge-as-particle representation.
momenta : MomentumArray
Point cloud of particle four-momenta.
Returns
-------
node_adj : AdjacencyList
The node-as-particle representation.
Examples
--------
>>> from graphicle import transform
>>> # restructuring existing graph:
>>> graph.adj = transform.particle_as_node(graph.adj)
SphericalAngle
Normal axis to the soft-hard plane, defined in terms of
inclination and azimuthal angles.
"""
# create the networkx edge graph
nx_edge_graph = _nx.MultiDiGraph()
graph_dicts = adj_list.to_dicts()
nx_edge_graph.add_edges_from(graph_dicts["edges"])
# transform into node graph (with edge tuples rep'ing nodes)
nx_node_graph = _nx.line_graph(G=nx_edge_graph, create_using=_nx.DiGraph)
# create a node index for each particle
edges = adj_list.edges
num_pcls = len(edges)
node_idxs = np.empty(num_pcls, dtype=_types.int)
# translate nodes represented with edge tuples into node indices
edge_node_type = _types.edge.copy()
edge_node_type.append(("key", _types.int)) # if > 1 pcl between vtxs
edges_as_nodes = cast_array(
np.array(tuple(nx_node_graph.nodes)), edge_node_type
)

def check_sign(x):
"""Returns -1 if x <= 0, and +1 if x > 0."""
sign = np.sign(x)
sign = sign + int(not sign) # if sign = 0 => sign = +1
return sign

# node labels set as particle indices in original edge array
for i, node_triplet in enumerate(edges_as_nodes):
key = node_triplet["key"]
node = node_triplet[["src", "dst"]]
sign = -1 * check_sign(node["src"] * node["dst"])
node_idxs[i] = sign * (np.where(edges == node)[0][key] + 1)
nx_node_graph = _nx.relabel_nodes(
nx_node_graph,
{n: idx for n, idx in zip(nx_node_graph, node_idxs)},
)
return gcl.AdjacencyList(np.array(nx_node_graph.edges))


@deprecated(
deprecated_in="0.3.1",
removed_in="0.4.0",
details="Use ``calculate.resultant_coords()`` and "
"``data.MomentumArray.shift_phi()`` instead.",
)
def centre_angle(
angle: base.DoubleVector, pt: base.DoubleVector
) -> base.DoubleVector:
"""Shifts angles so transverse momentum weighted centroid is at
``0``.
:group: transform
.. versionadded:: 0.1.0
data = momenta.data.view(np.float64).reshape(-1, 4)
softest = data[momenta.pt.argmin()]
hardest = data[momenta.pt.argmax()]
axis = np.cross(softest[:3], hardest[:3])
phi_polar = axis[:2].view(np.complex128).item()
pt, phi = cmath.polar(phi_polar)
theta_polar = complex(pt, axis[2].item())
theta = cmath.phase(theta_polar)
return SphericalAngle(theta=theta, phi=phi)


def rotation_matrix(
angle: float, axis: SphericalAngle
) -> npt.NDArray[np.float64]:
"""Computes the matrix operator to rotate a 3D vector with respect
to an arbitrary ``axis`` by a given ``angle``.
Parameters
----------
angle : array
Angular displacements.
pt : array
Transverse momenta.
angle : float
Desired angular displacement after matrix multiplication.
axis : SphericalAngle
Inclination and azimuthal angles, defining the axis about which
the rotation is to be performed.
Returns
-------
centred_angle : array
Shifted angular displacements, with centroid at 0.
ndarray[float64]
A 3x3 matrix which, when acting to the left of a 3D vector, will
rotate it about the provided ``axis`` by the given ``angle``.
Notes
-----
This is a matrix implementation of Rodrigues' rotation formula [1]_.
References
----------
.. [1] https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
"""
# convert angles into complex polar positions
pos = np.exp(1.0j * angle)
# obtain weighted sum positions ie. un-normalised midpoint
pos_wt_mid = (pos * pt).sum()
# convert to U(1) rotation operator e^(-i delta x)
rot_op = (pos_wt_mid / np.abs(pos_wt_mid)).conjugate()
# rotate positions so midpoint is at 0
pos_centred = rot_op * pos
return np.angle(pos_centred) # type: ignore


@deprecated(
deprecated_in="0.3.1",
removed_in="0.4.0",
details="Use ``calculate.resultant_coords()`` and "
"``data.MomentumArray.shift_eta()`` instead.",
)
def centre_pseudorapidity(
eta: base.DoubleVector, pt: base.DoubleVector
) -> base.DoubleVector:
"""Shifts pseudorapidities so pt weighted midpoint is at ``0``.
:group: transform
.. versionadded:: 0.1.0
cos_theta, sin_theta = _cos_sin(axis.theta)
cos_phi, sin_phi = _cos_sin(axis.phi)
u_x, u_y, u_z = (sin_theta * cos_phi, sin_theta * sin_phi, cos_theta)
skew_sym = np.zeros((3, 3), dtype=np.float64)
upper_idxs = np.triu_indices(3, 1)
skew_sym[upper_idxs] = -u_z, u_y, -u_x
skew_sym[np.tril_indices(3, -1)] = -skew_sym[upper_idxs]
cos_alpha, sin_alpha = _cos_sin(angle)
rot = sin_alpha * skew_sym + (1.0 - cos_alpha) * (skew_sym @ skew_sym)
np.fill_diagonal(rot, rot.diagonal() + 1.0)
return rot


def split_hardest(
momenta: gcl.MomentumArray, z: float, angle: float
) -> gcl.MomentumArray:
"""Splits the momentum of the hardest particle into two momenta.
Energy and 3-momentum is conserved over the whole MomentumArray.
Hardness and collinearity of the split are determined by function
parameters.
Parameters
----------
eta : ndarray[float64]
Values of pseudorapidity for the particle set.
pt : ndarray[float64]
Values of transverse momenta for the particle set.
momenta : MomentumArray
Set of four-momenta, representing the point cloud of particles,
prior to splitting.
z : float
Energy fraction retained by the first child after the split.
Must be in range ``0.0 < z <= 0.5``.
angle : float
Angular displacement of the first child after the split. Will be
rotated in the plane swept out by the hardest and softest
particles in the particle set.
Returns
-------
eta_centred : ndarray[float64]
Pseudorapidity values relative to the centre of transverse
momentum.
MomentumArray
Set of four-momenta after splitting, such that the length is
increased by one. The highest transverse momentum element has
been removed from the set, and replaced with two momenta
elements. The first and second children of the split are the
penultimate and final elements of the MomentumArray,
respectively.
Notes
-----
This function is intended to check the IRC safety of our GNN jet
clustering algorithms. It is implemented from the description given
in a jet tagging paper [1]_, which defined the IRC safe message
passing procedure used in this work.
References
----------
.. [1] https://doi.org/10.48550/arXiv.2109.14636
"""
pt_norm = pt / pt.sum()
eta_wt_mid = (eta * pt_norm).sum()
return eta - eta_wt_mid # type: ignore
if not (0.0 < z <= 0.5):
raise ValueError("z must be in range (0, 0.5]")
data = momenta.data.view(np.float64).reshape(-1, 4)
hard_idx = momenta.pt.argmax()
out = np.empty((len(momenta) + 1, 4), dtype=np.float64)
out[:hard_idx] = data[:hard_idx]
out[hard_idx:-2] = data[(hard_idx + 1) :]
parent = data[hard_idx]
child_1 = z * parent
child_1[:3] @= rotation_matrix(angle, soft_hard_axis(momenta)).T
child_2 = parent - child_1
out[-2] = child_1[...]
out[-1] = child_2[...]
return gcl.MomentumArray(out)

0 comments on commit be08644

Please sign in to comment.