diff --git a/pyblock2/driver/block2main b/pyblock2/driver/block2main index 55f2ffd9..ca1f7d62 100755 --- a/pyblock2/driver/block2main +++ b/pyblock2/driver/block2main @@ -1153,6 +1153,11 @@ if "compression" in dic or "stopt_compression" in dic or "delta_t" in dic: tags "read_mps_tags" and the output MPS tags "mps_tags" cannot be the same!""") +if "statespecific" in dic and "proj_weights" in dic: + if not (len(mps_tags) == 1 and os.path.isfile(scratch + "/%s-mps_info.bin" % mps_tags[0])) \ + and not os.path.isfile(scratch + "/mps_info.bin"): + del dic["fullrestart"] + # prepare mps if len(mps_tags) > 1 or ("compression" in dic and "random_mps_init" not in dic) \ or "stopt_sampling" in dic: @@ -1965,6 +1970,7 @@ if not pre_run: _print("para mpo finished", time.perf_counter() - tx) if mps is not None: + mps.save_data() mps.save_mutable() mps.deallocate() mps_info.save_mutable() @@ -1983,6 +1989,7 @@ if not pre_run: assert nroots != 1 ext_mpss = [] + dmrg_energies = [] for iroot in range(nroots): tx = time.perf_counter() _print('----- root = %3d / %3d -----' % (iroot, nroots)) @@ -2035,7 +2042,10 @@ if not pre_run: dmrg.state_specific = True proj_weights = dic.get("proj_weights", None) if proj_weights is not None: - dmrg.projection_weights = VectorFP([float(x) for x in proj_weights.split()][:iroot]) + proj_weights = [float(x) for x in proj_weights.split()][:iroot] + if len(proj_weights) == 1: + proj_weights = proj_weights * iroot + dmrg.projection_weights = VectorFP(proj_weights) dmrg.iprint = max(min(outputlevel, 3), 0) for ext_mps in dmrg.ext_mpss: ext_me = MovingEnvironment( @@ -2096,6 +2106,7 @@ if not pre_run: if MPI is None or MPI.rank == 0: np.save(scratch + "/E_dmrg-%d.npy" % iroot, E_dmrg) + dmrg_energies.append(E_dmrg) np.save(scratch + "/bond_dims-%d.npy" % iroot, bond_dims[:len(discarded_weights)]) np.save(scratch + "/sweep_energies-%d.npy" % @@ -2110,6 +2121,12 @@ if not pre_run: ext_mpss[iroot].info.save_data( scratch + '/%s-mps_info-ss-%d.bin' % (mps_tags[0], iroot)) + if MPI is None or MPI.rank == 0: + if stackblock_compat: + with open(os.path.join(scratch + "/dmrg.e"), "wb") as f: + import struct + f.write(struct.pack('d' * nroots, *dmrg_energies)) + if "twodot_to_onedot" in dic: dot = 1 @@ -2160,6 +2177,8 @@ if not pre_run: proj_weights = dic.get("proj_weights", None) assert proj_weights is not None proj_weights = VectorFP([float(x) for x in proj_weights.split()]) + if len(proj_weights) == 1: + proj_weights = VectorFP([float(x) for x in proj_weights.split()] * len(proj_tags)) assert len(proj_weights) == len(proj_tags) ext_mpss = [] for ipj in range(len(proj_weights)):