Skip to content

Commit

Permalink
remove psi::Psi<T, Device> from Diago_DavSubspace (deepmodeling#4416)
Browse files Browse the repository at this point in the history
  • Loading branch information
haozhihan authored Jun 17, 2024
1 parent 9217fc1 commit ec891a1
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 93 deletions.
149 changes: 59 additions & 90 deletions source/module_hsolver/diago_dav_subspace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,11 @@ Diago_DavSubspace<T, Device>::~Diago_DavSubspace()
}

template <typename T, typename Device>
int Diago_DavSubspace<T, Device>::diag_once(

const Func& hpsi_func,
T* psi_in,

psi::Psi<T, Device>& psi,

Real* eigenvalue_in_hsolver,
const std::vector<bool>& is_occupied)
int Diago_DavSubspace<T, Device>::diag_once(const Func& hpsi_func,
T* psi_in,
const int psi_in_dmax,
Real* eigenvalue_in_hsolver,
const std::vector<bool>& is_occupied)
{
ModuleBase::timer::tick("Diago_DavSubspace", "diag_once");

Expand Down Expand Up @@ -119,15 +115,10 @@ int Diago_DavSubspace<T, Device>::diag_once(
syncmem_complex_op()(this->ctx,
this->ctx,
this->psi_in_iter + m * this->dim,
psi.get_k_first() ? &psi(m, 0) : &psi(m, 0, 0),
psi_in + m * psi_in_dmax,
this->dim);
}

// auto psi_iter_wrapper = psi::Psi<T, Device>(this->psi_in_iter, 1, this->nbase_x, this->dim);
// // calculate H|psi>
// hpsi_info dav_hpsi_in(&psi_iter_wrapper, psi::Range(1, 0, 0, psi_iter_wrapper.get_nbands() - 1), this->hphi);
// phm_in->ops->hPsi(dav_hpsi_in);

hpsi_func(this->hphi, this->psi_in_iter, this->nbase_x, this->dim, 0, this->nbase_x - 1);

this->cal_elem(this->dim, nbase, this->notconv, this->psi_in_iter, this->hphi, this->hcc, this->scc);
Expand Down Expand Up @@ -155,18 +146,15 @@ int Diago_DavSubspace<T, Device>::diag_once(
{
dav_iter++;

this->cal_grad(

hpsi_func,

this->dim,
nbase,
this->notconv,
this->psi_in_iter,
this->hphi,
this->vcc,
unconv.data(),
&eigenvalue_iter);
this->cal_grad(hpsi_func,
this->dim,
nbase,
this->notconv,
this->psi_in_iter,
this->hphi,
this->vcc,
unconv.data(),
&eigenvalue_iter);

this->cal_elem(this->dim, nbase, this->notconv, this->psi_in_iter, this->hphi, this->hcc, this->scc);

Expand Down Expand Up @@ -212,23 +200,22 @@ int Diago_DavSubspace<T, Device>::diag_once(
ModuleBase::timer::tick("Diago_DavSubspace", "last");

// updata eigenvectors of Hamiltonian
setmem_complex_op()(this->ctx, psi.get_pointer(), 0, n_band * psi.get_nbasis());
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// haozhihan repalce 2022-10-18
setmem_complex_op()(this->ctx, psi_in, 0, n_band * psi_in_dmax);

gemm_op<T, Device>()(this->ctx,
'N',
'N',
this->dim, // m: row of A,C
this->n_band, // n: col of B,C
nbase, // k: col of A, row of B
this->dim,
this->n_band,
nbase,
this->one,
this->psi_in_iter, // A dim * nbase
this->psi_in_iter,
this->dim,
this->vcc, // B nbase * n_band
this->vcc,
this->nbase_x,
this->zero,
psi.get_pointer(), // C dim * n_band
psi.get_nbasis());
psi_in,
psi_in_dmax);

if (!this->notconv || (dav_iter == this->iter_nmax))
{
Expand All @@ -243,16 +230,26 @@ int Diago_DavSubspace<T, Device>::diag_once(
// then replace the first N (=nband) basis vectors with the current
// estimate of the eigenvectors and set the basis dimension to N;

// update this->psi_in_iter according to psi_in
for (size_t i = 0; i < this->n_band; i++)
{
syncmem_complex_op()(this->ctx,
this->ctx,
this->psi_in_iter + i * this->dim,
psi_in + i * psi_in_dmax,
this->dim);
}

this->refresh(this->dim,
this->n_band,
nbase,
eigenvalue_in_hsolver,
psi,
this->psi_in_iter,
this->hphi,
this->hcc,
this->scc,
this->vcc);

ModuleBase::timer::tick("Diago_DavSubspace", "last");
}
}
Expand Down Expand Up @@ -289,18 +286,17 @@ void Diago_DavSubspace<T, Device>::cal_grad(const Func& hpsi_func,
gemm_op<T, Device>()(this->ctx,
'N',
'N',
this->dim, // m: row of A,C
notconv, // n: col of B,C
nbase, // k: col of A, row of B
this->one, // alpha
psi_iter, // A
this->dim, // LDA
vcc, // B
this->nbase_x, // LDB
this->zero, // belta
psi_iter + nbase * this->dim, // C dim * notconv
this->dim // LDC
);
this->dim,
notconv,
nbase,
this->one,
psi_iter,
this->dim,
vcc,
this->nbase_x,
this->zero,
psi_iter + nbase * this->dim,
this->dim);

for (int m = 0; m < notconv; m++)
{
Expand All @@ -317,18 +313,17 @@ void Diago_DavSubspace<T, Device>::cal_grad(const Func& hpsi_func,
gemm_op<T, Device>()(this->ctx,
'N',
'N',
this->dim, // m: row of A,C
notconv, // n: col of B,C
nbase, // k: col of A, row of B
this->one, // alpha
hphi, // A dim * nbase
this->dim, // LDA
vcc, // B nbase * notconv
this->nbase_x, // LDB
this->one, // belta
this->dim,
notconv,
nbase,
this->one,
hphi,
this->dim,
vcc,
this->nbase_x,
this->one,
psi_iter + (nbase) * this->dim,
this->dim // LDC
);
this->dim);

// "precondition!!!"
std::vector<Real> pre(this->dim, 0.0);
Expand Down Expand Up @@ -365,10 +360,6 @@ void Diago_DavSubspace<T, Device>::cal_grad(const Func& hpsi_func,
psi_norm[i]);
}

// auto psi_iter_wrapper = psi::Psi<T, Device>(psi_iter, 1, this->nbase_x, this->dim);
// // "calculate H|psi>" for not convergence bands
// hpsi_info dav_hpsi_in(&psi_iter_wrapper, psi::Range(1, 0, nbase, nbase + notconv - 1), &hphi[nbase * this->dim]);
// phm_in->ops->hPsi(dav_hpsi_in);
hpsi_func(&hphi[nbase * this->dim], psi_iter, this->nbase_x, this->dim, nbase, nbase + notconv - 1);

ModuleBase::timer::tick("Diago_DavSubspace", "cal_grad");
Expand Down Expand Up @@ -516,7 +507,6 @@ void Diago_DavSubspace<T, Device>::diag_zhegvx(const int& nbase,
T* scc,
const int& nbase_x,
std::vector<Real>* eigenvalue_iter,
// Real* eigenvalue_iter,
T* vcc,
bool init,
bool is_subspace)
Expand Down Expand Up @@ -647,7 +637,7 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
const int& nband,
int& nbase,
const Real* eigenvalue_in_hsolver,
const psi::Psi<T, Device>& psi,
// const psi::Psi<T, Device>& psi,
T* psi_iter,
T* hp,
T* sp,
Expand All @@ -656,15 +646,6 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
{
ModuleBase::timer::tick("Diago_DavSubspace", "refresh");

// update psi
for (size_t i = 0; i < nband; i++)
{
syncmem_complex_op()(this->ctx,
this->ctx,
psi_iter + i * this->dim,
&psi(i, 0),
this->dim);
}
gemm_op<T, Device>()(this->ctx,
'N',
'N',
Expand All @@ -681,11 +662,7 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
this->dim);

// update hphi
syncmem_complex_op()(this->ctx,
this->ctx,
hphi,
psi_iter + nband * this->dim,
this->dim * nband);
syncmem_complex_op()(this->ctx, this->ctx, hphi, psi_iter + nband * this->dim, this->dim * nband);

nbase = nband;

Expand Down Expand Up @@ -816,15 +793,7 @@ int Diago_DavSubspace<T, Device>::diag(const Func& hpsi_func,
DiagoIterAssist<T, Device>::diagH_subspace(phm_in, psi, psi, eigenvalue_in_hsolver, psi.get_nbands());
}

sum_iter += this->diag_once(

hpsi_func,
psi_in,

psi,

eigenvalue_in_hsolver,
is_occupied);
sum_iter += this->diag_once(hpsi_func, psi_in, psi.get_nbasis(), eigenvalue_in_hsolver, is_occupied);

++ntry;

Expand Down
5 changes: 2 additions & 3 deletions source/module_hsolver/diago_dav_subspace.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class Diago_DavSubspace : public DiagH<T, Device>

hamilt::Hamilt<T, Device>* phm_in,
psi::Psi<T, Device>& phi,

Real* eigenvalue_in,
const std::vector<bool>& is_occupied,
const bool& scf_type);
Expand Down Expand Up @@ -105,7 +105,6 @@ class Diago_DavSubspace : public DiagH<T, Device>
const int& nband,
int& nbase,
const Real* eigenvalue,
const psi::Psi<T, Device>& psi,
T* psi_iter,
T* hphi,
T* hcc,
Expand All @@ -124,7 +123,7 @@ class Diago_DavSubspace : public DiagH<T, Device>

int diag_once(const Func& hpsi_func,
T* psi_in,
psi::Psi<T, Device>& psi,
const int psi_in_dmax,
Real* eigenvalue_in,
const std::vector<bool>& is_occupied);

Expand Down

0 comments on commit ec891a1

Please sign in to comment.