adding /docs
erfanzar committed Sep 10, 2023
1 parent 2e1b219 commit 9ff75c0
Showing 4 changed files with 339 additions and 127 deletions.
250 changes: 123 additions & 127 deletions FJUtils.egg-info/PKG-INFO
@@ -1,127 +1,123 @@
Metadata-Version: 2.1
Name: FJUtils
Version: 0.0.15
Summary: UNKNOWN
Home-page: https://github.com/erfanzar/
Author: Erfan Zare Chavoshi
Author-email: erfanzare82@yahoo.com
License: Apache License 2.0
Platform: UNKNOWN
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Requires-Python: >=3.7, <3.11
Description-Content-Type: text/markdown
License-File: LICENSE

# FJUtils

A package of custom JAX/Flax functions and utilities.
Welcome to FJUtils, a collection of useful functions and utilities for Flax and JAX!
Some parts of the code are from EasyLM <3

## Overview

FJUtils is a collection of functions and utilities that can help with various tasks when using Flax and JAX. It includes
checkpoint savers, partitioning tools, and other helpful functions.
The goal of FJUtils is to make your life easier when working with Flax and JAX. Whether you are training a new model,
fine-tuning an existing one, or just exploring the capabilities of these powerful frameworks, FJUtils has something to
offer.

## Features

Here are some of the features included in FJUtils:

Checkpoint saver: This tool provides an easy way to save and restore checkpoints during training. You can specify how
often to save checkpoints, where to store them, and more.

Partitioning tools: FJUtils includes several tools for partitioning data across multiple devices or nodes. These tools
can help you optimize the performance of your models on clusters or distributed systems.

Other utilities: FJUtils includes a variety of other helpful functions and utilities for everyday JAX and Flax work.

## Getting Started

To get started with FJUtils, simply install the package using pip:

```shell
pip install fjutils
```

Once installed, you can import the package and start using its functions and utilities. For example, here's how you can
use the streaming checkpointer to load a model:

```python
from fjutils import StreamingCheckpointer

ckpt = StreamingCheckpointer.load_trainstate_checkpoint('params::<path to model>')

```
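The `'params::<path to model>'` string follows a `<type>::<path>` convention, where the prefix selects how the checkpoint is interpreted (`trainstate`, `trainstate_params`, `params`, or `flax_params`, as documented in `docs/CheckPointing.md`). A minimal sketch of how such a prefix string can be split — the helper name `parse_load_from` is hypothetical, not part of FJUtils:

```python
def parse_load_from(load_from: str) -> tuple[str, str]:
    # Split "params::/path/to/ckpt" into ("params", "/path/to/ckpt").
    load_type, sep, load_path = load_from.partition("::")
    if not sep:
        raise ValueError(f"expected '<type>::<path>', got {load_from!r}")
    return load_type, load_path

print(parse_load_from("params::/tmp/model.ckpt"))  # ('params', '/tmp/model.ckpt')
```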

Or simply get an optimizer, for example Adafactor with a cosine scheduler:

```python
from jax import numpy as jnp
from fjutils.optimizers import get_adafactor_with_cosine_scheduler

optimizer, scheduler = get_adafactor_with_cosine_scheduler(
    steps=5000,
    learning_rate=5e-5,
    weight_decay=1e-1,
    min_dim_size_to_factor=128,
    decay_rate=0.8,
    decay_offset=0,
    multiply_by_parameter_scale=True,
    clipping_threshold=1.0,
    momentum=None,
    dtype_momentum=jnp.float32,
    weight_decay_rate=None,
    eps=1e-30,
    factored=True,
    weight_decay_mask=None,
)

```
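A cosine schedule anneals the learning rate from `learning_rate` toward zero over `steps`. A plain-Python sketch of the decay curve — a generic cosine decay for illustration, not necessarily FJUtils' exact implementation:

```python
import math

def cosine_decay(step: int, total_steps: int, init_lr: float = 5e-5) -> float:
    # Anneal from init_lr at step 0 down to 0 at total_steps.
    t = min(step, total_steps) / total_steps
    return init_lr * 0.5 * (1.0 + math.cos(math.pi * t))

print(cosine_decay(0, 5000))     # 5e-05 (full rate at the start)
print(cosine_decay(2500, 5000))  # halfway point, ~2.5e-05
```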

Or get AdamW with a linear scheduler:

```python
from fjutils.optimizers import get_adamw_with_linear_scheduler

optimizer, scheduler = get_adamw_with_linear_scheduler(
    steps=5000,
    learning_rate_start=5e-5,
    learning_rate_end=1e-5,
    b1=0.9,
    b2=0.999,
    eps=1e-8,
    eps_root=0.0,
    weight_decay=1e-1,
    mu_dtype=None,
)

```
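The linear scheduler simply interpolates from `learning_rate_start` to `learning_rate_end` over `steps`. A plain-Python sketch of the interpolation — generic, not FJUtils' exact code:

```python
def linear_schedule(step: int, total_steps: int,
                    lr_start: float = 5e-5, lr_end: float = 1e-5) -> float:
    # Linearly interpolate from lr_start at step 0 to lr_end at total_steps.
    frac = min(step, total_steps) / total_steps
    return lr_start + frac * (lr_end - lr_start)

print(linear_schedule(0, 5000))     # 5e-05 (lr_start at step 0)
print(linear_schedule(5000, 5000))  # lr_end at the final step
```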

## Documentation

- TODO

## Contributing

FJUtils is an open-source project, and contributions are always welcome! If you have a feature request, bug report, or
just want to help out with development, please check out our GitHub repository and feel free to submit a pull request or
open an issue.

Thank you for using FJUtils, and happy training!

## Credits

- Some of the JAX utilities are from EasyLM by [Young-Geng](https://github.com/young-geng).


Metadata-Version: 2.1
Name: FJUtils
Version: 0.0.16
Home-page: https://github.com/erfanzar/
Author: Erfan Zare Chavoshi
Author-email: erfanzare82@yahoo.com
License: Apache License 2.0
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Requires-Python: >=3.7
Description-Content-Type: text/markdown
License-File: LICENSE

55 changes: 55 additions & 0 deletions docs/CheckPointing.md
@@ -0,0 +1,55 @@
* **StreamingCheckpointer** class:
  * **__init__(config, checkpoint_dir, enable=True)**:
    Initializes a StreamingCheckpointer object.
    Args:
    * config: A dictionary of configuration options.
    * checkpoint_dir: The directory where checkpoints will be saved.
    * enable: Whether to enable the streaming checkpointing functionality.
  * **save_checkpoint(train_state, filename, gather_fns=None)**:
    Saves a checkpoint to the specified file.
    Args:
    * train_state: The train state to save.
    * filename: The name of the checkpoint file.
    * gather_fns: A dictionary of functions that can be used to gather large tensors into smaller chunks before saving them.
  * **save_all(train_state, gather_fns, metadata=None, dataset=None, milestone=False)**:
    Saves a checkpoint for the current step, as well as metadata and dataset information.
    Args:
    * train_state: The train state to save.
    * gather_fns: A dictionary of functions that can be used to gather large tensors into smaller chunks before saving them.
    * metadata: Metadata to save.
    * dataset: Dataset information to save.
    * milestone: Whether this is a milestone checkpoint.
  * **load_checkpoint(path, target=None, shard_fns=None, remove_dict_prefix=None)**:
    Loads a checkpoint from the specified file.
    Args:
    * path: The path to the checkpoint file.
    * target: The object to load the checkpoint into.
    * shard_fns: A dictionary of functions that can be used to shard tensors after loading them.
    * remove_dict_prefix: A tuple of keys to remove from the loaded checkpoint.
  * **load_flax_checkpoint(path, target=None, shard_fns=None)**:
    Loads a standard Flax checkpoint from the specified file.
    Args:
    * path: The path to the checkpoint file.
    * target: The object to load the checkpoint into.
    * shard_fns: A dictionary of functions that can be used to shard tensors after loading them.
  * **load_trainstate_checkpoint(load_from, trainstate_target=None, trainstate_shard_fns=None, disallow_trainstate=False)**:
    Loads a train state checkpoint according to the load_from string.
    Args:
    * load_from: The load_from string, which can be one of the following:
      * 'trainstate': Load the entire train state.
      * 'trainstate_params': Load the params part of the train state.
      * 'params': Load the params.
      * 'flax_params': Load the params in the standard Flax format (non-streaming).
    * trainstate_target: The target object to load the train state into.
    * trainstate_shard_fns: A dictionary of functions that can be used to shard tensors after loading them.
    * disallow_trainstate: Whether to disallow loading the full train state.
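The `gather_fns`/`shard_fns` arguments above are dictionaries of per-leaf callables matched against the checkpoint's nested parameter dict. A minimal sketch of applying such a mapping over a nested dict — the helper `apply_leaf_fns` is hypothetical, illustrating the idea rather than FJUtils' implementation:

```python
def apply_leaf_fns(fns, tree):
    # Recurse through a nested dict, applying the matching leaf function
    # (e.g. a shard_fn) to each tensor-like leaf.
    if isinstance(tree, dict):
        return {key: apply_leaf_fns(fns[key], value) for key, value in tree.items()}
    return fns(tree)

params = {"dense": {"kernel": [1, 2], "bias": [3]}}
shard_fns = {"dense": {"kernel": lambda x: x[:1], "bias": lambda x: x}}
print(apply_leaf_fns(shard_fns, params))  # {'dense': {'kernel': [1], 'bias': [3]}}
```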