Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add space-charge to Cheetah #142

Merged
merged 127 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
127 commits
Select commit Hold shift + click to select a range
2a02165
save test
Mar 21, 2024
8c0cbd3
first commit
Apr 5, 2024
610f706
maj1
greglenerd Apr 5, 2024
99f3c53
commit2
greglenerd Apr 5, 2024
7900034
com
greglenerd Apr 5, 2024
96e7fa3
c
greglenerd Apr 5, 2024
f25065e
c
greglenerd Apr 5, 2024
f914e2b
c
greglenerd Apr 5, 2024
b01d4fe
c
greglenerd Apr 5, 2024
357840e
Update of charge deposition. First draft of poisson solver
greglenerd Apr 5, 2024
e1f8785
added space_cherge_deposition_vec for faster computation. First draft…
greglenerd Apr 10, 2024
4565920
First version of the whole IGF solver. Works with one particle, but s…
greglenerd Apr 11, 2024
27632d7
Update cheetah/accelerator.py
greglenerd Apr 11, 2024
1b5535e
vectorized version of the code
greglenerd Apr 11, 2024
d7846c4
first "complete" version of the code, with the track method implement…
greglenerd Apr 25, 2024
61f6723
version that runs correctly, need to be quantitatively tested
greglenerd Apr 26, 2024
0e57d2a
Draft version of the code, tested with the test_space_charge_kick.py …
greglenerd Apr 30, 2024
8c9c3fc
minor written chqnges to accelerator.py
greglenerd May 3, 2024
1c18884
Update cheetah/accelerator.py
greglenerd May 6, 2024
342e162
Update cheetah/accelerator.py
greglenerd May 6, 2024
6821781
Update cheetah/accelerator.py
greglenerd May 6, 2024
3263513
Update cheetah/accelerator.py
greglenerd May 6, 2024
1b8b60d
Update cheetah/accelerator.py
greglenerd May 6, 2024
6220fb7
Update cheetah/accelerator.py
greglenerd May 6, 2024
f90cef1
Update cheetah/accelerator.py
greglenerd May 6, 2024
be0b706
before pulling the suggested changes
greglenerd May 6, 2024
cfecc6c
Merge branch 'space_charge' of https://github.com/greglenerd/cheetah …
greglenerd May 6, 2024
6ed5a74
Update cheetah/accelerator.py
greglenerd May 6, 2024
6bc41e7
.
greglenerd May 6, 2024
ff4ae36
Merge branch 'space_charge' of https://github.com/greglenerd/cheetah …
greglenerd May 6, 2024
db5876d
Update cheetah/accelerator.py
greglenerd May 6, 2024
96a6220
Update cheetah/accelerator.py
greglenerd May 6, 2024
5da3b8f
Update cheetah/accelerator.py
greglenerd May 6, 2024
7d12f0f
Update cheetah/accelerator.py
greglenerd May 6, 2024
4dcf3cb
Update tests/test_space_charge_kick.py
greglenerd May 6, 2024
3534a5b
Update cheetah/accelerator.py
greglenerd May 6, 2024
304c445
Update cheetah/accelerator.py
greglenerd May 6, 2024
2b787d3
Update cheetah/accelerator.py
greglenerd May 6, 2024
ea8bac2
Update cheetah/accelerator.py
greglenerd May 6, 2024
d9bb152
cleaning
greglenerd May 6, 2024
10d1bbb
Merge branch 'space_charge' of https://github.com/greglenerd/cheetah …
greglenerd May 6, 2024
85a6edd
new test, little change in accelerator.py
greglenerd May 7, 2024
86247e1
Merge branch 'desy_master' into space_charge_merged
greglenerd May 7, 2024
ed861b8
start adapting to PR 116
greglenerd May 7, 2024
1ea0d51
Fix bugs associated with array shapes
RemiLehe May 8, 2024
d456418
Fix initialization of the Green function
RemiLehe May 9, 2024
d8bbc5d
Fix errors with shapes
RemiLehe May 9, 2024
9b2d4f1
Fix remaining bugs
RemiLehe May 9, 2024
86c7d89
Fix test
RemiLehe May 9, 2024
4c9c94a
Merge pull request #1 from RemiLehe/space_charge_merged
greglenerd May 10, 2024
8f03a20
Merge branch 'master' into space_charge
jank324 May 14, 2024
28febad
Reformat code with `black`
RemiLehe May 23, 2024
5e86219
Merge remote-tracking branch 'public/master' into space_charge
RemiLehe May 28, 2024
975d4db
Merge branch 'master' into space_charge
jank324 May 29, 2024
7b555b8
Fix CI
RemiLehe May 31, 2024
d2de1fc
Fix CI
RemiLehe May 31, 2024
0da8484
Apply isort corrections
RemiLehe May 31, 2024
cf5119f
Apply flake8 corrections
RemiLehe May 31, 2024
a607d8d
Apply suggestions from code review
cr-xu Jun 4, 2024
527e416
black formatting
cr-xu Jun 4, 2024
0dc127f
Apply suggestions from code review
cr-xu Jun 4, 2024
68cb8ee
Merge branch 'master' into space_charge
jank324 Jun 6, 2024
71f421c
Set the random seed in space charge test
RemiLehe Jun 9, 2024
0b64405
Apply suggestions from code review
RemiLehe Jun 9, 2024
42c390c
Update test file
RemiLehe Jun 9, 2024
4a756de
Add docstrings
RemiLehe Jun 9, 2024
e38111b
Change a few names
RemiLehe Jun 9, 2024
ed7c924
Reformat files
RemiLehe Jun 9, 2024
665aabc
Apply suggestions from code review
RemiLehe Jun 11, 2024
bcd9703
Update name: n_batch -> batch_size
RemiLehe Jun 11, 2024
9947d97
Update formatting
RemiLehe Jun 11, 2024
5e87d63
Merge branch 'master' into space_charge
jank324 Jun 13, 2024
3db651b
Move random seed to individual test functions
jank324 Jun 15, 2024
bbb3057
Add test for space charge gradient computation
jank324 Jun 15, 2024
f7dd7a2
Minor formating changes
jank324 Jun 15, 2024
fd0be7a
Replace obvious in-place operations with out-of-place alternatives
jank324 Jun 15, 2024
98cbaf5
Fix in-place gradient error
jank324 Jun 15, 2024
27e58ab
Revert "Replace obvious in-place operations with out-of-place alterna…
jank324 Jun 15, 2024
e9a1f47
Add test to check if space charge works vectorised accroding to #116
jank324 Jun 15, 2024
cf629ca
Implement mostly functional vectorisation for space charge
jank324 Jun 16, 2024
d505d57
First `index_put_` location where vectorisation didn't work
jank324 Jun 18, 2024
81921ba
Add test to check that vectorisation doesn't just not crash but also …
jank324 Jun 18, 2024
78da8c5
Fix test name for better test selection
jank324 Jun 18, 2024
b9725ff
Fix shape issue in vectorised beam expansion test and add segment len…
jank324 Jun 18, 2024
ae1de55
Fix gradient issue
jank324 Jun 18, 2024
161e759
Refactor `gammaref`
jank324 Jun 18, 2024
20d4db0
Refactor `betaref`
jank324 Jun 18, 2024
83e2b6b
Refactor moments computations
jank324 Jun 18, 2024
e66a08e
Remove unused constants
jank324 Jun 18, 2024
c447667
Fix length computation
jank324 Jun 18, 2024
7b57805
Remove out-of-date todo
jank324 Jun 18, 2024
cc225bb
Add `SpaceChargeKick` to documentation
jank324 Jun 18, 2024
bab431d
Merge branch 'master' into space_charge
jank324 Jun 18, 2024
35d80c4
Fix issues in `_compute_forces_`
cr-xu Jun 19, 2024
396d840
Fix index out of range issue
cr-xu Jun 19, 2024
043c7b0
Merge branch 'master' into space_charge
jank324 Jun 19, 2024
2b15344
Add entry to changelog
jank324 Jun 19, 2024
2fe41b1
Merge branch 'space_charge' of github.com:greglenerd/cheetah into spa…
jank324 Jun 19, 2024
16d3b77
Remove not-needed `atol` as per comment by @ax3l
jank324 Jun 19, 2024
3d8b97d
Remove `try_batched` notebook
jank324 Jun 19, 2024
fea6c83
Implement plotting for `SpaceChargeKick`
jank324 Jun 19, 2024
fd9ee34
Use int instead of long
RemiLehe Jun 19, 2024
922417b
Updates to and from moments conversion method names
jank324 Jun 20, 2024
2a47899
Merge branch 'space_charge' of github.com:greglenerd/cheetah into spa…
jank324 Jun 20, 2024
b00784c
Merge branch 'master' into space_charge
jank324 Jun 20, 2024
86ffa53
Change `SpaceChargeKick` plotting to a single line
jank324 Jun 20, 2024
61cd0b4
Implement `SpaceChargeKick.defining_features`
jank324 Jun 20, 2024
a53a450
Implement `SpaceChargeKick.__repr__`
jank324 Jun 20, 2024
2ca6cf8
Move grid extent sigma comment up by one line
jank324 Jun 20, 2024
4092d82
Improve docstring and comments
cr-xu Jun 21, 2024
5e3df78
Update cheetah/particles/particle_beam.py
cr-xu Jun 21, 2024
f3aea76
Slight change to comment formatting
jank324 Jun 21, 2024
1e8235a
Fix flake8 warning
jank324 Jun 21, 2024
3e2503a
Improve docstring for `from_xyz_pxpypz` as well
jank324 Jun 21, 2024
c813416
Tiny formating improvement
jank324 Jun 21, 2024
d0ccaa7
Fix main docstring
jank324 Jun 21, 2024
daa08b9
Another fix to the docstring
jank324 Jun 21, 2024
595ab2c
Attempt to fix docstring bullet point rendering
jank324 Jun 21, 2024
4d9b311
Another bullet point docstring fix
jank324 Jun 22, 2024
ee07a52
Some refactoring
RemiLehe Jun 25, 2024
1a0131c
Simplify division by cell volume
RemiLehe Jun 25, 2024
913e1a8
Update a few docstrings
RemiLehe Jun 25, 2024
42e4faf
Reduce length of lines
RemiLehe Jun 25, 2024
ee2096e
Update name from moments to xp_coordinates
RemiLehe Jun 25, 2024
1e1888f
Use rfft and irfft
RemiLehe Jun 25, 2024
b2e45c0
Change from moments to xp_coords
RemiLehe Jun 25, 2024
f31034f
Merge branch 'master' into space_charge
jank324 Jun 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 222 additions & 2 deletions cheetah/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
electron_mass_eV = torch.tensor(
physical_constants["electron mass energy equivalent in MeV"][0] * 1e6
)

epsilon_0 = torch.tensor(constants.epsilon_0)

class Element(ABC, nn.Module):
"""
Expand All @@ -45,7 +45,7 @@ def __init__(self, name: Optional[str] = None) -> None:
self.name = name if name is not None else generate_unique_name()

def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
r"""
"""
greglenerd marked this conversation as resolved.
Show resolved Hide resolved
Generates the element's transfer map that describes how the beam and its
particles are transformed when traveling through the element.
The state vector consists of 6 values with a physical meaning:
Expand Down Expand Up @@ -299,6 +299,226 @@ def defining_features(self) -> list[str]:
def __repr__(self) -> str:
return f"{self.__class__.__name__}(length={repr(self.length)})"

class SpaceChargeKick(Element):
"""
Simulates space charge effects on a beam.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these params might be changed: I was thinking of a way to integrate spacechargekicks as a setting (eg spacecharge = True) which would automatically build spacechargekick objects and incorporate them appropriately, as they don't represent any physical element of the accelerator. @cr-xu @jank324

Copy link
Member

@jank324 jank324 May 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think at some point in the future, we should think about creating a Effects class and make SpaceChargeKick a subclass of that. But right now this doesn't make sense yet.

Nevertheless, here is how I think we should integrate adding space charge to elements right now in a way that can easily be extended in the future:

Basically what I would do is to add a method to the Element class that looks something like this:

class Element:

    ...

    def with_space_charge(self, resolution: float = 0.01, *args, **kwargs) -> Segment:
        splits = self.split(resolution)
        splits_with_space_charge = # List of [split, sc, split, sc, ..., split, sc] ... probably itertools has a nice way to create this ... *args and **kwargs go into SpaceChargeKick
        return Segment(elements=splits_with_space_charge, name=f"{self.name}_with_space_charge")

This should end up looking similar to what @RemiLehe proposed in the very beginning. So if you do

Drift(length=0.5, name="my_drift").with_space_charge(resolution=0.25, nx=64)

you get

Segment(
    name="my_drift_with_space_charge",
    elements=[
        Drift(length=0.25), SpaceChargeKick(nx=64), Drift(length=0.25), SpaceChargeKick(nx=64)
    ],
)

This example is probably not quite correct in many ways, but I think it illustrates my idea.

The nice thing about doing it this way would be that it automatically works for all elements that implement split correctly, and that it would be relatively straightforward to extend this for other collective effects in the future.

Does this make sense?

Copy link

@ax3l ax3l May 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An alternative thought for the low level implementation: One could also create sub- or mixing classes for thin (L=0) and thick (L>0) elements. For elements in ImpactX, we currently use these mixin classes, e.g. Drift vs. thin Multipole.

Space charge kicks could be thin elements :)

For the high level interface, I agree on the above syntax. You want to automatically slice this up, a property or method on how many slices on the sliced element is a nice user interface.

greglenerd marked this conversation as resolved.
Show resolved Hide resolved
:param grid_points: Number of grid points in each dimension.
:param grid_dimensions: Dimensions of the grid in meters.
:param name: Unique identifier of the element.
"""

greglenerd marked this conversation as resolved.
Show resolved Hide resolved
def __init__(
self,
nx: Union[torch.Tensor, nn.Parameter,int],
ny: Union[torch.Tensor, nn.Parameter,int],
ns: Union[torch.Tensor, nn.Parameter,int],
dx: Union[torch.Tensor, nn.Parameter],
dy: Union[torch.Tensor, nn.Parameter],
ds: Union[torch.Tensor, nn.Parameter],
name: Optional[str] = None,
device=None,
dtype=torch.float32,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
greglenerd marked this conversation as resolved.
Show resolved Hide resolved
greglenerd marked this conversation as resolved.
Show resolved Hide resolved
super().__init__(name=name)

self.nx = int(torch.as_tensor(nx, **factory_kwargs))
self.ny = int(torch.as_tensor(ny, **factory_kwargs))
self.ns = int(torch.as_tensor(ns, **factory_kwargs))
self.dx = torch.as_tensor(dx, **factory_kwargs) #in meters
self.dy = torch.as_tensor(dy, **factory_kwargs)
self.ds = torch.as_tensor(ds, **factory_kwargs)


def grid_shape(self) -> tuple[int,int,int]:
return (int(self.nx), int(self.ny), int(self.ns))


def grid_dimensions(self) -> torch.Tensor:
return torch.tensor([self.dx, self.dy, self.ds], device=self.dx.device)


def cell_size(self) -> torch.Tensor:
grid_shape = self.grid_shape()
grid_dimensions = self.grid_dimensions()
return 2*grid_dimensions / torch.tensor(grid_shape)


def space_charge_deposition(self, beam: ParticleBeam) -> torch.Tensor:
"""
Deposition of the beam on the grid using fully vectorized computation.
"""
grid_shape = self.grid_shape()
grid_dimensions = self.grid_dimensions()
cell_size = self.cell_size()

# Initialize the charge density grid
charge = torch.zeros(grid_shape, dtype=torch.float32)
greglenerd marked this conversation as resolved.
Show resolved Hide resolved

# Get particle positions and charges
particle_pos = beam.particles[:, [0, 2, 4]]
greglenerd marked this conversation as resolved.
Show resolved Hide resolved
particle_charge = beam.particle_charges

# Compute the normalized positions of the particles within the grid
normalized_pos = (particle_pos + grid_dimensions) / cell_size
greglenerd marked this conversation as resolved.
Show resolved Hide resolved

# Find the indices of the lower corners of the cells containing the particles
cell_indices = torch.floor(normalized_pos).type(torch.long)

# Calculate the weights for all surrounding cells
offsets = torch.tensor([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]])
surrounding_indices = cell_indices.unsqueeze(1) + offsets # Shape: (n_particles, 8, 3)
weights = 1 - torch.abs(normalized_pos.unsqueeze(1) - surrounding_indices) # Shape: (n_particles, 8, 3)
greglenerd marked this conversation as resolved.
Show resolved Hide resolved
cell_weights = weights.prod(dim=2) # Shape: (n_particles, 8)
greglenerd marked this conversation as resolved.
Show resolved Hide resolved

# Add the charge contributions to the cells
idx_x, idx_y, idx_s = surrounding_indices.view(-1, 3).T
greglenerd marked this conversation as resolved.
Show resolved Hide resolved
valid_mask = (idx_x >= 0) & (idx_x < grid_shape[0]) & \
greglenerd marked this conversation as resolved.
Show resolved Hide resolved
(idx_y >= 0) & (idx_y < grid_shape[1]) & \
(idx_s >= 0) & (idx_s < grid_shape[2])

# Accumulate the charge contributions
indices = torch.stack([idx_x[valid_mask], idx_y[valid_mask], idx_s[valid_mask]], dim=0)
greglenerd marked this conversation as resolved.
Show resolved Hide resolved
repeated_charges = particle_charge.repeat_interleave(8)
greglenerd marked this conversation as resolved.
Show resolved Hide resolved
values = (cell_weights.view(-1) * repeated_charges)[valid_mask]
charge.index_put_(tuple(indices), values, accumulate=True)
cell_volume = cell_size[0] * cell_size[1] * cell_size[2]
jank324 marked this conversation as resolved.
Show resolved Hide resolved
greglenerd marked this conversation as resolved.
Show resolved Hide resolved

return charge/cell_volume # Normalize by the cell volume, so that the charge density is in C/m^3
greglenerd marked this conversation as resolved.
Show resolved Hide resolved


def integrated_potential(self, x, y, s) -> torch.Tensor:
r = torch.sqrt(x**2 + y**2 + s**2)
G = (-0.5 * s**2 * torch.atan(x * y / (s * r))
-0.5 * y**2 * torch.atan(x * s / (y * r))
-0.5 * x**2 * torch.atan(y * s / (x * r))
+ y * s * torch.asinh(x / torch.sqrt(y**2 + s**2))
+ x * s * torch.asinh(y / torch.sqrt(x**2 + s**2))
+ x * y * torch.asinh(s / torch.sqrt(x**2 + y**2)))
return G


def cyclic_rho(self,beam: ParticleBeam) -> torch.Tensor:
"""
Compute the charge density on the grid using the cyclic deposition method.
"""
grid_shape = self.grid_shape()
charge_density = self.space_charge_deposition(beam)

# Double the dimensions
new_dims = tuple(dim * 2 for dim in grid_shape)
jank324 marked this conversation as resolved.
Show resolved Hide resolved

# Create a new tensor with the doubled dimensions, filled with zeros
cyclic_charge_density = torch.zeros(new_dims)

# Copy the original charge_density values to the beginning of the new tensor
cyclic_charge_density[:charge_density.shape[0], :charge_density.shape[1], :charge_density.shape[2]] = charge_density
greglenerd marked this conversation as resolved.
Show resolved Hide resolved
return cyclic_charge_density

def IGF(self, beam: ParticleBeam) -> torch.Tensor:
gamma = beam.energy / rest_energy
dx, dy, ds = self.cell_size()[0], self.cell_size()[1], self.cell_size()[2] * gamma # ds is scaled by gamma
nx, ny, ns = self.grid_shape()

# Create coordinate grids
x = torch.arange(nx) * dx
y = torch.arange(ny) * dy
s = torch.arange(ns) * ds
x_grid, y_grid, s_grid = torch.meshgrid(x, y, s, indexing='ij')

# Compute the Green's function values
G_values = (
self.integrated_potential(x_grid + 0.5 * dx, y_grid + 0.5 * dy, s_grid + 0.5 * ds)
- self.integrated_potential(x_grid - 0.5 * dx, y_grid + 0.5 * dy, s_grid + 0.5 * ds)
- self.integrated_potential(x_grid + 0.5 * dx, y_grid - 0.5 * dy, s_grid + 0.5 * ds)
- self.integrated_potential(x_grid + 0.5 * dx, y_grid + 0.5 * dy, s_grid - 0.5 * ds)
+ self.integrated_potential(x_grid + 0.5 * dx, y_grid - 0.5 * dy, s_grid - 0.5 * ds)
+ self.integrated_potential(x_grid - 0.5 * dx, y_grid + 0.5 * dy, s_grid - 0.5 * ds)
+ self.integrated_potential(x_grid - 0.5 * dx, y_grid - 0.5 * dy, s_grid + 0.5 * ds)
- self.integrated_potential(x_grid - 0.5 * dx, y_grid - 0.5 * dy, s_grid - 0.5 * ds)
)

# Initialize the grid with double dimensions
grid = torch.zeros(2 * nx, 2 * ny, 2 * ns)
greglenerd marked this conversation as resolved.
Show resolved Hide resolved

# Fill the grid with G_values and its periodic copies
grid[:nx, :ny, :ns] = G_values
grid[nx+1:, :ny, :ns] = G_values[1:,:,:].flip(dims=[0]) # Reverse the x dimension, excluding the first element
grid[:nx, ny+1:, :ns] = G_values[:, 1:,:].flip(dims=[1]) # Reverse the y dimension, excluding the first element
grid[:nx, :ny, ns+1:] = G_values[:, :, 1:].flip(dims=[2]) # Reverse the s dimension, excluding the first element
grid[nx+1:, ny+1:, :ns] = G_values[1:, 1:,:].flip(dims=[0, 1]) # Reverse the x and y dimensions
grid[:nx, ny+1:, ns+1:] = G_values[:, 1:, 1:].flip(dims=[1, 2]) # Reverse the y and s dimensions
grid[nx+1:, :ny, ns+1:] = G_values[1:, :, 1:].flip(dims=[0, 2]) # Reverse the x and s dimensions
grid[nx+1:, ny+1:, ns+1:] = G_values[1:, 1:, 1:].flip(dims=[0, 1, 2]) # Reverse all dimensions
jank324 marked this conversation as resolved.
Show resolved Hide resolved

return grid


def solve_poisson_equation(self, beam: ParticleBeam) -> torch.Tensor: #works only for ParticleBeam at this stage
"""
Solves the Poisson equation for the given charge density.
"""
# Compute the charge density
charge_density = self.cyclic_rho(beam)

# Compute the Fourier transform of the charge density
charge_density_ft = torch.fft.fftn(charge_density)
jank324 marked this conversation as resolved.
Show resolved Hide resolved

# Compute the integrated Green's function
integrated_green_function = self.IGF(beam)

# Compute the integrated Green's function's Fourier transform
integrated_green_function_ft = torch.fft.fftn(integrated_green_function)
jank324 marked this conversation as resolved.
Show resolved Hide resolved

# Compute the Fourier transform of the potential
potential_ft = charge_density_ft * integrated_green_function_ft

# Compute the potential
potential = (1/4*torch.pi*epsilon_0)*torch.fft.ifftn(potential_ft).real
greglenerd marked this conversation as resolved.
Show resolved Hide resolved

# Return the physical potential
return potential[:charge_density.shape[0]//2, :charge_density.shape[1]//2, :charge_density.shape[2]//2]


def split(self, resolution: torch.Tensor) -> list[Element]:
# TODO: Implement splitting for cavity properly, for now just returns the
# element itself
return [self]


def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
device = self.length.device
dtype = self.length.dtype

gamma = energy / rest_energy.to(device=device, dtype=dtype)
igamma2 = (
1 / gamma**2
if gamma != 0
else torch.tensor(0.0, device=device, dtype=dtype)
)
beta = torch.sqrt(1 - igamma2)

tm = torch.eye(7, device=device, dtype=dtype)
tm[0, 1] = self.length
tm[2, 3] = self.length
tm[4, 5] = -self.length / beta**2 * igamma2

return tm

@property
def is_skippable(self) -> bool:
return True

def plot(self, ax: matplotlib.axes.Axes, s: float) -> None:
jank324 marked this conversation as resolved.
Show resolved Hide resolved
pass

@property
def defining_features(self) -> list[str]:
return super().defining_features + ["length"]

def __repr__(self) -> str:
return f"{self.__class__.__name__}(length={repr(self.length)})"


class Quadrupole(Element):
"""
Expand Down
Empty file removed test.py
Empty file.
43 changes: 43 additions & 0 deletions tests/test_space_charge_kick.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import pytest
import torch

import cheetah

def test_charge_deposition():
"""
Test that the charge deposition is correct for a particle beam. The first test checks that the total charge is preserved, and the second test checks that the charge is deposited in the correct grid cells.
"""
space_charge_kick = cheetah.SpaceChargeKick(nx=32,ny=32,ns=32,dx=3e-9,dy=3e-9,ds=2e-6)
incoming_beam = cheetah.ParticleBeam.from_parameters(
num_particles=torch.tensor(1000),
sigma_xp=torch.tensor(2e-7),
sigma_yp=torch.tensor(2e-7),
)
total_charge = incoming_beam.total_charge
space_charge_grid = space_charge_kick.space_charge_deposition(incoming_beam)

assert torch.isclose(space_charge_grid.sum() * space_charge_kick.grid_resolution ** 3, torch.tensor(total_charge), atol=1e-12) # grid_resolution is a parameter of the space charge kick #Total charge is preserved

# something similar to the read function in the CIC code should be implemented
assert outgoing_beam.sigma_y > incoming_beam.sigma_y


@pytest.mark.skip(
reason="Requires rewriting Element and Beam member variables to be buffers."
)
def test_device_like_torch_module():
"""
Test that when changing the device, Drift reacts like a `torch.nn.Module`.
"""
# There is no point in running this test, if there aren't two different devices to
# move between
if not torch.cuda.is_available():
return

element = cheetah.Drift(length=torch.tensor(0.2), device="cuda")

assert element.length.device.type == "cuda"

element = element.cpu()

assert element.length.device.type == "cpu"
Loading