Skip to content

Commit

Permalink
core driver: multiply
Browse files Browse the repository at this point in the history
  • Loading branch information
hczhai committed Aug 31, 2022
1 parent c8b2dca commit 8bcae23
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions pyblock2/driver/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,9 @@ def get_conventional_qc_mpo(self, fcidump):
mpo = bw.bs.ParallelMPO(mpo, self.prule)
return mpo

def get_identity_mpo(self):
return self.get_mpo(self.expr_builder().add_term("", [], 1.0).finalize())

def get_qc_mpo(self, h1e, g2e, ecore=0, para_type=None, reorder=None, iprint=1):
import numpy as np

Expand Down Expand Up @@ -888,17 +891,23 @@ def split_mps(self, ket, iroot, tag):
iket = self.adjust_mps(iket)[0]
return iket

def multiply(self, bra, mpo, ket, n_sweeps=10, tol=1e-8, bond_dims=None, iprint=0):
def multiply(self, bra, mpo, ket, n_sweeps=10, tol=1e-8, bond_dims=None,
bra_bond_dims=None, cutoff=1E-24, iprint=0):
bw = self.bw
if bra.info.tag == ket.info.tag:
raise RuntimeError("Same tag for bra and ket!!")
if bond_dims is None:
bond_dims = [ket.info.bond_dim]
if bra_bond_dims is None:
bra_bond_dims = [bra.info.bond_dim]
self.align_mps_center(bra, ket)
me = bw.bs.MovingEnvironment(mpo, bra, ket, "MULT")
me.delayed_contraction = bw.b.OpNamesSet.normal_ops()
me.cached_contraction = True
me.init_environments(iprint >= 3)
cps = bw.bs.Linear(me, bw.b.VectorUBond(bond_dims), bw.b.VectorUBond(bond_dims))
cps = bw.bs.Linear(me, bw.b.VectorUBond(bra_bond_dims), bw.b.VectorUBond(bond_dims))
cps.iprint = iprint
cps.cutoff = cutoff
norm = cps.solve(n_sweeps, ket.center == 0, tol)
if self.mpi is not None:
self.mpi.barrier()
Expand Down

0 comments on commit 8bcae23

Please sign in to comment.