Metadata-Version: 2.1
Name: FJUtils
Version: 0.0.16
Home-page: https://github.com/erfanzar/
Author: Erfan Zare Chavoshi
Author-email: erfanzare82@yahoo.com
License: Apache License 2.0
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Requires-Python: >=3.7
Description-Content-Type: text/markdown
License-File: LICENSE

# FJUtils

A package of custom JAX/Flax functions and utilities.
Welcome to FJUtils, a collection of useful functions and utilities for Flax and JAX!
Some parts of the code are from EasyLM <3

## Overview

FJUtils is a collection of functions and utilities that help with various tasks when using Flax and JAX. It includes checkpoint savers, partitioning tools, and other helpful functions. The goal of FJUtils is to make your life easier when working with Flax and JAX. Whether you are training a new model, fine-tuning an existing one, or just exploring the capabilities of these powerful frameworks, FJUtils has something to offer.

## Features

Here are some of the features included in FJUtils:

- **Checkpoint saver**: an easy way to save and restore checkpoints during training. You can specify how often to save checkpoints, where to store them, and more.
- **Partitioning tools**: several tools for partitioning data across multiple devices or nodes, which can help you optimize the performance of your models on clusters or distributed systems (see the sketch below).
- **Other utilities**: a variety of other helpful functions for everyday Flax and JAX work.
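
To make the partitioning idea concrete, here is a minimal plain-JAX sketch of sharding an array across the available devices. It uses only the standard `jax.sharding` API rather than the FJUtils helpers themselves, and the mesh axis name `dp` is an arbitrary choice:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D mesh over whatever devices are available.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=('dp',))

# Shard the leading axis of an array across the 'dp' mesh axis.
sharding = NamedSharding(mesh, PartitionSpec('dp'))
x = jax.device_put(np.arange(16.0).reshape(8, 2), sharding)

print(x.sharding)  # shows how the array is laid out across devices
```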

## Getting Started

To get started with FJUtils, simply install the package using pip:

```shell
pip install fjutils
```

Once installed, you can import the package and start using its functions and utilities. For example, here's how you can use the checkpoint saver to load a model:

```python
from fjutils import StreamingCheckpointer

ckpt = StreamingCheckpointer.load_trainstate_checkpoint('params::<path to model>')
```
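
Saving goes through the same class. The sketch below follows the `__init__` and `save_checkpoint` signatures documented at the end of this page; the `config` value, checkpoint directory, and `train_state` are illustrative placeholders for whatever your training setup produces:

```python
from fjutils import StreamingCheckpointer

# `config` options and the checkpoint directory are placeholders.
ckpt = StreamingCheckpointer(config={}, checkpoint_dir='checkpoints', enable=True)

# `train_state` would come from your training loop (e.g. a flax TrainState):
# ckpt.save_checkpoint(train_state, filename='streaming_train_state')
```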

Or simply get an optimizer, for example Adafactor with a cosine learning-rate scheduler:

```python
from jax import numpy as jnp
from fjutils.optimizers import get_adafactor_with_cosine_scheduler

optimizer, scheduler = get_adafactor_with_cosine_scheduler(
    steps=5000,
    learning_rate=5e-5,
    weight_decay=1e-1,
    min_dim_size_to_factor=128,
    decay_rate=0.8,
    decay_offset=0,
    multiply_by_parameter_scale=True,
    clipping_threshold=1.0,
    momentum=None,
    dtype_momentum=jnp.float32,
    weight_decay_rate=None,
    eps=1e-30,
    factored=True,
    weight_decay_mask=None,
)
```
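
The returned `optimizer` can then be driven like any Optax gradient transformation. A minimal training-step sketch, assuming `get_adafactor_with_cosine_scheduler` returns a standard `optax.GradientTransformation` (the parameters and loss here are toy placeholders):

```python
import jax
import optax
from jax import numpy as jnp

params = {'w': jnp.ones((3,))}
opt_state = optimizer.init(params)

def loss_fn(p):
    # Toy quadratic loss, purely for illustration.
    return jnp.sum(p['w'] ** 2)

grads = jax.grad(loss_fn)(params)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```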

Or AdamW with a linear scheduler:

```python
from fjutils.optimizers import get_adamw_with_linear_scheduler

optimizer, scheduler = get_adamw_with_linear_scheduler(
    steps=5000,
    learning_rate_start=5e-5,
    learning_rate_end=1e-5,
    b1=0.9,
    b2=0.999,
    eps=1e-8,
    eps_root=0.0,
    weight_decay=1e-1,
    mu_dtype=None,
)
```
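
The second return value, `scheduler`, can be inspected directly. Assuming it is a standard Optax schedule, i.e. a callable mapping a step count to a learning rate (an assumption, not confirmed by the docs here), a quick sanity check looks like:

```python
# Assuming `scheduler` is a callable Optax schedule: step -> learning rate.
for step in (0, 2500, 4999):
    print(step, scheduler(step))
```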

## Documentation

- TODO

## Contributing

FJUtils is an open-source project, and contributions are always welcome! If you have a feature request, bug report, or just want to help out with development, please check out our GitHub repository and feel free to submit a pull request or open an issue.

Thank you for using FJUtils, and happy training!

## Credits

- Some of the JAX utilities are from EasyLM by [Young-Geng](https://github.com/young-geng).

## StreamingCheckpointer

* **StreamingCheckpointer** class:
  * **__init__(config, checkpoint_dir, enable=True)**:
    Initializes a StreamingCheckpointer object.
    Args:
    * config: A dictionary of configuration options.
    * checkpoint_dir: The directory where checkpoints will be saved.
    * enable: Whether to enable the streaming checkpointing functionality.
  * **save_checkpoint(train_state, filename, gather_fns=None)**:
    Saves a checkpoint to the specified file.
    Args:
    * train_state: The train state to save.
    * filename: The name of the checkpoint file.
    * gather_fns: A dictionary of functions that can be used to gather large tensors into smaller chunks before saving them.
  * **save_all(train_state, gather_fns, metadata=None, dataset=None, milestone=False)**:
    Saves a checkpoint for the current step, as well as metadata and dataset information.
    Args:
    * train_state: The train state to save.
    * gather_fns: A dictionary of functions that can be used to gather large tensors into smaller chunks before saving them.
    * metadata: Metadata to save.
    * dataset: Dataset information to save.
    * milestone: Whether this is a milestone checkpoint.
  * **load_checkpoint(path, target=None, shard_fns=None, remove_dict_prefix=None)**:
    Loads a checkpoint from the specified file.
    Args:
    * path: The path to the checkpoint file.
    * target: The object to load the checkpoint into.
    * shard_fns: A dictionary of functions that can be used to shard tensors after loading them.
    * remove_dict_prefix: A tuple of keys to remove from the loaded checkpoint.
  * **load_flax_checkpoint(path, target=None, shard_fns=None)**:
    Loads a standard (non-streaming) Flax checkpoint from the specified file.
    Args:
    * path: The path to the checkpoint file.
    * target: The object to load the checkpoint into.
    * shard_fns: A dictionary of functions that can be used to shard tensors after loading them.
  * **load_trainstate_checkpoint(load_from, trainstate_target=None, trainstate_shard_fns=None, disallow_trainstate=False)**:
    Loads a train state checkpoint from the specified load_from string.
    Args:
    * load_from: The load_from string, which can be one of the following:
      * 'trainstate': Load the entire train state.
      * 'trainstate_params': Load the params part of the train state.
      * 'params': Load the params.
      * 'flax_params': Load the params in the standard Flax format (non-streaming).
    * trainstate_target: The target object to load the train state into.
    * trainstate_shard_fns: A dictionary of functions that can be used to shard tensors after loading them.
    * disallow_trainstate: Whether to disallow loading the full train state.
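
A short usage sketch tying the `load_from` formats together; the checkpoint paths below are hypothetical placeholders:

```python
from fjutils import StreamingCheckpointer

# The prefix before '::' selects what to restore; the paths are hypothetical.
full_state = StreamingCheckpointer.load_trainstate_checkpoint(
    'trainstate::/ckpt/streaming_train_state'
)
params_only = StreamingCheckpointer.load_trainstate_checkpoint(
    'params::/ckpt/streaming_params'
)
flax_params = StreamingCheckpointer.load_trainstate_checkpoint(
    'flax_params::/ckpt/flax_model.msgpack'
)
```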