Skip to content

Commit

Permalink
multiple inputs for tenes_std
Browse files Browse the repository at this point in the history
  • Loading branch information
yomichi committed Jan 28, 2024
1 parent 2474fee commit d1f0382
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 12 deletions.
6 changes: 5 additions & 1 deletion docs/sphinx/en/how_to_use/standard_usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@ Usage of ``tenes_std``
$ tenes_std std.toml
- Takes a file as an argument
- Takes input files as arguments
- Multiple input files can be specified
- When parameters are duplicated, ``tenes_std`` stops with an error
- Sections that can be specified multiple times, such as ``[[observable.onesite]]``, can be specified in multiple input files simultaneously
- In this case, the sections in the latter input file are appended to those in the former input file
- Output an input file for ``tenes``
- Command line options are as follows
- ``--help``
Expand Down
6 changes: 5 additions & 1 deletion docs/sphinx/ja/how_to_use/standard_usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@

.. code:: bash
$ tenes_std std.toml
$ tenes_std std.toml [std_2.toml ...]
- 引数としてファイルを取ります
- 複数の入力ファイルからひとつの出力ファイルを生成可能です
- 同じ名前のパラメータを複数の入力で指定した場合はエラー終了します
- ``[[observable.onesite]]`` などの複数指定できるセクションについては同時に指定可能です
- 後ろのファイルの内容が前のファイルの内容に追記されます
- ``tenes`` の入力ファイルを出力します
- コマンドラインオプションは以下の通りです
- ``--help``
Expand Down
9 changes: 3 additions & 6 deletions sample/05_hardcore_boson_triangular/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,10 @@
cmd = f"tenes_simple {simple_toml} -o {std_toml}"
subprocess.call(cmd.split())

cmd = f"tenes_std -o {input_toml} {std_toml}"
if calculate_sq:
# append twosite observable <n_i n_j> to std.toml
with open(std_toml, "a") as f:
with open("nn_obs.toml") as fin:
for line in fin:
f.write(line)
cmd = f"tenes_std {std_toml} -o {input_toml}"
# twosite observable <n_i n_j> are also calculated
cmd += " nn_obs.toml"
subprocess.call(cmd.split())

cmd = f"{MPI_cmd} tenes {input_toml}"
Expand Down
8 changes: 8 additions & 0 deletions test/data/output_std_mode.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ bonds = """
1 1 1
"""
ops = [0, 1]
[[observable.multisite]]
name = "multisite"
group = 0
dim= [2, 2, 2]
multisites = """
0 1 0 0 1
"""
ops = [0, 0, 0]

[evolution]
[[evolution.simple]]
Expand Down
23 changes: 23 additions & 0 deletions test/std_mode.py.in
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ cmd = [
"-o",
join("output_std_mode.toml"),
join("data", "std_mode.toml"),
join("data", "std_mode_multi.toml"),
]
subprocess.call(cmd)

Expand Down Expand Up @@ -101,6 +102,28 @@ if not obs_result:
print('check for the section "observable.twosite" fails')
result = False

obs_result = True
if len(res['observable']['multisite']) != len(ref['observable']['multisite']):
obs_result = False
result = False
else:
for obs_res, obs_ref in zip(res['observable']['multisite'], ref['observable']['multisite']):
obs_result = obs_result and obs_res['name'] == obs_ref['name']
obs_result = obs_result and obs_res['group'] == obs_ref['group']
obs_result = obs_result and obs_res['multisites'] == obs_ref['multisites']
if 'elements' in obs_ref:
obs_result = obs_result and 'elements' in obs_res
obs_result = obs_result and obs_res['dim'] == obs_ref['dim']
elem_res = load_str_as_array(obs_res['elements'])
elem_ref = load_str_as_array(obs_ref['elements'])
obs_result = obs_result and np.allclose(elem_ref, elem_res)
else:
obs_result = obs_result and 'elements' not in obs_res
obs_result = obs_result and obs_res['ops'] == obs_ref['ops']
if not obs_result:
print('check for the section "observable.multisite" fails')
result = False


for name in ('simple', 'full'):
evo_result = True
Expand Down
79 changes: 75 additions & 4 deletions tool/tenes_std.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see http://www.gnu.org/licenses

from collections import namedtuple
from itertools import product
from typing import (
TextIO,
Expand Down Expand Up @@ -89,6 +88,75 @@ def value_to_str(v) -> str:
return "{}".format(v)


def merge_input_dict(
d1: dict, d2: dict
) -> None:
section1 = d1.get("parameter", {})
section2 = d2.get("parameter", {})
subsection_names = ("general", "simple_update", "full_update", "ctm", "random")
for name in subsection_names:
sub1 = section1.get(name, {})
sub2 = section2.get(name, {})
for k in sub1.keys():
if k in sub2:
msg = f"parameter.{name}.{k} is defined in multiple input files"
raise RuntimeError(msg)
for k in sub2.keys():
sub1[k] = sub2[k]
if len(sub1) > 0:
section1[name] = sub1
if len(section1) > 0:
d1["parameter"] = section1

section_names = ("correlation", "correlation_length")
for section_name in section_names:
section1 = d1.get(section_name, {})
section2 = d2.get(section_name, {})
for name in section1.keys():
if name in section2:
msg = f"{section_name}.{name} is defined in multiple input files"
raise RuntimeError(msg)
for name in section2.keys():
section1[name] = section2[name]
if len(section1) > 0:
d1[section_name] = section1

section1 = d1.get("tensor", {})
section2 = d2.get("tensor", {})
for k in section1.keys():
if k == "unitcell": continue
if k in section2:
msg = f"tensor.{k} is defined in multiple input files"
raise RuntimeError(msg)
for k in section2.keys():
if k == "unitcell": continue
section1[k] = section2[k]
if "unitcell" not in section1:
section1["unitcell"] = []
for u2 in section2.get("unitcell", []):
section1["unitcell"].append(u2)
if len(section1) > 0:
d1["tensor"] = section1

section1 = d1.get("hamiltonian", [])
section2 = d2.get("hamiltonian", [])
for h2 in section2:
section1.append(h2)
if len(section1) > 0:
d1["hamiltonian"] = section1

section1 = d1.get("observable", {})
section2 = d2.get("observable", {})
for name in ("onesite", "twosite", "multisite"):
sub1 = section1.get(name, [])
sub2 = section2.get(name, [])
for o2 in sub2:
sub1.append(o2)
if len(sub1) > 0:
section1[name] = sub1
if len(section1) > 0:
d1["observable"] = section1

class Bond:
source_site: int
dx: int
Expand Down Expand Up @@ -1172,7 +1240,7 @@ def to_toml(self, f: TextIO):
description="Input converter for TeNeS", add_help=True
)

parser.add_argument("input", help="Input TOML file")
parser.add_argument("input", nargs="+", help="Input TOML file")

parser.add_argument(
"-o", "--output", dest="output", default="input.toml", help="Output TOML file"
Expand All @@ -1183,11 +1251,14 @@ def to_toml(self, f: TextIO):

args = parser.parse_args()

if args.input == args.output:
if args.output in args.input:
print("The names of input and output are the same")
sys.exit(1)

param = toml.load(args.input)
params = [lower_dict(toml.load(f)) for f in args.input]
param = params[0]
for p in params[1:]:
merge_input_dict(param, p)
model = Model(param)

with open(args.output, "w") as f:
Expand Down

0 comments on commit d1f0382

Please sign in to comment.