Skip to content

Commit

Permalink
core: use center 0 mps for csf coefficients
Browse files Browse the repository at this point in the history
  • Loading branch information
hczhai committed Oct 14, 2023
1 parent 4e275c2 commit 86085f7
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 4 deletions.
3 changes: 3 additions & 0 deletions pyblock2/driver/block2main
Original file line number Diff line number Diff line change
Expand Up @@ -4496,6 +4496,9 @@ if not pre_run:
dtrie = DeterminantTRIE(n_sites, True)
mps.info.load_mutable()
mps.load_mutable()
if mps.center != 0:
_print("Warning: sample an MPS with center != 0 will be highly inefficient!")
_print("One can load the MPS and do one extra sweep to change the center.")
dtrie.evaluate(UnfusedMPS(mps), sample_cutoff, max_rank, VectorUInt8(sample_ref))

if "sample_phase" in dic:
Expand Down
11 changes: 9 additions & 2 deletions pyblock2/driver/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3600,6 +3600,12 @@ def get_csf_coefficients(self, ket, cutoff=0.1, max_print=200, iprint=1):
iprint = iprint >= 1 and (self.mpi is None or self.mpi.rank == self.mpi.root)
import numpy as np, time

if ket.center != 0:
ket = self.copy_mps(ket, tag="CSF-TMP")
self.align_mps_center(ket, ref=0)
if iprint:
print("mps center changed (temporarily)")

tx = time.perf_counter()
dtrie = bw.bs.DeterminantTRIE(ket.n_sites, True)
dtrie.evaluate(bw.bs.UnfusedMPS(ket), cutoff)
Expand Down Expand Up @@ -3638,9 +3644,10 @@ def get_csf_coefficients(self, ket, cutoff=0.1, max_print=200, iprint=1):
def align_mps_center(self, ket, ref):
if self.mpi is not None:
self.mpi.barrier()
refc = ref if isinstance(ref, int) else ref.center
ket.info.bond_dim = max(ket.info.bond_dim, ket.info.get_max_bond_dimension())
if ket.center != ref.center:
if ref.center == 0:
if ket.center != refc:
if refc == 0:
if ket.dot == 2:
ket.center += 1
if ket.canonical_form[-1] == 'C':
Expand Down
4 changes: 2 additions & 2 deletions src/dmrg/determinant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ namespace block2 {
// time complexity: (D = MPS bond dimension)
// (n_dets << 4^n_sites) : n_sites * n_dets * D * D
// (n_dets ~ 4^n_sites) : (4 / 3) * n_dets * D * D
template <typename D, typename FL, uint8_t L = 4, typename IT = long long>
template <typename D, typename FL, uint8_t L = 4, typename IT = uint32_t>
struct TRIE {
typedef typename GMatrix<FL>::FP FP;
typedef IT XIT;
Expand Down Expand Up @@ -112,7 +112,7 @@ struct TRIE {
}
// find the index of a determinant
// dets must be sorted
IT find(const vector<uint8_t> &det) {
int find(const vector<uint8_t> &det) {
assert((int)det.size() == n_sites);
IT cur = 0;
for (int i = 0; i < n_sites; i++) {
Expand Down

0 comments on commit 86085f7

Please sign in to comment.