Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rascal json encoder #366

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bindings/rascal/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
get_supported_io_versions,
dump_obj,
load_obj,
json_dumps_frame,
)

# Warning potential dependency loop: FPS imports models, which imports KRR,
Expand Down
56 changes: 56 additions & 0 deletions bindings/rascal/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections import Iterable
import numpy as np
import json
import datetime
from copy import deepcopy
from abc import ABC, abstractmethod

Expand Down Expand Up @@ -422,3 +423,58 @@ def _load_npy(data, path):
if len(v) == 2:
if "npy" == v[0]:
data[k] = np.array(v[1])


class RascalEncoder(json.JSONEncoder):
def default(self, obj):
if hasattr(obj, "todict"):
d = obj.todict()

if not isinstance(d, dict):
raise RuntimeError(
"todict() of {} returned object of type {} "
"but should have returned dict".format(obj, type(d))
)
if hasattr(obj, "ase_objtype"):
d["__ase_objtype__"] = obj.ase_objtype

return d
if isinstance(obj, np.ndarray):
return obj.tolist()
if isinstance(obj, np.integer):
return int(obj)
if isinstance(obj, np.bool_):
return bool(obj)
if isinstance(obj, datetime.datetime):
return {"__datetime__": obj.isoformat()}
if isinstance(obj, complex):
return {"__complex__": (obj.real, obj.imag)}
return json.JSONEncoder.default(self, obj)


def json_dumps_frame(frames, **json_dumps_kwargs):
"""Serialize frames to a JSON formatted string.

Parameters
----------
frames : list(ase.Atoms) or ase.Atoms
List of atomic structures (or single one) to be dumped to a json

json_dumps_kwargs : dict
List of arguments forwarded to json.dumps

Return
------
T
"""
if type(frames) is not list:
frames = [frames]

json_frames = {}
for i, frame in enumerate(frames):
json_frames[str(i)] = json.loads(json.dumps(frame, cls=RascalEncoder))

json_frames["ids"] = list(range(len(frames)))
json_frames["nextid"] = len(frames)

return json.dumps(json_frames, **json_dumps_kwargs)
3 changes: 2 additions & 1 deletion tests/python/python_binding_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from python_models_test import TestNumericalKernelGradient, TestCosineKernel
from python_math_test import TestMath
from python_test_sparsify_fps import TestFPS
from python_utils_test import TestOptimalRadialBasis
from python_utils_test import TestOptimalRadialBasis, TestIO

from md_calculator_test import TestGenericMD


Expand Down
40 changes: 40 additions & 0 deletions tests/python/python_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,23 @@
get_radial_basis_pca,
get_radial_basis_projections,
get_optimal_radial_basis_hypers,
json_dumps_frame,
)
from rascal.lib import neighbour_list
from rascal.neighbourlist import base

from test_utils import load_json_frame, BoxList, Box, dot
import tempfile
import unittest
import numpy as np
import sys
import os
import json
import tempfile
from copy import copy, deepcopy
from scipy.stats import ortho_group
import pickle
import ase.io

rascal_reference_path = "reference_data"
inputs_path = os.path.join(rascal_reference_path, "inputs")
Expand Down Expand Up @@ -91,3 +97,37 @@ def test_hypers_construction(self):
soap_feats_2 = soap_opt_2.transform(self.frames).get_features(soap_opt_2)

self.assertTrue(np.allclose(soap_feats, soap_feats_2))


class TestIO(unittest.TestCase):
def setUp(self):
self.fns = [
os.path.join(inputs_path, "CaCrP2O7_mvc-11955_symmetrized.json"),
os.path.join(inputs_path, "SiC_moissanite_supercell.json"),
os.path.join(inputs_path, "methane.json"),
]

def test_json_dumps_frame(self):
"""
Checks if json file decoded by RascalEncoder in dumps_frame can be read
by rascal
"""
nl_options = [
dict(name="centers", args=dict()),
dict(name="neighbourlist", args=dict(cutoff=3)),
dict(name="centercontribution", args=dict()),
dict(name="strict", args=dict(cutoff=3)),
]
managers = base.StructureCollectionFactory(nl_options)
for fn in self.fns:
frame = ase.io.read(fn)
dumped_json = json_dumps_frame(frame)
tmp = tempfile.NamedTemporaryFile("w", suffix=".json", delete=False)
tmp.write(dumped_json)
try:
managers.add_structures(tmp.name)
tmp.close()
os.unlink(tmp.name)
except:
tmp.close()
os.unlink(tmp.name)