Skip to content

Commit

Permalink
manual grabage collection to limit OOM
Browse files Browse the repository at this point in the history
  • Loading branch information
SNMS95 committed Sep 28, 2023
1 parent 3c4a7c2 commit 1c321e5
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions jax_am/fem/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@

# petsc4py.init()
from petsc4py import PETSc
import gc

from jax_am import logger

################################################################################
# PETSc linear solver or JAX linear solver


def petsc_solve(A, b, ksp_type, pc_type, direct_solve=True):
rhs = PETSc.Vec().createSeq(len(b))
rhs.setValues(range(len(b)), onp.array(b))
Expand Down Expand Up @@ -49,7 +49,10 @@ def petsc_solve(A, b, ksp_type, pc_type, direct_solve=True):
rhs.destroy()
y.destroy()
A.destroy()
return x.getArray()
result = x.getArray()
x.destroy()
gc.collect()
return result


def jax_solve(problem, A_fn, b, x0, precond: bool, pc_matrix=None):
Expand Down Expand Up @@ -213,7 +216,7 @@ def operator_to_matrix(operator_fn, problem):

def jacobi_preconditioner(problem):
logger.debug(f"Compute and use jacobi preconditioner")
jacobi = np.array(problem.A_sp_scipy.diagonal())
jacobi = np.array(problem.A_sp_scipy_diag)
jacobi = assign_ones_bc(jacobi.reshape(-1), problem)
return jacobi

Expand Down Expand Up @@ -321,7 +324,7 @@ def get_A_fn(problem, use_petsc):
# logger.debug(f"Creating sparse matrix from scipy using JAX BCOO...")
A_sp = BCOO.from_scipy_sparse(A_sp_scipy).sort_indices()
# logger.info(f"Global sparse matrix takes about {A_sp.data.shape[0]*8*3/2**30} G memory to store.")
problem.A_sp_scipy = A_sp_scipy
problem.A_sp_scipy_diag = A_sp_scipy.diagonal()

def compute_linearized_residual(dofs):
return A_sp @ dofs
Expand All @@ -338,6 +341,8 @@ def compute_linearized_residual(dofs):
else:
A = row_elimination(compute_linearized_residual, problem)

del A_sp_scipy
del A_sp
return A


Expand Down Expand Up @@ -414,6 +419,13 @@ def newton_update_helper(dofs):
logger.debug(f"max of sol = {np.max(sol)}")
logger.debug(f"min of sol = {np.min(sol)}")

if use_petsc:
A_fn.destroy()
else:
del A_fn
del dofs
del res_vec
gc.collect()
return sol


Expand Down

0 comments on commit 1c321e5

Please sign in to comment.