Skip to content

Commit

Permalink
and example benchmarks and utils.save_to_json, utils.load_from_json
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Aug 9, 2024
1 parent 984c8da commit f329eeb
Show file tree
Hide file tree
Showing 9 changed files with 14,861 additions and 4 deletions.
57 changes: 53 additions & 4 deletions cotengra/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,9 +1026,7 @@ def rand_tree(
seed=seed,
)

tree = array_contract_tree(
inputs, output, size_dict, optimize=optimize
)
tree = array_contract_tree(inputs, output, size_dict, optimize=optimize)
return tree


Expand Down Expand Up @@ -1280,7 +1278,7 @@ def make_arrays_from_inputs(inputs, size_dict, seed=None, dtype="float64"):
elif dtype == "complex64":
array = (array + 1j * rng.normal(size=shape)).astype(np.complex64)
elif dtype == "complex128":
array = (array + 1j * rng.normal(size=shape))
array = array + 1j * rng.normal(size=shape)
elif dtype != "float64":
raise ValueError(f"unsupported dtype {dtype}")

Expand Down Expand Up @@ -1558,3 +1556,54 @@ def parse_einsum_input(args, shapes=False, tuples=False, constants=None):
inputs, output = parse_equation_ellipses(eq, _shapes, tuples=tuples)

return (inputs, output, arrays)


def save_to_json(inputs, output, size_dict, filename):
"""Save a contraction to a json file.
Parameters
----------
inputs : list[list[str]]
The input terms.
output : list[str]
The output term.
size_dict : dict[str, int]
The index size dictionary.
filename : str
The filename to save to.
"""
import json

data = {
"inputs": tuple(map(tuple, inputs)),
"output": tuple(output),
"size_dict": size_dict,
}

with open(filename, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)


def load_from_json(filename):
"""Load a contraction from a json file.
Parameters
----------
filename : str
The filename to load from.
Returns
-------
inputs : list[list[str]]
The input terms.
output : list[str]
The output term.
size_dict : dict[str, int]
The index size dictionary.
"""
import json

with open(filename, "r", encoding="utf-8") as f:
data = json.load(f)

return (data["inputs"], data["output"], data["size_dict"])
Loading

0 comments on commit f329eeb

Please sign in to comment.