Skip to content

Commit

Permalink
dataset and potgnn internally scales polarizability
Browse files Browse the repository at this point in the history
  • Loading branch information
wolearyc committed Sep 20, 2024
1 parent 5f1ad3f commit 7722d49
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 20 deletions.
1 change: 0 additions & 1 deletion ramannoodle/io/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,5 +157,4 @@ def _read_polarizability_dataset(
atomic_numbers,
np.array(positions_list),
np.array(polarizabilities),
scale_mode="standard",
)
5 changes: 2 additions & 3 deletions ramannoodle/polarizability/torch/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _scale_and_flatten_polarizabilities(
class PolarizabilityDataset(Dataset[tuple[Tensor, Tensor, Tensor, Tensor]]):
"""PyTorch dataset of atomic structures and polarizabilities.
Polarizabilities are scaled and flattened into 6-vectors containing the
Polarizabilities are standard scaled and flattened into 6-vectors containing the
independent tensor components.
Parameters
Expand All @@ -98,7 +98,6 @@ def __init__( # pylint: disable=too-many-arguments
atomic_numbers: list[int],
positions: NDArray[np.float64],
polarizabilities: NDArray[np.float64],
scale_mode: str = "standard",
):
# Validate parameter shapes
verify_ndarray_shape("lattice", lattice, (3, 3))
Expand All @@ -117,7 +116,7 @@ def __init__( # pylint: disable=too-many-arguments
self._polarizabilities = torch.tensor(polarizabilities)

_, _, scaled = _scale_and_flatten_polarizabilities(
self._polarizabilities, scale_mode=scale_mode
self._polarizabilities, scale_mode="standard"
)
self._scaled_polarizabilities = scaled.type(default_type)

Expand Down
12 changes: 10 additions & 2 deletions ramannoodle/polarizability/torch/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ class PotGNN(
| (Å) Upper bound of the Gaussian filter used in initial edge embedding.
"""

def __init__( # pylint: disable=too-many-arguments,too-many-instance-attributes
def __init__( # pylint: disable=too-many-arguments,too-many-locals
self,
ref_structure: ReferenceStructure,
cutoff: float,
Expand All @@ -442,6 +442,8 @@ def __init__( # pylint: disable=too-many-arguments,too-many-instance-attributes
num_message_passes: int,
gaussian_filter_start: float,
gaussian_filter_end: float,
mean_polarizability: NDArray[np.float64],
stddev_polarizability: NDArray[np.float64],
):
super().__init__()
# Validate parameters
Expand All @@ -459,9 +461,13 @@ def __init__( # pylint: disable=too-many-arguments,too-many-instance-attributes
if gaussian_filter_end <= gaussian_filter_start:
inequality = f"{gaussian_filter_end} <= gaussian_filter_start"
raise ValueError(f"invalid gaussian_filter_end: {inequality}")
verify_ndarray_shape("mean_polarizability", mean_polarizability, (3, 3))
verify_ndarray_shape("stddev_polarizability", stddev_polarizability, (3, 3))

self._ref_structure = ref_structure
self._cutoff = cutoff
self._mean_polarizability = mean_polarizability
self._stddev_polarizability = stddev_polarizability

# Set up graph.
lattice = torch.from_numpy(ref_structure.lattice).unsqueeze(0)
Expand Down Expand Up @@ -692,4 +698,6 @@ def calc_polarizabilities(
end_index = start_index + positions_subbatch.shape[0]
polarizabilities[start_index:end_index] = polarizability.detach()

return polarizabilities.detach().numpy()
result = polarizabilities.detach().numpy() * self._stddev_polarizability
result += self._mean_polarizability
return result
13 changes: 4 additions & 9 deletions test/tests/torch/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,15 @@ def test_load_polarizability_dataset(


@pytest.mark.parametrize(
"lattice, atomic_numbers, positions, polarizabilities, scale_mode, exception_type,"
"in_reason",
"lattice, atomic_numbers, positions, polarizabilities, exception_type,in_reason",
[
(
np.zeros((3, 3)),
[1, 2],
np.random.random((2, 2, 3)),
np.random.random((3, 2, 3)),
np.random.random((2, 3, 3)),
"invalid_scale_mode",
ValueError,
"unsupported scale mode: invalid_scale_mode",
"polarizabilities has wrong shape: (2,3,3) != (3,3,3)",
),
],
)
Expand All @@ -63,14 +61,11 @@ def test_polarizability_dataset_exception( # pylint: disable=too-many-arguments
atomic_numbers: list[int],
positions: NDArray[np.float64],
polarizabilities: NDArray[np.float64],
scale_mode: str,
exception_type: Type[Exception],
in_reason: str,
) -> None:
"""Test polarizability dataset (exception)."""
with pytest.raises(exception_type) as error:
PolarizabilityDataset(
lattice, atomic_numbers, positions, polarizabilities, scale_mode
)
PolarizabilityDataset(lattice, atomic_numbers, positions, polarizabilities)

assert in_reason in str(error.value)
10 changes: 5 additions & 5 deletions test/tests/torch/test_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def test_radius_graph_pbc() -> None:
def test_batch_polarizability(poscar_ref_structure_fixture: ReferenceStructure) -> None:
"""Test of batch functions for forward pass (normal)."""
ref_structure = poscar_ref_structure_fixture
model = PotGNN(ref_structure, 5, 5, 5, 5, 0, 5)
model = PotGNN(ref_structure, 5, 5, 5, 5, 0, 5, np.zeros((3, 3)), np.zeros((3, 3)))
model.eval()

for batch_size in range(1, 4):
Expand Down Expand Up @@ -128,7 +128,7 @@ def test_gpu(poscar_ref_structure_fixture: ReferenceStructure) -> None:
return

torch.set_default_device(device) # type: ignore
model = PotGNN(ref_structure, 5, 5, 5, 5, 0, 5)
model = PotGNN(ref_structure, 5, 5, 5, 5, 0, 5, np.zeros((3, 3)), np.ones((3, 3)))
model.eval()

for batch_size in range(1, 4):
Expand Down Expand Up @@ -158,7 +158,7 @@ def test_calc_polarizabilities(
) -> None:
"""Test of calc_polarizabilities (normal)."""
ref_structure = poscar_ref_structure_fixture
model = PotGNN(ref_structure, 2, 5, 5, 5, 0, 5)
model = PotGNN(ref_structure, 2, 5, 5, 5, 0, 5, np.zeros((3, 3)), np.ones((3, 3)))
model.eval()

for batch_size in [50, 100, 200]:
Expand Down Expand Up @@ -192,7 +192,7 @@ def test_calc_polarizabilities(
# def test_symmetry(poscar_ref_structure_fixture: ReferenceStructure) -> None:
# """Test that model obeys symmetries."""
# ref_structure = poscar_ref_structure_fixture
# model = PotGNN(ref_structure, 2.3, 5, 5, 6)
# model = PotGNN(ref_structure, 2.3, 5, 5, 6, np.zeros((3, 3)), np.zeros((3, 3)))
# model.eval()

# lattice = torch.tensor([ref_structure.lattice]).float()
Expand All @@ -216,7 +216,7 @@ def test_calc_polarizabilities(
# displaced_positions = vasp_io.outcar.read_positions(
# "test/data/TiO2/Ti5_0.1x_eps_OUTCAR"
# )
# model = PotGNN(ref_structure, 5, 5, 6, 4)
# model = PotGNN(ref_structure, 5, 5, 6, 4, np.zeros((3, 3)), np.zeros((3, 3)))
# model.eval()
# parent_displacement = (displaced_positions - ref_structure.positions) / (
# (np.linalg.norm(displaced_positions - ref_structure.positions) * 10)
Expand Down

0 comments on commit 7722d49

Please sign in to comment.