Skip to content

Commit

Permalink
allow state-specific keyword in pyscf interface
Browse files Browse the repository at this point in the history
  • Loading branch information
hczhai committed Sep 21, 2022
1 parent 93079c7 commit c7809e9
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion pyblock2/driver/block2main
Original file line number Diff line number Diff line change
Expand Up @@ -1153,6 +1153,11 @@ if "compression" in dic or "stopt_compression" in dic or "delta_t" in dic:
tags "read_mps_tags" and the output MPS tags "mps_tags" cannot
be the same!""")

if "statespecific" in dic and "proj_weights" in dic:
if not (len(mps_tags) == 1 and os.path.isfile(scratch + "/%s-mps_info.bin" % mps_tags[0])) \
and not os.path.isfile(scratch + "/mps_info.bin"):
del dic["fullrestart"]

# prepare mps
if len(mps_tags) > 1 or ("compression" in dic and "random_mps_init" not in dic) \
or "stopt_sampling" in dic:
Expand Down Expand Up @@ -1965,6 +1970,7 @@ if not pre_run:
_print("para mpo finished", time.perf_counter() - tx)

if mps is not None:
mps.save_data()
mps.save_mutable()
mps.deallocate()
mps_info.save_mutable()
Expand All @@ -1983,6 +1989,7 @@ if not pre_run:
assert nroots != 1

ext_mpss = []
dmrg_energies = []
for iroot in range(nroots):
tx = time.perf_counter()
_print('----- root = %3d / %3d -----' % (iroot, nroots))
Expand Down Expand Up @@ -2035,7 +2042,10 @@ if not pre_run:
dmrg.state_specific = True
proj_weights = dic.get("proj_weights", None)
if proj_weights is not None:
dmrg.projection_weights = VectorFP([float(x) for x in proj_weights.split()][:iroot])
proj_weights = [float(x) for x in proj_weights.split()][:iroot]
if len(proj_weights) == 1:
proj_weights = proj_weights * iroot
dmrg.projection_weights = VectorFP(proj_weights)
dmrg.iprint = max(min(outputlevel, 3), 0)
for ext_mps in dmrg.ext_mpss:
ext_me = MovingEnvironment(
Expand Down Expand Up @@ -2096,6 +2106,7 @@ if not pre_run:

if MPI is None or MPI.rank == 0:
np.save(scratch + "/E_dmrg-%d.npy" % iroot, E_dmrg)
dmrg_energies.append(E_dmrg)
np.save(scratch + "/bond_dims-%d.npy" %
iroot, bond_dims[:len(discarded_weights)])
np.save(scratch + "/sweep_energies-%d.npy" %
Expand All @@ -2110,6 +2121,12 @@ if not pre_run:
ext_mpss[iroot].info.save_data(
scratch + '/%s-mps_info-ss-%d.bin' % (mps_tags[0], iroot))

if MPI is None or MPI.rank == 0:
if stackblock_compat:
with open(os.path.join(scratch + "/dmrg.e"), "wb") as f:
import struct
f.write(struct.pack('d' * nroots, *dmrg_energies))

if "twodot_to_onedot" in dic:
dot = 1

Expand Down Expand Up @@ -2160,6 +2177,8 @@ if not pre_run:
proj_weights = dic.get("proj_weights", None)
assert proj_weights is not None
proj_weights = VectorFP([float(x) for x in proj_weights.split()])
if len(proj_weights) == 1:
proj_weights = VectorFP([float(x) for x in proj_weights.split()] * len(proj_tags))
assert len(proj_weights) == len(proj_tags)
ext_mpss = []
for ipj in range(len(proj_weights)):
Expand Down

0 comments on commit c7809e9

Please sign in to comment.