diff --git a/pySDC/playgrounds/dedalus/sdc.py b/pySDC/playgrounds/dedalus/sdc.py index eba390228..52720d7a4 100644 --- a/pySDC/playgrounds/dedalus/sdc.py +++ b/pySDC/playgrounds/dedalus/sdc.py @@ -955,3 +955,47 @@ def step(self, dt, wall_time): # Only last rank (i.e node) will be allowed to (eventually) write outputs if self.rank != self.M-1: self.firstEval = False + + +class SDCIMEX_MPI2(SDCIMEX_MPI): + + def _broadcastState(self): + state = self.solver.state + sizes = [f.data.size for f in state] + buffer = np.empty(sum(sizes), dtype=state[0].data.dtype) + rank, M = self.rank, self.M + + if rank == M-1: # copy last rank state into buffer + pos = 0 + for f, size in zip(state, sizes): + np.copyto(buffer[pos:size], f.data.flat) + pos += size + + self.comm.Bcast(buffer, root=self.M-1) + + if rank != M-1: # copy buffer data into state + pos = 0 + for f, size in zip(state, sizes): + np.copyto(f.data, buffer[pos:size].reshape(f.data.shape)) + pos += size + + def _computeMX0(self, MX0:CoeffSystem): + """ + Compute MX0 term used in RHS of both initStep and sweep methods + + Update the MX0 attribute of the timestepper object. + """ + super(SDCIMEX_MPI, self)._computeMX0(MX0) + + def _initSweep(self): + t0, dt, wall_time = self.solver.sim_time, self.dt, self.wall_time + Fk, LXk = self.F[0], self.LX[0] + if self.initSweep == 'COPY': + self._evalLX(LXk) + self._evalF(Fk, t0, dt, wall_time) + else: + raise NotImplementedError() + + def step(self, dt, wall_time): + super().step(dt, wall_time) + self._broadcastState()