diff --git a/examples/selection/GCH-ROY.py b/examples/selection/GCH-ROY.py index 7cc5df38d..0767b3a98 100644 --- a/examples/selection/GCH-ROY.py +++ b/examples/selection/GCH-ROY.py @@ -27,11 +27,9 @@ roy_data = load_roy_dataset() -structures = roy_data["structures"] - -density = np.array([s.info["density"] for s in structures]) -energy = np.array([s.info["energy"] for s in structures]) -structype = np.array([s.info["type"] for s in structures]) +density = roy_data["densities"] +energy = roy_data["energies"] +structype = roy_data["structure_types"] iknown = np.where(structype == "known")[0] iothers = np.where(structype != "known")[0] @@ -247,3 +245,5 @@ }, ) """ + +# %% diff --git a/pyproject.toml b/pyproject.toml index 019f9b395..20399c980 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,6 @@ dynamic = ["version"] [project.optional-dependencies] examples = [ - "ase", "matplotlib", "pandas", "tqdm", diff --git a/src/skmatter/datasets/_base.py b/src/skmatter/datasets/_base.py index ac8287daf..90dd7f38c 100644 --- a/src/skmatter/datasets/_base.py +++ b/src/skmatter/datasets/_base.py @@ -119,32 +119,27 @@ def load_who_dataset(): def load_roy_dataset(): - """Load and returns the ROY dataset, which contains structures, - energies and SOAP-derived descriptors for 264 polymorphs of ROY, from [Beran et Al, - Chemical Science (2022)](https://doi.org/10.1039/D1SC06074K) + """Load and returns the ROY dataset, which contains densities, + energies and SOAP-derived descriptors for 264 structures of polymorphs of ROY, + from [Beran et Al, Chemical Science (2022)](https://doi.org/10.1039/D1SC06074K) + Each structure is labeled as "Known" or "Unknown". Returns ------- roy_dataset : sklearn.utils.Bunch Dictionary-like object, with the following attributes: - structures : `ase.Atoms` -- the roy structures as ASE objects - features: `np.array` -- SOAP-derived descriptors for the structures - energies: `np.array` -- energies of the structures + densities : `np.array` -- the densities of the structures + structure_types : `np.array` -- the type of the structures + features : `np.array` -- SOAP-derived descriptors for the structures + energies : `np.array` -- energies of the structures """ module_path = dirname(__file__) - target_structures = join(module_path, "data", "beran_roy_structures.xyz.bz2") - - try: - from ase.io import read - except ImportError: - raise ImportError("load_roy_dataset requires the ASE package.") - - import bz2 - - structures = read(bz2.open(target_structures, "rt"), ":", format="extxyz") - energies = np.array([f.info["energy"] for f in structures]) - - target_features = join(module_path, "data", "beran_roy_features.npz") - features = np.load(target_features)["feats"] - - return Bunch(structures=structures, features=features, energies=energies) + target_properties = join(module_path, "data", "beran_roy_properties.npz") + properties = np.load(target_properties) + + return Bunch( + densities=properties["densities"], + energies=properties["energies"], + structure_types=properties["structure_types"], + features=properties["feats"], + ) diff --git a/src/skmatter/datasets/data/beran_roy_features.npz b/src/skmatter/datasets/data/beran_roy_features.npz deleted file mode 100644 index ba307e433..000000000 Binary files a/src/skmatter/datasets/data/beran_roy_features.npz and /dev/null differ diff --git a/src/skmatter/datasets/data/beran_roy_properties.npz b/src/skmatter/datasets/data/beran_roy_properties.npz new file mode 100644 index 000000000..fec71543b Binary files /dev/null and b/src/skmatter/datasets/data/beran_roy_properties.npz differ diff --git a/src/skmatter/datasets/data/beran_roy_structures.xyz.bz2 b/src/skmatter/datasets/data/beran_roy_structures.xyz.bz2 deleted file mode 100644 index 7c2aaae9f..000000000 Binary files a/src/skmatter/datasets/data/beran_roy_structures.xyz.bz2 and /dev/null differ diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 5dd7f144a..4d7340e6a 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -107,45 +107,16 @@ class ROYTests(unittest.TestCase): def setUpClass(cls): cls.size = 264 cls.shape = (264, 32) - try: - from ase.io import read # NoQa: F401 - - cls.has_ase = True - cls.roy = load_roy_dataset() - except ImportError: - cls.has_ase = False - - def test_load_dataset_without_ase(self): - """Check if the correct exception occurs when ase isn't present.""" - with unittest.mock.patch.dict("sys.modules", {"ase.io": None}): - with self.assertRaises(ImportError) as cm: - _ = load_roy_dataset() - self.assertEqual( - str(cm.exception), "load_roy_dataset requires the ASE package." - ) + cls.roy = load_roy_dataset() def test_dataset_content(self): """Check if the correct number of datapoints are present in the dataset. Also check if the size of the dataset is correct. """ - if self.has_ase is True: - self.assertEqual(len(self.roy["structures"]), self.size) - self.assertEqual(self.roy["features"].shape, self.shape) - self.assertEqual(len(self.roy["energies"]), self.size) - - def test_dataset_consistency(self): - """Check if the energies in the structures are the same as in the explicit - array. - """ - if self.has_ase is True: - self.assertTrue( - np.allclose( - self.roy["energies"], - [f.info["energy"] for f in self.roy["structures"]], - rtol=1e-6, - ) - ) + self.assertEqual(len(self.roy["structure_types"]), self.size) + self.assertEqual(self.roy["features"].shape, self.shape) + self.assertEqual(len(self.roy["energies"]), self.size) if __name__ == "__main__":