-
-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Documentation generated for AverageModelMerger.
Documentation generated for SLERPModelMerger. Documentation generated for QuantizedLN. Documentation generated in docs/zeta/nn/modules directory.
- Loading branch information
Kye
committed
Dec 27, 2023
1 parent
08b18b3
commit f3f3de9
Showing
6 changed files
with
349 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
# Zeta.nn.modules.AverageModelMerger Documentation | ||
|
||
## Introduction | ||
|
||
The AverageModelMerger class, found in the zeta.nn.modules library, is a simple yet powerful class to merge multiple models by averaging their weights. It offers a straightforward way to combine models trained in different stages, such as instruction and alignment tuning, leading to improved model performance in certain circumstances. | ||
|
||
## Class Definition: AverageModelMerger | ||
|
||
```python | ||
class AverageModelMerger: | ||
""" | ||
A class to merge multiple models by averaging their weights. | ||
Attributes: | ||
models (List[nn.Module]): A list of PyTorch models to be merged. | ||
Examples::- Example usage: | ||
model1 = nn.Linear(in_features=10, out_features=10) | ||
model2 = nn.Linear(in_features=10, out_features=10) | ||
model3 = nn.Linear(in_features=10, out_features=10) | ||
merge = AverageModelMerger([model1, model2, model3]) | ||
merged_model = merge.merge_models() | ||
print(merged_model) | ||
""" | ||
``` | ||
|
||
### Class Parameters: | ||
|
||
| Parameters | Data Type | Default Value | Description | | ||
|------------|---------------|---------------|-------------| | ||
| models | List[nn.Module] | N/A | List of PyTorch models to be merged | ||
|
||
### Class Methods: | ||
|
||
| Method Name | Description | Parameters | Returns | | ||
|-------------------|-------------|------------|---------| | ||
| `__init__(self, models: List[nn.Module])`| Initializes the AverageModelMerger with a list of models. | models (List[nn.Module]) | None | | ||
| `merge_models(self)` | Merges the models by averaging their weights. | None | A new model with averaged weights. | | ||
| `_copy_model_structure(model: nn.Module)` | Creates a new instance of a model with the same structure as the given model. | model (nn.Module) | A new model with the same structure. | | ||
|
||
### Constructor `__init__(self, models: List[nn.Module])` | ||
|
||
Initializes an instance of the AverageModelMerge class. It takes a list of PyTorch models as input which are to be merged later using the `merge_models` method. | ||
|
||
- **models (List[nn.Module])**: Models to be merged. | ||
|
||
### Method `merge_models(self) -> nn.Module` | ||
|
||
This function merges the models by averaging the weights of the PyTorch models. | ||
|
||
**Returns** | ||
|
||
nn.Module: A new model with averaged weights. | ||
|
||
### Method `_copy_model_structure(self, model: nn.Module) -> nn.Module` | ||
|
||
This function creates a new instance of a model with exactly the same structure as the given model. | ||
|
||
**Parameters** | ||
- **model (nn.Module)**: The model whose structure is to be copied. | ||
|
||
**Returns** | ||
|
||
nn.Module: A new model with exactly the same structure. | ||
|
||
## Examples of Usage: | ||
|
||
### Example 1 | ||
```python | ||
import torch.nn as nn | ||
from typing import List | ||
from zeta.nn.modules import AverageModelMerger | ||
|
||
# Define models | ||
model1 = nn.Linear(in_features=10, out_features=10) | ||
model2 = nn.Linear(in_features=10, out_features=10) | ||
model3 = nn.Linear(in_features=10, out_features=10) | ||
|
||
# Initialize AverageModelMerger | ||
merger = AverageModelMerger([model1, model2, model3]) | ||
|
||
# Merge models | ||
merged_model = merger.merge_models() | ||
|
||
# Print merged model | ||
print(merged_model) | ||
``` | ||
|
||
### Example 2 | ||
```python | ||
import torch.nn as nn | ||
from typing import List | ||
from zeta.nn.modules import AverageModelMerger | ||
|
||
# Define models | ||
model1 = nn.Conv2d(3, 6, 5) | ||
model2 = nn.Conv2d(3, 6, 5) | ||
model3 = nn.Conv2d(3, 6, 5) | ||
|
||
# Initialize AverageModelMerger | ||
merger = AverageModelMerger([model1, model2, model3]) | ||
|
||
# Merge models | ||
merged_model = merger.merge_models() | ||
|
||
# Print merged model | ||
print(merged_model) | ||
``` | ||
|
||
### Example 3 | ||
```python | ||
import torch.nn as nn | ||
from typing import List | ||
from zeta.nn.modules import AverageModelMerger | ||
|
||
# Define models | ||
model1 = nn.CrossEntropyLoss() | ||
model2 = nn.CrossEntropyLoss() | ||
model3 = nn.CrossEntropyLoss() | ||
|
||
# Initialize AverageModelMerger | ||
merger = AverageModelMerger([model1, model2, model3]) | ||
|
||
# Merge models | ||
merged_model = merger.merge_models() | ||
|
||
# Print merged model | ||
print(merged_model) | ||
``` | ||
|
||
All the examples above demonstrate the basic usage of this class. In cases where you have multiple trained models (e.g., resultant from a k-fold cross-validation or models trained on different datasets), you can use this class to merge or average their weights. The resultant model will carry averaged weights, giving a balanced representation of all the models. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
# Module/Class Name: QuantizedLN | ||
|
||
## Overview | ||
`QuantizedLN` is a PyTorch module built on the lower-level `nn.Module` class. This module is designed for applying a form of normalization where the layer inputs are transformed to have zero mean and one standard deviation, and subsequently quantized. The main purpose of this module is to provide normalized inputs with reduced precision for performance and memory optimization purposes, seen typically in low-resource environments like mobile devices. | ||
|
||
The 'LN' in the class name refers to Layer Normalization, a technique that normalizes the inputs across the features instead of the batch size. The 'Quantized' in the class name signifies that the normalized output is then quantized to a specified bit size for memory and speed optimizations. | ||
|
||
```python | ||
class QuantizedLN(nn.Module): | ||
def __init__( | ||
self, | ||
normalized_shape, | ||
bits: int = 8, | ||
eps=1e-5, | ||
element_wise_affine=True, | ||
): | ||
""" | ||
Initializes a QuantizedLN module. | ||
Args: | ||
normalized_shape (int or tuple): The expected input shape. | ||
bits (int, optional): Number of bits for quantization. Defaults to 8. | ||
eps (float, optional): A value added to the denominator for numerical stability. Defaults to 1e-5. | ||
element_wise_affine (bool, optional): Whether to include learnable affine parameters. Defaults to True. | ||
""" | ||
... | ||
|
||
def forward(self, x: Tensor): | ||
""" | ||
Forward pass of the QuantizedLN module. | ||
Args: | ||
x (torch.Tensor): Input tensor. | ||
Returns: | ||
torch.Tensor: Output tensor after applying quantization and layer normalization. | ||
""" | ||
... | ||
``` | ||
|
||
## Parameters | ||
The `QuantizedLN` class takes the following arguments during initialization: | ||
|
||
| Parameter Name | Type | Description | Default Value | | ||
| --- | --- | --- | --- | | ||
| normalized_shape | int or tuple | The expected input shape | Required | | ||
| bits | int | Number of bits for quantization | 8 | | ||
| eps | float | A small value added to the denominator for numerical stability | 1e-5 | | ||
| element_wise_affine | bool | If True, includes learnable affine parameters | True | | ||
|
||
## Methods | ||
The `QuantizedLN` class has the following methods: | ||
|
||
| Method Name | Args | Returns | Description | | ||
| --- | --- | --- | --- | | ||
| init | normalized_shape, bits, eps, element_wise_affine | None | Initializes the QuantizedLN module | | ||
| forward | x | torch.Tensor | Performs the forward pass | | ||
|
||
## Usage Examples | ||
|
||
Below are three examples of how to use the `QuantizedLN` module. | ||
|
||
### Example 1 | ||
|
||
```python | ||
import torch | ||
from torch import nn, Tensor | ||
from torch.nn.parameter import Parameter | ||
from zeta.nn.modules import QuantizedLN | ||
|
||
# Define input tensor | ||
x = torch.randn(128, 10) | ||
# Create module instance | ||
ln = QuantizedLN(10) | ||
# Apply module to input | ||
output = ln(x) | ||
``` | ||
|
||
### Example 2 | ||
|
||
Define a custom network that uses have the `QuantizedLN` module: | ||
|
||
```python | ||
import torch.nn as nn | ||
|
||
class CustomNetwork(nn.Module): | ||
def __init__(self): | ||
super(CustomNetwork, self).__init__() | ||
self.layer1 = nn.Linear(128, 256) | ||
self.ln = QuantizedLN(256) | ||
|
||
def forward(self, x): | ||
x = self.layer1(x) | ||
x = self.ln(x) | ||
return x | ||
|
||
# Define input tensor | ||
x = torch.randn(128, 10) | ||
|
||
# Create network instance | ||
network = CustomNetwork() | ||
|
||
# Forward pass | ||
output = network(x) | ||
``` | ||
|
||
### Example 3 | ||
|
||
The `QuantizedLN` module in a multi-layer setup: | ||
|
||
```python | ||
import torch.nn as nn | ||
|
||
class DeepNetwork(nn.Module): | ||
def __init__(self): | ||
super(DeepNetwork, self).__init__() | ||
self.layer1 = nn.Linear(128, 256) | ||
self.ln1 = QuantizedLN(256) | ||
self.layer2 = nn.Linear(256, 512) | ||
self.ln2 = QuantizedLN(512) | ||
|
||
def forward(self, x): | ||
x = self.layer1(x) | ||
x = self.ln1(x) | ||
x = self.layer2(x) | ||
x = self.ln2(x) | ||
return x | ||
|
||
# Define input tensor | ||
x = torch.randn(128, 10) | ||
|
||
# Create network instance | ||
network = DeepNetwork() | ||
|
||
# Forward pass | ||
output = network(x) | ||
``` | ||
|
||
## Additional Notes: | ||
|
||
Please make sure that the `absmax_quantize` function used in the `forward` method is properly defined in the scope of this class or is imported correctly from an external module. It is a quantization function that is not included by default in PyTorch's `nn` module. Failure to define or import this function will result in errors during execution. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# SLERPModelMerger | ||
|
||
- **Description**: | ||
SLERPModelMerger is a Python class that performs model merging using Spherical Linear Interpolation (SLERP). Interpolation is a process of finding a value between two points on a line or curve to create new geometries. Spherical Linear Interpolation (SLERP) is a method of interpolation where the model weights are visualized on a hypersphere, and the interpolated weight is obtained by moving along the geodesic (or the shortest path) on the hypersphere. This class is implemented under the PyTorch framework. | ||
|
||
The class can blend or interpolate the weights of two trained models, allowing one to create an ensemble or composite model of the input models, essentially capturing the strengths of both. In ML terminology, this can be thought of as a "committee machine" where transformations applied to input data by multiple models are combined to produce a single output. This method is known to improve the robustness and performance of models, especially in scenarios where the strength of individual models varies across different sections of the input space. | ||
|
||
- **Class Definition**: | ||
|
||
Here is the class definition: | ||
|
||
```python | ||
class SLERPModelMerger(nn.Module): | ||
@enforce_types | ||
def __init__(self, model1: nn.Module, model2: nn.Module, t: float = 0.5): | ||
|
||
def merge(self) -> nn.Module: | ||
|
||
@staticmethod | ||
@enforce_types | ||
def _slerp(w1: Tensor, w2: Tensor, t: float) -> Tensor: | ||
|
||
@staticmethod | ||
@enforce_types | ||
def _copy_model_structure(model: nn.Module) -> nn.Module: | ||
``` | ||
|
||
- **Parameters:** | ||
`model1` and `model2` are instances of PyTorch's neural network models (such as instances of `nn.Linear, nn.Conv2d` etc.) between which weights' interpolation is to be done. The parameter `t` is the interpolation parameter that ranges from 0 (model1) to 1 (model2), indicating the weightage given to the two models during interpolation. Hence, for t=0, the resulting model would be the same as model1, and for t=1, the resulting model would be the same as model2. | ||
|
||
- **Methods:** | ||
|
||
- `merge()` : This method merges the input models (`model1` and `model2`), according to the interpolation parameter `t`. The merging is done by interpolating the weights of the two models using Spherical Linear Interpolation (SLERP). | ||
|
||
- `_slerp(w1: Tensor, w2: Tensor, t: float) -> Tensor:` : This method performs Spherical Linear Interpolation (SLERP) between two tensors. | ||
|
||
- `_copy_model_structure(model: nn.Module) -> nn.Module:` : This method creates a new instance of a model with the same structure as the given model. | ||
|
||
- **Usage:** | ||
|
||
The following code shows how to use the SLERPModelMerger class to merge two PyTorch models (in this case two linear models): | ||
|
||
```python | ||
import torch.nn as nn | ||
model1 = nn.Linear(10, 10) | ||
model2 = nn.Linear(10, 10) | ||
|
||
merger = SLERPModelMerger(model1, model2, 0.5) | ||
merged_model = merger.merge() | ||
|
||
# This will output the merged state_dict | ||
print(merged_model.state_dict()) | ||
``` | ||
|
||
The prints statement will output the state_dict of the merged model. The state_dict is a Python dictionary that maps each layer to its corresponding parameters (tensors). | ||
|
||
The weightage given to the two models for interpolation is specified by the interpolation parameter `t`. As t ranges from 0 to 1, we can see the merged model evolve from model1 to model2. Thus, by changing `t` we can generate a spectrum of models from model1 to model2. | ||
|
||
This gives us a strategy to generate an ensemble of models by interpolating between two carefully chosen base models. This ensemble could then be used for model selection or for creating a more robust composite model. | ||
|
||
- **References:** | ||
|
||
- Ken Shoemake. Animating rotation with quaternion curves. In ACM SIGGRAPH Computer Graphics, volume 19, pp. 245–254. ACM, 1985. | ||
|
||
Remarks: Remember, while PyTorch models accept parameters as single arguments to their constructors, this is not the case with all models. Some models might accept parameters as lists, sets, or other non-single-parameter-type objects. As such, additional pre-processing or configuration might be needed if using those models with SLERPModelMerger. Try these different configurations and methods to find the one that best suits your requirements. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters