Skip to content

Commit

Permalink
[nnx] add flaxlib
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Sep 16, 2024
1 parent d111adf commit abe94ff
Show file tree
Hide file tree
Showing 14 changed files with 455 additions and 39 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# This workflows will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

name: Upload Python Package
name: Flax - Build and upload to PyPI

on:
release:
types: [created]
types: [published]

jobs:
deploy:
Expand Down
25 changes: 8 additions & 17 deletions .github/workflows/build.yml → .github/workflows/flax_test.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Build
name: Flax - Test

on:
push:
Expand Down Expand Up @@ -70,7 +70,7 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- uses: yezz123/setup-uv@v4
- uses: astral-sh/setup-uv@v2
with:
uv-version: "0.3.0"
- name: Install standalone dependencies only
Expand Down Expand Up @@ -104,23 +104,14 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- uses: yezz123/setup-uv@v4
- name: Setup uv
uses: astral-sh/setup-uv@v2
with:
uv-version: "0.3.0"
- name: Cached virtual environment
id: venv_cache
uses: actions/cache@v3
with:
path: .venv
key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('uv.lock') }}
- name: Install Dependencies for cache
if: steps.venv_cache.outputs.cache-hit != 'true'
run: |
if [ -d ".venv" ]; then rm -rf .venv; fi
uv sync --locked --all-extras
- name: Check lockfile
version: "0.3.0"
- name: Install dependencies
run: |
uv sync --locked --all-extras
uv sync --locked --extra all --extra testing --extra docs
uv pip install ./flaxlib
- name: Install JAX
run: |
if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
Expand Down
65 changes: 65 additions & 0 deletions .github/workflows/flaxlib_publish.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
name: Flaxlib - Build and upload to PyPI

on:
workflow_dispatch:
pull_request:
push:
branches: [main]
paths: ['flaxlib/**']
release:
types: [published]

jobs:
build_wheels:
name: Build wheels on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
strategy:
matrix:
# macos-13 is an intel runner, macos-14 is apple silicon
os: [ubuntu-latest, windows-latest, macos-13, macos-14]

steps:
- uses: actions/checkout@v4

- name: Build wheels
uses: pypa/cibuildwheel@v2.21.0

- uses: actions/upload-artifact@v4
with:
name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }}
path: ./wheelhouse/*.whl

build_sdist:
name: Build source distribution
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Build sdist
run: pipx run build --sdist

- uses: actions/upload-artifact@v4
with:
name: cibw-sdist
path: dist/*.tar.gz

upload_pypi:
needs: [build_wheels, build_sdist]
runs-on: ubuntu-latest
environment: pypi
permissions:
id-token: write
if: github.event_name == 'release' && github.event.action == 'published'
# or, alternatively, upload to PyPI on every tag starting with 'v' (remove on: release above to use this)
# if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v')
steps:
- uses: actions/download-artifact@v4
with:
# unpacks all CIBW artifacts into dist/
pattern: cibw-*
path: dist
merge-multiple: true

- uses: pypa/gh-action-pypi-publish@release/v1
with:
# To test: repository-url: https://test.pypi.org/legacy/
2 changes: 1 addition & 1 deletion flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
AuxData = tp.TypeVar('AuxData')

StateLeaf = VariableState[tp.Any]
NodeLeaf = VariableState[tp.Any]
NodeLeaf = Variable[tp.Any]
GraphState = State[Key, StateLeaf]
GraphFlatState = FlatState[StateLeaf]

Expand Down
72 changes: 72 additions & 0 deletions flaxlib/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/target

# Byte-compiled / optimized / DLL files
__pycache__/
.pytest_cache/
*.py[cod]

# C extensions
*.so

# Distribution / packaging
.Python
.venv/
env/
bin/
build/
develop-eggs/
dist/
eggs/
lib/
lib64/
parts/
sdist/
var/
include/
man/
venv/
*.egg-info/
.installed.cfg
*.egg

# Installer logs
pip-log.txt
pip-delete-this-directory.txt
pip-selfcheck.json

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.cache
nosetests.xml
coverage.xml

# Translations
*.mo

# Mr Developer
.mr.developer.cfg
.project
.pydevproject

# Rope
.ropeproject

# Django stuff:
*.log
*.pot

.DS_Store

# Sphinx documentation
docs/_build/

# PyCharm
.idea/

# VSCode
.vscode/

# Pyenv
.python-version
171 changes: 171 additions & 0 deletions flaxlib/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions flaxlib/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[package]
name = "flaxlib"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
name = "flaxlib"
crate-type = ["cdylib"]

[dependencies]
pyo3 = "0.22.0"
Loading

0 comments on commit abe94ff

Please sign in to comment.