diff --git a/abipy/ml/aseml.py b/abipy/ml/aseml.py index c6023c6a3..f95a8f7b6 100644 --- a/abipy/ml/aseml.py +++ b/abipy/ml/aseml.py @@ -967,7 +967,7 @@ def from_string(cls, string: str): class _MyMlCalculator: """ - Add _abi_forces_list and _abi_stress_list internal attributes to an ASE calculator. + Add __abi_forces_list and __abi_stress_list internal attributes to an ASE calculator. Extend `calculate` method so that ML forces and stresses can be corrected. """ @@ -987,6 +987,7 @@ def __init__(self, *args, **kwargs): self.__verbose = 0 def set_correct_forces_algo(self, new_algo: int) -> int: + """Set the correction algorithm for forces.""" assert new_algo in CORRALGO old_algo = self.__correct_forces_algo self.__correct_forces_algo = new_algo @@ -994,16 +995,19 @@ def set_correct_forces_algo(self, new_algo: int) -> int: @property def correct_forces_algo(self) -> int: + """Correction algorithm for forces.""" return self.__correct_forces_algo def set_correct_stress_algo(self, new_algo: int) -> int: + """Set the correction algorithm for the stress.""" assert new_algo in CORRALGO old_algo = self.__correct_stress_algo self.__correct_stress_algo = new_algo return old_algo @property - def correct_stress_algo(self) -> bool: + def correct_stress_algo(self) -> int: + """Correction algorithm for the stress.""" return self.__correct_stress_algo def store_abi_forstr_atoms(self, abi_forces, abi_stress, atoms): @@ -1018,17 +1022,17 @@ def store_abi_forstr_atoms(self, abi_forces, abi_stress, atoms): self.__abi_stress_list.append(abi_stress) self.__abi_atoms_list.append(atoms.copy()) - # Compute ML values + # Compute ML forces ans stresses for the input atoms. self.reset() old_forces_algo = self.set_correct_forces_algo(CORRALGO.none) old_stress_algo = self.set_correct_stress_algo(CORRALGO.none) ml_forces = self.get_forces(atoms=atoms) ml_stress = self.get_stress(atoms=atoms) #print(f"{ml_forces=}"); print(f"{ml_stress=}") - self.set_correct_forces_algo(old_forces_algo) - self.set_correct_stress_algo(old_stress_algo) self.__ml_forces_list.append(ml_forces) self.__ml_stress_list.append(ml_stress) + self.set_correct_forces_algo(old_forces_algo) + self.set_correct_stress_algo(old_stress_algo) self.reset() def fmt_vec3(vec3) -> str: @@ -1044,20 +1048,24 @@ def fmt_vec6(vec6) -> str: print(f"abi_fcart_{iat=}:", fmt_vec3(abi_forces[iat])) print(f"ml_fcart_{iat=} :", fmt_vec3(ml_forces[iat])) - def get_abi_forces(self): - if self.__abi_forces_list: return self.__abi_forces_list[-1] + def get_abi_forces(self, pos=-1): + """Return the ab-initio forces or None if not available.""" + if self.__abi_forces_list: return self.__abi_forces_list[pos] return None - def get_ml_forces(self): - if self.__ml_forces_list: return self.__ml_forces_list[-1] + def get_ml_forces(self, pos=-1): + """Return the ML forces or None if not available.""" + if self.__ml_forces_list: return self.__ml_forces_list[pos] return None - def get_abi_stress(self): - if self.__abi_stress_list: return self.__abi_stress_list[-1] + def get_abi_stress(self, pos=-1): + """Return the ab-initio stress or None if not available.""" + if self.__abi_stress_list: return self.__abi_stress_list[pos] return None - def get_ml_stress(self): - if self.__ml_stress_list: return self.__ml_stress_list[-1] + def get_ml_stress(self, pos=-1): + """Return the ML stress or None if not available.""" + if self.__ml_stress_list: return self.__ml_stress_list[pos] return None def calculate( @@ -1080,7 +1088,7 @@ def calculate( super().calculate(atoms=atoms, properties=properties, system_changes=system_changes) if self.correct_forces_algo != CORRALGO.none: - # Apply ab-initio correction to ml_forces. + # Apply ab-initio correction to the ml_forces. forces = self.results["forces"] abi_forces = self.get_abi_forces() ml_forces = self.get_ml_forces() @@ -1098,7 +1106,7 @@ def calculate( self.results.update(forces=forces) if self.correct_stress_algo != CORRALGO.none: - # Apply ab-initio correction to stress. + # Apply ab-initio correction to the ml_stress. stress = self.results["stress"] abi_stress = self.get_abi_stress() ml_stress = self.get_ml_stress()