Skip to content

Commit

Permalink
"torch" --> "mlx".
Browse files Browse the repository at this point in the history
  • Loading branch information
hallvardnmbu committed Feb 11, 2024
1 parent 6513c89 commit 48e9423
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 9 deletions.
7 changes: 3 additions & 4 deletions reinforcement-learning/mlx-policy-based.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
"source": [
"import time\n",
"import imageio\n",
"import numpy as np\n",
"import mlx.nn as nn\n",
"import mlx.core as mx\n",
"import gymnasium as gym\n",
Expand Down Expand Up @@ -232,7 +231,7 @@
" ax[0].axvline(x=i, color='gray', linewidth=0.5)\n",
" ax[1].axvline(x=i, color='gray', linewidth=0.5)\n",
"\n",
"plt.savefig(\"./static/images/torch-pbg.png\")\n",
"plt.savefig(\"./static/images/mlx-pbg.png\")\n",
"plt.show()"
],
"metadata": {
Expand Down Expand Up @@ -270,7 +269,7 @@
" state = mx.array(state)\n",
"\n",
" images.append(environment.render())\n",
"_ = imageio.mimsave('./static/images/torch-pbg.gif', images, duration=25)"
"_ = imageio.mimsave('./static/images/mlx-pbg.gif', images, duration=25)"
],
"metadata": {
"collapsed": false,
Expand All @@ -284,7 +283,7 @@
{
"cell_type": "markdown",
"source": [
"<img src=\"./static/images/torch-pbg.gif\" width=\"1000\" height=\"1000\" />"
"<img src=\"./static/images/mlx-pbg.gif\" width=\"1000\" height=\"1000\" />"
],
"metadata": {
"collapsed": false
Expand Down
8 changes: 3 additions & 5 deletions reinforcement-learning/mlx-value-based.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,8 @@
},
"outputs": [],
"source": [
"import copy\n",
"import time\n",
"import imageio\n",
"import mlx.nn as nn\n",
"import mlx.core as mx\n",
"import gymnasium as gym\n",
"import mlx.optimizers as optim\n",
Expand Down Expand Up @@ -296,7 +294,7 @@
" ax[0].axvline(x=i, color='gray', linewidth=0.5)\n",
" ax[1].axvline(x=i, color='gray', linewidth=0.5)\n",
"\n",
"plt.savefig(\"./static/images/torch-dqn.png\")\n",
"plt.savefig(\"./static/images/mlx-dqn.png\")\n",
"plt.show()"
],
"metadata": {
Expand Down Expand Up @@ -333,7 +331,7 @@
" state = mx.array(state)\n",
"\n",
" images.append(environment.render())\n",
"_ = imageio.mimsave('./static/images/torch-dqn.gif', images, duration=25)"
"_ = imageio.mimsave('./static/images/mlx-dqn.gif', images, duration=25)"
],
"metadata": {
"collapsed": false,
Expand All @@ -347,7 +345,7 @@
{
"cell_type": "markdown",
"source": [
"<img src=\"./static/images/torch-dqn.gif\" width=\"1000\" height=\"1000\" />"
"<img src=\"./static/images/mlx-dqn.gif\" width=\"1000\" height=\"1000\" />"
],
"metadata": {
"collapsed": false
Expand Down

0 comments on commit 48e9423

Please sign in to comment.