Skip to content

Commit

Permalink
Merge branch 'devel' into pre-commit-ci-update-config
Browse files Browse the repository at this point in the history
  • Loading branch information
njzjz authored Jun 5, 2024
2 parents f2bdf69 + bba3e9f commit 1360492
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 4 deletions.
6 changes: 4 additions & 2 deletions dpdata/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,9 +776,11 @@ def replicate(self, ncopy: list[int] | tuple[int, int, int]):
tmp.data["atom_numbs"] = list(
np.array(np.copy(data["atom_numbs"])) * np.prod(ncopy)
)
tmp.data["atom_types"] = np.sort(
np.tile(np.copy(data["atom_types"]), np.prod(ncopy).item()), kind="stable"
tmp.data["atom_types"] = np.tile(
np.copy(data["atom_types"]), (int(np.prod(ncopy)),) + (1,)
)
tmp.data["atom_types"] = np.transpose(tmp.data["atom_types"]).reshape([-1])

tmp.data["cells"] = np.copy(data["cells"])
for ii in range(3):
tmp.data["cells"][:, ii, :] *= ncopy[ii]
Expand Down
8 changes: 6 additions & 2 deletions tests/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ def setUp(self):
)
zero_driver = ZeroDriver()
self.system_1 = ori_sys.predict(driver=zero_driver)
self.system_2 = ori_sys.minimize(driver=zero_driver, minimizer="ase")
self.system_2 = ori_sys.minimize(
driver=zero_driver, minimizer="ase", max_steps=100
)
self.places = 6
self.e_places = 6
self.f_places = 6
Expand All @@ -123,7 +125,9 @@ def setUp(self):
zero_driver = ZeroDriver()
self.system_1 = list(multi_sys.predict(driver=zero_driver).systems.values())[0]
self.system_2 = list(
multi_sys.minimize(driver=zero_driver, minimizer="ase").systems.values()
multi_sys.minimize(
driver=zero_driver, minimizer="ase", max_steps=100
).systems.values()
)[0]
self.places = 6
self.e_places = 6
Expand Down
28 changes: 28 additions & 0 deletions tests/test_replicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import unittest

import numpy as np
from comp_sys import CompSys, IsPBC
from context import dpdata

Expand Down Expand Up @@ -36,5 +37,32 @@ def setUp(self):
self.places = 6


class TestReplicateTriclinicBox(unittest.TestCase, CompSys, IsPBC):
def setUp(self):
self.system_1 = dpdata.System()
self.system_1.data["atom_names"] = ["foo", "bar"]
self.system_1.data["atom_types"] = np.array([1, 0], dtype=int)
self.system_1.data["atom_numbs"] = [1, 1]
self.system_1.data["cells"] = np.array(
[10, 0, 0, 0, 10, 0, 0, 0, 10], dtype=float
).reshape(1, 3, 3)
self.system_1.data["coords"] = np.array(
[0, 0, 0, 0, 0, 1], dtype=float
).reshape(1, 2, 3)
self.system_1 = self.system_1.replicate([2, 1, 1])

self.system_2 = dpdata.System()
self.system_2.data["atom_names"] = ["foo", "bar"]
self.system_2.data["atom_types"] = np.array([1, 1, 0, 0], dtype=int)
self.system_2.data["atom_numbs"] = [2, 2]
self.system_2.data["cells"] = np.array(
[20, 0, 0, 0, 10, 0, 0, 0, 10], dtype=float
).reshape(1, 3, 3)
self.system_2.data["coords"] = np.array(
[0, 0, 0, 10, 0, 0, 0, 0, 1, 10, 0, 1], dtype=float
).reshape(1, 4, 3)
self.places = 6


if __name__ == "__main__":
unittest.main()

0 comments on commit 1360492

Please sign in to comment.