From c0969c8e289af956ae427fcf2af4f9083bce58fd Mon Sep 17 00:00:00 2001 From: Jacan Chaplais Date: Fri, 31 May 2024 11:57:32 +0100 Subject: [PATCH] cleared deprecated transforms, added momentum split #175 --- graphicle/transform.py | 315 +++++++++++++++++++++++++---------------- 1 file changed, 193 insertions(+), 122 deletions(-) diff --git a/graphicle/transform.py b/graphicle/transform.py index c13b2e6..70e6eda 100644 --- a/graphicle/transform.py +++ b/graphicle/transform.py @@ -4,158 +4,229 @@ Utilities for manipulating the graph structure of particle data. -.. deprecated:: 0.3.1 - Module is out of date, and will be removed in 0.4.0. """ -import networkx as _nx +import cmath +import operator as op +import typing as ty + import numpy as np -from deprecation import deprecated -from typicle import Types -from typicle.convert import cast_array +import numpy.typing as npt import graphicle as gcl -from . import base +_complex_unpack = op.attrgetter("real", "imag") + + +class SphericalAngle(ty.NamedTuple): + """Pair of inclination and azimuthal angles, respectively.""" + + theta: float + phi: float + + +class SphericalAxis(ty.NamedTuple): + x: float + y: float + z: float + -__all__ = ["particle_as_node", "centre_angle", "centre_pseudorapidity"] +def _angles_to_axis(angles: SphericalAngle) -> SphericalAxis: + axis = gcl.calculate._angles_to_axis(np.array([angles.phi, angles.theta])) + return SphericalAxis(*axis.tolist()) -_types = Types() +def _axis_to_angles(axis: SphericalAxis) -> SphericalAngle: + phi_polar = complex(axis.x, axis.y) + pt, phi = cmath.polar(phi_polar) + theta_polar = complex(pt, axis.z) + theta = cmath.phase(theta_polar) + return SphericalAngle(theta=theta, phi=phi) -@deprecated( - deprecated_in="0.3.1", - removed_in="0.4.0", - details="See ``networkx.line_graph()`` for potential replacement.", -) -def particle_as_node(adj_list: gcl.AdjacencyList) -> gcl.AdjacencyList: - """Converts an ``AdjacencyList`` in which the particles are - represented as edges, to one in which the particles are the nodes. - The order of the nodes in the resulting ``AdjacencyList`` retains - the same particle ordering of the initial edge list. - :group: transform +def _momentum_to_numpy(momenta: gcl.MomentumArray) -> gcl.base.DoubleVector: + return momenta.data.view(np.float64).reshape(-1, 4) - .. versionadded:: 0.1.0 + +def _cos_sin(angle: float) -> ty.Tuple[float, float]: + """Returns a tuple containing the sine and cosine of an angle.""" + return _complex_unpack(cmath.rect(1.0, angle)) + + +def soft_hard_axis(momenta: gcl.MomentumArray) -> SphericalAngle: + """Calculates the axis defined by the plane swept out between the + hardest and softest particles in ``momenta``. Parameters ---------- - adj_list : AdjacencyList - The edge-as-particle representation. + momenta : MomentumArray + Point cloud of particle four-momenta. Returns ------- - node_adj : AdjacencyList - The node-as-particle representation. + SphericalAngle + Normal axis to the soft-hard plane, defined in terms of + inclination and azimuthal angles. + """ + data = momenta.data.view(np.float64).reshape(-1, 4) + softest = data[momenta.pt.argmin()] + hardest = data[momenta.pt.argmax()] + axis = np.cross(softest[:3], hardest[:3]) + return _axis_to_angles(SphericalAxis(*axis.tolist())) - Examples - -------- - >>> from graphicle import transform - >>> # restructuring existing graph: - >>> graph.adj = transform.particle_as_node(graph.adj) + +def rotation_matrix( + angle: float, axis: SphericalAngle +) -> npt.NDArray[np.float64]: + """Computes the matrix operator to rotate a 3D vector with respect + to an arbitrary ``axis`` by a given ``angle``. + + Parameters + ---------- + angle : float + Desired angular displacement after matrix multiplication. + axis : SphericalAngle + Inclination and azimuthal angles, defining the axis about which + the rotation is to be performed. + + Returns + ------- + ndarray[float64] + A 3x3 matrix which, when acting to the left of a 3D vector, will + rotate it about the provided ``axis`` by the given ``angle``. + + Notes + ----- + This is a matrix implementation of Rodrigues' rotation formula [1]_. + + References + ---------- + .. [1] https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula """ - # create the networkx edge graph - nx_edge_graph = _nx.MultiDiGraph() - graph_dicts = adj_list.to_dicts() - nx_edge_graph.add_edges_from(graph_dicts["edges"]) - # transform into node graph (with edge tuples rep'ing nodes) - nx_node_graph = _nx.line_graph(G=nx_edge_graph, create_using=_nx.DiGraph) - # create a node index for each particle - edges = adj_list.edges - num_pcls = len(edges) - node_idxs = np.empty(num_pcls, dtype=_types.int) - # translate nodes represented with edge tuples into node indices - edge_node_type = _types.edge.copy() - edge_node_type.append(("key", _types.int)) # if > 1 pcl between vtxs - edges_as_nodes = cast_array( - np.array(tuple(nx_node_graph.nodes)), edge_node_type - ) - - def check_sign(x): - """Returns -1 if x <= 0, and +1 if x > 0.""" - sign = np.sign(x) - sign = sign + int(not sign) # if sign = 0 => sign = +1 - return sign - - # node labels set as particle indices in original edge array - for i, node_triplet in enumerate(edges_as_nodes): - key = node_triplet["key"] - node = node_triplet[["src", "dst"]] - sign = -1 * check_sign(node["src"] * node["dst"]) - node_idxs[i] = sign * (np.where(edges == node)[0][key] + 1) - nx_node_graph = _nx.relabel_nodes( - nx_node_graph, - {n: idx for n, idx in zip(nx_node_graph, node_idxs)}, - ) - return gcl.AdjacencyList(np.array(nx_node_graph.edges)) - - -@deprecated( - deprecated_in="0.3.1", - removed_in="0.4.0", - details="Use ``calculate.resultant_coords()`` and " - "``data.MomentumArray.shift_phi()`` instead.", -) -def centre_angle( - angle: base.DoubleVector, pt: base.DoubleVector -) -> base.DoubleVector: - """Shifts angles so transverse momentum weighted centroid is at - ``0``. - - :group: transform - - .. versionadded:: 0.1.0 + axis_3d = _angles_to_axis(axis) + skew_sym = np.zeros((3, 3), dtype=np.float64) + upper_idxs = np.triu_indices(3, 1) + skew_sym[upper_idxs] = -axis_3d.z, axis_3d.y, -axis_3d.x + skew_sym[np.tril_indices(3, -1)] = -skew_sym[upper_idxs] + cos_alpha, sin_alpha = _cos_sin(angle) + rot = sin_alpha * skew_sym + (1.0 - cos_alpha) * (skew_sym @ skew_sym) + np.fill_diagonal(rot, rot.diagonal() + 1.0) + return rot + + +def split_momentum( + momentum: gcl.MomentumArray, + z: float, + angle: float, + axis: ty.Union[ty.Tuple[float, float], SphericalAngle], +) -> gcl.MomentumArray: + """Splits the momentum of the given particle into two momenta. + Energy and 3-momentum is conserved. Hardness and collinearity of the + split are determined by ``z`` and ``angle``. Parameters ---------- - angle : array - Angular displacements. - pt : array - Transverse momenta. + momentum : MomentumArray + Four-momentum prior to splitting. Must contain only one element. + z : float + Energy fraction retained by the first child after the split. + Must be in range ``0.0 < z <= 0.5``. + angle : float + Angular displacement of the first child after the split. + axis : SphericalAngle or tuple[float, float], optional + The theta and phi values of the axis vector about which to + rotate the momentum. If ``None``, will choose the axis normal to + the plane swept out by the hardest and softest momentum + constituents. Returns ------- - centred_angle : array - Shifted angular displacements, with centroid at 0. + MomentumArray + Four-momenta of two particles produced by splitting. + + See Also + -------- + soft_hard_axis : Axis of plane swept by softest and hardest momenta. + rotation_matrix : Matrix to rotate 3-vectors about a given axis. """ - # convert angles into complex polar positions - pos = np.exp(1.0j * angle) - # obtain weighted sum positions ie. un-normalised midpoint - pos_wt_mid = (pos * pt).sum() - # convert to U(1) rotation operator e^(-i delta x) - rot_op = (pos_wt_mid / np.abs(pos_wt_mid)).conjugate() - # rotate positions so midpoint is at 0 - pos_centred = rot_op * pos - return np.angle(pos_centred) # type: ignore - - -@deprecated( - deprecated_in="0.3.1", - removed_in="0.4.0", - details="Use ``calculate.resultant_coords()`` and " - "``data.MomentumArray.shift_eta()`` instead.", -) -def centre_pseudorapidity( - eta: base.DoubleVector, pt: base.DoubleVector -) -> base.DoubleVector: - """Shifts pseudorapidities so pt weighted midpoint is at ``0``. - - :group: transform - - .. versionadded:: 0.1.0 + if len(momentum) > 1: + raise ValueError("momentum must have only one element.") + if not isinstance(axis, SphericalAngle): + axis = SphericalAngle(*axis) + parent = _momentum_to_numpy(momentum) + children = np.tile(parent, (2, 1)) + children[0, :] *= z + children[0, :3] @= rotation_matrix(angle, axis).T + children[1, :] -= children[0, :] + return gcl.MomentumArray(children) + + +def split_hardest( + momenta: gcl.MomentumArray, + z: float, + angle: float, + axis: ty.Optional[ty.Union[ty.Tuple[float, float], SphericalAngle]] = None, +) -> gcl.MomentumArray: + """Splits the momentum of the hardest particle into two momenta. + Energy and 3-momentum is conserved over the whole MomentumArray. + Hardness and collinearity of the split are determined by function + parameters. Parameters ---------- - eta : ndarray[float64] - Values of pseudorapidity for the particle set. - pt : ndarray[float64] - Values of transverse momenta for the particle set. + momenta : MomentumArray + Set of four-momenta, representing the point cloud of particles, + prior to splitting. + z : float + Energy fraction retained by the first child after the split. + Must be in range ``0.0 < z <= 0.5``. + angle : float + Angular displacement of the first child after the split. + axis : SphericalAngle or tuple[float, float], optional + The theta and phi values of the axis vector about which to + rotate the momenta. If ``None``, will choose the axis normal to + the plane swept out by the hardest and softest momentum + constituents. Returns ------- - eta_centred : ndarray[float64] - Pseudorapidity values relative to the centre of transverse - momentum. + MomentumArray + Set of four-momenta after splitting, such that the length is + increased by one. The highest transverse momentum element has + been removed from the set, and replaced with two momenta + elements. The first and second children of the split are the + penultimate and final elements of the MomentumArray, + respectively. + + See Also + -------- + soft_hard_axis : Axis of plane swept by softest and hardest momenta. + rotation_matrix : Matrix to rotate 3-vectors about a given axis. + + Notes + ----- + This function is intended to check the IRC safety of our GNN jet + clustering algorithms. It is implemented from the description given + in a jet tagging paper [1]_, which defined the IRC safe message + passing procedure used in this work. + + References + ---------- + .. [1] https://doi.org/10.48550/arXiv.2109.14636 """ - pt_norm = pt / pt.sum() - eta_wt_mid = (eta * pt_norm).sum() - return eta - eta_wt_mid # type: ignore + if not (0.0 < z <= 0.5): + raise ValueError("z must be in range (0, 0.5]") + if axis is None: + if len(momenta) < 2: + raise ValueError( + "If axis is not provided, there must be at least two elements " + "in momenta." + ) + axis = soft_hard_axis(momenta) + data = _momentum_to_numpy(momenta) + hard_idx = momenta.pt.argmax() + out = np.empty((len(momenta) + 1, 4), dtype=np.float64) + out[:hard_idx] = data[:hard_idx] + out[hard_idx:-2] = data[(hard_idx + 1) :] + children = split_momentum(momenta[hard_idx], z, angle, axis) + out[-2:, ...] = _momentum_to_numpy(children)[...] + return gcl.MomentumArray(out)