Skip to content

Commit

Permalink
[nnx] add flaxlib
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Sep 16, 2024
1 parent d111adf commit cb47a63
Show file tree
Hide file tree
Showing 15 changed files with 556 additions and 39 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# This workflows will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

name: Upload Python Package
name: Flax - Build and upload to PyPI

on:
release:
types: [created]
types: [published]

jobs:
deploy:
Expand Down
27 changes: 10 additions & 17 deletions .github/workflows/build.yml → .github/workflows/flax_test.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Build
name: Flax - Test

on:
push:
Expand Down Expand Up @@ -70,7 +70,7 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- uses: yezz123/setup-uv@v4
- uses: astral-sh/setup-uv@v2
with:
uv-version: "0.3.0"
- name: Install standalone dependencies only
Expand Down Expand Up @@ -104,23 +104,16 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- uses: yezz123/setup-uv@v4
- name: Setup uv
uses: astral-sh/setup-uv@v2
with:
uv-version: "0.3.0"
- name: Cached virtual environment
id: venv_cache
uses: actions/cache@v3
with:
path: .venv
key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('uv.lock') }}
- name: Install Dependencies for cache
if: steps.venv_cache.outputs.cache-hit != 'true'
run: |
if [ -d ".venv" ]; then rm -rf .venv; fi
uv sync --locked --all-extras
- name: Check lockfile
version: "0.3.0"
- name: Setup Rust (flaxlib)
uses: actions-rs/toolchain@v1
- name: Install dependencies
run: |
uv sync --locked --all-extras
uv sync --locked --extra all --extra testing --extra docs
uv pip install ./flaxlib
- name: Install JAX
run: |
if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
Expand Down
78 changes: 78 additions & 0 deletions .github/workflows/flaxlib_publish.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
name: Flaxlib - Build and upload to PyPI

# for testing only:
on:
pull_request:
branches: [main]

# on:
# workflow_dispatch:
# pull_request:
# push:
# branches: [main]
# paths: ['flaxlib/**']
# release:
# types: [published]

jobs:
build_wheels:
name: Build wheels on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
strategy:
matrix:
# macos-13 is an intel runner, macos-14 is apple silicon
os: [ubuntu-latest, windows-latest, macos-13, macos-14]

steps:
- uses: actions/checkout@v4

- name: Setup Rust
uses: actions-rs/toolchain@v1

- name: Build wheels
uses: pypa/cibuildwheel@v2.21.0
with:
package-dir: './flaxlib'
output-dir: './flaxlib/wheelhouse'

- uses: actions/upload-artifact@v4
with:
name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }}
path: ./flaxlib/wheelhouse/*.whl

build_sdist:
name: Build source distribution
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Setup Rust
uses: actions-rs/toolchain@v1

- name: Build sdist
run: pipx run build --sdist flaxlib

- uses: actions/upload-artifact@v4
with:
name: cibw-sdist
path: dist/*.tar.gz

upload_pypi:
needs: [build_wheels, build_sdist]
runs-on: ubuntu-latest
permissions:
id-token: write
steps:
- uses: actions/download-artifact@v4
with:
# unpacks all CIBW artifacts into dist/
pattern: cibw-*
path: ./flaxlib/dist
merge-multiple: true

- name: Build and publish
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
run: |
twine upload flaxlib/dist/*
2 changes: 1 addition & 1 deletion flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
AuxData = tp.TypeVar('AuxData')

StateLeaf = VariableState[tp.Any]
NodeLeaf = VariableState[tp.Any]
NodeLeaf = Variable[tp.Any]
GraphState = State[Key, StateLeaf]
GraphFlatState = FlatState[StateLeaf]

Expand Down
72 changes: 72 additions & 0 deletions flaxlib/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/target

# Byte-compiled / optimized / DLL files
__pycache__/
.pytest_cache/
*.py[cod]

# C extensions
*.so

# Distribution / packaging
.Python
.venv/
env/
bin/
build/
develop-eggs/
dist/
eggs/
lib/
lib64/
parts/
sdist/
var/
include/
man/
venv/
*.egg-info/
.installed.cfg
*.egg

# Installer logs
pip-log.txt
pip-delete-this-directory.txt
pip-selfcheck.json

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.cache
nosetests.xml
coverage.xml

# Translations
*.mo

# Mr Developer
.mr.developer.cfg
.project
.pydevproject

# Rope
.ropeproject

# Django stuff:
*.log
*.pot

.DS_Store

# Sphinx documentation
docs/_build/

# PyCharm
.idea/

# VSCode
.vscode/

# Pyenv
.python-version
171 changes: 171 additions & 0 deletions flaxlib/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit cb47a63

Please sign in to comment.