From 9ff75c0a50737d54edb76418ceb6013d9acb4ea4 Mon Sep 17 00:00:00 2001 From: erfanzar Date: Sun, 10 Sep 2023 19:39:10 +0330 Subject: [PATCH] adding `/docs` --- FJUtils.egg-info/PKG-INFO | 250 +++++++++++++++++++------------------- docs/CheckPointing.md | 55 +++++++++ docs/Optimizers.md | 161 ++++++++++++++++++++++++ docs/README.md | 0 4 files changed, 339 insertions(+), 127 deletions(-) create mode 100644 docs/CheckPointing.md create mode 100644 docs/Optimizers.md create mode 100644 docs/README.md diff --git a/FJUtils.egg-info/PKG-INFO b/FJUtils.egg-info/PKG-INFO index 4c91e1b..5bfe881 100644 --- a/FJUtils.egg-info/PKG-INFO +++ b/FJUtils.egg-info/PKG-INFO @@ -1,127 +1,123 @@ -Metadata-Version: 2.1 -Name: FJUtils -Version: 0.0.15 -Summary: UNKNOWN -Home-page: https://github.com/erfanzar/ -Author: Erfan Zare Chavoshi -Author-email: erfanzare82@yahoo.com -License: Apache License 2.0 -Platform: UNKNOWN -Classifier: Programming Language :: Python :: 3 -Classifier: License :: OSI Approved :: MIT License -Classifier: Operating System :: OS Independent -Classifier: Programming Language :: Python :: 3 -Classifier: Programming Language :: Python :: 3.7 -Classifier: Programming Language :: Python :: 3.8 -Classifier: Programming Language :: Python :: 3.9 -Classifier: Programming Language :: Python :: 3.10 -Classifier: Programming Language :: Python :: 3.11 -Requires-Python: >=3.7, <3.11 -Description-Content-Type: text/markdown -License-File: LICENSE - -# FJUtils - -a package for custom Jax Flax Functions and Utils -Welcome to FJUtils - A collection of useful functions and utilities for Flax and JAX! -Some Parts of code if from EasyLM <3 - -## Overview - -FJUtils is a collection of functions and utilities that can help with various tasks when using Flax and JAX. It includes -checkpoint savers, partitioning tools, and other helpful functions. -The goal of FJUtils is to make your life easier when working with Flax and JAX. Whether you are training a new model, -fine-tuning an existing one, or just exploring the capabilities of these powerful frameworks, FJUtils has something to -offer. - -## Features - -Here are some of the features included in FJUtils: - -Checkpoint saver: This tool provides an easy way to save and restore checkpoints during training. You can specify how -often to save checkpoints, where to store them, and more. - -Partitioning tools: FJUtils includes several tools for partitioning data across multiple devices or nodes. These tools -can help you optimize the performance of your models on clusters or distributed systems. - -Other utilities: FJUtils includes a variety of other helpful functions and utilities and more. - -## Getting Started - -To get started with FJUtils, simply install the package using pip: - -```shell -pip install fjutils -``` - -Once installed, you can import the package and start using its functions and utilities. 
For example, here's how you can
-use the checkpoint saver for loading models like :
-
-```python
-from fjutils import StreamingCheckpointer
-
-ckpt = StreamingCheckpointer.load_trainstate_checkpoint('params::')
-
-```
-
-or simply getting an optimizer for example adafactor with cosine scheduler :
-
-```python
-from jax import numpy as jnp
-from fjutils.optimizers import get_adafactor_with_cosine_scheduler
-
-optimizer, scheduler = get_adafactor_with_cosine_scheduler(
-    steps=5000,
-    learning_rate=5e-5,
-    weight_decay=1e-1,
-    min_dim_size_to_factor=128,
-    decay_rate=0.8,
-    decay_offset=0,
-    multiply_by_parameter_scale=True,
-    clipping_threshold=1.0,
-    momentum=None,
-    dtype_momentum=jnp.float32,
-    weight_decay_rate=None,
-    eps=1e-30,
-    factored=True,
-    weight_decay_mask=None,
-)
-
-```
-
-or getting adamw with linear scheduler:
-
-```python
-from fjutils.optimizers import get_adamw_with_linear_scheduler
-
-optimizer, scheduler = get_adamw_with_linear_scheduler(
-    steps=5000,
-    learning_rate_start=5e-5,
-    learning_rate_end=1e-5,
-    b1=0.9,
-    b2=0.999,
-    eps=1e-8,
-    eps_root=0.0,
-    weight_decay=1e-1,
-    mu_dtype=None,
-)
-
-```
-
-## Documentation
-
-- TODO
-
-## Contributing
-
-FJUtils is an open-source project, and contributions are always welcome! If you have a feature request, bug report, or
-just want to help out with development, please check out our GitHub repository and feel free to submit a pull request or
-open an issue.
-
-Thank you for using FJUtils, and happy training!
-
-## Credits
-
-- some of the jax utilities are from EasyLM by [Young-Geng](https://github.com/young-geng)
-
-
+Metadata-Version: 2.1
+Name: FJUtils
+Version: 0.0.16
+Home-page: https://github.com/erfanzar/
+Author: Erfan Zare Chavoshi
+Author-email: erfanzare82@yahoo.com
+License: Apache License 2.0
+Classifier: Programming Language :: Python :: 3
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Operating System :: OS Independent
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.7
+Classifier: Programming Language :: Python :: 3.8
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Requires-Python: >=3.7
+Description-Content-Type: text/markdown
+License-File: LICENSE
+
+# FJUtils
+
+A package of custom JAX/Flax functions and utilities.
+Welcome to FJUtils - a collection of useful functions and utilities for Flax and JAX!
+Some parts of the code are from EasyLM <3
+
+## Overview
+
+FJUtils is a collection of functions and utilities that can help with various tasks when using Flax and JAX. It includes
+checkpoint savers, partitioning tools, and other helpful functions.
+The goal of FJUtils is to make your life easier when working with Flax and JAX. Whether you are training a new model,
+fine-tuning an existing one, or just exploring the capabilities of these powerful frameworks, FJUtils has something to
+offer.
+
+## Features
+
+Here are some of the features included in FJUtils:
+
+Checkpoint saver: This tool provides an easy way to save and restore checkpoints during training. You can specify how
+often to save checkpoints, where to store them, and more.
+
+Partitioning tools: FJUtils includes several tools for partitioning data across multiple devices or nodes. These tools
+can help you optimize the performance of your models on clusters or distributed systems.
+
+Other utilities: FJUtils also includes a variety of smaller helper functions and utilities.
+
+## Getting Started
+
+To get started with FJUtils, simply install the package using pip:
+
+```shell
+pip install fjutils
+```
+
+Once installed, you can import the package and start using its functions and utilities. For example, here is how you can
+use the streaming checkpointer to load model parameters:
+
+```python
+from fjutils import StreamingCheckpointer
+
+ckpt = StreamingCheckpointer.load_trainstate_checkpoint('params::')
+
+```
+
+or create an optimizer, for example Adafactor with a cosine scheduler:
+
+```python
+from jax import numpy as jnp
+from fjutils.optimizers import get_adafactor_with_cosine_scheduler
+
+optimizer, scheduler = get_adafactor_with_cosine_scheduler(
+    steps=5000,
+    learning_rate=5e-5,
+    weight_decay=1e-1,
+    min_dim_size_to_factor=128,
+    decay_rate=0.8,
+    decay_offset=0,
+    multiply_by_parameter_scale=True,
+    clipping_threshold=1.0,
+    momentum=None,
+    dtype_momentum=jnp.float32,
+    weight_decay_rate=None,
+    eps=1e-30,
+    factored=True,
+    weight_decay_mask=None,
+)
+
+```
+
+or AdamW with a linear scheduler:
+
+```python
+from fjutils.optimizers import get_adamw_with_linear_scheduler
+
+optimizer, scheduler = get_adamw_with_linear_scheduler(
+    steps=5000,
+    learning_rate_start=5e-5,
+    learning_rate_end=1e-5,
+    b1=0.9,
+    b2=0.999,
+    eps=1e-8,
+    eps_root=0.0,
+    weight_decay=1e-1,
+    mu_dtype=None,
+)
+
+```
+
+## Documentation
+
+- TODO
+
+## Contributing
+
+FJUtils is an open-source project, and contributions are always welcome! If you have a feature request, bug report, or
+just want to help out with development, please check out our GitHub repository and feel free to submit a pull request or
+open an issue.
+
+Thank you for using FJUtils, and happy training!
+
+## Credits
+
+- Some of the JAX utilities are from EasyLM by [Young-Geng](https://github.com/young-geng)
diff --git a/docs/CheckPointing.md b/docs/CheckPointing.md
new file mode 100644
index 0000000..e04e4fe
--- /dev/null
+++ b/docs/CheckPointing.md
@@ -0,0 +1,55 @@
+* **StreamingCheckpointer** class:
+    * **__init__(config, checkpoint_dir, enable=True)**:
+      Initializes a StreamingCheckpointer object.
+      Args:
+        config: A dictionary of configuration options.
+        checkpoint_dir: The directory where checkpoints will be saved.
+        enable: Whether to enable the streaming checkpointing functionality.
+    * **save_checkpoint(train_state, filename, gather_fns=None)**:
+      Saves a checkpoint to the specified file.
+      Args:
+        train_state: The train state to save.
+        filename: The name of the checkpoint file.
+        gather_fns: A dictionary of functions used to gather sharded
+          tensors before saving them.
+    * **save_all(train_state, gather_fns, metadata=None, dataset=None, milestone=False)**:
+      Saves a checkpoint for the current step, as well as metadata and dataset
+      information.
+      Args:
+        train_state: The train state to save.
+        gather_fns: A dictionary of functions used to gather sharded
+          tensors before saving them.
+        metadata: Metadata to save.
+        dataset: Dataset information to save.
+        milestone: Whether this is a milestone checkpoint.
+    * **load_checkpoint(path, target=None, shard_fns=None, remove_dict_prefix=None)**:
+      Loads a checkpoint from the specified file.
+      Args:
+        path: The path to the checkpoint file.
+        target: The object to load the checkpoint into.
+        shard_fns: A dictionary of functions that can be used to shard
+          tensors after loading them.
+        remove_dict_prefix: A tuple of keys to remove from the loaded
+          checkpoint.
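+
+      For illustration, here is a minimal, hypothetical sketch of the save/load flow
+      described above. It is not taken from the repository: the empty config dict, the
+      toy pytree standing in for a train state, and the file names are placeholders,
+      and the exact configuration options accepted by `StreamingCheckpointer` depend
+      on the FJUtils version.
+
+      ```python
+      from jax import numpy as jnp
+      from fjutils import StreamingCheckpointer
+
+      # Stand-in for a real flax TrainState; any pytree of arrays illustrates the flow.
+      train_state = {'params': {'w': jnp.zeros((4, 4))}}
+
+      ckpt = StreamingCheckpointer(
+          config={},                     # placeholder options dict (see __init__ above)
+          checkpoint_dir='checkpoints',  # directory where checkpoint files are written
+          enable=True,                   # set False on hosts that should not write
+      )
+
+      # Stream the current train state to checkpoint_dir/filename.
+      ckpt.save_checkpoint(train_state, 'streaming_train_state')
+
+      # Later: read it back, using the saved object as a structure template.
+      restored = ckpt.load_checkpoint(
+          'checkpoints/streaming_train_state',
+          target=train_state,
+      )
+      ```
+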
+ * **load_flax_checkpoint(path, target=None, shard_fns=None)**: + Loads a standard flax checkpoint from the specified file. + Args: + path: The path to the checkpoint file. + target: The object to load the checkpoint into. + shard_fns: A dictionary of functions that can be used to shard + tensors after loading them. + +* **load_trainstate_checkpoint(load_from, trainstate_target=None, + trainstate_shard_fns=None, + disallow_trainstate=False)**: + Load a train state checkpoint from the specified load_from string. + Args: + load_from: The load_from string, which can be one of the following: + * 'trainstate': Load the entire train state. + * 'trainstate_params': Load the params part of the train state. + * 'params': Load the params. + * 'flax_params': Load the params in the standard flax format (non-streaming). + trainstate_target: The target object to load the train state into. + trainstate_shard_fns: A dictionary of functions that can be used to shard + tensors after loading them. + disallow_trainstate: Whether to disallow loading the full train state. diff --git a/docs/Optimizers.md b/docs/Optimizers.md new file mode 100644 index 0000000..712ad1a --- /dev/null +++ b/docs/Optimizers.md @@ -0,0 +1,161 @@ +* **optax_add_scheduled_weight_decay** function: + * **Arguments:** + * **schedule_fn:** A function that takes the current step number as input and returns the weight decay value. + * **mask:** An optional mask that can be used to apply weight decay to a subset of parameters. + * **Returns:** + An Optax GradientTransformation object that adds the scheduled weight decay to the updates. + +* **get_adamw_with_cosine_scheduler** function: + * **Arguments:** + * **steps:** The total number of training steps. + * **learning_rate:** The initial learning rate. + * **b1:** The first Adam beta parameter. + * **b2:** The second Adam beta parameter. + * **eps:** The Adam epsilon parameter. + * **eps_root:** The Adam epsilon root parameter. + * **weight_decay:** The weight decay coefficient. + * **gradient_accumulation_steps:** The number of gradient accumulation steps. + * **mu_dtype:** The dtype of the Adam momentum terms. + * **Returns:** + A tuple of the Adam optimizer and the cosine learning rate scheduler. + +* **get_adamw_with_linear_scheduler** function: + * **Arguments:** + * **steps:** The total number of training steps. + * **learning_rate_start:** The initial learning rate. + * **learning_rate_end:** The final learning rate. + * **b1:** The first Adam beta parameter. + * **b2:** The second Adam beta parameter. + * **eps:** The Adam epsilon parameter. + * **eps_root:** The Adam epsilon root parameter. + * **weight_decay:** The weight decay coefficient. + * **gradient_accumulation_steps:** The number of gradient accumulation steps. + * **mu_dtype:** The dtype of the Adam momentum terms. + * **Returns:** + A tuple of the Adam optimizer and the linear learning rate scheduler. + + +* **get_adafactor_with_linear_scheduler** function: + * **Arguments:** + * **steps:** The total number of training steps. + * **learning_rate_start:** The initial learning rate. + * **learning_rate_end:** The final learning rate. + * **weight_decay:** The weight decay coefficient. + * **min_dim_size_to_factor:** The minimum size of a parameter tensor for it to be factored by Adafactor. + * **decay_rate:** The decay rate parameter for Adafactor. + * **decay_offset:** The decay offset parameter for Adafactor. + * **multiply_by_parameter_scale:** Whether to multiply the learning rate by the parameter scale. 
+ * **clipping_threshold:** The gradient clipping threshold. + * **momentum:** The momentum parameter for Adafactor. + * **dtype_momentum:** The dtype of the momentum term for Adafactor. + * **weight_decay_rate:** The weight decay rate for Adafactor. + * **eps:** The epsilon parameter for Adafactor. + * **factored:** Whether to use the factored implementation of Adafactor. + * **gradient_accumulation_steps:** The number of gradient accumulation steps. + * **weight_decay_mask:** A mask tensor that specifies which parameters to apply weight decay to. + * **Returns:** + A tuple of the Adafactor optimizer and the linear learning rate scheduler. + +* **get_adafactor_with_cosine_scheduler** function: + * **Arguments:** + * **steps:** The total number of training steps. + * **learning_rate:** The initial learning rate. + * **weight_decay:** The weight decay coefficient. + * **min_dim_size_to_factor:** The minimum size of a parameter tensor for it to be factored by Adafactor. + * **decay_rate:** The decay rate parameter for Adafactor. + * **decay_offset:** The decay offset parameter for Adafactor. + * **multiply_by_parameter_scale:** Whether to multiply the learning rate by the parameter scale. + * **clipping_threshold:** The gradient clipping threshold. + * **momentum:** The momentum parameter for Adafactor. + * **dtype_momentum:** The dtype of the momentum term for Adafactor. + * **weight_decay_rate:** The weight decay rate for Adafactor. + * **eps:** The epsilon parameter for Adafactor. + * **factored:** Whether to use the factored implementation of Adafactor. + * **gradient_accumulation_steps:** The number of gradient accumulation steps. + * **weight_decay_mask:** A mask tensor that specifies which parameters to apply weight decay to. + * **Returns:** + A tuple of the Adafactor optimizer and the cosine learning rate scheduler. + +* **get_lion_with_cosine_scheduler** function: + * **Arguments:** + * **steps:** The total number of training steps. + * **learning_rate:** The initial learning rate. + * **alpha:** The minimum value of the multiplier used to adjust the learning rate. + * **exponent:** The exponent of the cosine decay schedule. + * **b1:** The first Lion beta parameter. + * **b2:** The second Lion beta parameter. + * **gradient_accumulation_steps:** The number of gradient accumulation steps. + * **mu_dtype:** The dtype of the Lion momentum terms. + * **Returns:** + A tuple of the Lion optimizer and the cosine learning rate scheduler. + + + +* **get_lion_with_linear_scheduler** function: + * **Arguments:** + * **steps:** The total number of training steps. + * **learning_rate_start:** The initial learning rate. + * **learning_rate_end:** The final learning rate. + * **b1:** The first Lion beta parameter. + * **b2:** The second Lion beta parameter. + * **gradient_accumulation_steps:** The number of gradient accumulation steps. + * **mu_dtype:** The dtype of the Lion momentum terms. + * **Returns:** + A tuple of the Lion optimizer and the linear learning rate scheduler. + +* **get_adamw_with_warm_up_cosine_scheduler** function: + * **Arguments:** + * **steps:** The total number of training steps. + * **learning_rate:** The initial learning rate. + * **b1:** The first Adam beta parameter. + * **b2:** The second Adam beta parameter. + * **eps:** The Adam epsilon parameter. + * **eps_root:** The Adam epsilon root parameter. + * **weight_decay:** The weight decay coefficient. + * **exponent:** The exponent of the cosine decay schedule. 
+    * **gradient_accumulation_steps:** The number of gradient accumulation steps.
+    * **mu_dtype:** The dtype of the Adam momentum terms.
+  * **Returns:**
+    A tuple of the Adam optimizer and the cosine learning rate scheduler.
+
+* **get_adafactor_with_warm_up_cosine_scheduler** function:
+  * **Arguments:**
+    * **steps:** The total number of training steps.
+    * **learning_rate:** The initial learning rate.
+    * **weight_decay:** The weight decay coefficient.
+    * **min_dim_size_to_factor:** The minimum size of a parameter tensor for it to be factored by Adafactor.
+    * **decay_rate:** The decay rate parameter for Adafactor.
+    * **decay_offset:** The decay offset parameter for Adafactor.
+    * **multiply_by_parameter_scale:** Whether to multiply the learning rate by the parameter scale.
+    * **clipping_threshold:** The gradient clipping threshold.
+    * **momentum:** The momentum parameter for Adafactor.
+    * **dtype_momentum:** The dtype of the momentum term for Adafactor.
+    * **weight_decay_rate:** The weight decay rate for Adafactor.
+    * **eps:** The epsilon parameter for Adafactor.
+    * **factored:** Whether to use the factored implementation of Adafactor.
+    * **exponent:** The exponent of the cosine decay schedule.
+    * **weight_decay_mask:** A mask tensor that specifies which parameters to apply weight decay to.
+    * **gradient_accumulation_steps:** The number of gradient accumulation steps.
+  * **Returns:**
+    A tuple of the Adafactor optimizer and the cosine learning rate scheduler.
+
+* **get_lion_with_warm_up_cosine_scheduler** function:
+  * **Arguments:**
+    * **steps:** The total number of training steps.
+    * **learning_rate:** The initial learning rate.
+    * **exponent:** The exponent of the cosine decay schedule.
+    * **b1:** The first Lion beta parameter.
+    * **b2:** The second Lion beta parameter.
+    * **gradient_accumulation_steps:** The number of gradient accumulation steps.
+    * **mu_dtype:** The dtype of the Lion momentum terms.
+  * **Returns:**
+    A tuple of the Lion optimizer and the cosine learning rate scheduler.
+
+The references for these functions are:
+
+* Lion - Symbolic Discovery of Optimization Algorithms: https://arxiv.org/abs/2302.06675
+* Adam: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980
+* SGDR: Stochastic Gradient Descent with Warm Restarts (cosine annealing): https://arxiv.org/abs/1608.03983
+* Adafactor: Adaptive Learning Rates with Sublinear Memory Cost: https://arxiv.org/abs/1804.04235
\ No newline at end of file
diff --git a/docs/README.md b/docs/README.md
new file mode 100644
index 0000000..e69de29