Skip to content

Commit

Permalink
Tests: Avoid vise suppression of warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
kavanase committed Mar 31, 2024
1 parent 9584137 commit 9a40855
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 39 deletions.
15 changes: 13 additions & 2 deletions doped/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,15 @@ def from_dict(cls, d: dict):
Returns:
``DefectEntry`` object
"""
from doped.utils.parsing import suppress_logging
from doped.utils.parsing import _reset_warnings, suppress_logging

with suppress_logging():
return super().from_dict(d)
obj = super().from_dict(d)

_reset_warnings() # vise suppresses ``UserWarning``s (and this initialises ``vise``
# ``BandEdgeStates`` objects) so need to reset

return obj

def _check_correction_error_and_return_output(
self,
Expand Down Expand Up @@ -598,10 +603,13 @@ def _load_and_parse_eigenvalue_data(
from doped.utils.parsing import (
_get_output_files_and_check_if_multiple,
_multiple_files_warning,
_reset_warnings,
get_procar,
get_vasprun,
)

_reset_warnings() # vise suppresses `UserWarning`s, so need to reset

parsed_vr_procar_dict = {}
for vr, procar, label in [(bulk_vr, bulk_procar, "bulk"), (defect_vr, defect_procar, "defect")]:
path = self.calculation_metadata.get(f"{label}_path")
Expand Down Expand Up @@ -782,6 +790,9 @@ def get_eigenvalue_analysis(
``Figure`` object (if ``plot=True``).
"""
from doped.utils.eigenvalues import get_eigenvalue_analysis
from doped.utils.parsing import _reset_warnings

_reset_warnings() # vise suppresses `UserWarning`s, so need to reset

self._load_and_parse_eigenvalue_data(
bulk_vr=bulk_vr,
Expand Down
14 changes: 3 additions & 11 deletions doped/corrections.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
"""

import os
import warnings
from typing import Optional, Union

import matplotlib.pyplot as plt
Expand All @@ -49,21 +48,18 @@
from pymatgen.io.vasp.outputs import Locpot, Outcar
from shakenbreak.plotting import _install_custom_font

from doped import _ignore_pmg_warnings
from doped.analysis import _convert_dielectric_to_tensor
from doped.utils.parsing import (
_get_bulk_supercell,
_get_defect_supercell,
_get_defect_supercell_bulk_site_coords,
_reset_warnings,
get_locpot,
get_outcar,
)
from doped.utils.plotting import _get_backend, format_defect_name

warnings.simplefilter("default")
# `message` only needs to match start of message:
warnings.filterwarnings("ignore", message="`np.int` is a deprecated alias for the builtin `int`")
warnings.filterwarnings("ignore", message="Use get_magnetic_symmetry()")
_reset_warnings() # vise suppresses `UserWarning`s, so need to reset


def _monty_decode_nested_dicts(d):
Expand Down Expand Up @@ -418,11 +414,7 @@ def get_kumagai_correction(
"You can do this by running `pip install pydefect`."
) from exc

# vise suppresses `UserWarning`s, so need to reset
warnings.simplefilter("default")
warnings.filterwarnings("ignore", message="`np.int` is a deprecated alias for the builtin `int`")
warnings.filterwarnings("ignore", message="Use get_magnetic_symmetry()")
_ignore_pmg_warnings()
_reset_warnings() # vise suppresses `UserWarning`s, so need to reset

def doped_make_efnv_correction(
charge: float,
Expand Down
14 changes: 7 additions & 7 deletions doped/utils/eigenvalues.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@
from pymatgen.io.vasp.outputs import Procar, Vasprun
from shakenbreak.plotting import _install_custom_font

from doped import _ignore_pmg_warnings
from doped.core import DefectEntry
from doped.utils.parsing import get_magnetization_from_vasprun, get_nelect_from_vasprun, get_procar
from doped.utils.parsing import (
_reset_warnings,
get_magnetization_from_vasprun,
get_nelect_from_vasprun,
get_procar,
)
from doped.utils.plotting import _get_backend

if TYPE_CHECKING:
Expand Down Expand Up @@ -52,11 +56,7 @@
"You can do this by running `pip install pydefect`."
) from exc

# vise suppresses `UserWarning`s, so need to reset
warnings.simplefilter("default")
warnings.filterwarnings("ignore", message="`np.int` is a deprecated alias for the builtin `int`")
warnings.filterwarnings("ignore", message="Use get_magnetic_symmetry()")
_ignore_pmg_warnings()
_reset_warnings() # vise suppresses `UserWarning`s, so need to reset


def _coordination(self, include_on_site=True, cutoff_factor=None) -> "Coordination":
Expand Down
17 changes: 12 additions & 5 deletions doped/utils/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,17 @@ def _get_potcar_summary_stats() -> dict:
return loadfn(POTCAR_STATS_PATH)


def _reset_warnings():
"""
When importing ``vise``/``pydefect``, ``UserWarning``s are suppressed, so
we need to reset.
"""
warnings.simplefilter("default")
warnings.filterwarnings("ignore", message="`np.int` is a deprecated alias for the builtin `int`")
warnings.filterwarnings("ignore", message="Use get_magnetic_symmetry()")
_ignore_pmg_warnings()


@contextlib.contextmanager
def suppress_logging(level=logging.CRITICAL):
"""
Expand Down Expand Up @@ -542,11 +553,7 @@ def check_atom_mapping_far_from_defect(bulk, defect, defect_coords):
except ImportError: # can't check as vise/pydefect not installed. Not critical so just return
return

# vise suppresses `UserWarning`s, so need to reset
warnings.simplefilter("default")
warnings.filterwarnings("ignore", message="`np.int` is a deprecated alias for the builtin `int`")
warnings.filterwarnings("ignore", message="Use get_magnetic_symmetry()")
_ignore_pmg_warnings()
_reset_warnings() # vise suppresses `UserWarning`s, so need to reset

far_from_defect_disps = {site.specie.symbol: [] for site in bulk}

Expand Down
2 changes: 1 addition & 1 deletion examples/YTOS/YTOS_example_defect_dict.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion examples/YTOS/YTOS_example_thermo.json

Large diffs are not rendered by default.

32 changes: 20 additions & 12 deletions tests/test_thermodynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def setUp(self):
self.MgO_defect_thermo = deepcopy(self.orig_MgO_defect_thermo)
self.MgO_defect_dict = deepcopy(self.orig_MgO_defect_dict)
self.Sb2O5_defect_thermo = deepcopy(self.orig_Sb2O5_defect_thermo)
self.Zns_defect_thermo = deepcopy(self.orig_ZnS_defect_thermo)
self.ZnS_defect_thermo = deepcopy(self.orig_ZnS_defect_thermo)

@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -175,8 +175,8 @@ def _compare_defect_thermo_and_dict(self, defect_thermo, defect_dict):
round(entry.get_ediff(), 3) for entry in defect_dict.values()
}
assert { # check coords are the same by getting their products
round(np.product(entry.sc_defect_frac_coords), 3) for entry in defect_thermo.defect_entries
} == {round(np.product(entry.sc_defect_frac_coords), 3) for entry in defect_dict.values()}
round(np.prod(entry.sc_defect_frac_coords), 3) for entry in defect_thermo.defect_entries
} == {round(np.prod(entry.sc_defect_frac_coords), 3) for entry in defect_dict.values()}

def _compare_defect_thermos(self, defect_thermo1, defect_thermo2):
assert len(defect_thermo1.defect_entries) == len(defect_thermo2.defect_entries)
Expand All @@ -187,8 +187,8 @@ def _compare_defect_thermos(self, defect_thermo1, defect_thermo2):
round(entry.get_ediff(), 3) for entry in defect_thermo2.defect_entries
}
assert { # check coords are the same by getting their products
round(np.product(entry.sc_defect_frac_coords), 3) for entry in defect_thermo1.defect_entries
} == {round(np.product(entry.sc_defect_frac_coords), 3) for entry in defect_thermo2.defect_entries}
round(np.prod(entry.sc_defect_frac_coords), 3) for entry in defect_thermo1.defect_entries
} == {round(np.prod(entry.sc_defect_frac_coords), 3) for entry in defect_thermo2.defect_entries}

def _check_defect_thermo(
self,
Expand Down Expand Up @@ -228,8 +228,8 @@ def _check_defect_thermo(
defect_thermo.to_json() # test default naming
compositions = ["CdTe", "Y2Ti2S2O5", "Sb2Se3", "SiSbTe3", "V2O5", "MgO", "Sb2O5", "ZnS"]
assert defect_thermo.bulk_formula in compositions
assert any(os.path.exists(f"{i}_defect_thermodynamics.json") for i in compositions)
for i in compositions:
os.path.exists(f"{i}_defect_thermodynamics.json")
if_present_rm(f"{i}_defect_thermodynamics.json")

thermo_dict = defect_thermo.as_dict()
Expand Down Expand Up @@ -446,7 +446,7 @@ def test_DefectsParser_thermo_objs(self):
(self.V2O5_defect_thermo, "V2O5_defect_thermo"),
(self.MgO_defect_thermo, "MgO_defect_thermo"),
(self.Sb2O5_defect_thermo, "Sb2O5_defect_thermo"),
(self.Zns_defect_thermo, "ZnS_defect_thermo"),
(self.ZnS_defect_thermo, "ZnS_defect_thermo"),
]:
print(f"Checking {name}")
if "V2O5" in name:
Expand Down Expand Up @@ -475,7 +475,7 @@ def test_DefectsParser_thermo_objs(self):
def test_DefectsParser_thermo_objs_no_metadata(self):
"""
Test the `DefectThermodynamics` objects created from the
`DefectsParser.get_defect_thermodynamics()` method.
``DefectsParser.get_defect_thermodynamics()`` method.
"""
for defect_thermo, name in [
(self.CdTe_defect_thermo, "CdTe_defect_thermo"),
Expand All @@ -485,11 +485,16 @@ def test_DefectsParser_thermo_objs_no_metadata(self):
(self.V2O5_defect_thermo, "V2O5_defect_thermo"),
(self.MgO_defect_thermo, "MgO_defect_thermo"),
(self.Sb2O5_defect_thermo, "Sb2O5_defect_thermo"),
(self.Zns_defect_thermo, "ZnS_defect_thermo"),
(self.ZnS_defect_thermo, "ZnS_defect_thermo"),
]:
print(f"Checking {name}")
defect_entries_wout_metadata = defect_thermo.defect_entries
# get set of random 7 entries from defect_entries_wout_metadata (7 because need full CdTe
# example set for its semi-hard 😏 tests)
defect_entries_wout_metadata = random.sample(
defect_thermo.defect_entries, min(7, len(defect_thermo.defect_entries))
)
for entry in defect_entries_wout_metadata:
print(f"Setting metadata to empty for {entry.name}")
entry.calculation_metadata = {}

with pytest.raises(ValueError) as exc:
Expand All @@ -512,8 +517,11 @@ def test_DefectsParser_thermo_objs_no_metadata(self):
)
self._check_defect_thermo(thermo_wout_metadata) # default values

defect_entries_wout_metadata_or_degeneracy = defect_thermo.defect_entries
defect_entries_wout_metadata_or_degeneracy = random.sample(
defect_thermo.defect_entries, min(7, len(defect_thermo.defect_entries))
) # get set of random 7 entries from defect_entries_wout_metadata (7 needed for CdTe tests)
for entry in defect_entries_wout_metadata_or_degeneracy:
print(f"Setting metadata and degeneracy factors to empty for {entry.name}")
entry.calculation_metadata = {}
entry.degeneracy_factors = {}

Expand All @@ -524,7 +532,7 @@ def test_DefectsParser_thermo_objs_no_metadata(self):
symm_df = thermo_wout_metadata_or_degeneracy.get_symmetries_and_degeneracies()
assert symm_df["g_Spin"].apply(lambda x: isinstance(x, int)).all()

for defect_entry in defect_thermo.defect_entries:
for defect_entry in defect_entries_wout_metadata_or_degeneracy:
assert defect_entry.degeneracy_factors["spin degeneracy"] in {1, 2}
assert isinstance(
defect_entry.degeneracy_factors["orientational degeneracy"], (int, float)
Expand Down

0 comments on commit 9a40855

Please sign in to comment.