From 5891523f6e9c2f904171e2c8564eeb64d2637465 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc=20Ravent=C3=B3s?= Date: Mon, 21 Oct 2024 17:51:51 +0200 Subject: [PATCH] Numpy removed from computations --- xrd_simulator/beam.py | 22 ++- xrd_simulator/cuda.py | 5 +- xrd_simulator/detector.py | 201 +++++++++++----------------- xrd_simulator/laue.py | 62 ++++----- xrd_simulator/motion.py | 53 ++++---- xrd_simulator/polycrystal.py | 146 +++++++++----------- xrd_simulator/scattering_factors.py | 24 ++-- xrd_simulator/utils.py | 33 ++--- 8 files changed, 217 insertions(+), 329 deletions(-) diff --git a/xrd_simulator/beam.py b/xrd_simulator/beam.py index 7a2673a..e2b6a17 100644 --- a/xrd_simulator/beam.py +++ b/xrd_simulator/beam.py @@ -12,10 +12,7 @@ import dill from scipy.spatial import ConvexHull, HalfspaceIntersection from scipy.optimize import linprog -from xrd_simulator.cuda import fw -if fw != np: - fw.array = fw.tensor - +import torch class Beam: """Represents a monochromatic xray beam as a convex polyhedra with uniform intensity. @@ -89,12 +86,12 @@ def contains(self, points): """ - normal_distances = fw.matmul(self.halfspaces[:,:3],points) + normal_distances = torch.matmul(self.halfspaces[:,:3],points) if len(points.shape) == 1: - return fw.all(normal_distances + self.halfspaces[:, 3] < 0) + return torch.all(normal_distances + self.halfspaces[:, 3] < 0) else: return ( - fw.sum( + torch.sum( ( normal_distances + self.halfspaces[:, 3].reshape(self.halfspaces.shape[0], 1) @@ -178,13 +175,10 @@ def load(cls, path): raise ValueError("The loaded file must end with .beam") with open(path, "rb") as f: loaded = dill.load(f) - if fw is np: - pass - else: - loaded.wave_vector = fw.array(loaded.wave_vector, dtype=fw.float32) - loaded.vertices = fw.array(loaded.vertices, dtype=fw.float32) - loaded.polarization_vector = fw.array(loaded.polarization_vector, dtype=fw.float32) - loaded.halfspaces = fw.array(loaded.halfspaces, dtype=fw.float32) + loaded.wave_vector = torch.tensor(loaded.wave_vector, dtype=torch.float32) + loaded.vertices = torch.tensor(loaded.vertices, dtype=torch.float32) + loaded.polarization_vector = torch.tensor(loaded.polarization_vector, dtype=torch.float32) + loaded.halfspaces = torch.tensor(loaded.halfspaces, dtype=torch.float32) return loaded diff --git a/xrd_simulator/cuda.py b/xrd_simulator/cuda.py index cee2761..2e8ae02 100644 --- a/xrd_simulator/cuda.py +++ b/xrd_simulator/cuda.py @@ -1,17 +1,14 @@ import torch -import numpy as np -import pandas as pd # Default to False -fw = np # =============================================== +torch.set_default_device('cpu') try: # Check if CUDA is available if torch.cuda.is_available(): print("CUDA is available and GPUs are found.") gpu = input("Do you want to run in GPU? [y/n]").strip().lower() or 'y' if gpu == 'y': - fw = torch torch.set_default_device('cuda') print("Running in GPU...") else: diff --git a/xrd_simulator/detector.py b/xrd_simulator/detector.py index 7d08780..ec99414 100644 --- a/xrd_simulator/detector.py +++ b/xrd_simulator/detector.py @@ -11,16 +11,11 @@ Below follows a detailed description of the detector class attributes and functions. """ - +import xrd_simulator.cuda import numpy as np from xrd_simulator import utils import dill -from scipy.signal import fftconvolve -from multiprocessing import Pool -import matplotlib.pyplot as plt -from xrd_simulator.cuda import fw -if fw != np: - fw.array = fw.tensor +import torch class Detector: """Represents a rectangular 2D area detector. @@ -59,20 +54,20 @@ class Detector: def __init__( self, pixel_size_z, pixel_size_y, det_corner_0, det_corner_1, det_corner_2 ): - self.det_corner_0 = fw.array(det_corner_0) - self.det_corner_1 = fw.array(det_corner_1) - self.det_corner_2 = fw.array(det_corner_2) + self.det_corner_0 = torch.tensor(det_corner_0) + self.det_corner_1 = torch.tensor(det_corner_1) + self.det_corner_2 = torch.tensor(det_corner_2) - self.pixel_size_z = fw.array(pixel_size_z) - self.pixel_size_y = fw.array(pixel_size_y) + self.pixel_size_z = torch.tensor(pixel_size_z) + self.pixel_size_y = torch.tensor(pixel_size_y) - self.zmax = fw.linalg.norm(self.det_corner_2 - self.det_corner_0) - self.ymax = fw.linalg.norm(self.det_corner_1 - self.det_corner_0) + self.zmax = torch.linalg.norm(self.det_corner_2 - self.det_corner_0) + self.ymax = torch.linalg.norm(self.det_corner_1 - self.det_corner_0) self.zdhat = (self.det_corner_2 - self.det_corner_0 ) / self.zmax self.ydhat = (self.det_corner_1 - self.det_corner_0 ) / self.ymax - self.normal = fw.linalg.cross(self.zdhat, self.ydhat) - self.normal = self.normal / fw.linalg.norm(self.normal) + self.normal = torch.linalg.cross(self.zdhat, self.ydhat) + self.normal = self.normal / torch.linalg.norm(self.normal) self.frames = [] self.pixel_coordinates = self._get_pixel_coordinates() self._point_spread_kernel_shape = (5, 5) @@ -130,72 +125,43 @@ def render(self, # Intersect scattering vectors with detector plane zd_yd_angle = self.get_intersection(peaks[:,13:16],peaks[:,16:19]) - - if fw is np: - peaks = fw.concatenate((peaks,zd_yd_angle),axis=1) - else: - peaks = fw.cat((peaks,zd_yd_angle),dim=1) + peaks = torch.cat((peaks,zd_yd_angle),dim=1) # Filter out peaks not hitting the detector peaks = peaks[self.contains(peaks[:,21], peaks[:,22])] # Add frame number at the end of the tensor - if fw is np: - bin_edges = fw.linspace(0, 1,number_of_frames + 1) - frames = fw.digitize(peaks[:,6], bin_edges) - frames = frames[:,fw.newaxis]-1 - peaks = fw.concatenate((peaks, frames), axis=1) - else: - bin_edges = fw.linspace(0, 1, steps=number_of_frames + 1) - frames = fw.bucketize(peaks[:,6].contiguous(), bin_edges).unsqueeze(1)-1 - peaks = fw.cat((peaks,frames),dim=1) + bin_edges = torch.linspace(0, 1, steps=number_of_frames + 1) + frames = torch.bucketize(peaks[:,6].contiguous(), bin_edges).unsqueeze(1)-1 + peaks = torch.cat((peaks,frames),dim=1) # Create a 3 colum matrix with X,Y and frame coordinates for each peak - if fw is np: - pixel_indices = fw.concatenate( - (((peaks[:, 21])/self.pixel_size_z).reshape(-1, 1), - ((peaks[:, 22])/self.pixel_size_y).reshape(-1, 1), - peaks[:, 24].reshape(-1, 1)), axis=1).astype(fw.int32) - frames_n = np.unique(peaks[:,24]).shape[0] - - else: - pixel_indices = fw.cat( - (((peaks[:, 21])/self.pixel_size_z).unsqueeze(1), - ((peaks[:, 22])/self.pixel_size_y).unsqueeze(1), - peaks[:, 24].unsqueeze(1)), dim=1).to(fw.int32) - frames_n = peaks[:,24].unique().shape[0] + pixel_indices = torch.cat( + (((peaks[:, 21])/self.pixel_size_z).unsqueeze(1), + ((peaks[:, 22])/self.pixel_size_y).unsqueeze(1), + peaks[:, 24].unsqueeze(1)), dim=1).to(torch.int32) + frames_n = peaks[:,24].unique().shape[0] # Create the future frames as an empty tensor - rendered_frames = fw.zeros((frames_n,self.pixel_coordinates.shape[0],self.pixel_coordinates.shape[1]),dtype=fw.float32) + rendered_frames = torch.zeros((frames_n,self.pixel_coordinates.shape[0],self.pixel_coordinates.shape[1]),dtype=torch.float32) # Generate the relative intensity for all the diffraction peaks using the different factors. structure_factors = peaks[:,5] lorentz_factors = peaks[:,22] polarization_factors = peaks[:,23] - # peaks = peaks.cpu().numpy() - # plt.subplot(1,3,1) - # plt.scatter(peaks[:,22],peaks[:,5],s=1) - # plt.subplot(1,3,2) - # plt.scatter(peaks[:,22],peaks[:,22],s=1) - # plt.subplot(1,3,3) - # plt.scatter(peaks[:,22],peaks[:,23],s=1) - # plt.show() - relative_intensity = structure_factors*lorentz_factors*polarization_factors + relative_intensity = structure_factors*polarization_factors*lorentz_factors # Turn from lists of peaks to rendered frames - if fw is np: - fw.add.at(rendered_frames, (pixel_indices[:,2],pixel_indices[:,0],pixel_indices[:,1]), relative_intensity) - else: - # Step 1: Find unique coordinates and the inverse indices - unique_coords, inverse_indices = fw.unique(pixel_indices, dim=0, return_inverse=True) + # Step 1: Find unique coordinates and the inverse indices + unique_coords, inverse_indices = torch.unique(pixel_indices, dim=0, return_inverse=True) - # Step 2: Count occurrences of each unique coordinate, weighting by the relative intensity - counts = fw.bincount(inverse_indices,weights=relative_intensity) + # Step 2: Count occurrences of each unique coordinate, weighting by the relative intensity + counts = torch.bincount(inverse_indices,weights=relative_intensity) - # Step 3: Combine unique coordinates and their counts into a new tensor (mx4) - result = fw.cat((unique_coords, counts.unsqueeze(1)), dim=1).type_as(rendered_frames) + # Step 3: Combine unique coordinates and their counts into a new tensor (mx4) + result = torch.cat((unique_coords, counts.unsqueeze(1)), dim=1).type_as(rendered_frames) - # Step 4: Use the new column as a pixel value to be added to each coordinate - rendered_frames[result[:,2].int(),result[:,0].int(),result[:,1].int()] = result[:,3] + # Step 4: Use the new column as a pixel value to be added to each coordinate + rendered_frames[result[:,2].int(),result[:,0].int(),result[:,1].int()] = result[:,3] #rendered_frames = self._apply_point_spread_function(rendered_frames) @@ -208,9 +174,9 @@ def render(self, def _apply_point_spread_function(self, frames): # Define the 3x3 Gaussian filter - gaussian_kernel = fw.array([[[1, 2, 1], + gaussian_kernel = torch.tensor([[[1, 2, 1], [2, 4, 2], - [1, 2, 1]]], dtype=fw.float32) / 16.0 + [1, 2, 1]]], dtype=torch.float32) / 16.0 ''' frames_n = frames.shape[0] if frames.ndim == 2: @@ -221,8 +187,8 @@ def _apply_point_spread_function(self, frames): gaussian_kernel = gaussian_kernel.repeat(frames_n,frames_n,1,1) # Perform the convolution - with fw.no_grad(): - output = fw.nn.functional.conv2d(frames.unsqueeze(0),weight=gaussian_kernel, padding=1) + with torch.no_grad(): + output = torch.nn.functional.conv2d(frames.unsqueeze(0),weight=gaussian_kernel, padding=1) ''' output = fftconvolve(frames,gaussian_kernel, mode="same") @@ -234,14 +200,14 @@ def pixel_index_to_theta_eta( incoming_wavevector, pixel_zd_index, pixel_yd_index, - scattering_origin=np.array([0, 0, 0]), + scattering_origin=torch.tensor([0, 0, 0]), ): """Compute bragg angle and azimuth angle for a detector pixel index. Args: pixel_zd_index (:obj:`float`): Coordinate in microns along detector zd axis. pixel_yd_index (:obj:`float`): Coordinate in microns along detector yd axis. - scattering_origin (obj:`numpy array`): Origin of diffraction in microns. Defaults to np.array([0, 0, 0]). + scattering_origin (obj:`numpy array`): Origin of diffraction in microns. Defaults to np.tensor([0, 0, 0]). Returns: (:obj:`tuple`) Bragg angle theta and azimuth angle eta (measured from det_corner_1 - det_corner_0 axis) in radians @@ -253,7 +219,7 @@ def pixel_index_to_theta_eta( incoming_wavevector, pixel_zd_coord, pixel_yd_coord, - scattering_origin=np.array([0, 0, 0]), + scattering_origin=scattering_origin, ) return theta, eta @@ -269,7 +235,7 @@ def pixel_coord_to_theta_eta( Args: pixel_zd_coord (:obj:`float`): Coordinate in microns along detector zd axis. pixel_yd_coord (:obj:`float`): Coordinate in microns along detector yd axis. - scattering_origin (obj:`numpy array`): Origin of diffraction in microns. Defaults to np.array([0, 0, 0]). + scattering_origin (obj:`numpy array`): Origin of diffraction in microns. Defaults to np.tensor([0, 0, 0]). Returns: (:obj:`tuple`) Bragg angle theta and azimuth angle eta (measured from det_corner_1 - det_corner_0 axis) in radians @@ -300,26 +266,18 @@ def get_intersection(self, ray_direction, source_point): (:obj:`tuple`) zd, yd in detector plane coordinates. """ - s = fw.matmul(self.det_corner_0 - source_point,self.normal) / fw.matmul(ray_direction,self.normal) - if fw is np: - intersection = source_point + ray_direction * s[:, fw.newaxis] - else: - intersection = source_point + ray_direction * s.unsqueeze(1) - zd = fw.matmul(intersection - self.det_corner_0, self.zdhat) - yd = fw.matmul(intersection - self.det_corner_0, self.ydhat) + s = torch.matmul(self.det_corner_0 - source_point,self.normal) / torch.matmul(ray_direction,self.normal) - # Calculate incident angle - if fw is np: - ray_dir_norm = ray_direction / fw.linalg.norm(ray_direction,axis=1)[:,fw.newaxis] - else: - ray_dir_norm = ray_direction / fw.norm(ray_direction, dim=1).unsqueeze(1) - normal_norm = self.normal / fw.linalg.norm(self.normal) + intersection = source_point + ray_direction * s.unsqueeze(1) + zd = torch.matmul(intersection - self.det_corner_0, self.zdhat) + yd = torch.matmul(intersection - self.det_corner_0, self.ydhat) - cosine_theta = fw.matmul(ray_dir_norm, -normal_norm) # The detector normal by default goes against the beam - incident_angle_deg = fw.arccos(cosine_theta) * (180 / fw.pi) - if fw is np: - return fw.array([zd, yd,incident_angle_deg]).T - return fw.stack((zd, yd, incident_angle_deg), dim=1) + # Calculate incident angle + ray_dir_norm = ray_direction / torch.norm(ray_direction, dim=1).unsqueeze(1) + normal_norm = self.normal / torch.linalg.norm(self.normal) + cosine_theta = torch.matmul(ray_dir_norm, -normal_norm) # The detector normal by default goes against the beam + incident_angle_deg = torch.arccos(cosine_theta) * (180 / torch.pi) + return torch.stack((zd, yd, incident_angle_deg), dim=1) def contains(self, zd, yd): """Determine if the detector coordinate zd,yd lies within the detector bounds. @@ -387,7 +345,7 @@ def get_wrapping_cone(self, k, source_point): fourth_corner_of_detector = self.det_corner_2 + ( self.det_corner_1 - self.det_corner_0[:] ) - geom_mat = fw.zeros((3, 4)) + geom_mat = torch.zeros((3, 4)) for i, det_corner in enumerate( [ self.det_corner_0, @@ -397,9 +355,9 @@ def get_wrapping_cone(self, k, source_point): ] ): geom_mat[:, i] = det_corner - source_point - normalised_local_coord_geom_mat = geom_mat / fw.linalg.norm(geom_mat, axis=0) - cone_opening = fw.arccos(fw.matmul(normalised_local_coord_geom_mat.T, k / fw.linalg.norm(k))) # These are two time Bragg angles - return fw.max(cone_opening) / 2.0 + normalised_local_coord_geom_mat = geom_mat / torch.linalg.norm(geom_mat, axis=0) + cone_opening = torch.arccos(torch.matmul(normalised_local_coord_geom_mat.T, k / torch.linalg.norm(k))) # These are two time Bragg angles + return torch.max(cone_opening) / 2.0 def save(self, path): """Save the detector object to disc (via pickling). Change the arrays formats to np first. @@ -408,20 +366,20 @@ def save(self, path): path (:obj:`str`): File path at which to save, ending with the desired filename. """ - self.det_corner_0 = np.array(self.det_corner_0) - self.det_corner_1 = np.array(self.det_corner_1) - self.det_corner_2 = np.array(self.det_corner_2) + self.det_corner_0 = np.tensor(self.det_corner_0) + self.det_corner_1 = np.tensor(self.det_corner_1) + self.det_corner_2 = np.tensor(self.det_corner_2) - self.pixel_size_z = np.array(self.pixel_size_z) - self.pixel_size_y = np.array(self.pixel_size_y) + self.pixel_size_z = np.tensor(self.pixel_size_z) + self.pixel_size_y = np.tensor(self.pixel_size_y) - self.zmax = np.array(self.zmax) - self.ymax = np.array(self.ymax) + self.zmax = np.tensor(self.zmax) + self.ymax = np.tensor(self.ymax) - self.zdhat = np.array(self.zdhat) - self.ydhat = np.array(self.ydhat) - self.normal = np.array(self.normal) - self.pixel_coordinates = np.array(self.pixel_coordinates) + self.zdhat = np.tensor(self.zdhat) + self.ydhat = np.tensor(self.ydhat) + self.normal = np.tensor(self.normal) + self.pixel_coordinates = np.tensor(self.pixel_coordinates) if not path.endswith(".det"): @@ -446,19 +404,16 @@ def load(cls, path): raise ValueError("The loaded motion file must end with .det") with open(path, "rb") as f: loaded=dill.load(f) - if fw is np: - pass - else: - loaded.normal = fw.array(loaded.normal, dtype=fw.float32) - loaded.det_corner_0 = fw.array(loaded.det_corner_0, dtype=fw.float32) - loaded.det_corner_1 = fw.array(loaded.det_corner_1, dtype=fw.float32) - loaded.det_corner_2 = fw.array(loaded.det_corner_2, dtype=fw.float32) - loaded.zdhat = fw.array(loaded.zdhat, dtype=fw.float32) - loaded.ydhat = fw.array(loaded.ydhat, dtype=fw.float32) - loaded.zmax = fw.array(loaded.zmax, dtype=fw.float32) - loaded.ymax = fw.array(loaded.ymax, dtype=fw.float32) - loaded.pixel_size_z = fw.array(loaded.pixel_size_z) - loaded.pixel_size_y = fw.array(loaded.pixel_size_y) + loaded.normal = torch.tensor(loaded.normal, dtype=torch.float32) + loaded.det_corner_0 = torch.tensor(loaded.det_corner_0, dtype=torch.float32) + loaded.det_corner_1 = torch.tensor(loaded.det_corner_1, dtype=torch.float32) + loaded.det_corner_2 = torch.tensor(loaded.det_corner_2, dtype=torch.float32) + loaded.zdhat = torch.tensor(loaded.zdhat, dtype=torch.float32) + loaded.ydhat = torch.tensor(loaded.ydhat, dtype=torch.float32) + loaded.zmax = torch.tensor(loaded.zmax, dtype=torch.float32) + loaded.ymax = torch.tensor(loaded.ymax, dtype=torch.float32) + loaded.pixel_size_z = torch.tensor(loaded.pixel_size_z) + loaded.pixel_size_y = torch.tensor(loaded.pixel_size_y) return loaded def _get_point_spread_function_kernel(self): @@ -482,11 +437,11 @@ def _get_point_spread_function_kernel(self): return kernel / np.sum(kernel) def _get_pixel_coordinates(self): - zds = fw.arange(0, self.zmax, self.pixel_size_z) - yds = fw.arange(0, self.ymax, self.pixel_size_y) - Z, Y = fw.meshgrid(zds, yds, indexing="ij") - Zds = fw.zeros((len(zds), len(yds), 3)) - Yds = fw.zeros((len(zds), len(yds), 3)) + zds = torch.arange(0, self.zmax, self.pixel_size_z) + yds = torch.arange(0, self.ymax, self.pixel_size_y) + Z, Y = torch.meshgrid(zds, yds, indexing="ij") + Zds = torch.zeros((len(zds), len(yds), 3)) + Yds = torch.zeros((len(zds), len(yds), 3)) for i in range(3): Zds[:, :, i] = Z Yds[:, :, i] = Y diff --git a/xrd_simulator/laue.py b/xrd_simulator/laue.py index a576f8b..ede7a38 100644 --- a/xrd_simulator/laue.py +++ b/xrd_simulator/laue.py @@ -5,8 +5,8 @@ import numpy as np import cupy as cp import torch -from xrd_simulator import utils -from xrd_simulator.cuda import fw +from xrd_simulator import utils,cuda + def get_G(U, B, G_hkl): """Compute the diffraction vector @@ -25,16 +25,11 @@ def get_G(U, B, G_hkl): """ - if fw is np: - U = U.astype(fw.float32) - B = B.astype(fw.float32) - G_hkl = G_hkl.astype(fw.float32) - else: - U = fw.asarray(U,dtype=fw.float32) - B = fw.asarray(B,dtype=fw.float32) - G_hkl = fw.asarray(G_hkl,dtype=fw.float32) + U = torch.asarray(U,dtype=torch.float32) + B = torch.asarray(B,dtype=torch.float32) + G_hkl = torch.asarray(G_hkl,dtype=torch.float32) - return fw.matmul(fw.matmul(U, B), G_hkl.T) + return torch.matmul(torch.matmul(U, B), G_hkl.T) @@ -96,11 +91,11 @@ def find_solutions_to_tangens_half_angle_equation( # Ensure G_0 has at least 3 dimensions if len(G_0.shape) == 2: - G_0 = G_0[fw.newaxis, :, :] + G_0 = G_0[torch.newaxis, :, :] # Compute rho_0 and rho_2 - rho_0 = fw.matmul(rho_0_factor, G_0) - rho_2 = fw.matmul(rho_2_factor, G_0) + fw.sum(G_0**2, axis=1) / 2.0 + rho_0 = torch.matmul(rho_0_factor, G_0) + rho_2 = torch.matmul(rho_2_factor, G_0) + torch.sum(G_0**2, axis=1) / 2.0 denominator = rho_2 - rho_0 numerator = rho_2 + rho_0 @@ -109,14 +104,14 @@ def find_solutions_to_tangens_half_angle_equation( denominator[denominator==0] = np.nan # Calculate coefficients for quadratic equation - a = fw.divide( - fw.matmul(rho_1_factor, G_0), + a = torch.divide( + torch.matmul(rho_1_factor, G_0), denominator, - out=fw.full_like(rho_0, np.nan) + out=torch.full_like(rho_0, np.nan) ) - b = fw.divide( - numerator, denominator, out=fw.full_like(rho_0, np.nan) + b = torch.divide( + numerator, denominator, out=torch.full_like(rho_0, np.nan) ) # Clean up unnecessary variables @@ -132,23 +127,23 @@ def find_solutions_to_tangens_half_angle_equation( # discriminant[discriminant>10] = np.nan # Calculate solutions for s - s1 = -a + fw.sqrt(discriminant) - s2 = -a - fw.sqrt(discriminant) + s1 = -a + torch.sqrt(discriminant) + s2 = -a - torch.sqrt(discriminant) '''The bug is above this # Clean up discriminant and a # del discriminant, a - s = fw.concatenate((s1,s2),axis=0) + s = torch.concatenate((s1,s2),axis=0) # del s1,s2 # Calculate solutions for t1 and t2 - t = 2 * fw.arctan(s) / delta_omega + t = 2 * torch.arctan(s) / delta_omega # del s,delta_omega # Filter solutions within range [0, 1] - valid_t_indices = fw.logical_and(t >= 0, t <= 1) + valid_t_indices = torch.logical_and(t >= 0, t <= 1) # del t - peak_index = fw.argwhere(valid_t_indices) + peak_index = torch.argwhere(valid_t_indices) # peak_index = peak_index % G_0.shape[0] # del valid_t_indices grains = peak_index[:, 0] @@ -157,23 +152,20 @@ def find_solutions_to_tangens_half_angle_equation( times = t[grains,planes] ''' - t1 = 2 * fw.arctan(s1) / delta_omega - indices_t1 = fw.argwhere(fw.logical_and(t1 >= 0, t1 <= 1)) + t1 = 2 * torch.arctan(s1) / delta_omega + indices_t1 = torch.argwhere(torch.logical_and(t1 >= 0, t1 <= 1)) values_t1 = t1[indices_t1[:,0], indices_t1[:,1]] - t2 = 2 * fw.arctan(s2) / delta_omega - indices_t2 = fw.argwhere(fw.logical_and(t2 >= 0, t2 <= 1)) + t2 = 2 * torch.arctan(s2) / delta_omega + indices_t2 = torch.argwhere(torch.logical_and(t2 >= 0, t2 <= 1)) values_t2 = t2[indices_t2[:,0], indices_t2[:,1]] - peak_index = fw.concatenate((indices_t1, indices_t2), axis=0) - times = fw.concatenate((values_t1, values_t2), axis=0) + peak_index = torch.concatenate((indices_t1, indices_t2), axis=0) + times = torch.concatenate((values_t1, values_t2), axis=0) grains = peak_index[:, 0] planes = peak_index[:, 1] - if fw is np: - G_0 = fw.transpose(G_0,(0,2,1)) - else: - G_0 = fw.transpose(G_0,2,1) + G_0 = torch.transpose(G_0,2,1) G = G_0[grains, planes] return grains, planes, times, G diff --git a/xrd_simulator/motion.py b/xrd_simulator/motion.py index 11ad0a2..e732c77 100644 --- a/xrd_simulator/motion.py +++ b/xrd_simulator/motion.py @@ -12,11 +12,8 @@ Below follows a detailed description of the RigidBodyMotion class attributes and functions. """ -import numpy as np import dill -from xrd_simulator.cuda import fw -if fw != np: - fw.array = fw.tensor +import torch class RigidBodyMotion(): """Rigid body transformation of euclidean points by an euler axis rotation and a translation. @@ -46,13 +43,13 @@ class RigidBodyMotion(): """ - def __init__(self, rotation_axis, rotation_angle, translation, origin=fw.zeros((3,))): - assert rotation_angle < fw.pi and rotation_angle > 0, "The rotation angle must be in [0 pi]" + def __init__(self, rotation_axis, rotation_angle, translation, origin=torch.zeros((3,))): + assert rotation_angle < torch.pi and rotation_angle > 0, "The rotation angle must be in [0 pi]" self.rotator = _RodriguezRotator(rotation_axis) - self.rotation_axis = fw.array(rotation_axis,dtype=fw.float32) - self.rotation_angle = fw.array(rotation_angle,dtype=fw.float32) - self.translation = fw.array(translation,dtype=fw.float32) - self.origin = fw.array(origin,dtype=fw.float32) + self.rotation_axis = torch.tensor(rotation_axis,dtype=torch.float32) + self.rotation_angle = torch.tensor(rotation_angle,dtype=torch.float32) + self.translation = torch.tensor(translation,dtype=torch.float32) + self.origin = torch.tensor(origin,dtype=torch.float32) def __call__(self, vectors, time): """Find the transformation of a set of points at a prescribed time. @@ -75,28 +72,27 @@ def __call__(self, vectors, time): centered_vectors = vectors - origin centered_rotated_vectors = self.rotator(centered_vectors, self.rotation_angle * time) rotated_vectors = centered_rotated_vectors + origin - return fw.squeeze(rotated_vectors + translation * time) + return torch.squeeze(rotated_vectors + translation * time) elif len(vectors.shape) == 2: - translation = fw.array(self.translation.reshape(1,3)) + translation = torch.tensor(self.translation.reshape(1,3)) origin = self.origin.reshape(1,3) centered_vectors = vectors - origin centered_rotated_vectors = self.rotator(centered_vectors, self.rotation_angle * time) rotated_vectors = centered_rotated_vectors + origin if isinstance(time,(int,float)): return rotated_vectors + translation * time - if fw is np: - return fw.squeeze(rotated_vectors + translation * fw.array(time)[:,fw.newaxis]) - else: - return fw.squeeze(rotated_vectors + translation * time.unsqueeze(1)) + + + return torch.squeeze(rotated_vectors + translation * time.unsqueeze(1)) elif len(vectors.shape) == 3: translation = self.translation.reshape(1,3) origin = self.origin.reshape(1,3) centered_vectors = vectors - origin - centered_rotated_vectors = self.rotator(centered_vectors.reshape(-1,3), self.rotation_angle * fw.tile(time,(4,1)).T.reshape(-1)).reshape(-1,4,3) + centered_rotated_vectors = self.rotator(centered_vectors.reshape(-1,3), self.rotation_angle * torch.tile(time,(4,1)).T.reshape(-1)).reshape(-1,4,3) rotated_vectors = centered_rotated_vectors + origin - return fw.squeeze(rotated_vectors + translation * fw.array(time)[:,fw.newaxis,fw.newaxis]) + return torch.squeeze(rotated_vectors + translation * torch.tensor(time)[:,torch.newaxis,torch.newaxis]) def rotate(self, vectors, time): """Find the rotational transformation of a set of vectors at a prescribed time. @@ -193,23 +189,20 @@ class _RodriguezRotator(object): """ def __init__(self, rotation_axis): - rotation_axis=fw.array(rotation_axis,dtype=fw.float32) - assert fw.allclose(fw.linalg.norm(rotation_axis), - fw.array(1.)), "The rotation axis must be length unity." + rotation_axis=torch.tensor(rotation_axis,dtype=torch.float32) + assert torch.allclose(torch.linalg.norm(rotation_axis), + torch.tensor(1.)), "The rotation axis must be length unity." self.rotation_axis = rotation_axis rx, ry, rz = self.rotation_axis - self.K = fw.array([[0, -rz, ry], + self.K = torch.tensor([[0, -rz, ry], [rz, 0, -rx], [-ry, rx, 0]]) self.K2 = self.K@self.K def get_rotation_matrix(self, rotation_angle): - if fw is np: - rotation_matrix = fw.eye(3, 3)[:,:,fw.newaxis] + fw.sin(rotation_angle) * self.K[:,:,fw.newaxis] + (1 - fw.cos(rotation_angle)) * self.K2[:,:,fw.newaxis] - rotation_matrix = rotation_matrix.transpose(2,1,0) - else: - rotation_matrix = fw.eye(3, dtype=self.K.dtype, device=self.K.device).unsqueeze(2)+fw.sin(rotation_angle) * self.K.unsqueeze(2) + (1 - fw.cos(rotation_angle)) * self.K2.unsqueeze(2) - rotation_matrix = rotation_matrix.permute(2,1,0).float() + + rotation_matrix = torch.eye(3, dtype=self.K.dtype, device=self.K.device).unsqueeze(2)+torch.sin(rotation_angle) * self.K.unsqueeze(2) + (1 - torch.cos(rotation_angle)) * self.K2.unsqueeze(2) + rotation_matrix = rotation_matrix.permute(2,1,0).float() return rotation_matrix def __call__(self, vectors, rotation_angle): """Rotate a vector in the plane described by v1 and v2 towards v2 a fraction s=[0,1]. @@ -224,7 +217,7 @@ def __call__(self, vectors, rotation_angle): """ R = self.get_rotation_matrix(rotation_angle) -# vectors = fw.array(vectors,dtype=fw.float32) +# vectors = torch.tensor(vectors,dtype=torch.float32) if len(vectors.shape)==1: vectors = vectors[None,:] - return fw.matmul(R,vectors[:,:,None])[:,:,0] # Syntax valid for the rotation fo the G vectors from the grains + return torch.matmul(R,vectors[:,:,None])[:,:,0] # Syntax valid for the rotation fo the G vectors from the grains diff --git a/xrd_simulator/polycrystal.py b/xrd_simulator/polycrystal.py index 863ba53..c252ebf 100644 --- a/xrd_simulator/polycrystal.py +++ b/xrd_simulator/polycrystal.py @@ -11,19 +11,13 @@ """ import copy -from multiprocessing import Pool +import xrd_simulator.cuda import numpy as np -from scipy.spatial import ConvexHull -import pandas as pd import dill from xfab import tools -from xrd_simulator.scattering_unit import ScatteringUnit from xrd_simulator import utils, laue from xrd_simulator.scattering_factors import lorentz,polarization -from xrd_simulator.cuda import fw - -if fw != np: - fw.array = fw.tensor +import torch def _diffract(dict): """ @@ -52,26 +46,26 @@ def _diffract(dict): beam = dict["beam"] rigid_body_motion = dict["rigid_body_motion"] phases = dict["phases"] - espherecentroids = fw.array(dict["espherecentroids"]) + espherecentroids = torch.tensor(dict["espherecentroids"]) orientation_lab = dict["orientation_lab"] eB = dict["eB"] element_phase_map = dict["element_phase_map"] - rho_0_factor = fw.matmul(-beam.wave_vector,rigid_body_motion.rotator.K2) - rho_1_factor = fw.matmul(beam.wave_vector,rigid_body_motion.rotator.K) - rho_2_factor = fw.matmul(beam.wave_vector,(fw.eye(3, 3) + rigid_body_motion.rotator.K2)) + rho_0_factor = torch.matmul(-beam.wave_vector,rigid_body_motion.rotator.K2) + rho_1_factor = torch.matmul(beam.wave_vector,rigid_body_motion.rotator.K) + rho_2_factor = torch.matmul(beam.wave_vector,(torch.eye(3, 3) + rigid_body_motion.rotator.K2)) - peaks = fw.empty((0,10),dtype=fw.float32) # We create a dataframe to store all the relevant values for each individual reflection inr an organized manner + peaks = torch.empty((0,10),dtype=torch.float32) # We create a dataframe to store all the relevant values for each individual reflection inr an organized manner # For each phase of the sample, we compute all reflections at once in a vectorized manner for i, phase in enumerate(phases): # Get all scatterers belonging to one phase at a time, and the corresponding miller indices. - grain_indices = fw.where(element_phase_map == i)[0] - miller_indices = fw.array(phase.miller_indices, dtype=fw.float32) + grain_indices = torch.where(element_phase_map == i)[0] + miller_indices = torch.tensor(phase.miller_indices, dtype=torch.float32) # Retrieve the structure factors of the miller indices for this phase, exclude the miller incides with zero structure factor - structure_factors = fw.sum(fw.array(phase.structure_factors, dtype=fw.float32)**2,axis=1) + structure_factors = torch.sum(torch.tensor(phase.structure_factors, dtype=torch.float32)**2,axis=1) miller_indices = miller_indices[structure_factors>1e-6] structure_factors = structure_factors[structure_factors>1e-6] @@ -90,22 +84,14 @@ def _diffract(dict): # We now assemble the tensors with the valid reflections for each grain and phase including time, hkl plane and G vector #Column names of peaks are 'grain_index','phase_number','h','k','l','structure_factors','times','G0_x','G0_y','G0_z') - if fw is np: - structure_factors = structure_factors[planes][:, fw.newaxis] - grain_indices = grain_indices[grains][:, fw.newaxis] - miller_indices = miller_indices[planes] - phase_index = fw.full((G0_xyz.shape[0],), i)[:, fw.newaxis] - times = times[:,fw.newaxis] - peaks_ith_phase = fw.concatenate((grain_indices, phase_index, miller_indices[:, fw.newaxis].squeeze(), structure_factors, times[:, fw.newaxis].squeeze(2), G0_xyz), axis=1) - peaks = fw.concatenate((peaks, peaks_ith_phase), axis=0) - else: - structure_factors = structure_factors[planes].unsqueeze(1) - grain_indices = grain_indices[grains].unsqueeze(1) - miller_indices = miller_indices[planes] - phase_index = fw.full((G0_xyz.shape[0],),i).unsqueeze(1) - times = times.unsqueeze(1) - peaks_ith_phase = fw.cat((grain_indices,phase_index,miller_indices,structure_factors,times,G0_xyz),dim=1) - peaks = fw.cat([peaks, peaks_ith_phase], axis=0) + + structure_factors = structure_factors[planes].unsqueeze(1) + grain_indices = grain_indices[grains].unsqueeze(1) + miller_indices = miller_indices[planes] + phase_index = torch.full((G0_xyz.shape[0],),i).unsqueeze(1) + times = times.unsqueeze(1) + peaks_ith_phase = torch.cat((grain_indices,phase_index,miller_indices,structure_factors,times,G0_xyz),dim=1) + peaks = torch.cat([peaks, peaks_ith_phase], axis=0) # Rotated G-vectors Gxyz = rigid_body_motion.rotate(peaks[:,7:10], -peaks[:,6]) #I dont know why the - sign is necessary, there is a bug somewhere and this is a patch. Sue me. @@ -116,27 +102,24 @@ def _diffract(dict): # Polarization factor polarization_factors = polarization(beam,K_out_xyz) - if fw is np: - Sources_xyz = rigid_body_motion(espherecentroids[peaks[:,0].astype(int)],peaks[:,6].astype(int)) - peaks = fw.concatenate((peaks,Gxyz,K_out_xyz,Sources_xyz,lorentz_factors[:,fw.newaxis],polarization_factors[:,fw.newaxis]),axis=1) - else: - Sources_xyz = rigid_body_motion(espherecentroids[peaks[:,0].int()],peaks[:,6].int()) - peaks = fw.cat((peaks,Gxyz,K_out_xyz,Sources_xyz,lorentz_factors.unsqueeze(1),polarization_factors.unsqueeze(1)),dim=1) + Sources_xyz = rigid_body_motion(espherecentroids[peaks[:,0].int()],peaks[:,6].int()) + peaks = torch.cat((peaks,Gxyz,K_out_xyz,Sources_xyz,lorentz_factors.unsqueeze(1),polarization_factors.unsqueeze(1)),dim=1) - """ - Column names of peaks are - 0: 'grain_index' 10: 'Gx' 20: 'polarization_factors' - 1: 'phase_number' 11: 'Gy' - 2: 'h' 12: 'Gz' - 3: 'k' 13: 'K_out_x' - 4: 'l' 14: 'K_out_y' - 5: 'structure_factors' 15: 'K_out_z' - 6: 'diffraction_times' 16: 'Source_x' - 7: 'G0_x' 17: 'Source_y' - 8: 'G0_y' 18: 'Source_z' - 9: 'G0_z' 19: 'lorentz_factors' - """ + + """ + Column names of peaks are + 0: 'grain_index' 10: 'Gx' 20: 'polarization_factors' + 1: 'phase_number' 11: 'Gy' + 2: 'h' 12: 'Gz' + 3: 'k' 13: 'K_out_x' + 4: 'l' 14: 'K_out_y' + 5: 'structure_factors' 15: 'K_out_z' + 6: 'diffraction_times' 16: 'Source_x' + 7: 'G0_x' 17: 'Source_y' + 8: 'G0_y' 18: 'Source_z' + 9: 'G0_z' 19: 'lorentz_factors' + """ # Filter out tets not illuminated peaks = peaks[peaks[:,17] < (beam.vertices[:, 1].max())] peaks = peaks[peaks[:,17] > (beam.vertices[:, 1].min())] @@ -221,7 +204,7 @@ def diffract( ): """Compute diffraction from the rotating and translating polycrystal while illuminated by an xray beam. - The xray beam interacts with the polycrystal producing scattering units which are stored in a detector fw. + The xray beam interacts with the polycrystal producing scattering units which are stored in a detector torch. The scattering units may be rendered as pixelated patterns on the detector by using a detector rendering option. @@ -241,7 +224,7 @@ def diffract( computation. Defaults to 1, i.e a single processes. number_of_frames (:obj:`int`): Optional keyword specifying the number of desired temporally equidistantly spaced frames to be collected. Defaulrenderts to 1, which means that the detector reads diffraction during the full rigid body - motion and integrates out the signal to a single fw. The number_of_frames keyword primarily allows for single + motion and integrates out the signal to a single torch. The number_of_frames keyword primarily allows for single rotation axis full 180 dgrs or 360 dgrs sample rotation data sets to be computed rapidly and convinently. proximity (:obj:`bool`): Set to False if all or most grains from the sample are expected to diffract. For instance, if the diffraction scan illuminates all grains from the sample at least once at a give angle/position. @@ -251,7 +234,7 @@ def diffract( """ - #beam.wave_vector = fw.array(beam.wave_vector, dtype=fw.float32) + #beam.wave_vector = torch.tensor(beam.wave_vector, dtype=torch.float32) #min_bragg_angle, max_bragg_angle = self._get_bragg_angle_bounds(beam, min_bragg_angle, max_bragg_angle) @@ -270,7 +253,6 @@ def diffract( peaks = _diffract(args) - """ Column names of peaks are 0: 'grain_index' 10: 'Gx' 20: 'yd' @@ -300,12 +282,9 @@ def transform(self, rigid_body_motion, time): Rot_mat = rigid_body_motion.rotator.get_rotation_matrix( rigid_body_motion.rotation_angle * time ) + self.orientation_lab = torch.matmul(Rot_mat, self.orientation_lab) + self.strain_lab = torch.matmul(torch.matmul(Rot_mat, self.strain_lab), Rot_mat.transpose(2,1)) - self.orientation_lab = fw.matmul(Rot_mat, self.orientation_lab) - if fw is np: - self.strain_lab = fw.matmul(fw.matmul(Rot_mat, self.strain_lab), Rot_mat.transpose(0,2,1)) - else: - self.strain_lab = fw.matmul(fw.matmul(Rot_mat, self.strain_lab), Rot_mat.transpose(2,1)) def save(self, path, save_mesh_as_xdmf=True): """Save polycrystal to disc (via pickling). @@ -377,17 +356,14 @@ def load(cls, path): raise ValueError("The loaded polycrystal file must end with .pc") with open(path, "rb") as f: loaded = dill.load(f) - if fw is np: - pass - else: - loaded.orientation_lab = fw.array(loaded.orientation_lab, dtype=fw.float32) - loaded.strain_lab = fw.array(loaded.strain_lab, dtype=fw.float32) - loaded.element_phase_map = fw.array(loaded.element_phase_map, dtype=fw.float32) - loaded._eB = fw.array(loaded._eB, dtype=fw.float32) - loaded.mesh_lab = cls._move_mesh_to_gpu(loaded.mesh_lab) - loaded.mesh_sample = cls._move_mesh_to_gpu(loaded.mesh_sample) - loaded.strain_sample = fw.array(loaded.strain_sample, dtype=fw.float32) - loaded.orientation_sample = fw.array(loaded.orientation_sample, dtype=fw.float32) + loaded.orientation_lab = torch.tensor(loaded.orientation_lab, dtype=torch.float32) + loaded.strain_lab = torch.tensor(loaded.strain_lab, dtype=torch.float32) + loaded.element_phase_map = torch.tensor(loaded.element_phase_map, dtype=torch.float32) + loaded._eB = torch.tensor(loaded._eB, dtype=torch.float32) + loaded.mesh_lab = cls._move_mesh_to_gpu(loaded.mesh_lab) + loaded.mesh_sample = cls._move_mesh_to_gpu(loaded.mesh_sample) + loaded.strain_sample = torch.tensor(loaded.strain_sample, dtype=torch.float32) + loaded.orientation_sample = torch.tensor(loaded.orientation_sample, dtype=torch.float32) return loaded def _instantiate_orientation(self, orientation, mesh): @@ -423,7 +399,7 @@ def _instantiate_phase(self, phases, element_phase_map, mesh): raise ValueError("element_phase_map not set for multiphase polycrystal") element_phase_map = np.zeros((mesh.number_of_elements,), dtype=int) else: - element_phase_map = np.array(element_phase_map) + element_phase_map = np.tensor(element_phase_map) return element_phase_map, phases def _instantiate_eB( @@ -441,7 +417,7 @@ def _instantiate_eB( B0s = np.zeros((len(phases), 3, 3)) for i, phase in enumerate(phases): B0s[i] = tools.form_b_mat(phase.unit_cell) - grain_indices = np.where(np.array(element_phase_map) == i)[0] + grain_indices = np.where(np.tensor(element_phase_map) == i)[0] _eB[grain_indices] = utils.lab_strain_to_B_matrix( strain_lab[grain_indices], orientation_lab[grain_indices], B0s[i] ) @@ -455,11 +431,11 @@ def _get_bragg_angle_bounds(self, detector, beam, min_bragg_angle, max_bragg_ang """ if max_bragg_angle is None: mesh_nodes_contained_by_beam = self.mesh_lab.coord[beam.contains(self.mesh_lab.coord.T), :] - mesh_nodes_contained_by_beam = fw.array(mesh_nodes_contained_by_beam, dtype=fw.float32) + mesh_nodes_contained_by_beam = torch.tensor(mesh_nodes_contained_by_beam, dtype=torch.float32) if mesh_nodes_contained_by_beam.shape[0] != 0: - source_point = fw.mean(mesh_nodes_contained_by_beam, axis=0) + source_point = torch.mean(mesh_nodes_contained_by_beam, axis=0) else: - source_point = fw.array(self.mesh_lab.centroid, dtype=fw.float32) + source_point = torch.tensor(self.mesh_lab.centroid, dtype=torch.float32) max_bragg_angle = detector.get_wrapping_cone(beam.wave_vector, source_point).item() assert ( @@ -475,13 +451,13 @@ def _get_bragg_angle_bounds(self, detector, beam, min_bragg_angle, max_bragg_ang return min_bragg_angle, max_bragg_angle def _move_mesh_to_gpu(mesh): - mesh.coord = fw.array(mesh.coord, dtype=fw.float32) - mesh.enod = fw.array(mesh.enod, dtype=fw.int32) - mesh.dof = fw.array(mesh.dof, dtype=fw.float32) - mesh.efaces = fw.array(mesh.efaces, dtype=fw.int32) - mesh.enormals = fw.array(mesh.enormals, dtype=fw.float32) - mesh.ecentroids = fw.array(mesh.ecentroids, dtype=fw.float32) - mesh.eradius = fw.array(mesh.eradius, dtype=fw.float32) - mesh.espherecentroids = fw.array(mesh.espherecentroids, dtype=fw.float32) - mesh.centroid = fw.array(mesh.centroid, dtype=fw.float32) + mesh.coord = torch.tensor(mesh.coord, dtype=torch.float32) + mesh.enod = torch.tensor(mesh.enod, dtype=torch.int32) + mesh.dof = torch.tensor(mesh.dof, dtype=torch.float32) + mesh.efaces = torch.tensor(mesh.efaces, dtype=torch.int32) + mesh.enormals = torch.tensor(mesh.enormals, dtype=torch.float32) + mesh.ecentroids = torch.tensor(mesh.ecentroids, dtype=torch.float32) + mesh.eradius = torch.tensor(mesh.eradius, dtype=torch.float32) + mesh.espherecentroids = torch.tensor(mesh.espherecentroids, dtype=torch.float32) + mesh.centroid = torch.tensor(mesh.centroid, dtype=torch.float32) return mesh \ No newline at end of file diff --git a/xrd_simulator/scattering_factors.py b/xrd_simulator/scattering_factors.py index 302ccb3..0b18849 100644 --- a/xrd_simulator/scattering_factors.py +++ b/xrd_simulator/scattering_factors.py @@ -1,29 +1,23 @@ import numpy as np -import pandas as pd import torch -from xrd_simulator.cuda import fw -if fw != np: - fw.array = fw.tensor - fw.degrees = fw.rad2deg - def lorentz(beam,rigid_body_motion,K_out_xyz): """Compute the Lorentz intensity factor for all reflections.""" k = beam.wave_vector kp = K_out_xyz - rot_axis = fw.array(rigid_body_motion.rotation_axis,dtype=fw.float32) - k_kp_norm = fw.matmul(k,kp.T) / (fw.linalg.norm(k,axis=0) * fw.linalg.norm(kp,axis=1)) - theta = fw.arccos(k_kp_norm) / 2.0 + rot_axis = torch.tensor(rigid_body_motion.rotation_axis,dtype=torch.float32) + k_kp_norm = torch.matmul(k,kp.T) / (torch.linalg.norm(k,axis=0) * torch.linalg.norm(kp,axis=1)) + theta = torch.arccos(k_kp_norm) / 2.0 korthogonal = kp - k_kp_norm.reshape(-1,1)*k.reshape(1,3) - eta = fw.arccos(fw.matmul(rot_axis,korthogonal.T) / fw.linalg.norm(korthogonal)) + eta = torch.arccos(torch.matmul(rot_axis,korthogonal.T) / torch.linalg.norm(korthogonal)) tol = 0.5 - condition = (fw.abs(fw.degrees(eta)) < tol) | (fw.degrees(eta) < tol) | (fw.abs(fw.degrees(eta)) > 180 - tol) - infs = fw.where(condition, fw.inf, 0) - return infs + 1.0 / (fw.sin(2 * theta) * fw.abs(fw.sin(eta))) + #condition = (torch.abs(torch.degrees(eta)) < tol) | (torch.degrees(eta) < tol) | (torch.abs(torch.degrees(eta)) > 180 - tol) + #infs = torch.where(condition, torch.inf, 0) + return 1.0 / (torch.sin(2 * theta) * torch.abs(torch.sin(eta))) def polarization(beam,K_out_xyz): """Compute the Polarization intensity factor for all reflections.""" - kp_norm = K_out_xyz / fw.linalg.norm(K_out_xyz) - return 1 - fw.matmul(beam.polarization_vector, kp_norm.T) ** 2 + kp_norm = K_out_xyz / torch.linalg.norm(K_out_xyz) + return 1 - torch.matmul(beam.polarization_vector, kp_norm.T) ** 2 diff --git a/xrd_simulator/utils.py b/xrd_simulator/utils.py index eb9ff6a..c99123b 100644 --- a/xrd_simulator/utils.py +++ b/xrd_simulator/utils.py @@ -36,7 +36,7 @@ from numba import njit import sys import cupy as cp -from xrd_simulator.cuda import fw +import xrd_simulator.cuda import torch @@ -318,39 +318,26 @@ def _b_to_epsilon(B_matrix, B0): def _epsilon_to_b(crystal_strain, B0): """Handle large deformations as opposed to current xfab.tools.epsilon_to_b""" - - if fw is np: - crystal_strain = crystal_strain.astype(np.float32) - B0 = B0.astype(np.float32) - else: - crystal_strain = fw.tensor(crystal_strain, dtype=torch.float32) - B0 = fw.tensor(B0, dtype=torch.float32) + crystal_strain = torch.tensor(crystal_strain, dtype=torch.float32) + B0 = torch.tensor(B0, dtype=torch.float32) - C = 2 * crystal_strain + fw.eye(3, dtype=fw.float32) + C = 2 * crystal_strain + torch.eye(3, dtype=torch.float32) - eigen_vals = fw.linalg.eigvalsh(C) - if fw.any(eigen_vals< 0): + eigen_vals = torch.linalg.eigvalsh(C) + if torch.any(eigen_vals< 0): raise ValueError( "Unfeasible strain tensor with value: " + str(_strain_as_vector(crystal_strain)) + ", will invert the unit cell with negative deformation gradient tensor determinant" ) - if C.ndim == 3: - if fw is np: - F = fw.transpose(fw.linalg.cholesky(C),(0,2,1)) - else: - F = fw.transpose(fw.linalg.cholesky(C),2,1) + F = torch.transpose(torch.linalg.cholesky(C),2,1) else: - if fw is np: - F = fw.transpose(fw.linalg.cholesky(C),(1,0,2)) - else: - F = fw.transpose(fw.linalg.cholesky(C),1,0) + F = torch.transpose(torch.linalg.cholesky(C),1,0) - B = fw.matmul(fw.linalg.inv(F),B0) - B = B.cpu() if fw == torch else B - + B = torch.matmul(torch.linalg.inv(F),B0) + B = B.cpu() return B