diff --git a/ramannoodle/polarizability/torch/dataset.py b/ramannoodle/polarizability/torch/dataset.py index d29827c..9f75b5b 100644 --- a/ramannoodle/polarizability/torch/dataset.py +++ b/ramannoodle/polarizability/torch/dataset.py @@ -221,8 +221,8 @@ def scale_polarizabilities( scaled = self._polarizabilities.detach().clone() scaled = scaled - torch.tensor(mean) scaled /= torch.tensor(stddev) - self._scaled_polarizabilities = rn_torch_utils.get_polarizability_vectors( - scaled + self._scaled_polarizabilities = ( + rn_torch_utils.polarizability_tensors_to_vectors(scaled) ) def __len__(self) -> int: diff --git a/ramannoodle/polarizability/torch/gnn.py b/ramannoodle/polarizability/torch/gnn.py index 2125e25..70f4b90 100644 --- a/ramannoodle/polarizability/torch/gnn.py +++ b/ramannoodle/polarizability/torch/gnn.py @@ -403,7 +403,7 @@ def _get_edge_polarizability_vectors( + a5 * torch.tensor([[0, 0, 0], [0, 1, 0], [0, 0, 0]]) + a6 * torch.tensor([[0, 0, 0], [0, 0, 0], [0, 0, 1]]) ) - return rn_torch_utils.get_polarizability_vectors(edge_polarizability) + return rn_torch_utils.polarizability_tensors_to_vectors(edge_polarizability) class PotGNN( @@ -676,6 +676,8 @@ def calc_polarizabilities( atomic_numbers, torch.tensor(positions_batch).type(default_type), ) - polarizability = rn_torch_utils.get_polarizability_tensors(polarizability) + polarizability = rn_torch_utils.polarizability_vectors_to_tensors( + polarizability + ) return polarizability.detach().clone().numpy() diff --git a/ramannoodle/polarizability/torch/utils.py b/ramannoodle/polarizability/torch/utils.py index 1c6fd13..61fc212 100644 --- a/ramannoodle/polarizability/torch/utils.py +++ b/ramannoodle/polarizability/torch/utils.py @@ -28,27 +28,7 @@ def polarizability_vectors_to_tensors(polarizability_vectors: Tensor) -> Tensor: : 3D tensor with size [S,3,3]. """ - indices = torch.tensor( - [ - [0, 3, 4], - [3, 1, 5], - [4, 5, 2], - ] - ) - try: - return polarizability_vectors[:, indices] - except IndexError as exc: - raise get_tensor_size_error( - "polarizability_vectors", polarizability_vectors, "[_,6]" - ) from exc - except TypeError as exc: - raise get_type_error( - "polarizability_vectors", polarizability_vectors, "Tensor" - ) from exc - - -def get_polarizability_tensors(polarizability_vectors: Tensor) -> Tensor: - """X should have size (_,6).""" + verify_tensor_size("polarizability_vectors", polarizability_vectors, (None, 6)) indices = torch.tensor( [ [0, 3, 4], @@ -59,8 +39,21 @@ def get_polarizability_tensors(polarizability_vectors: Tensor) -> Tensor: return polarizability_vectors[:, indices] -def get_polarizability_vectors(polarizability_tensors: Tensor) -> Tensor: - """X should have size (_,3,3).""" +def polarizability_tensors_to_vectors(polarizability_tensors: Tensor) -> Tensor: + """Convert polarizability tensors to vectors. + + Parameters + ---------- + polarizability_tensors + | 3D tensor with size [S,3,3] where S is the number of samples. + + Returns + ------- + : + 2D tensor with size [S,6]. + + """ + verify_tensor_size("polarizability_tensors", polarizability_tensors, (None, 3, 3)) indices = torch.tensor([[0, 0], [1, 1], [2, 2], [0, 1], [0, 2], [1, 2]]).T return polarizability_tensors[:, indices[0], indices[1]] diff --git a/test/tests/torch/test_gnn.py b/test/tests/torch/test_gnn.py index e3bb63a..28e4aef 100644 --- a/test/tests/torch/test_gnn.py +++ b/test/tests/torch/test_gnn.py @@ -11,7 +11,7 @@ from ramannoodle.polarizability.torch.utils import ( _radius_graph_pbc, get_rotations, - get_polarizability_tensors, + polarizability_vectors_to_tensors, ) # import ramannoodle.io.vasp as vasp_io @@ -173,7 +173,7 @@ def test_calc_polarizabilities( batch_positions = torch.randn(batch_size, num_atoms, 3) forward = model.forward(batch_lattices, batch_atomic_numbers, batch_positions) - forward = get_polarizability_tensors(forward.detach().clone()).numpy() + forward = polarizability_vectors_to_tensors(forward.detach().clone()).numpy() calc = model.calc_polarizabilities(batch_positions.detach().clone().numpy()) assert np.allclose(forward, calc)