Skip to content

Commit

Permalink
AAMAS submission
Browse files Browse the repository at this point in the history
  • Loading branch information
LovelyBuggies committed Oct 28, 2022
1 parent f6de537 commit f3ada66
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 6 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
3 changes: 1 addition & 2 deletions MFG_VI.py → MFG-RL-PIDL.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,4 @@
option = options[2]
d = np.loadtxt(f"data/rho-{option}.txt")[:, 0].flatten('F')
train_ddpg(option, n_cell, T_terminal, d, fake_critic=True, pidl=True, surf_plot=True, smooth_plot=False,
diff_plot=True)
# plot_diff("./diff/", smooth=False)
diff_plot=True)
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# MFG-RL-PIDL

This is the source code for paper - A Hybrid Framework of Reinforcement Learning and Physics-Informed Deep Learning for Spatiotemporal Mean Field Games.

## File Structure

- `data` folder includes the numerical results of MFGs.
- `MFG-RL-PIDL.py` is the runner of Alg. 1 MFG RL-PIDL in our paper.
- `value_iteration_DDPG.py` is the training function of Alg. 1 MFG RL-PIDL.
- `MFG-Pure-PIDL-*.ipynb` are the source codes of the Alg. 2 MFG-Pure-PIDL.
- `model.py` includes the PyTorch network models of $\rho$ -Net, V-Net and u-net.
- `uitls.py` are the implementations of auxiliary function.
8 changes: 4 additions & 4 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,8 @@ def plot_diff(fig_path=None, smooth=False):
u_diff_plot = u_diff_hist
rho_diff_plot = rho_diff_hist

plt.plot(u_diff_plot, lw=3, label=r"$|u^{(k)} - u^{(k-1)}|$", c='steelblue', ls='--')
plt.plot(rho_diff_plot, lw=3, label=r"$|\rho^{(k)} - \rho^{(k-1)}|$", c='indianred', alpha=.8)
plt.plot(u_diff_plot, lw=3, label=r"$|u^{(i)} - u^{(i-1)}|$", c='steelblue', ls='--')
plt.plot(rho_diff_plot, lw=3, label=r"$|\rho^{(i)} - \rho^{(i-1)}|$", c='indianred', alpha=.8)
plt.xlabel("iterations", fontsize=18, labelpad=6)
plt.xticks(fontsize=18)
plt.ylabel("convergence gap", fontsize=18, labelpad=6)
Expand All @@ -263,8 +263,8 @@ def plot_diff(fig_path=None, smooth=False):
u_diff_plot = u_diff_hist
rho_diff_plot = rho_diff_hist

plt.plot(u_diff_plot, lw=3, label=r"$|u^{(k)} - u^*|$", c='steelblue', ls='--')
plt.plot(rho_diff_plot, lw=3, label=r"$|\rho^{(k)} - \rho^*|$", c='indianred', alpha=.8)
plt.plot(u_diff_plot, lw=3, label=r"$|u^{(i)} - u^*|$", c='steelblue', ls='--')
plt.plot(rho_diff_plot, lw=3, label=r"$|\rho^{(i)} - \rho^*|$", c='indianred', alpha=.8)
plt.xlabel("iterations", fontsize=18, labelpad=6)
plt.xticks(fontsize=18)
plt.ylabel("loss", fontsize=18, labelpad=6)
Expand Down

0 comments on commit f3ada66

Please sign in to comment.