Skip to content

Commit

Permalink
Update generation tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kavanase committed Jun 10, 2024
1 parent 2990040 commit a69c2b2
Showing 1 changed file with 22 additions and 15 deletions.
37 changes: 22 additions & 15 deletions tests/test_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,28 @@ def _compare_attributes(obj1, obj2, exclude=None):

if isinstance(val1, np.ndarray):
assert np.allclose(val1, val2)
elif attr == "prim_interstitial_coords":
_compare_prim_interstitial_coords(val1, val2)
elif attr == "defects" and any(len(i.defect_structure) == 0 for i in val1["vacancies"]):
continue # StructureMatcher comparison breaks for empty structures, which we can have with
# our 1-atom primitive Cu input
else:
assert val1 == val2


def _compare_prim_interstitial_coords(result, expected): # check attribute set
if result is None:
assert expected is None
return

assert len(result) == len(expected), "Lengths do not match"

for (r_coord, r_num, r_list), (e_coord, e_num, e_list) in zip(result, expected):
assert np.array_equal(r_coord, e_coord), "Coordinates do not match"
assert r_num == e_num, "Number of coordinates do not match"
assert all(np.array_equal(r, e) for r, e in zip(r_list, e_list)), "List of arrays do not match"


class DefectsGeneratorTest(unittest.TestCase):
def setUp(self):
# don't run heavy tests on GH Actions, these are run locally (too slow without multiprocessing etc)
Expand Down Expand Up @@ -1453,18 +1471,7 @@ def test_interstitial_coords(self):
(np.array([0.75, 0.75, 0.75]), 1, [np.array([0.75, 0.75, 0.75])]),
(np.array([0.5, 0.5, 0.5]), 1, [np.array([0.5, 0.5, 0.5])]),
]

def _test_prim_interstitial_coords(result, expected): # check attribute set
assert len(result) == len(expected), "Lengths do not match"

for (r_coord, r_num, r_list), (e_coord, e_num, e_list) in zip(result, expected):
assert np.array_equal(r_coord, e_coord), "Coordinates do not match"
assert r_num == e_num, "Number of coordinates do not match"
assert all(
np.array_equal(r, e) for r, e in zip(r_list, e_list)
), "List of arrays do not match"

_test_prim_interstitial_coords(CdTe_defect_gen.prim_interstitial_coords, expected)
_compare_prim_interstitial_coords(CdTe_defect_gen.prim_interstitial_coords, expected)

# defect_gen_check changes defect_entries ordering, so save to json first:
self._save_defect_gen_jsons(CdTe_defect_gen)
Expand All @@ -1475,7 +1482,7 @@ def _test_prim_interstitial_coords(result, expected): # check attribute set
self.prim_cdte, interstitial_coords=[[0.5, 0.5, 0.5]], extrinsic={"Te": ["Se", "S"]}
)
assert CdTe_defect_gen.interstitial_coords == [[0.5, 0.5, 0.5]] # check attribute set
_test_prim_interstitial_coords(
_compare_prim_interstitial_coords(
CdTe_defect_gen.prim_interstitial_coords,
[(np.array([0.5, 0.5, 0.5]), 1, [np.array([0.5, 0.5, 0.5])])], # check attribute set
)
Expand Down Expand Up @@ -1516,7 +1523,7 @@ def _test_prim_interstitial_coords(result, expected): # check attribute set
self._load_and_test_defect_gen_jsons(ytos_defect_gen)

assert ytos_defect_gen.interstitial_coords == ytos_interstitial_coords # check attribute set
_test_prim_interstitial_coords(
_compare_prim_interstitial_coords(
ytos_defect_gen.prim_interstitial_coords,
[
(
Expand Down Expand Up @@ -1563,7 +1570,7 @@ def _test_prim_interstitial_coords(result, expected): # check attribute set
assert (
CdTe_defect_gen.interstitial_coords == CdTe_supercell_interstitial_coords
) # check attribute set
_test_prim_interstitial_coords(
_compare_prim_interstitial_coords(
CdTe_defect_gen.prim_interstitial_coords,
[
(
Expand Down

0 comments on commit a69c2b2

Please sign in to comment.