Skip to content

Commit

Permalink
Numpy removed from computations
Browse files Browse the repository at this point in the history
  • Loading branch information
Marcraven committed Oct 21, 2024
1 parent dd86062 commit 5891523
Show file tree
Hide file tree
Showing 8 changed files with 217 additions and 329 deletions.
22 changes: 8 additions & 14 deletions xrd_simulator/beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@
import dill
from scipy.spatial import ConvexHull, HalfspaceIntersection
from scipy.optimize import linprog
from xrd_simulator.cuda import fw
if fw != np:
fw.array = fw.tensor

import torch

class Beam:
"""Represents a monochromatic xray beam as a convex polyhedra with uniform intensity.
Expand Down Expand Up @@ -89,12 +86,12 @@ def contains(self, points):
"""

normal_distances = fw.matmul(self.halfspaces[:,:3],points)
normal_distances = torch.matmul(self.halfspaces[:,:3],points)
if len(points.shape) == 1:
return fw.all(normal_distances + self.halfspaces[:, 3] < 0)
return torch.all(normal_distances + self.halfspaces[:, 3] < 0)
else:
return (
fw.sum(
torch.sum(
(
normal_distances
+ self.halfspaces[:, 3].reshape(self.halfspaces.shape[0], 1)
Expand Down Expand Up @@ -178,13 +175,10 @@ def load(cls, path):
raise ValueError("The loaded file must end with .beam")
with open(path, "rb") as f:
loaded = dill.load(f)
if fw is np:
pass
else:
loaded.wave_vector = fw.array(loaded.wave_vector, dtype=fw.float32)
loaded.vertices = fw.array(loaded.vertices, dtype=fw.float32)
loaded.polarization_vector = fw.array(loaded.polarization_vector, dtype=fw.float32)
loaded.halfspaces = fw.array(loaded.halfspaces, dtype=fw.float32)
loaded.wave_vector = torch.tensor(loaded.wave_vector, dtype=torch.float32)
loaded.vertices = torch.tensor(loaded.vertices, dtype=torch.float32)
loaded.polarization_vector = torch.tensor(loaded.polarization_vector, dtype=torch.float32)
loaded.halfspaces = torch.tensor(loaded.halfspaces, dtype=torch.float32)
return loaded


Expand Down
5 changes: 1 addition & 4 deletions xrd_simulator/cuda.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
import torch
import numpy as np
import pandas as pd
# Default to False
fw = np

# ===============================================
torch.set_default_device('cpu')
try:
# Check if CUDA is available
if torch.cuda.is_available():
print("CUDA is available and GPUs are found.")
gpu = input("Do you want to run in GPU? [y/n]").strip().lower() or 'y'
if gpu == 'y':
fw = torch
torch.set_default_device('cuda')
print("Running in GPU...")
else:
Expand Down
201 changes: 78 additions & 123 deletions xrd_simulator/detector.py

Large diffs are not rendered by default.

62 changes: 27 additions & 35 deletions xrd_simulator/laue.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import numpy as np
import cupy as cp
import torch
from xrd_simulator import utils
from xrd_simulator.cuda import fw
from xrd_simulator import utils,cuda


def get_G(U, B, G_hkl):
"""Compute the diffraction vector
Expand All @@ -25,16 +25,11 @@ def get_G(U, B, G_hkl):
"""

if fw is np:
U = U.astype(fw.float32)
B = B.astype(fw.float32)
G_hkl = G_hkl.astype(fw.float32)
else:
U = fw.asarray(U,dtype=fw.float32)
B = fw.asarray(B,dtype=fw.float32)
G_hkl = fw.asarray(G_hkl,dtype=fw.float32)
U = torch.asarray(U,dtype=torch.float32)
B = torch.asarray(B,dtype=torch.float32)
G_hkl = torch.asarray(G_hkl,dtype=torch.float32)

return fw.matmul(fw.matmul(U, B), G_hkl.T)
return torch.matmul(torch.matmul(U, B), G_hkl.T)



Expand Down Expand Up @@ -96,11 +91,11 @@ def find_solutions_to_tangens_half_angle_equation(

# Ensure G_0 has at least 3 dimensions
if len(G_0.shape) == 2:
G_0 = G_0[fw.newaxis, :, :]
G_0 = G_0[torch.newaxis, :, :]

# Compute rho_0 and rho_2
rho_0 = fw.matmul(rho_0_factor, G_0)
rho_2 = fw.matmul(rho_2_factor, G_0) + fw.sum(G_0**2, axis=1) / 2.0
rho_0 = torch.matmul(rho_0_factor, G_0)
rho_2 = torch.matmul(rho_2_factor, G_0) + torch.sum(G_0**2, axis=1) / 2.0
denominator = rho_2 - rho_0
numerator = rho_2 + rho_0

Expand All @@ -109,14 +104,14 @@ def find_solutions_to_tangens_half_angle_equation(
denominator[denominator==0] = np.nan

# Calculate coefficients for quadratic equation
a = fw.divide(
fw.matmul(rho_1_factor, G_0),
a = torch.divide(
torch.matmul(rho_1_factor, G_0),
denominator,
out=fw.full_like(rho_0, np.nan)
out=torch.full_like(rho_0, np.nan)
)

b = fw.divide(
numerator, denominator, out=fw.full_like(rho_0, np.nan)
b = torch.divide(
numerator, denominator, out=torch.full_like(rho_0, np.nan)
)

# Clean up unnecessary variables
Expand All @@ -132,23 +127,23 @@ def find_solutions_to_tangens_half_angle_equation(
# discriminant[discriminant>10] = np.nan

# Calculate solutions for s
s1 = -a + fw.sqrt(discriminant)
s2 = -a - fw.sqrt(discriminant)
s1 = -a + torch.sqrt(discriminant)
s2 = -a - torch.sqrt(discriminant)
'''The bug is above this
# Clean up discriminant and a
# del discriminant, a
s = fw.concatenate((s1,s2),axis=0)
s = torch.concatenate((s1,s2),axis=0)
# del s1,s2
# Calculate solutions for t1 and t2
t = 2 * fw.arctan(s) / delta_omega
t = 2 * torch.arctan(s) / delta_omega
# del s,delta_omega
# Filter solutions within range [0, 1]
valid_t_indices = fw.logical_and(t >= 0, t <= 1)
valid_t_indices = torch.logical_and(t >= 0, t <= 1)
# del t
peak_index = fw.argwhere(valid_t_indices)
peak_index = torch.argwhere(valid_t_indices)
# peak_index = peak_index % G_0.shape[0]
# del valid_t_indices
grains = peak_index[:, 0]
Expand All @@ -157,23 +152,20 @@ def find_solutions_to_tangens_half_angle_equation(
times = t[grains,planes]
'''

t1 = 2 * fw.arctan(s1) / delta_omega
indices_t1 = fw.argwhere(fw.logical_and(t1 >= 0, t1 <= 1))
t1 = 2 * torch.arctan(s1) / delta_omega
indices_t1 = torch.argwhere(torch.logical_and(t1 >= 0, t1 <= 1))
values_t1 = t1[indices_t1[:,0], indices_t1[:,1]]

t2 = 2 * fw.arctan(s2) / delta_omega
indices_t2 = fw.argwhere(fw.logical_and(t2 >= 0, t2 <= 1))
t2 = 2 * torch.arctan(s2) / delta_omega
indices_t2 = torch.argwhere(torch.logical_and(t2 >= 0, t2 <= 1))
values_t2 = t2[indices_t2[:,0], indices_t2[:,1]]

peak_index = fw.concatenate((indices_t1, indices_t2), axis=0)
times = fw.concatenate((values_t1, values_t2), axis=0)
peak_index = torch.concatenate((indices_t1, indices_t2), axis=0)
times = torch.concatenate((values_t1, values_t2), axis=0)

grains = peak_index[:, 0]
planes = peak_index[:, 1]

if fw is np:
G_0 = fw.transpose(G_0,(0,2,1))
else:
G_0 = fw.transpose(G_0,2,1)
G_0 = torch.transpose(G_0,2,1)
G = G_0[grains, planes]
return grains, planes, times, G
53 changes: 23 additions & 30 deletions xrd_simulator/motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,8 @@
Below follows a detailed description of the RigidBodyMotion class attributes and functions.
"""
import numpy as np
import dill
from xrd_simulator.cuda import fw
if fw != np:
fw.array = fw.tensor
import torch

class RigidBodyMotion():
"""Rigid body transformation of euclidean points by an euler axis rotation and a translation.
Expand Down Expand Up @@ -46,13 +43,13 @@ class RigidBodyMotion():
"""

def __init__(self, rotation_axis, rotation_angle, translation, origin=fw.zeros((3,))):
assert rotation_angle < fw.pi and rotation_angle > 0, "The rotation angle must be in [0 pi]"
def __init__(self, rotation_axis, rotation_angle, translation, origin=torch.zeros((3,))):
assert rotation_angle < torch.pi and rotation_angle > 0, "The rotation angle must be in [0 pi]"
self.rotator = _RodriguezRotator(rotation_axis)
self.rotation_axis = fw.array(rotation_axis,dtype=fw.float32)
self.rotation_angle = fw.array(rotation_angle,dtype=fw.float32)
self.translation = fw.array(translation,dtype=fw.float32)
self.origin = fw.array(origin,dtype=fw.float32)
self.rotation_axis = torch.tensor(rotation_axis,dtype=torch.float32)
self.rotation_angle = torch.tensor(rotation_angle,dtype=torch.float32)
self.translation = torch.tensor(translation,dtype=torch.float32)
self.origin = torch.tensor(origin,dtype=torch.float32)

def __call__(self, vectors, time):
"""Find the transformation of a set of points at a prescribed time.
Expand All @@ -75,28 +72,27 @@ def __call__(self, vectors, time):
centered_vectors = vectors - origin
centered_rotated_vectors = self.rotator(centered_vectors, self.rotation_angle * time)
rotated_vectors = centered_rotated_vectors + origin
return fw.squeeze(rotated_vectors + translation * time)
return torch.squeeze(rotated_vectors + translation * time)

elif len(vectors.shape) == 2:
translation = fw.array(self.translation.reshape(1,3))
translation = torch.tensor(self.translation.reshape(1,3))
origin = self.origin.reshape(1,3)
centered_vectors = vectors - origin
centered_rotated_vectors = self.rotator(centered_vectors, self.rotation_angle * time)
rotated_vectors = centered_rotated_vectors + origin
if isinstance(time,(int,float)):
return rotated_vectors + translation * time
if fw is np:
return fw.squeeze(rotated_vectors + translation * fw.array(time)[:,fw.newaxis])
else:
return fw.squeeze(rotated_vectors + translation * time.unsqueeze(1))


return torch.squeeze(rotated_vectors + translation * time.unsqueeze(1))

elif len(vectors.shape) == 3:
translation = self.translation.reshape(1,3)
origin = self.origin.reshape(1,3)
centered_vectors = vectors - origin
centered_rotated_vectors = self.rotator(centered_vectors.reshape(-1,3), self.rotation_angle * fw.tile(time,(4,1)).T.reshape(-1)).reshape(-1,4,3)
centered_rotated_vectors = self.rotator(centered_vectors.reshape(-1,3), self.rotation_angle * torch.tile(time,(4,1)).T.reshape(-1)).reshape(-1,4,3)
rotated_vectors = centered_rotated_vectors + origin
return fw.squeeze(rotated_vectors + translation * fw.array(time)[:,fw.newaxis,fw.newaxis])
return torch.squeeze(rotated_vectors + translation * torch.tensor(time)[:,torch.newaxis,torch.newaxis])

def rotate(self, vectors, time):
"""Find the rotational transformation of a set of vectors at a prescribed time.
Expand Down Expand Up @@ -193,23 +189,20 @@ class _RodriguezRotator(object):
"""

def __init__(self, rotation_axis):
rotation_axis=fw.array(rotation_axis,dtype=fw.float32)
assert fw.allclose(fw.linalg.norm(rotation_axis),
fw.array(1.)), "The rotation axis must be length unity."
rotation_axis=torch.tensor(rotation_axis,dtype=torch.float32)
assert torch.allclose(torch.linalg.norm(rotation_axis),
torch.tensor(1.)), "The rotation axis must be length unity."
self.rotation_axis = rotation_axis
rx, ry, rz = self.rotation_axis
self.K = fw.array([[0, -rz, ry],
self.K = torch.tensor([[0, -rz, ry],
[rz, 0, -rx],
[-ry, rx, 0]])
self.K2 = self.K@self.K

def get_rotation_matrix(self, rotation_angle):
if fw is np:
rotation_matrix = fw.eye(3, 3)[:,:,fw.newaxis] + fw.sin(rotation_angle) * self.K[:,:,fw.newaxis] + (1 - fw.cos(rotation_angle)) * self.K2[:,:,fw.newaxis]
rotation_matrix = rotation_matrix.transpose(2,1,0)
else:
rotation_matrix = fw.eye(3, dtype=self.K.dtype, device=self.K.device).unsqueeze(2)+fw.sin(rotation_angle) * self.K.unsqueeze(2) + (1 - fw.cos(rotation_angle)) * self.K2.unsqueeze(2)
rotation_matrix = rotation_matrix.permute(2,1,0).float()

rotation_matrix = torch.eye(3, dtype=self.K.dtype, device=self.K.device).unsqueeze(2)+torch.sin(rotation_angle) * self.K.unsqueeze(2) + (1 - torch.cos(rotation_angle)) * self.K2.unsqueeze(2)
rotation_matrix = rotation_matrix.permute(2,1,0).float()
return rotation_matrix
def __call__(self, vectors, rotation_angle):
"""Rotate a vector in the plane described by v1 and v2 towards v2 a fraction s=[0,1].
Expand All @@ -224,7 +217,7 @@ def __call__(self, vectors, rotation_angle):
"""

R = self.get_rotation_matrix(rotation_angle)
# vectors = fw.array(vectors,dtype=fw.float32)
# vectors = torch.tensor(vectors,dtype=torch.float32)
if len(vectors.shape)==1:
vectors = vectors[None,:]
return fw.matmul(R,vectors[:,:,None])[:,:,0] # Syntax valid for the rotation fo the G vectors from the grains
return torch.matmul(R,vectors[:,:,None])[:,:,0] # Syntax valid for the rotation fo the G vectors from the grains
Loading

0 comments on commit 5891523

Please sign in to comment.