Skip to content

Commit

Permalink
renamed tensor<->vector utilities
Browse files Browse the repository at this point in the history
  • Loading branch information
wolearyc committed Sep 20, 2024
1 parent 6a6650b commit 2ab314e
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 29 deletions.
4 changes: 2 additions & 2 deletions ramannoodle/polarizability/torch/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,8 @@ def scale_polarizabilities(
scaled = self._polarizabilities.detach().clone()
scaled = scaled - torch.tensor(mean)
scaled /= torch.tensor(stddev)
self._scaled_polarizabilities = rn_torch_utils.get_polarizability_vectors(
scaled
self._scaled_polarizabilities = (
rn_torch_utils.polarizability_tensors_to_vectors(scaled)
)

def __len__(self) -> int:
Expand Down
6 changes: 4 additions & 2 deletions ramannoodle/polarizability/torch/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def _get_edge_polarizability_vectors(
+ a5 * torch.tensor([[0, 0, 0], [0, 1, 0], [0, 0, 0]])
+ a6 * torch.tensor([[0, 0, 0], [0, 0, 0], [0, 0, 1]])
)
return rn_torch_utils.get_polarizability_vectors(edge_polarizability)
return rn_torch_utils.polarizability_tensors_to_vectors(edge_polarizability)


class PotGNN(
Expand Down Expand Up @@ -676,6 +676,8 @@ def calc_polarizabilities(
atomic_numbers,
torch.tensor(positions_batch).type(default_type),
)
polarizability = rn_torch_utils.get_polarizability_tensors(polarizability)
polarizability = rn_torch_utils.polarizability_vectors_to_tensors(
polarizability
)

return polarizability.detach().clone().numpy()
39 changes: 16 additions & 23 deletions ramannoodle/polarizability/torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,7 @@ def polarizability_vectors_to_tensors(polarizability_vectors: Tensor) -> Tensor:
:
3D tensor with size [S,3,3].
"""
indices = torch.tensor(
[
[0, 3, 4],
[3, 1, 5],
[4, 5, 2],
]
)
try:
return polarizability_vectors[:, indices]
except IndexError as exc:
raise get_tensor_size_error(
"polarizability_vectors", polarizability_vectors, "[_,6]"
) from exc
except TypeError as exc:
raise get_type_error(
"polarizability_vectors", polarizability_vectors, "Tensor"
) from exc


def get_polarizability_tensors(polarizability_vectors: Tensor) -> Tensor:
"""X should have size (_,6)."""
verify_tensor_size("polarizability_vectors", polarizability_vectors, (None, 6))
indices = torch.tensor(
[
[0, 3, 4],
Expand All @@ -59,8 +39,21 @@ def get_polarizability_tensors(polarizability_vectors: Tensor) -> Tensor:
return polarizability_vectors[:, indices]


def get_polarizability_vectors(polarizability_tensors: Tensor) -> Tensor:
"""X should have size (_,3,3)."""
def polarizability_tensors_to_vectors(polarizability_tensors: Tensor) -> Tensor:
"""Convert polarizability tensors to vectors.
Parameters
----------
polarizability_tensors
| 3D tensor with size [S,3,3] where S is the number of samples.
Returns
-------
:
2D tensor with size [S,6].
"""
verify_tensor_size("polarizability_tensors", polarizability_tensors, (None, 3, 3))
indices = torch.tensor([[0, 0], [1, 1], [2, 2], [0, 1], [0, 2], [1, 2]]).T
return polarizability_tensors[:, indices[0], indices[1]]

Expand Down
4 changes: 2 additions & 2 deletions test/tests/torch/test_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ramannoodle.polarizability.torch.utils import (
_radius_graph_pbc,
get_rotations,
get_polarizability_tensors,
polarizability_vectors_to_tensors,
)

# import ramannoodle.io.vasp as vasp_io
Expand Down Expand Up @@ -173,7 +173,7 @@ def test_calc_polarizabilities(
batch_positions = torch.randn(batch_size, num_atoms, 3)

forward = model.forward(batch_lattices, batch_atomic_numbers, batch_positions)
forward = get_polarizability_tensors(forward.detach().clone()).numpy()
forward = polarizability_vectors_to_tensors(forward.detach().clone()).numpy()
calc = model.calc_polarizabilities(batch_positions.detach().clone().numpy())

assert np.allclose(forward, calc)
Expand Down

0 comments on commit 2ab314e

Please sign in to comment.