diff --git a/graphicle/calculate.py b/graphicle/calculate.py index e43a0b3..c9d16f1 100644 --- a/graphicle/calculate.py +++ b/graphicle/calculate.py @@ -709,20 +709,36 @@ def thrust(momenta: "MomentumArray") -> float: pmu : MomentumArray Momentum of hadronised particles in the final state of the event record. + return_axis : bool + If ``True``, will return a tuple with the thrust, and the axis + unit vector which was found to maximise thrust. + rng_seed : int, optional + Initial guess for the axis unit vector is sampled from a uniform + random distribution, over the surface of a sphere. If passed, + will initialise the random number generator with the provided + seed, enabling reproducible results. Returns ------- - float + thrust : float The thrust of the event. + axis : ndarray[float64], optional + The axis which maximises the thrust for the event. """ domain = (-math.pi, math.pi) + rng = np.random.default_rng(seed=rng_seed) + guess = rng.uniform(*domain, size=2) optim = spo.minimize( fun=lambda n, p: -_thrust_with_axis(n, p), - x0=np.zeros(2), + x0=guess, bounds=(domain, domain), - args=(momenta.data,), + jac=lambda n, p: -_grad_thrust(n, p), + args=(pmu.data,), ) - return -optim.fun + thrust_val = -optim.fun + if return_axis: + return thrust_val, optim.x + return thrust_val @nb.njit(