Skip to content

Commit

Permalink
Minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
gmatteo committed Sep 6, 2023
1 parent 833559d commit a185ad4
Showing 1 changed file with 23 additions and 14 deletions.
37 changes: 23 additions & 14 deletions abipy/ml/aseml.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,7 +967,7 @@ def from_string(cls, string: str):

class _MyMlCalculator:
"""
Add _abi_forces_list and _abi_stress_list internal attributes to an ASE calculator.
Add __abi_forces_list and __abi_stress_list internal attributes to an ASE calculator.
Extend `calculate` method so that ML forces and stresses can be corrected.
"""

Expand All @@ -987,23 +987,27 @@ def __init__(self, *args, **kwargs):
self.__verbose = 0

def set_correct_forces_algo(self, new_algo: int) -> int:
"""Set the correction algorithm for forces."""
assert new_algo in CORRALGO
old_algo = self.__correct_forces_algo
self.__correct_forces_algo = new_algo
return old_algo

@property
def correct_forces_algo(self) -> int:
"""Correction algorithm for forces."""
return self.__correct_forces_algo

def set_correct_stress_algo(self, new_algo: int) -> int:
"""Set the correction algorithm for the stress."""
assert new_algo in CORRALGO
old_algo = self.__correct_stress_algo
self.__correct_stress_algo = new_algo
return old_algo

@property
def correct_stress_algo(self) -> bool:
def correct_stress_algo(self) -> int:
"""Correction algorithm for the stress."""
return self.__correct_stress_algo

def store_abi_forstr_atoms(self, abi_forces, abi_stress, atoms):
Expand All @@ -1018,16 +1022,17 @@ def store_abi_forstr_atoms(self, abi_forces, abi_stress, atoms):
self.__abi_stress_list.append(abi_stress)
self.__abi_atoms_list.append(atoms.copy())

# Compute ML forces ans stresses for the input atoms.
self.reset()
old_forces_algo = self.set_correct_forces_algo(CORRALGO.none)
old_stress_algo = self.set_correct_stress_algo(CORRALGO.none)
ml_forces = self.get_forces(atoms=atoms)
ml_stress = self.get_stress(atoms=atoms)
#print(f"{ml_forces=}"); print(f"{ml_stress=}")
self.set_correct_forces_algo(old_forces_algo)
self.set_correct_stress_algo(old_stress_algo)
self.__ml_forces_list.append(ml_forces)
self.__ml_stress_list.append(ml_stress)
self.set_correct_forces_algo(old_forces_algo)
self.set_correct_stress_algo(old_stress_algo)
self.reset()

def fmt_vec3(vec3) -> str:
Expand All @@ -1043,20 +1048,24 @@ def fmt_vec6(vec6) -> str:
print(f"abi_fcart_{iat=}:", fmt_vec3(abi_forces[iat]))
print(f"ml_fcart_{iat=} :", fmt_vec3(ml_forces[iat]))

def get_abi_forces(self):
if self.__abi_forces_list: return self.__abi_forces_list[-1]
def get_abi_forces(self, pos=-1):
"""Return the ab-initio forces or None if not available."""
if self.__abi_forces_list: return self.__abi_forces_list[pos]
return None

def get_ml_forces(self):
if self.__ml_forces_list: return self.__ml_forces_list[-1]
def get_ml_forces(self, pos=-1):
"""Return the ML forces or None if not available."""
if self.__ml_forces_list: return self.__ml_forces_list[pos]
return None

def get_abi_stress(self):
if self.__abi_stress_list: return self.__abi_stress_list[-1]
def get_abi_stress(self, pos=-1):
"""Return the ab-initio stress or None if not available."""
if self.__abi_stress_list: return self.__abi_stress_list[pos]
return None

def get_ml_stress(self):
if self.__ml_stress_list: return self.__ml_stress_list[-1]
def get_ml_stress(self, pos=-1):
"""Return the ML stress or None if not available."""
if self.__ml_stress_list: return self.__ml_stress_list[pos]
return None

def calculate(
Expand All @@ -1079,7 +1088,7 @@ def calculate(
super().calculate(atoms=atoms, properties=properties, system_changes=system_changes)

if self.correct_forces_algo != CORRALGO.none:
# Apply ab-initio correction to ml_forces.
# Apply ab-initio correction to the ml_forces.
forces = self.results["forces"]
abi_forces = self.get_abi_forces()
ml_forces = self.get_ml_forces()
Expand All @@ -1097,7 +1106,7 @@ def calculate(
self.results.update(forces=forces)

if self.correct_stress_algo != CORRALGO.none:
# Apply ab-initio correction to stress.
# Apply ab-initio correction to the ml_stress.
stress = self.results["stress"]
abi_stress = self.get_abi_stress()
ml_stress = self.get_ml_stress()
Expand Down

0 comments on commit a185ad4

Please sign in to comment.