Skip to content
This repository has been archived by the owner on Nov 5, 2022. It is now read-only.

Commit

Permalink
fix: improve logging and minor performance improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanljones committed May 12, 2022
1 parent 7f3ce95 commit 4cbe4a5
Showing 1 changed file with 33 additions and 22 deletions.
55 changes: 33 additions & 22 deletions cmpy/exactdiag.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def compute_groundstate(model, thresh=50):
def solve_sector(model: AbstractManyBodyModel, sector: Sector, cache: dict = None):
sector_key = (sector.n_up, sector.n_dn)
if cache is not None and sector_key in cache:
logger.debug("Loading eig %d, %d", sector.n_up, sector.n_dn)
logger.debug("Loading eig %d, %d (%s)", sector.n_up, sector.n_dn, sector.size)
eigvals, eigvecs = cache[sector_key]
else:
logger.debug("Solving eig %d, %d (%s)", sector.n_up, sector.n_dn, sector.size)
Expand Down Expand Up @@ -121,11 +121,12 @@ def _accumulate_sum(gf, z, evals, evals_p1, evecs_p1, cdag_evec, beta, emin):
num_m = len(evals_p1)
num_n = len(evals)
for m in prange(num_m):
eig_m = evals_p1[m]
z_m = z - eig_m
for n in range(num_n):
eig_m = evals_p1[m]
eig_n = evals[n]
weights = exp_evals[n] + exp_evals_p1[m]
gf += overlap[m, n] * weights / (z + eig_n - eig_m)
gf += overlap[m, n] * weights / (z_m + eig_n)


def accumulate_gf(gf, z, cdag, evals, evecs, evals_p1, evecs_p1, beta, emin=0.0):
Expand Down Expand Up @@ -179,17 +180,13 @@ def _acc_gf(self, sector, sector_p1, evals, evecs, evals_p1, evecs_p1, factor):
e0 = self._gs_energy
accumulate_gf(self._gf, z, cdag, evals, evecs, evals_p1, evecs_p1, beta, e0)

def _acc_occ(self, sector, evals, evecs, factor):
up = sector.up_states
dn = sector.dn_states
def _acc_occ(self, up, dn, evals, evecs, factor):
beta = self.beta
e0 = self._gs_energy
self._occ *= factor
self._occ += occupation(up, dn, evals, evecs, beta, e0, self.pos, self.sigma)

def _acc_occ_double(self, sector, evals, evecs, factor):
up = sector.up_states
dn = sector.dn_states
def _acc_occ_double(self, up, dn, evals, evecs, factor):
beta = self.beta
e0 = self._gs_energy
self._occ_double *= factor
Expand All @@ -203,31 +200,45 @@ def accumulate(self, sector, sector_p1, evals, evecs, evals_p1, evecs_p1):
self._gs_energy = min_energy
logger.debug("New ground state: E_0=%.4f", min_energy)

logger.debug("accumulating")
logger.debug("Accumulating")
up = np.array(sector.up_states, dtype=np.int64)
dn = np.array(sector.dn_states, dtype=np.int64)
self._acc_part(evals, factor)
self._acc_gf(sector, sector_p1, evals, evecs, evals_p1, evecs_p1, factor)
# self._acc_occ(sector, evals, evecs, factor)
# self._acc_occ_double(sector, evals, evecs, factor)
self._acc_occ(up, dn, evals, evecs, factor)
self._acc_occ_double(up, dn, evals, evecs, factor)


def greens_function_lehmann(model, z, beta, pos=0, sigma=UP, eig_cache=None):
logger.debug("Accumulating Lehmann sum (pos=%s, sigma=%s)", pos, sigma)
basis = model.basis

logger.info("Accumulating Lehmann sum (pos=%s, sigma=%s)", pos, sigma)
logger.debug("Sites: %s (%s states)", basis.num_sites, basis.size)

data = GreensFunctionMeasurement(z, beta, pos, sigma)
eig_cache = eig_cache if eig_cache is not None else dict()
for n_up, n_dn in model.iter_fillings():

fillings = list(basis.iter_fillings())
num = len(fillings)
w = len(str(num))
for i, (n_up, n_dn) in enumerate(fillings):
sector = model.get_sector(n_up, n_dn)
sector_p1 = model.basis.upper_sector(n_up, n_dn, sigma)
logger.info("[%s/%s] Sector %s, %s", f"{i+1:>{w}}", num, n_up, n_dn)

sector_p1 = basis.upper_sector(n_up, n_dn, sigma)
if sector_p1 is not None:
eigvals, eigvecs = solve_sector(model, sector, cache=eig_cache)
eigvals_p1, eigvecs_p1 = solve_sector(model, sector_p1, cache=eig_cache)
data.accumulate(sector, sector_p1, eigvals, eigvecs, eigvals_p1, eigvecs_p1)
# else: eig_cache.clear()

logger.debug("-" * 40)
logger.debug("gs-energy: %+.4f", data.gs_energy)
logger.debug("occupation: %.4f", data.occ)
logger.debug("double-occ: %.4f", data.occ_double)
logger.debug("-" * 40)
else:
logger.debug("No upper sector, skipping")
# eig_cache.clear()

logger.info("-" * 40)
logger.info("gs-energy: %+.4f", data.gs_energy)
logger.info("occupation: %.4f", data.occ)
logger.info("double-occ: %.4f", data.occ_double)
logger.info("-" * 40)
return data


Expand Down

0 comments on commit 4cbe4a5

Please sign in to comment.