Skip to content

Commit

Permalink
add experintal nnx
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Nov 9, 2023
1 parent cbf7bea commit 79e770d
Show file tree
Hide file tree
Showing 69 changed files with 12,937 additions and 36 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,7 @@ build/
docs/**/tmp

# used by direnv
.envrc
.envrc

# custom
/tmp-files
3 changes: 2 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@

# -- Options for myst ----------------------------------------------
# uncomment line below to avoid running notebooks during development
# nb_execution_mode = 'off'
nb_execution_mode = 'off'
# Notebook cell execution timeout; defaults to 30.
nb_execution_timeout = 100
# List of patterns, relative to source directory, that match notebook
Expand All @@ -133,6 +133,7 @@
nb_execution_excludepatterns = [
'quick_start.ipynb', # <-- times out
'transfer_learning.ipynb', # <-- transformers requires flax<=0.7.0
'flax/experimental/nnx', # exclude nnx
]
# raise exceptions on execution so CI can catch errors
nb_execution_allow_errors = False
Expand Down
6 changes: 6 additions & 0 deletions docs/experimental/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@


.. toctree::
:maxdepth: 2

nnx
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -323,4 +323,5 @@ Notable examples in Flax include:
developer_notes/index
philosophy
contributing
experimental
api_reference/index
92 changes: 68 additions & 24 deletions flax/core/flax_functional_engine.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"import functools\n",
"import jax\n",
"from jax import numpy as jnp, random, lax\n",
"import numpy as np\n"
"import numpy as np"
]
},
{
Expand Down Expand Up @@ -78,26 +78,34 @@
}
],
"source": [
"def dense(scope: Scope, inputs: Array, features: int, bias: bool = True,\n",
" kernel_init=nn.linear.default_kernel_init,\n",
" bias_init=nn.initializers.zeros_init()):\n",
"def dense(\n",
" scope: Scope,\n",
" inputs: Array,\n",
" features: int,\n",
" bias: bool = True,\n",
" kernel_init=nn.linear.default_kernel_init,\n",
" bias_init=nn.initializers.zeros_init(),\n",
"):\n",
" kernel = scope.param('kernel', kernel_init, (inputs.shape[-1], features))\n",
" y = jnp.dot(inputs, kernel)\n",
" if bias:\n",
" y += scope.param('bias', bias_init, (features,))\n",
" return y\n",
"\n",
"\n",
"model_fn = functools.partial(dense, features=3)\n",
"\n",
"x = jnp.ones((1, 2))\n",
"y, params = init(model_fn)(random.key(0), x)\n",
"print(params)\n",
"\n",
"\n",
"def mlp(scope: Scope, inputs: Array, features: int):\n",
" hidden = scope.child(dense, 'hidden')(inputs, features)\n",
" hidden = nn.relu(hidden)\n",
" return dense(scope.push('out'), hidden, 1)\n",
"\n",
"\n",
"init(mlp)(random.key(0), x, features=3)"
]
},
Expand Down Expand Up @@ -138,16 +146,31 @@
" def attend(self, query):\n",
" return jnp.dot(query, self.table.T)\n",
"\n",
"\n",
"# all the embedding module does is provide a convenient initializers\n",
"\n",
"def embedding(scope: Scope, num_embeddings: int, features: int, init_fn=nn.linear.default_embed_init) -> Embedding:\n",
"\n",
"def embedding(\n",
" scope: Scope,\n",
" num_embeddings: int,\n",
" features: int,\n",
" init_fn=nn.linear.default_embed_init,\n",
") -> Embedding:\n",
" table = scope.param('table', init_fn, (num_embeddings, features))\n",
" return Embedding(table)\n",
"\n",
"\n",
"embedding, _ = init(embedding)(random.key(0), num_embeddings=2, features=3)\n",
"print(embedding.table)\n",
"print(embedding.lookup(1))\n",
"print(embedding.attend(jnp.ones((1, 3,))))"
"print(\n",
" embedding.attend(\n",
" jnp.ones((\n",
" 1,\n",
" 3,\n",
" ))\n",
" )\n",
")"
]
},
{
Expand Down Expand Up @@ -177,11 +200,16 @@
}
],
"source": [
"def lstm(scope, carry, inputs,\n",
" gate_fn=nn.activation.sigmoid, activation_fn=nn.activation.tanh,\n",
" kernel_init=nn.linear.default_kernel_init,\n",
" recurrent_kernel_init=nn.initializers.orthogonal(),\n",
" bias_init=nn.initializers.zeros_init()):\n",
"def lstm(\n",
" scope,\n",
" carry,\n",
" inputs,\n",
" gate_fn=nn.activation.sigmoid,\n",
" activation_fn=nn.activation.tanh,\n",
" kernel_init=nn.linear.default_kernel_init,\n",
" recurrent_kernel_init=nn.initializers.orthogonal(),\n",
" bias_init=nn.initializers.zeros_init(),\n",
"):\n",
" r\"\"\"A long short-term memory (LSTM) cell.\n",
"\n",
" the mathematical definition of the cell is as follows\n",
Expand Down Expand Up @@ -217,11 +245,15 @@
" hidden_features = h.shape[-1]\n",
" # input and recurrent layers are summed so only one needs a bias.\n",
" dense_h = lambda name: scope.child(dense, name)(\n",
" h, features=hidden_features, bias=True,\n",
" kernel_init=recurrent_kernel_init, bias_init=bias_init)\n",
" h,\n",
" features=hidden_features,\n",
" bias=True,\n",
" kernel_init=recurrent_kernel_init,\n",
" bias_init=bias_init,\n",
" )\n",
" dense_i = lambda name: scope.child(dense, name)(\n",
" inputs, features=hidden_features, bias=False,\n",
" kernel_init=kernel_init)\n",
" inputs, features=hidden_features, bias=False, kernel_init=kernel_init\n",
" )\n",
" i = gate_fn(dense_i(name='ii') + dense_h(name='hi'))\n",
" f = gate_fn(dense_i(name='if') + dense_h(name='hf'))\n",
" g = activation_fn(dense_i(name='ig') + dense_h(name='hg'))\n",
Expand All @@ -230,10 +262,12 @@
" new_h = o * activation_fn(new_c)\n",
" return (new_c, new_h), new_h\n",
"\n",
"\n",
"def lstm_init_carry(batch_dims, size, init_fn=jnp.zeros):\n",
" shape = batch_dims + (size,)\n",
" return init_fn(shape), init_fn(shape)\n",
"\n",
"\n",
"x = jnp.ones((1, 2))\n",
"carry = lstm_init_carry((1,), 3)\n",
"y, variables = init(lstm)(random.key(0), carry, x)\n",
Expand All @@ -259,23 +293,33 @@
"source": [
"def simple_scan(scope: Scope, xs):\n",
" init_carry = lstm_init_carry(xs.shape[:1], xs.shape[-1])\n",
"# cell = scope.child(lstm, 'cell')\n",
"# ys = []\n",
"# for i in range(xs.shape[1]):\n",
"# x = xs[:, i]\n",
"# init_carry, y = cell(init_carry, x)\n",
"# ys.append(y)\n",
"# return init_carry, ys\n",
" lstm_scan = lift.scan(lstm, in_axes=1, out_axes=1, variable_broadcast='params', split_rngs={'params': False})\n",
" # cell = scope.child(lstm, 'cell')\n",
" # ys = []\n",
" # for i in range(xs.shape[1]):\n",
" # x = xs[:, i]\n",
" # init_carry, y = cell(init_carry, x)\n",
" # ys.append(y)\n",
" # return init_carry, ys\n",
" lstm_scan = lift.scan(\n",
" lstm,\n",
" in_axes=1,\n",
" out_axes=1,\n",
" variable_broadcast='params',\n",
" split_rngs={'params': False},\n",
" )\n",
" return lstm_scan(scope, init_carry, xs)\n",
"\n",
"\n",
"key1, key2 = random.split(random.key(0), 2)\n",
"xs = random.uniform(key1, (1, 5, 2))\n",
"\n",
"\n",
"y, init_variables = init(simple_scan)(key2, xs)\n",
"\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n"
"print(\n",
" 'initialized parameter shapes:\\n',\n",
" jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)),\n",
")"
]
},
{
Expand Down
Empty file added flax/experimental/__init__.py
Empty file.
133 changes: 133 additions & 0 deletions flax/experimental/nnx/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# project specific
.vscode
/tmp
Loading

0 comments on commit 79e770d

Please sign in to comment.