Skip to content

Commit

Permalink
TL: update on sdc solver
Browse files Browse the repository at this point in the history
  • Loading branch information
tlunet committed Dec 13, 2024
1 parent 9e67024 commit 9d3ef1d
Showing 1 changed file with 55 additions and 12 deletions.
67 changes: 55 additions & 12 deletions pySDC/playgrounds/dedalus/sdc.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ class IMEXSDCCore(object):

# For NN use to compute initial guess, etc ...
model = None
modelIsCopy = False

@classmethod
def setParameters(cls, nNodes=None, nodeType=None, quadType=None,
Expand Down Expand Up @@ -189,14 +190,15 @@ def setParameters(cls, nNodes=None, nodeType=None, quadType=None,
cls.diagonal = diagonal

@classmethod
def setupNN(cls, nnType, nEval=1, **params):
def setupNN(cls, nnType, nEval=1, initSweep="NN", modelIsCopy=False, **params):
if nnType == "FNOP-1":
from fnop.inference.inference import FNOInference as ModelClass
from cfno.inference.inference import FNOInference as ModelClass
elif nnType == "FNOP-2":
from fnop.training.fno_pysdc import FourierNeuralOp as ModelClass
from cfno.training.pySDC import FourierNeuralOp as ModelClass
cls.model = ModelClass(**params)
cls.nModelEval = nEval
cls.initSweep = "NN"
cls.initSweep = initSweep
cls.modelIsCopy = modelIsCopy

# -------------------------------------------------------------------------
# Class properties
Expand Down Expand Up @@ -268,7 +270,7 @@ def __init__(self, solver):
self.firstStep = True

# FNO state
if self.initSweep == "NN":
if self.initSweep.startswith("NN"):
self.stateFNO = [field.copy() for field in self.solver.state]

@property
Expand Down Expand Up @@ -436,18 +438,31 @@ def _toNumpy(self, state):
"""Extract from state fields a 3D numpy array containing ux, uz, b and p,
to be given to a NN model."""
for field in state:
old_scales = field.scales[0]
field.change_scales(1)
field.require_grid_space()
return np.asarray(
u = np.asarray(
# ux , uz , b , p
[state[2].data[0], state[2].data[1], state[1].data, state[0].data])
for field in state:
field.require_coeff_space()
field.change_scales(old_scales)
return u


def _setStateWith(self, u, state):
"""Write a 3D numpy array containing ux, uz, b and p into a dedalus state."""
for field in state:
old_scales = field.scales[0]
field.change_scales(1)
field.require_grid_space()
np.copyto(state[2].data[0], u[0]) # ux
np.copyto(state[2].data[1], u[1]) # uz
np.copyto(state[1].data, u[2]) # b
np.copyto(state[0].data, u[3]) # p
for field in state:
field.require_coeff_space()
field.change_scales(old_scales)


def _initSweep(self):
Expand Down Expand Up @@ -517,10 +532,36 @@ def _initSweep(self):
np.copyto(LXk[m].data, LXk[0].data)
np.copyto(Fk[m].data, Fk[0].data)

elif self.initSweep == "NN":
elif self.initSweep in "NN":
# nothing to do, initialization of tendencies already done
# during last sweep ...
pass
self._evalLX(self.LX[1][0])
self._evalF(self.F[1][0], t0, dt, wall_time)
print(f"NN, t={t0:1.2f}, firstEval : {self.firstEval}")

elif self.initSweep == "NNI":
self._evalLX(self.LX[1][0])
self._evalF(self.F[1][0], t0, dt, wall_time)
print(f"NNI, t={t0:1.2f}, firstEval : {self.firstEval}")

current = solver.state
state = self.stateFNO

# Evaluate FNO on current state
for c, f in zip(current, state):
np.copyto(f.data, c.data)
u0 = self._toNumpy(state)
u1 = self.model(u0, nEval=self.nModelEval)

# Evaluate RHS with interpolation between current and FNO solution
solver.state = state
for m in range(self.M):
tEval = t0 + dt*tau[m]
self._setStateWith(u0 + tau[m]*(u1-u0), state)
self._evalLX(LXk[m])
self._evalF(Fk[m], tEval, dt, wall_time)

solver.state = current

else:
raise NotImplementedError(f'initSweep={self.initSweep}')
Expand Down Expand Up @@ -574,16 +615,18 @@ def _sweep(self, k):

tEval = t0+dt*tau[m]
# In case NN is used for initial guess (last sweep only)
if self.initSweep == "NN" and k == self.nSweeps-1:
if self.initSweep == "NN" and k == (self.nSweeps-1):
# => evaluate current state with NN to be used
# for the tendencies at k=0 for the initial guess of next step
current = solver.state
state = self.stateFNO
for c, f in zip(current, state):
np.copyto(f.data, c.data)
uState = self._toNumpy(state)
uNext = self.model(uState, nEval=self.nModelEval)
np.clip(uNext[2], a_min=0, a_max=1, out=uNext[2]) # temporary : clip buoyancy between 0 and 1
if self.modelIsCopy:
uNext = uState
else:
uNext = self.model(uState, nEval=self.nModelEval)
self._setStateWith(uNext, state)
solver.state = state
tEval += dt
Expand All @@ -593,7 +636,7 @@ def _sweep(self, k):
# Evaluate and store LX with current state
self._evalLX(LXk1[m])

if self.initSweep == "NN" and k == self.nSweeps-1:
if self.initSweep == "NN" and k == (self.nSweeps-1):
# Reset state if it was used for NN initial guess
solver.state = current

Expand Down

0 comments on commit 9d3ef1d

Please sign in to comment.