Skip to content

Commit

Permalink
Merge pull request #8 from desy-ml/4-remove-ocelot-dependencies
Browse files Browse the repository at this point in the history
Removes Ocelot dependencies in Cheetah unless one actually wants to convert from Ocelot
  • Loading branch information
jank324 authored Oct 12, 2022
2 parents 557ad7a + 31ec33b commit f153a87
Show file tree
Hide file tree
Showing 5 changed files with 365 additions and 10 deletions.
22 changes: 15 additions & 7 deletions cheetah/particles.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import numpy as np
import ocelot.adaptors.astra2ocelot as oca
import torch
from torch.distributions import MultivariateNormal

from cheetah.utils import from_astrabeam


class Beam:

Expand Down Expand Up @@ -68,8 +69,7 @@ def from_ocelot(cls, parray):
@classmethod
def from_astra(cls, path, **kwargs):
"""Load an Astra particle distribution as a Cheetah Beam."""
ocelot_parray = oca.astraBeam2particleArray(path, print_params=False)
return cls.from_ocelot(ocelot_parray, **kwargs)
raise NotImplementedError

def transformed_to(
self,
Expand Down Expand Up @@ -278,8 +278,14 @@ def from_ocelot(cls, parray):
@classmethod
def from_astra(cls, path, **kwargs):
"""Load an Astra particle distribution as a Cheetah Beam."""
ocelot_parray = oca.astraBeam2particleArray(path, print_params=False)
return cls.from_ocelot(ocelot_parray, **kwargs)
particles, energy = from_astrabeam(path)
mu = torch.ones(7)
mu[:6] = torch.tensor(particles.mean(axis=0), dtype=torch.float32)

cov = torch.zeros(7, 7)
cov[:6, :6] = torch.tensor(np.cov(particles.transpose()), dtype=torch.float32)

return cls(mu=mu, cov=cov, energy=energy)

def transformed_to(
self,
Expand Down Expand Up @@ -596,8 +602,10 @@ def from_ocelot(cls, parray, device="auto"):
@classmethod
def from_astra(cls, path, **kwargs):
"""Load an Astra particle distribution as a Cheetah Beam."""
ocelot_parray = oca.astraBeam2particleArray(path, print_params=False)
return cls.from_ocelot(ocelot_parray, **kwargs)
particles, energy = from_astrabeam(path)
particles_7d = torch.ones((particles.shape[0], 7))
particles_7d[:, :6] = torch.from_numpy(particles)
return cls(particles_7d, energy, **kwargs)

def transformed_to(
self,
Expand Down
80 changes: 79 additions & 1 deletion cheetah/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,74 @@
import ocelot as oc
import numpy as np
import torch

from cheetah import accelerator as acc

# electron mass in eV
m_e_eV = 510998.8671


def from_astrabeam(path: str):
"""
Read from a ASTRA beam distribution, and prepare for conversion to a Cheetah
ParticleBeam or ParameterBeam.
Adapted from the implementation in ocelot:
https://github.com/ocelot-collab/ocelot/blob/master/ocelot/adaptors/astra2ocelot.py
Parameters
----------
path : str
Path to the ASTRA beam distribution file.
Returns
-------
particles : np.ndarray
Particle 6D phase space information
energy : float
Mean energy of the particle beam
"""
P0 = np.loadtxt(path)

# remove lost particles
inds = np.argwhere(P0[:, 9] > 0)
inds = inds.reshape(inds.shape[0])

P0 = P0[inds, :]
n_particles = P0.shape[0]

# s_ref = P0[0, 2]
Pref = P0[0, 5]

xp = P0[:, :6]
xp[0, 2] = 0.0
xp[0, 5] = 0.0

gamref = np.sqrt((Pref / m_e_eV) ** 2 + 1)
# energy in eV: E = gamma * m_e
energy = gamref * m_e_eV

n_particles = xp.shape[0]
particles = np.zeros((n_particles, 6))

u = np.c_[xp[:, 3], xp[:, 4], xp[:, 5] + Pref]
gamma = np.sqrt(1 + np.sum(u * u, 1) / m_e_eV**2)
beta = np.sqrt(1 - gamma**-2)
betaref = np.sqrt(1 - gamref**-2)

p0 = np.linalg.norm(u, 2, 1).reshape((n_particles, 1))

u = u / p0
cdt = -xp[:, 2] / (beta * u[:, 2])
particles[:, 0] = xp[:, 0] + beta * u[:, 0] * cdt
particles[:, 2] = xp[:, 1] + beta * u[:, 1] * cdt
particles[:, 4] = cdt
particles[:, 1] = xp[:, 3] / Pref
particles[:, 3] = xp[:, 4] / Pref
particles[:, 5] = (gamma / gamref - 1) / betaref

return particles, energy


def ocelot2cheetah(element, warnings=True):
"""
Expand All @@ -26,6 +92,14 @@ def ocelot2cheetah(element, warnings=True):
need adjusting afterwards. BPM objects are only created from `ocelot.Monitor`
objects when their id has a substring "BPM".
"""
try:
import ocelot as oc
except ImportError:
raise ImportError(
"""To use the ocelot2cheetah lattice converter, Ocelot must be first
installed, see https://github.com/ocelot-collab/ocelot """
)

if isinstance(element, oc.Drift):
return acc.Drift(element.l, name=element.id)
elif isinstance(element, oc.Quadrupole):
Expand All @@ -47,6 +121,10 @@ def ocelot2cheetah(element, warnings=True):
elif isinstance(element, oc.Undulator):
return acc.Undulator(element.l, name=element.id)
else:
if warnings:
print(
f"WARNING: Unknown element {element.id}, replacing with drift section."
)
return acc.Drift(element.l, name=element.id)


Expand Down
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

setup(
name="cheetah-accelerator",
version="0.5.14",
version="0.5.15",
author="Jan Kaiser & Oliver Stein",
author_email="jan.kaiser@desy.de",
url="https://github.com/desy-ml/cheetah",
Expand All @@ -19,7 +19,6 @@
"torch",
"matplotlib",
"numpy",
"ocelot @ git+https://github.com/ocelot-collab/ocelot.git@21.12.1",
"scipy",
],
)
Loading

0 comments on commit f153a87

Please sign in to comment.