You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Hi all,
Currently I am running my program; it has a function called "update", and I get the following error:
Traceback (most recent call last):
File "./hydrogen-finiteT-hf-new/src/main.py", line 720, in <module>
= update(params_flow, params_van, params_wfn, state_idx, mo_coeff, bands, s, x, keys, data_acc, grads_acc, flow_score_acc, van_score_acc, wfn_score_acc, flow_fisher_acc, van_fisher_acc, wfn_fisher_acc, wfn_score_mean_acc, final_step, mix_fisher, opt_state_wfn_e, opt_state_flow_p,opt_state_van)
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 3743729664 bytes.
It is clear that my GPU is out of memory, and I want to check which lines in the function "update" caused this; however, I cannot even print the last line of function "update" (the figure is in the attached files).
Now I am confused about it, and I do not know which line caused this error, could you please give me some advice?
Here is the whole code
# ---- environment and JAX setup ----
# NOTE: XLA allocator flags must be exported BEFORE the JAX backend is
# initialized (safest: before `import jax`). The original script set
# XLA_PYTHON_CLIENT_PREALLOCATE after importing and configuring jax, in
# which case the flag can be silently ignored and the default memory
# preallocation behavior applies — relevant when debugging OOM errors.
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
#os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".XX"
#os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"

import time
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from jax import config
config.update("jax_enable_x64", True)  # double precision everywhere
#config.update("jax_traceback_filtering", "off")
from functools import partial
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
import argparse
import sys
sys.path.append("src/")
# project-local modules (resolved via the sys.path entry above)
import checkpoint
from vmc import sample_s_and_x, make_loss
from mcmc import adjust_mc_width
from ad import make_grad_real, make_grad_complex

key = jax.random.PRNGKey(42)
jax.print_environment_info()
devices = jax.devices()
num_devices = jax.device_count()
print("GPU devices:")
for i, device in enumerate(devices):
    print("---- ", i, " ", device.device_kind)
# ---- command-line interface ----
# (argparse is already imported above; the duplicate import is kept so this
# section remains self-contained.)
import argparse
parser = argparse.ArgumentParser(description="Hydrogen")
# path
parser.add_argument("--folder", default="../data/", help="the folder to save data")
parser.add_argument("--restore_path", default=None, help="checkpoint path or file")
# physical parameters
parser.add_argument("--n", type=int, default=14, help="total number of electrons == # of protons")
parser.add_argument("--dim", type=int, default=3, help="spatial dimension")
parser.add_argument("--rs", type=float, default=1.86, help="rs")
parser.add_argument("--T", type=float, default=31250.0, help="temperature in Kelvin")
# many-body state distribution: autoregressive transformer
parser.add_argument("--Emax", type=int, default=10, help="energy cutoff for the single-particle orbitals")
parser.add_argument("--nlayers", type=int, default=2, help="CausalTransformer: number of layers")
parser.add_argument("--modelsize", type=int, default=16, help="CausalTransformer: embedding dimension")
parser.add_argument("--nheads", type=int, default=4, help="CausalTransformer: number of heads")
parser.add_argument("--nhidden", type=int, default=32, help="CausalTransformer: number of hidden units of the MLP within each layer")
# normalizing flow
parser.add_argument("--flow_steps", type=int, default=1, help="FermiNet: transformation steps")
parser.add_argument("--flow_depth", type=int, default=3, help="FermiNet: network depth")
parser.add_argument("--flow_h1size", type=int, default=64, help="FermiNet: single-particle feature size")
parser.add_argument("--flow_h2size", type=int, default=16, help="FermiNet: two-particle feature size")
parser.add_argument("--wfn_depth", type=int, default=3, help="FermiNet: network depth")
parser.add_argument("--wfn_h1size", type=int, default=32, help="FermiNet: single-particle feature size")
parser.add_argument("--wfn_h2size", type=int, default=16, help="FermiNet: two-particle feature size")
parser.add_argument("--Nf", type=int, default=5, help="FermiNet: number of frequencies")
parser.add_argument("--K", type=int, default=4, help="FermiNet: number of dets")
parser.add_argument("--nk", type=int, default=None, help="FermiNet: number of plane wave basis")
# parameters relevant to th Ewald summation of Coulomb interaction
parser.add_argument("--Gmax", type=int, default=15, help="k-space cutoff in the Ewald summation of Coulomb potential")
parser.add_argument("--kappa", type=int, default=10, help="screening parameter (in unit of 1/L) in Ewald summation")
# MCMC
parser.add_argument("--mc_therm", type=int, default=5, help="MCMC thermalization steps")
parser.add_argument("--mc_steps_p", type=int, default=100, help="MCMC update steps")
parser.add_argument("--mc_steps_e", type=int, default=400, help="MCMC update steps")
parser.add_argument("--mc_width_p", type=float, default=0.02, help="standard deviation of the Gaussian proposal in MCMC update")
parser.add_argument("--mc_width_e", type=float, default=0.04, help="standard deviation of the Gaussian proposal in MCMC update")
# technical miscellaneous
parser.add_argument("--hutchinson", action='store_true', help="use Hutchinson's trick to compute the laplacian")
parser.add_argument("--remat", action='store_true', help="remat FermiNet and Transformer to save RAM")
# optimizer parameters
parser.add_argument("--lr_flow", type=float, default=1.00, help="initial learning rate")
parser.add_argument("--lr_van", type=float, default=1.00, help="initial learning rate")
parser.add_argument("--lr_wfn", type=float, default=1.00, help="initial learning rate")
parser.add_argument("--decay", type=float, default=1e-2, help="learning rate decay")
parser.add_argument("--damping_flow", type=float, default=1e-3, help="damping")
parser.add_argument("--damping_van", type=float, default=1e-3, help="damping")
parser.add_argument("--damping_wfn", type=float, default=1e-3, help="damping")
parser.add_argument("--maxnorm_flow", type=float, default=1e-3, help="gradnorm maximum")
parser.add_argument("--maxnorm_van", type=float, default=1e-3, help="gradnorm maximum")
parser.add_argument("--maxnorm_wfn", type=float, default=1e-3, help="gradnorm maximum")
parser.add_argument("--clip_factor", type=float, default=5.0, help="clip factor for gradient")
parser.add_argument("--alpha", type=float, default=0.1, help="mixing of new fisher matrix")
# training parameters
parser.add_argument("--batchsize", type=int, default=256, help="batch size (per single gradient accumulation step)")
parser.add_argument("--acc_steps", type=int, default=1, help="gradient accumulation steps")
parser.add_argument("--epoch", type=int, default=100000, help="final epoch")
# electron wavefunction base
parser.add_argument("--orbital", type=str, default="hf", choices=["pw", "hf", "dft"], help="orbitals for electron wavefunction")
parser.add_argument("--basis", default="gth-dzv", help="gto basis used in Hartree Fock calculation, eg. gth-dzv")
# yqm add from wanghan
parser.add_argument("--gamma", type=float, default=1., help="cg: the decaying parameter for the ProxSR method, gamma = 1 means ordinary sr without damping")
parser.add_argument("--lr_e", type=float, default=1e-2, help="initial learning rate for electrons")
parser.add_argument("--lr_p", type=float, default=1e-2, help="initial learning rate for protons")
parser.add_argument("--decay_e", type=float, default=1e-3, help="learning rate decay")
parser.add_argument("--decay_p", type=float, default=1e-3, help="learning rate decay")
parser.add_argument("--damping_e", type=float, default=1e-3, help="sr & fac & cg: damping for electrons")
parser.add_argument("--damping_p", type=float, default=1e-3, help="sr & fac & cg: damping for protons")
parser.add_argument("--maxnorm_e", type=float, default=1e-3, help="sr & fac & cg: gradnorm maximum for electrons")
parser.add_argument("--maxnorm_p", type=float, default=1e-3, help="sr & fac & cg: gradnorm maximum for protons")
parser.add_argument("--cg_mode", type=str, default="qr", choices=["score", "vjp_jvp", "svd", "qr"], help="optimizer cg mode")
parser.add_argument("--init_vec_last_step", action="store_true", help="use the solution of last step as the initial guess of the solver iterations")
parser.add_argument("--solver_precondition", action="store_true", help="precondition the solver iterations")
parser.add_argument("--solver_maxiter", type=int, default=None, help="maximum number of solver iterations")
parser.add_argument("--solver_tol", type=float, default=1e-10, help="the tolerance of solver iterations")
# help text fixed: this flag selects the solver algorithm, not the tolerance
parser.add_argument("--solver_style", type=str, default="cg", choices=["cg", "gmres"], help="the style of the linear solver iterations")
parser.add_argument("--num_hosts", type=int, default=1, help="num of hosts")
parser.add_argument("--walker", type=int, default=16, help="walker per twist")
parser.add_argument("--batch", type=int, default=8, help="batch size per walker")
parser.add_argument("--mc_alg", type=str, default="mala", choices=["mcmc", "mala", "hmc"], help="MCMC: algorithm type")
parser.add_argument("--mc_init_therm", type=int, help="MCMC: initial thermalization steps")
parser.add_argument("--twist_shift", type=float, nargs="+", default=None,
                    help="twist angle.\n"
                         "If not None, it should be a single vector of length 'dim' and the same for each configuration. A float number t is also supported and represents the vector (t, ..., t).\n"
                         "Otherwise the twist for each member of the training set is sampled from the uniform distribution over [-.5, .5)^dim.")
parser.add_argument("--lf_steps_p", type=int, default=50, help="MCMC: leap-frog steps for protons")
parser.add_argument("--lf_steps_e", type=int, default=10, help="MCMC: leap-frog steps for electron")
args = parser.parse_args()
# ---- sanity checks on the run configuration ----
if args.batchsize % num_devices != 0:
    raise ValueError("Batch size must be divisible by the number of GPU devices. "
                     "Got batchsize = %d for %d devices now." % (args.batchsize, num_devices))

# names of the batch axes used when mapping over walkers ("W") and batch ("B")
batch_in_axes_name = {
    "xe": ("W", "B"),
    "xp": ("W",),
    "twist": ("W",),
    "v": ("W", "B"),  # used in hutchinson's implementation of laplacian
}
batch_out_axes_name = ("W", "B")

n, dim = args.n, args.dim
# raise instead of assert: assertions are stripped when Python runs with -O
if n % 2 != 0:
    raise ValueError("The total number of electrons n must be even, got n = %d." % n)

# Ry = 157888.088922572 Kelvin
beta = 157888.088922572/args.T  # inverse temperature in unit of 1/Ry
print("temperature in Rydberg unit:", 1.0/beta)

# linear size of the simulation cell at unit density (scaled by rs elsewhere)
if dim == 3:
    L = (4/3*jnp.pi*n)**(1/3)
elif dim == 2:
    L = jnp.sqrt(jnp.pi*n)
else:
    # previously an unsupported dim left L undefined and crashed later with
    # a confusing NameError; fail fast with a clear message instead
    raise ValueError("Only dim = 2 or 3 is supported, got dim = %d." % dim)
print("n = %d, dim = %d, L = %f, rs = %f" % (n, dim, L, args.rs))
####################################################################################
print("\n========== Initialize single-particle orbitals ==========")
# Gamma point only; `kpt` is passed to make_logpsi below
kpt = jnp.array([0,0,0])
print("k point =", kpt)
from lcao import make_hf
from orbitals import make_lcao_orbitals
if args.orbital == "hf":
    # Hartree-Fock solver for a given proton configuration
    hf = make_hf(args.n, L, args.rs, basis=args.basis)
else:
    raise ValueError("Unknown orbital type: %s" % args.orbital)
hf_orbitals = make_lcao_orbitals(n, L, args.rs, basis=args.basis)
# random initial proton positions, used once here to size the HF problem
s = jax.random.uniform(key, (n, dim), minval=0., maxval=L)
mo_coeff, bands = hf(s)
from scipy.special import comb
# number of single-particle orbitals = number of bands returned by HF
num_states = bands.shape[0]
print("Number of available single-particle orbitals: %d" % num_states)
print("Total number of many-body states (%d in %d)^2: %f" % (n//2, num_states, comb(num_states, n//2)**2))
# from orbitals import sp_orbitals
# sp_indices, Es = sp_orbitals(dim, args.Emax)
# sp_indices, Es = jnp.array(sp_indices), jnp.array(Es)
# Ef = Es[n//2-1]
# print("beta = %f, Ef = %d, Emax = %d, corresponding delta_logit = %f"
#       % (beta, Ef, args.Emax, beta * (2*jnp.pi/L)**2 * (args.Emax - Ef)))
# sp_indices, Es = sp_indices[::-1], Es[::-1]
# num_states = Es.size
# print("Number of available single-particle orbitals: %d" % num_states)
# from scipy.special import comb
# print("Total number of many-body states (%d in %d)^2: %f" % (n//2, num_states, comb(num_states, n//2)**2))
####################################################################################
print("\n========== Initialize relevant quantities for Ewald summation ==========")
from potential import kpoints, Madelung
G = kpoints(dim, args.Gmax)
# constant (self-interaction / Madelung) term of the Ewald energy
Vconst = n * args.rs/L * Madelung(dim, args.kappa, G)
print("(scaled) Vconst:", Vconst/(n*args.rs/L))
####################################################################################
print("\n========== Initialize many-body state distribution (VAN) ==========")
import haiku as hk
from autoregressive import Transformer
def forward_fn(state):
    # autoregressive transformer over occupied-orbital indices
    model = Transformer(num_states, args.nlayers, args.modelsize, args.nheads, args.nhidden, remat=args.remat)
    return model(state)
van = hk.transform(forward_fn)
# dummy input of shape (n, 1) used only to initialize the parameters
state_idx_dummy = jnp.array([jnp.arange(n, dtype=jnp.float64)]).T
# state_idx_dummy = sp_indices[-n:].astype(jnp.float64)
params_van = van.init(key, state_idx_dummy)
raveled_params_van, _ = ravel_pytree(params_van)
print("#parameters in the autoregressive model: %d" % raveled_params_van.size)
from sampler import make_autoregressive_sampler, make_classical_score, make_classical_score_van
sampler, logprob_van_novmap = make_autoregressive_sampler(van, n, num_states, beta)
# batch the log-probability over samples (params are shared, hence None)
logprob_e = jax.vmap(logprob_van_novmap, (None, 0, 0), 0)
####################################################################################
# print("\n========== Pretraining ==========")
# # Pretraining parameters for the free-fermion model.
# pre_lr = 1e-3
# pre_sr, pre_damping, pre_maxnorm = True, 0.001, 0.001
# pre_batch = 8192
# freefermion_path = args.folder + "freefermion/pretraining/" \
#                  + "n_%d_dim_%d_T_%g_Emax_%d/" % (n, dim, args.T, args.Emax) \
#                  + "nlayers_%d_modelsize_%d_nheads_%d_nhidden_%d" % \
#                    (args.nlayers, args.modelsize, args.nheads, args.nhidden) \
#                  + ("_damping_%g_maxnorm_%g" % (pre_damping, pre_maxnorm)
#                     if pre_sr else "_lr_%g" % pre_lr) \
#                  + "_batch_%d" % pre_batch
# import os
# if not os.path.isdir(freefermion_path):
#     os.makedirs(freefermion_path)
#     print("Create freefermion directory: %s" % freefermion_path)
# import checkpoint
# pretrained_model_filename = checkpoint.pretrained_model_filename(freefermion_path)
# if os.path.isfile(pretrained_model_filename):
#     print("Load pretrained free-fermion model parameters from file: %s" % pretrained_model_filename)
#     params_van = checkpoint.load_data(pretrained_model_filename)
# else:
#     print("No pretrained free-fermion model found. Initialize parameters from scratch...")
#     from pretraining import pretrain
#     params_van = pretrain(van, params_van,
#                           logprob_van_novmap, sampler,
#                           n, dim, beta/args.rs**2, L, args.Emax,
#                           freefermion_path, key,
#                           pre_lr, pre_sr, pre_damping, pre_maxnorm,
#                           pre_batch, epoch=args.epoch)
#     print("Initialization done. Save the model to file: %s" % pretrained_model_filename)
#     checkpoint.save_data(params_van, pretrained_model_filename)
####################################################################################
print("\n========== Initialize normalizing flow ==========")
import haiku as hk
from ferminet import FermiNet
def forward_fn(x):
    # stack `flow_steps` FermiNet transformations; each step builds a fresh
    # module instance (haiku names them uniquely inside the transform)
    for _ in range(args.flow_steps):
        model = FermiNet(args.flow_depth, args.flow_h1size, args.flow_h2size, args.Nf, L, False, remat=args.remat)
        x = model(x)
    return x
network_flow = hk.transform(forward_fn)
x_dummy = jax.random.uniform(key, (n, dim), minval=0., maxval=L)
params_flow = network_flow.init(key, x_dummy)
raveled_params_flow, _ = ravel_pytree(params_flow)
print("#parameters in the flow model: %d" % raveled_params_flow.size)
from flow import make_flow
logprob_flow_novmap = make_flow(network_flow, n, dim, L)
# vmap over the batch of proton configurations (params broadcast)
vmap_p = partial(jax.vmap, in_axes=(None, 0), out_axes=0)
logprob_p = vmap_p(logprob_flow_novmap)
# force: gradient w.r.t. coordinates (argnums=1)
force_fn_p = vmap_p(make_grad_real(logprob_flow_novmap, argnums=1))
#yqm add
# score: gradient w.r.t. flow parameters (default argnums)
score_fn_p = vmap_p(make_grad_real(logprob_flow_novmap))
#over
####################################################################################
print("\n========== Initialize wavefunction ==========")
def forward_fn(x):
    # electron wavefunction backflow network; input is the concatenated
    # (proton, electron) coordinates of shape (2*n, dim)
    model = FermiNet(args.wfn_depth, args.wfn_h1size, args.wfn_h2size, args.Nf, L, True, remat=args.remat)
    return model(x)
network_wfn = hk.transform(forward_fn)
sx_dummy = jax.random.uniform(key, (2*n, dim), minval=0., maxval=L)
params_wfn = network_wfn.init(key, sx_dummy)
raveled_params_wfn, _ = ravel_pytree(params_wfn)
print("#parameters in the wavefunction model: %d" % raveled_params_wfn.size)
from logpsi import make_logpsi, make_logpsi_grad_laplacian, \
                   make_logpsi2, make_quantum_score
logpsi_novmap = make_logpsi(network_wfn, hf_orbitals, kpt)
logpsi2_novmap = make_logpsi2(logpsi_novmap)
#vmap_wfn = partial(jax.vmap, in_axes=(0, None, 0, 0, 0), out_axes=0)
# yqm add
# batch over (state_idx, s, x, ...) while broadcasting the parameters
vmap_wfn = partial(jax.vmap, in_axes=(0, None, 0, 0, 0), out_axes=0)
# over
logpsi2 = vmap_wfn(logpsi2_novmap)
force_fn_e = vmap_wfn(make_grad_real(logpsi2_novmap))
# yqm modified
#score_fn_e = vmap_wfn(make_grad_complex(logpsi_novmap, argnums=1))
score_fn_e =vmap_wfn(make_grad_complex(logpsi_novmap, argnums=1))
# over
####################################################################################
print("\n========== Initialize optimizer ==========")
import optax
van_score_fn = make_classical_score_van(logprob_van_novmap)
flow_score_fn = make_classical_score(logprob_flow_novmap)
wfn_score_fn = make_quantum_score(logpsi_novmap)
from sr import hybrid_fisher_sr
"""
fishers_fn, optimizer = hybrid_fisher_sr(flow_score_fn, van_score_fn, wfn_score_fn,
args.lr_flow, args.lr_van, args.lr_wfn, args.decay,
args.damping_flow, args.damping_van, args.damping_wfn,
args.maxnorm_flow, args.maxnorm_van, args.maxnorm_wfn)
"""
# yqm add
from sr_wangshan import quantum_fisher_cg
fishers_fn_wfn_e, _, optimizer_wfn_e = quantum_fisher_cg(logpsi = logpsi2, score_fn = wfn_score_fn, acc_steps = args.acc_steps,
gamma = args.gamma,lr = args.lr_wfn, decay = args.decay, damping = args.damping_wfn, maxnorm = args.maxnorm_wfn,
mode=args.cg_mode,
init_vec_last_step=args.init_vec_last_step,
solver_precondition=args.solver_precondition,
solver_maxiter=args.solver_maxiter,
solver_tol=args.solver_tol,
solver_style=args.solver_style,)
fishers_fn_flow_p, _, optimizer_flow_p = quantum_fisher_cg(logpsi = logpsi2, score_fn = score_fn_p, acc_steps = args.acc_steps,
gamma = args.gamma,lr = args.lr_flow, decay = args.decay, damping = args.damping_flow, maxnorm = args.maxnorm_flow,
mode=args.cg_mode,
init_vec_last_step=args.init_vec_last_step,
solver_precondition=args.solver_precondition,
solver_maxiter=args.solver_maxiter,
solver_tol=args.solver_tol,
solver_style=args.solver_style,)
fishers_fn_van, _, optimizer_van = quantum_fisher_cg(logpsi = logpsi2, score_fn = van_score_fn, acc_steps = args.acc_steps,
gamma = args.gamma,lr = args.lr_van, decay = args.decay, damping = args.damping_van, maxnorm = args.maxnorm_van,
mode=args.cg_mode,
init_vec_last_step=args.init_vec_last_step,
solver_precondition=args.solver_precondition,
solver_maxiter=args.solver_maxiter,
solver_tol=args.solver_tol,
solver_style=args.solver_style,)
# yqm add over
####################################################################################
print("\n========== Checkpointing ==========")
from utils import shard, replicate, p_split
# run directory name encodes all hyperparameters of this run
path = args.folder + "n_%d_dim_%d_rs_%g_T_%g" % (n, dim, args.rs, args.T) \
       + "_Em_%d" % args.Emax \
       + "_l_%d_m_%d_he_%d_hi_%d" % \
         (args.nlayers, args.modelsize, args.nheads, args.nhidden) \
       + "_fs_%d_fd_%d_fh1_%d_fh2_%d" % \
         (args.flow_steps, args.flow_depth, args.flow_h1size, args.flow_h2size) \
       + "_wd_%d_wh1_%d_wh2_%d_Nf_%d" % \
         (args.wfn_depth, args.wfn_h1size, args.wfn_h2size, args.Nf) \
       + "_G_%d_kp_%d" % (args.Gmax, args.kappa) \
       + "_mt_%d_mp_%d_%d_mw_%g_%g" % (args.mc_therm, args.mc_steps_p, args.mc_steps_e, args.mc_width_p, args.mc_width_e) \
       + ("_ht" if args.hutchinson else "") \
       + "_lr_%g_%g_%g_decay_%g_dp_%g_%g_%g_nm_%g_%g_%g" % (args.lr_flow, args.lr_van, args.lr_wfn, args.decay, args.damping_flow, args.damping_van, args.damping_wfn, args.maxnorm_flow, args.maxnorm_van, args.maxnorm_wfn) \
       + "_cl_%g_al_%g"%(args.clip_factor, args.alpha) \
       + "_bs_%d_ap_%d" %(args.batchsize, args.acc_steps)
if not os.path.isdir(path):
    os.makedirs(path)
    print("Create directory: %s" % path)
ckpt_filename, epoch_finished = checkpoint.find_ckpt_filename(args.restore_path or path)
batch_per_device = args.batchsize // num_devices
if ckpt_filename is not None:
    # ---- resume from checkpoint ----
    print("Load checkpoint file: %s, epoch finished: %g" %(ckpt_filename, epoch_finished))
    ckpt = checkpoint.load_data(ckpt_filename)
    #keys, s, x, params_flow, params_van, params_wfn, opt_state = \
    #        ckpt["keys"], ckpt["s"], ckpt["x"], ckpt["params_flow"], ckpt["params_van"], ckpt["params_wfn"], ckpt["opt_state"]
    # yqm add
    # NOTE(review): opt_state_van is NOT restored from the checkpoint here,
    # while opt_state_wfn_e and opt_state_flow_p are — confirm intentional
    keys, s, x, params_flow, params_van, params_wfn, opt_state_wfn_e, opt_state_flow_p = \
            ckpt["keys"], ckpt["s"], ckpt["x"], ckpt["params_flow"], ckpt["params_van"], ckpt["params_wfn"], ckpt["opt_state_wfn_e"], ckpt["opt_state_flow_p"]
    # over
    keys = jax.random.split(keys[0], num_devices)
    try:
        mc_width_p, mc_width_e = ckpt["mc_width_p"], ckpt["mc_width_e"]
    except (NameError, KeyError):
        # older checkpoints may not contain the MC widths; fall back to CLI
        mc_width_p, mc_width_e = args.mc_width_p, args.mc_width_e
    if (s.size == num_devices*batch_per_device*n*dim) and (x.size == num_devices*batch_per_device*n*dim):
        # checkpointed samples match the current device/batch layout
        s = jnp.reshape(s, (num_devices, batch_per_device, n, dim))
        x = jnp.reshape(x, (num_devices, batch_per_device, n, dim))
    else:
        # layout mismatch: regenerate samples and restart the epoch count
        keys, subkeys = p_split(keys)
        s = jax.pmap(jax.random.uniform, static_broadcasted_argnums=(1,2,3,4))(subkeys, (batch_per_device, n, dim), sx_dummy.dtype, 0., L)
        keys, subkeys = p_split(keys)
        x = jax.pmap(jax.random.uniform, static_broadcasted_argnums=(1,2,3,4))(subkeys, (batch_per_device, n, dim), sx_dummy.dtype, 0., L)
        epoch_finished = 0
    s, x, keys = shard(s), shard(x), shard(keys)
    params_flow, params_van, params_wfn = replicate((params_flow, params_van, params_wfn), num_devices)
else:
    # ---- fresh start ----
    print("No checkpoint file found. Start from scratch.")
    #opt_state = optimizer.init((params_flow, params_van, params_wfn))
    # yqm modified
    opt_state_wfn_e = optimizer_wfn_e.init(params_wfn)
    opt_state_flow_p = optimizer_flow_p.init(params_flow)
    opt_state_van = optimizer_van.init(params_van)
    # over
    print("Initialize key and coordinate samples...")
    key, key_proton, key_electron = jax.random.split(key, 3)
    s = jax.random.uniform(key_proton, (num_devices, batch_per_device, n, dim), minval=0., maxval=L)
    x = jax.random.uniform(key_electron, (num_devices, batch_per_device, n, dim), minval=0., maxval=L)
    keys = jax.random.split(key, num_devices)
    s, x, keys = shard(s), shard(x), shard(keys)
    params_flow, params_van, params_wfn = replicate((params_flow, params_van, params_wfn), num_devices)
    mc_width_p, mc_width_e = args.mc_width_p, args.mc_width_e
    # yqm add
    #opt_state_wfn_e = optimizer_wfn_e.init(params_wfn, x)
    # over
# rerun thermalization steps since we regenerated s and x samples
if epoch_finished == 0:
    for i in range(args.mc_therm):
        print("---- thermal step %d ----" % (i+1))
        print("s.shape =", s.shape)
        keys, state_idx, mo_coeff, bands, s, x, ar_s, ar_x = sample_s_and_x(keys,
                sampler, params_van,
                logprob_p, force_fn_p, s, params_flow,
                logpsi2, force_fn_e, x, params_wfn,
                args.mc_steps_p, args.mc_steps_e, mc_width_p, mc_width_e, L, hf, kpt)
        print ('acc, proton entropy:', jnp.mean(ar_s), jnp.mean(ar_x), -jax.pmap(logprob_p)(params_flow, s).mean()/n)
print("keys shape:", keys.shape, "\t\ttype:", type(keys))
print("x shape:", x.shape, "\t\ttype:", type(x))
####################################################################################
print("\n========== Training ==========")
from vmc import make_loss
# laplacian computed exactly or via Hutchinson's stochastic estimator
logpsi, logpsi_grad_laplacian = make_logpsi_grad_laplacian(logpsi_novmap, hutchinson=args.hutchinson)
# returns (observables, per-model loss closures) for a sample batch
observable_and_lossfn = make_loss(logprob_p, logprob_e, logpsi, logpsi_grad_laplacian,
        args.kappa, G, L, args.rs, Vconst, beta, args.clip_factor)
@partial(jax.pmap, axis_name="p",
         in_axes=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, None, None, None, None, None),
         out_axes=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, None, None, None),
         static_broadcasted_argnums=(18, 19))
def update(params_flow, params_van, params_wfn, state_idx, mo_coeff, bands, s, x, key,
           data_acc, grads_acc, flow_score_acc, van_score_acc, wfn_score_acc,
           flow_fisher_acc, van_fisher_acc, wfn_fisher_acc, wfn_score_mean_acc,
           final_step, mix_fisher, opt_state_wfn_e, opt_state_flow_p, opt_state_van):
    """One pmapped gradient-accumulation step of the (flow, van, wfn) optimization.

    Accumulates observables, gradients, scores and Fisher matrices; on
    `final_step` it applies one SR/CG parameter update per model and returns
    the new parameters and optimizer states unchanged otherwise.

    NOTE: this function is traced by jax.pmap, so plain Python `print` inside
    it runs once at trace time and shows abstract tracers, never runtime
    values — use jax.debug.print to inspect values at runtime. The dominant
    memory consumers here are the three dense (P, P) Fisher matrices below
    (P = number of parameters of each model), plus the per-sample score
    matrices of shape (batch, P) produced by the fishers_fn_* calls.
    """
    data, flow_lossfn, van_lossfn, wfn_lossfn = observable_and_lossfn(
        params_flow, params_van, params_wfn, state_idx, mo_coeff, bands, s, x, key)

    # per-model gradients and score terms via reverse-mode AD
    grad_params_flow, flow_score = jax.jacrev(flow_lossfn)(params_flow)
    grad_params_van, van_score = jax.jacrev(van_lossfn)(params_van)
    grad_params_wfn, wfn_score = jax.jacrev(wfn_lossfn)(params_wfn)
    grads = grad_params_flow, grad_params_van, grad_params_wfn
    grads, flow_score, van_score, wfn_score = jax.lax.pmean(
        (grads, flow_score, van_score, wfn_score), axis_name="p")

    # accumulate over gradient-accumulation steps
    data_acc, grads_acc, flow_score_acc, van_score_acc, wfn_score_acc = jax.tree_util.tree_map(
        lambda acc, i: acc + i,
        (data_acc, grads_acc, flow_score_acc, van_score_acc, wfn_score_acc),
        (data, grads, flow_score, van_score, wfn_score))

    # per-sample score matrices, assumed shape (batch, P) — TODO confirm
    wfn_score = fishers_fn_wfn_e('wfn', params_wfn, x, opt_state_wfn_e, s, state_idx, mo_coeff)
    flow_score = fishers_fn_flow_p('flow', params_flow, s, opt_state_flow_p)
    van_score = fishers_fn_van('van', params_van, x, opt_state_van, state_idx, bands)
    wfn_score_mean = jax.lax.pmean(wfn_score.mean(axis=0), axis_name="p")
    batchsize = flow_score.shape[0]

    # dense (P, P) Fisher matrices — the main GPU-memory hotspot of this step
    wfn_fisher = jax.lax.pmean(
        wfn_score.conj().T.dot(wfn_score).real / batchsize,
        axis_name="p")
    flow_fisher = jax.lax.pmean(
        flow_score.T.dot(flow_score) / batchsize,
        axis_name="p")
    van_fisher = jax.lax.pmean(
        van_score.T.dot(van_score) / batchsize,
        axis_name="p")

    if mix_fisher:
        # FIX: the original read `opt_state["acc"]`, but `opt_state` no longer
        # exists after the optimizer was split per model (NameError when
        # mix_fisher=True). The accumulation-step counter is taken from the
        # wfn optimizer state; all three states advance in lockstep.
        # TODO(review): confirm "acc" is the counter key of opt_state_wfn_e.
        i = opt_state_wfn_e["acc"]
        # 1/(1-alpha)**(a-1-i) factor to account for the same mixing factor for all acc steps
        flow_fisher_acc = (1-args.alpha)*flow_fisher_acc + args.alpha/(1-args.alpha)**(args.acc_steps-1-i)*flow_fisher/args.acc_steps
        van_fisher_acc = (1-args.alpha)*van_fisher_acc + args.alpha/(1-args.alpha)**(args.acc_steps-1-i)*van_fisher/args.acc_steps
        wfn_fisher_acc = (1-args.alpha)*wfn_fisher_acc + args.alpha/(1-args.alpha)**(args.acc_steps-1-i)*wfn_fisher/args.acc_steps
        wfn_score_mean_acc = (1-args.alpha)*wfn_score_mean_acc + args.alpha/(1-args.alpha)**(args.acc_steps-1-i)*wfn_score_mean/args.acc_steps
    else:
        flow_fisher_acc += flow_fisher/args.acc_steps
        van_fisher_acc += van_fisher/args.acc_steps
        wfn_fisher_acc += wfn_fisher/args.acc_steps
        wfn_score_mean_acc += wfn_score_mean/args.acc_steps

    if final_step:
        # average the accumulators over the accumulation steps
        # (jax.tree_util.tree_map: the bare jax.tree_map alias is deprecated)
        data_acc, grads_acc, flow_score_acc, van_score_acc, wfn_score_acc = \
            jax.tree_util.tree_map(lambda acc: acc / args.acc_steps,
                                   (data_acc, grads_acc, flow_score_acc, van_score_acc, wfn_score_acc))
        # center the gradients: grad - <O> * score (F for classical, E for quantum)
        grad_params_flow, grad_params_van, grad_params_wfn = grads_acc
        grad_params_flow = jax.tree_util.tree_map(
            lambda grad, score: grad - data_acc["F"] * score,
            grad_params_flow, flow_score_acc)
        grad_params_van = jax.tree_util.tree_map(
            lambda grad, score: grad - data_acc["F"] * score,
            grad_params_van, van_score_acc)
        grad_params_wfn = jax.tree_util.tree_map(
            lambda grad, score: grad - data_acc["E"] * score,
            grad_params_wfn, wfn_score_acc)
        grads_acc = grad_params_flow, grad_params_van, grad_params_wfn

        # one SR/CG update per model. NOTE(review): `params=(flow_fisher_acc)`
        # is not a tuple (no trailing comma) — kept as in the original since
        # the optimizer API is defined in sr_wangshan; verify intent there.
        updates_wfn_e, opt_state_wfn_e = optimizer_wfn_e.update('wfn', grad_params_wfn, opt_state_wfn_e,
                                                                params=(wfn_fisher_acc, wfn_score_mean_acc))
        updates_flow_p, opt_state_flow_p = optimizer_flow_p.update('flow', grad_params_flow, opt_state_flow_p,
                                                                   params=(flow_fisher_acc))
        updates_van, opt_state_van = optimizer_van.update('van', grad_params_van, opt_state_van,
                                                          params=(van_fisher_acc))
        params_wfn = optax.apply_updates(params_wfn, updates_wfn_e)
        params_flow = optax.apply_updates(params_flow, updates_flow_p)
        params_van = optax.apply_updates(params_van, updates_van)

    return params_flow, params_van, params_wfn, data_acc, grads_acc, flow_score_acc, van_score_acc, wfn_score_acc,\
           flow_fisher_acc, van_fisher_acc, wfn_fisher_acc, wfn_score_mean_acc, opt_state_wfn_e, opt_state_flow_p, opt_state_van
# Zero-initialize the device-replicated Fisher-information accumulators:
# one square matrix per parameter group (sized by the flattened parameter
# vector), plus the wavefunction score-mean vector.
_n_flow = raveled_params_flow.size
_n_van = raveled_params_van.size
_n_wfn = raveled_params_wfn.size
flow_fisher_acc = replicate(jnp.zeros((_n_flow, _n_flow)), num_devices)
van_fisher_acc = replicate(jnp.zeros((_n_van, _n_van)), num_devices)
wfn_fisher_acc = replicate(jnp.zeros((_n_wfn, _n_wfn)), num_devices)
wfn_score_mean_acc = replicate(jnp.zeros(_n_wfn), num_devices)
# No complete optimization step has run yet, so there is no Fisher estimate
# to mix from previous steps.
mix_fisher = False
# Open the per-epoch text log.  Fresh runs truncate ("w"); resumed runs
# append ("a").  Line-buffered (buffering=1) so each epoch's row reaches
# disk immediately.  The handle stays open for the whole training loop and
# is closed after it.
time_of_last_ckpt = time.time()
log_filename = os.path.join(path, "data.txt")
log_mode = "a" if epoch_finished != 0 else "w"
f = open(log_filename, log_mode, buffering=1, newline="\n")
# Emit the column header only when the file is empty (new or truncated).
if os.path.getsize(log_filename) == 0:
    f.write("epoch f f_err e e_err k k_err vpp vpp_err vep vep_err vee vee_err p p_err ep_cov sp sp_err se se_err acc_s acc_x\n")
# ---------------------------------------------------------------------------
# Main training loop.  Each epoch:
#   1. zeroes the per-epoch observable / gradient / score accumulators,
#   2. runs `args.acc_steps` accumulation steps of MCMC sampling + `update`
#      (the optimizer step itself fires inside `update` on the final step),
#   3. logs the accumulated observables, and
#   4. checkpoints at most every 10 minutes.
#
# Fixes relative to the original:
#   * the checkpoint now stores the three optimizer states that are actually
#     updated (wfn_e / flow_p / van) instead of the stale pre-refactor single
#     `opt_state`,
#   * deprecated `jax.tree_map` replaced by `jax.tree_util.tree_map` for
#     consistency with the rest of the file,
#   * tree-map lambdas renamed so they no longer shadow the electron
#     coordinates `x`.
for i in range(epoch_finished + 1, args.epoch + 1):
    # Observables are accumulated together with their squares ("K2", ...) so
    # standard errors can be computed after averaging.
    data_acc = replicate({"K": 0., "K2": 0.,
                          "Vpp": 0., "Vpp2": 0.,
                          "Vep": 0., "Vep2": 0.,
                          "Vee": 0., "Vee2": 0.,
                          "P": 0., "P2": 0.,
                          "E": 0., "E2": 0.,
                          "EP": 0.,
                          "F": 0., "F2": 0.,
                          "Sp": 0., "Sp2": 0.,
                          "Se": 0., "Se2": 0.,
                          }, num_devices)
    # Zeroed accumulators sharing the pytree structure of the parameters.
    grads_acc = shard(jax.tree_util.tree_map(
        jnp.zeros_like, (params_flow, params_van, params_wfn)))
    flow_score_acc = shard(jax.tree_util.tree_map(jnp.zeros_like, params_flow))
    van_score_acc = shard(jax.tree_util.tree_map(jnp.zeros_like, params_van))
    wfn_score_acc = shard(jax.tree_util.tree_map(jnp.zeros_like, params_wfn))
    ar_s_acc = shard(jnp.zeros(num_devices))  # proton MCMC acceptance rate
    ar_x_acc = shard(jnp.zeros(num_devices))  # electron MCMC acceptance rate

    for acc in range(args.acc_steps):
        # Propose new proton (s) and electron (x) configurations via MCMC.
        keys, state_idx, mo_coeff, bands, s, x, ar_s, ar_x = sample_s_and_x(
            keys,
            sampler, params_van,
            logprob_p, force_fn_p, s, params_flow,
            logpsi2, force_fn_e, x, params_wfn,
            args.mc_steps_p, args.mc_steps_e,
            mc_width_p, mc_width_e, L, hf, kpt)
        ar_s_acc += ar_s / args.acc_steps
        ar_x_acc += ar_x / args.acc_steps
        final_step = (acc == args.acc_steps - 1)

        # `update` reads the accumulation index from the optimizer state to
        # weight the Fisher-mixing factor for this step.
        opt_state_wfn_e["acc"] = acc
        opt_state_flow_p["acc"] = acc
        opt_state_van["acc"] = acc

        params_flow, params_van, params_wfn, data_acc, grads_acc, \
        flow_score_acc, van_score_acc, wfn_score_acc, \
        flow_fisher_acc, van_fisher_acc, wfn_fisher_acc, wfn_score_mean_acc, \
        opt_state_wfn_e, opt_state_flow_p, opt_state_van \
            = update(params_flow, params_van, params_wfn, state_idx, mo_coeff,
                     bands, s, x, keys, data_acc, grads_acc, flow_score_acc,
                     van_score_acc, wfn_score_acc, flow_fisher_acc,
                     van_fisher_acc, wfn_fisher_acc, wfn_score_mean_acc,
                     final_step, mix_fisher,
                     opt_state_wfn_e, opt_state_flow_p, opt_state_van)

        # After the first complete optimization step the Fisher estimate is
        # meaningful, so exponential mixing may begin.
        if final_step:
            mix_fisher = True

    # Statistics were pmean-ed across devices, so every replica is identical;
    # read replica 0.  (Lambda parameter deliberately not named `x` to avoid
    # shadowing the electron coordinates.)
    data = jax.tree_util.tree_map(lambda leaf: leaf[0], data_acc)
    ar_s = ar_s_acc[0]
    ar_x = ar_x_acc[0]

    K, K2, Vpp, Vpp2, Vep, Vep2, Vee, Vee2, P, P2, E, E2, EP, F, F2, Sp, Sp2, Se, Se2 = \
        data["K"], data["K2"], \
        data["Vpp"], data["Vpp2"], \
        data["Vep"], data["Vep2"], \
        data["Vee"], data["Vee2"], \
        data["P"], data["P2"], \
        data["E"], data["E2"], \
        data["EP"], \
        data["F"], data["F2"], \
        data["Sp"], data["Sp2"], \
        data["Se"], data["Se2"]

    # Standard error of the mean over the full batch (batchsize * acc_steps).
    K_std = jnp.sqrt((K2 - K**2) / (args.batchsize*args.acc_steps))
    Vpp_std = jnp.sqrt((Vpp2 - Vpp**2) / (args.batchsize*args.acc_steps))
    Vep_std = jnp.sqrt((Vep2 - Vep**2) / (args.batchsize*args.acc_steps))
    Vee_std = jnp.sqrt((Vee2 - Vee**2) / (args.batchsize*args.acc_steps))
    P_std = jnp.sqrt((P2 - P**2) / (args.batchsize*args.acc_steps))
    E_std = jnp.sqrt((E2 - E**2) / (args.batchsize*args.acc_steps))
    EP_cov = (EP - E*P) / (args.batchsize*args.acc_steps)
    F_std = jnp.sqrt((F2 - F**2) / (args.batchsize*args.acc_steps))
    Sp_std = jnp.sqrt((Sp2 - Sp**2) / (args.batchsize*args.acc_steps))
    Se_std = jnp.sqrt((Se2 - Se**2) / (args.batchsize*args.acc_steps))

    # Note the quantities with energy dimension have a prefactor 1/rs^2.
    print("iter: %04d" % i,
          "F:", F/args.rs**2, "F_std:", F_std/args.rs**2,
          "E:", E/args.rs**2, "E_std:", E_std/args.rs**2,
          "K:", K/args.rs**2, "K_std:", K_std/args.rs**2,
          "Sp:", Sp, "Sp_std:", Sp_std,
          "Se:", Se, "Se_std:", Se_std,
          "accept_rate:", ar_s, ar_x
          )
    f.write(("%6d" + " %.6f"*19 + " %.4f"*2 + "\n") % (i,
            F/n/args.rs**2, F_std/n/args.rs**2,
            E/n/args.rs**2, E_std/n/args.rs**2,
            K/n/args.rs**2, K_std/n/args.rs**2,
            Vpp/n/args.rs**2, Vpp_std/n/args.rs**2,
            Vep/n/args.rs**2, Vep_std/n/args.rs**2,
            Vee/n/args.rs**2, Vee_std/n/args.rs**2,  # Ry
            P/args.rs**2, P_std/args.rs**2,  # GPa
            EP_cov/n/args.rs**4,  # GPa
            Sp/n, Sp_std/n,
            Se/n, Se_std/n,
            ar_s, ar_x))

    # Checkpoint at most once every 10 minutes.
    if time.time() - time_of_last_ckpt > 600:
        ckpt = {"keys": keys, "s": s, "x": x,
                "params_flow": jax.tree_util.tree_map(lambda leaf: leaf[0], params_flow),
                "params_van": jax.tree_util.tree_map(lambda leaf: leaf[0], params_van),
                "params_wfn": jax.tree_util.tree_map(lambda leaf: leaf[0], params_wfn),
                # FIX: store the three optimizer states that are actually
                # updated; the pre-refactor single `opt_state` was stale on
                # this code path and would restore wrong state on resume.
                # NOTE(review): the resume/load path must unpack these keys —
                # confirm against the checkpoint-loading code.
                "opt_state": {"wfn_e": opt_state_wfn_e,
                              "flow_p": opt_state_flow_p,
                              "van": opt_state_van},
                "mc_width_p": mc_width_p,
                "mc_width_e": mc_width_e
                }
        ckpt_filename = os.path.join(path, "epoch_%06d.pkl" % i)
        checkpoint.save_data(ckpt, ckpt_filename)
        print("Save checkpoint file: %s" % ckpt_filename)
        time_of_last_ckpt = time.time()

    # Abort early on numerical failure or a frozen Markov chain.
    if jnp.isnan(F):
        raise RuntimeError("Free energy is nan")
    if ar_s < 1e-7 or ar_x < 1e-7:
        raise RuntimeError("Acceptance rate nearly zero")

    # Periodically retune the MCMC proposal widths toward a healthy
    # acceptance rate.
    if i % 100 == 0:
        mc_width_p = adjust_mc_width(mc_width_p, ar_s, "mcmc")
        mc_width_e = adjust_mc_width(mc_width_e, ar_x, "mcmc")
f.close()
reacted with thumbs up emoji reacted with thumbs down emoji reacted with laugh emoji reacted with hooray emoji reacted with confused emoji reacted with heart emoji reacted with rocket emoji reacted with eyes emoji
-
Hi all,
Currently I am running my program, I have a function called update, and I get error as follows:
It is clear that my gpu is out of memory, and I want to check which lines in function "update" caused this, however, I can even print the last line of function "update", the fig is in attached files.
Now I am confused about it, and I do not know which line caused this error, could you please give me some advice?
Here is the whole code
Many thanks!
Beta Was this translation helpful? Give feedback.
All reactions