Commit

Documentation generated for AverageModelMerger.
Documentation generated for SLERPModelMerger.
Documentation generated for QuantizedLN.
Documentation generated in docs/zeta/nn/modules directory.
Kye committed Dec 27, 2023
1 parent 08b18b3 commit f3f3de9
Showing 6 changed files with 349 additions and 11 deletions.
17 changes: 7 additions & 10 deletions scripts/auto_tests_docs/auto_docs.py → auto_docs.py
@@ -9,11 +9,9 @@
from swarms import OpenAIChat

##########
from zeta.nn.modules.triple_skip import TripleSkipBlock
from zeta.nn.modules.dynamic_routing_block import DynamicRoutingBlock
from zeta.nn.modules.gated_residual_block import GatedResidualBlock
from zeta.nn.modules.stochastic_depth import StochasticSkipBlocK

from zeta.nn.modules.quantized_layernorm import QuantizedLN
from zeta.nn.modules.slerp_model_merger import SLERPModelMerger
from zeta.nn.modules.avg_model_merger import AverageModelMerger

####################
load_dotenv()
@@ -23,7 +21,7 @@
model = OpenAIChat(
model_name="gpt-4",
openai_api_key=api_key,
max_tokens=2000,
max_tokens=3000,
)


@@ -61,10 +59,9 @@ def process_documentation(cls):

def main():
classes = [
TripleSkipBlock,
DynamicRoutingBlock,
GatedResidualBlock,
StochasticSkipBlocK,
QuantizedLN,
SLERPModelMerger,
AverageModelMerger,
]

threads = []
131 changes: 131 additions & 0 deletions docs/zeta/nn/modules/averagemodelmerger.md
@@ -0,0 +1,131 @@
# Zeta.nn.modules.AverageModelMerger Documentation

## Introduction

The AverageModelMerger class, found in the zeta.nn.modules library, is a simple but useful utility for merging multiple models of identical architecture by averaging their weights. It offers a straightforward way to combine models trained in different stages, such as instruction tuning and alignment tuning, which can improve model performance in certain circumstances.

## Class Definition: AverageModelMerger

```python
class AverageModelMerger:
    """
    A class to merge multiple models by averaging their weights.

    Attributes:
        models (List[nn.Module]): A list of PyTorch models to be merged.

    Example usage:
        model1 = nn.Linear(in_features=10, out_features=10)
        model2 = nn.Linear(in_features=10, out_features=10)
        model3 = nn.Linear(in_features=10, out_features=10)
        merge = AverageModelMerger([model1, model2, model3])
        merged_model = merge.merge_models()
        print(merged_model)
    """
```

### Class Parameters:

| Parameters | Data Type | Default Value | Description |
|------------|---------------|---------------|-------------|
| models | List[nn.Module] | N/A | List of PyTorch models to be merged |

### Class Methods:

| Method Name | Description | Parameters | Returns |
|-------------------|-------------|------------|---------|
| `__init__(self, models: List[nn.Module])`| Initializes the AverageModelMerger with a list of models. | models (List[nn.Module]) | None |
| `merge_models(self)` | Merges the models by averaging their weights. | None | A new model with averaged weights. |
| `_copy_model_structure(model: nn.Module)` | Creates a new instance of a model with the same structure as the given model. | model (nn.Module) | A new model with the same structure. |

### Constructor `__init__(self, models: List[nn.Module])`

Initializes an instance of the AverageModelMerger class. It takes a list of PyTorch models as input, which are merged later using the `merge_models` method.

- **models (List[nn.Module])**: Models to be merged.

### Method `merge_models(self) -> nn.Module`

This method merges the given models by averaging their weights and returns a single model carrying the averaged parameters.

**Returns**

nn.Module: A new model with averaged weights.
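
Conceptually, merging amounts to walking every entry of the models' `state_dict`s and taking the element-wise mean. The sketch below is illustrative only, using a hypothetical `average_merge` helper rather than the class internals:

```python
import copy
from typing import List

import torch
import torch.nn as nn


def average_merge(models: List[nn.Module]) -> nn.Module:
    """Illustrative sketch: average the weights of models with identical architectures."""
    merged = copy.deepcopy(models[0])  # new model with the same structure as the first one
    avg_state = merged.state_dict()
    with torch.no_grad():
        for key in avg_state:
            # Stack the corresponding tensor from every model and take the mean
            avg_state[key] = torch.stack(
                [m.state_dict()[key].float() for m in models]
            ).mean(dim=0)
    merged.load_state_dict(avg_state)
    return merged
```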

### Method `_copy_model_structure(self, model: nn.Module) -> nn.Module`

This function creates a new instance of a model with exactly the same structure as the given model.

**Parameters**
- **model (nn.Module)**: The model whose structure is to be copied.

**Returns**

nn.Module: A new model with exactly the same structure.

## Examples of Usage:

### Example 1
```python
import torch.nn as nn

from zeta.nn.modules import AverageModelMerger

# Define models
model1 = nn.Linear(in_features=10, out_features=10)
model2 = nn.Linear(in_features=10, out_features=10)
model3 = nn.Linear(in_features=10, out_features=10)

# Initialize AverageModelMerger
merger = AverageModelMerger([model1, model2, model3])

# Merge models
merged_model = merger.merge_models()

# Print merged model
print(merged_model)
```

### Example 2
```python
import torch.nn as nn

from zeta.nn.modules import AverageModelMerger

# Define models
model1 = nn.Conv2d(3, 6, 5)
model2 = nn.Conv2d(3, 6, 5)
model3 = nn.Conv2d(3, 6, 5)

# Initialize AverageModelMerger
merger = AverageModelMerger([model1, model2, model3])

# Merge models
merged_model = merger.merge_models()

# Print merged model
print(merged_model)
```

### Example 3
```python
import torch.nn as nn

from zeta.nn.modules import AverageModelMerger

# Define models (small MLPs with identical architectures; averaging requires matching parameters)
model1 = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
model2 = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
model3 = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))

# Initialize AverageModelMerger
merger = AverageModelMerger([model1, model2, model3])

# Merge models
merged_model = merger.merge_models()

# Print merged model
print(merged_model)
```

All the examples above demonstrate the basic usage of this class. Whenever you have multiple trained models with identical architectures (e.g., models resulting from k-fold cross-validation or trained on different datasets), you can use this class to average their weights. The resulting model carries the averaged weights, giving a balanced representation of all the input models.
141 changes: 141 additions & 0 deletions docs/zeta/nn/modules/quantizedln.md
@@ -0,0 +1,141 @@
# Module/Class Name: QuantizedLN

## Overview
`QuantizedLN` is a PyTorch module built on the lower-level `nn.Module` class. It transforms its layer inputs to have zero mean and unit standard deviation and then quantizes the result. Its main purpose is to provide normalized outputs at reduced precision, which saves memory and computation, a trade-off typically made in low-resource environments such as mobile devices.

The 'LN' in the class name refers to Layer Normalization, a technique that normalizes the inputs across the feature dimension rather than the batch dimension. The 'Quantized' in the class name signifies that the normalized output is then quantized to a specified bit width for memory and speed optimizations.

```python
class QuantizedLN(nn.Module):
    def __init__(
        self,
        normalized_shape,
        bits: int = 8,
        eps=1e-5,
        element_wise_affine=True,
    ):
        """
        Initializes a QuantizedLN module.

        Args:
            normalized_shape (int or tuple): The expected input shape.
            bits (int, optional): Number of bits for quantization. Defaults to 8.
            eps (float, optional): A value added to the denominator for numerical stability. Defaults to 1e-5.
            element_wise_affine (bool, optional): Whether to include learnable affine parameters. Defaults to True.
        """
        ...

    def forward(self, x: Tensor):
        """
        Forward pass of the QuantizedLN module.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after applying quantization and layer normalization.
        """
        ...
```

## Parameters
The `QuantizedLN` class takes the following arguments during initialization:

| Parameter Name | Type | Description | Default Value |
| --- | --- | --- | --- |
| normalized_shape | int or tuple | The expected input shape | Required |
| bits | int | Number of bits for quantization | 8 |
| eps | float | A small value added to the denominator for numerical stability | 1e-5 |
| element_wise_affine | bool | If True, includes learnable affine parameters | True |

## Methods
The `QuantizedLN` class has the following methods:

| Method Name | Args | Returns | Description |
| --- | --- | --- | --- |
| `__init__` | normalized_shape, bits, eps, element_wise_affine | None | Initializes the QuantizedLN module |
| forward | x | torch.Tensor | Performs the forward pass |

## Usage Examples

Below are three examples of how to use the `QuantizedLN` module.

### Example 1

```python
import torch
from zeta.nn.modules import QuantizedLN

# Define input tensor
x = torch.randn(128, 10)
# Create module instance
ln = QuantizedLN(10)
# Apply module to input
output = ln(x)
```
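
As a quick sanity check, the output can be compared against a full-precision `nn.LayerNorm`. This sketch assumes that `QuantizedLN` returns a dequantized, normalized tensor of the same shape as its input; under that assumption, the observed difference is the quantization error:

```python
import torch
from torch import nn
from zeta.nn.modules import QuantizedLN

x = torch.randn(128, 10)

reference = nn.LayerNorm(10)(x)  # full-precision layer normalization
quantized = QuantizedLN(10)(x)   # quantized variant (assumed to return dequantized values)

# If the assumption holds, the difference is bounded by the quantization step
print((reference - quantized).abs().max())
```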

### Example 2

Define a custom network that uses the `QuantizedLN` module:

```python
import torch
import torch.nn as nn

from zeta.nn.modules import QuantizedLN


class CustomNetwork(nn.Module):
    def __init__(self):
        super(CustomNetwork, self).__init__()
        self.layer1 = nn.Linear(128, 256)
        self.ln = QuantizedLN(256)

    def forward(self, x):
        x = self.layer1(x)
        x = self.ln(x)
        return x


# Define input tensor (feature size must match layer1's in_features of 128)
x = torch.randn(32, 128)

# Create network instance
network = CustomNetwork()

# Forward pass
output = network(x)
```

### Example 3

The `QuantizedLN` module in a multi-layer setup:

```python
import torch
import torch.nn as nn

from zeta.nn.modules import QuantizedLN


class DeepNetwork(nn.Module):
    def __init__(self):
        super(DeepNetwork, self).__init__()
        self.layer1 = nn.Linear(128, 256)
        self.ln1 = QuantizedLN(256)
        self.layer2 = nn.Linear(256, 512)
        self.ln2 = QuantizedLN(512)

    def forward(self, x):
        x = self.layer1(x)
        x = self.ln1(x)
        x = self.layer2(x)
        x = self.ln2(x)
        return x


# Define input tensor (feature size must match layer1's in_features of 128)
x = torch.randn(32, 128)

# Create network instance
network = DeepNetwork()

# Forward pass
output = network(x)
```

## Additional Notes:

Please make sure that the `absmax_quantize` function used in the `forward` method is defined in scope or imported correctly from an external module. It is a quantization function that is not included in PyTorch's `nn` module by default, and execution will fail if it is missing.
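
For reference, one common absmax quantization scheme is sketched below. This is an assumption about the general shape of such a function, not the actual zeta implementation, whose signature and return values may differ:

```python
import torch
from torch import Tensor


def absmax_quantize(x: Tensor, bits: int = 8) -> tuple[Tensor, Tensor]:
    """Illustrative absmax quantization: scale by the largest absolute value."""
    qmax = 2 ** (bits - 1) - 1                     # e.g. 127 for 8-bit quantization
    scale = qmax / x.abs().max().clamp(min=1e-8)   # guard against an all-zero input
    quant = (x * scale).round().clamp(-qmax, qmax).to(torch.int8)  # int8 storage assumes bits <= 8
    dequant = quant.float() / scale                # dequantized copy for downstream layers
    return quant, dequant
```
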
65 changes: 65 additions & 0 deletions docs/zeta/nn/modules/slerpmodelmerger.md
@@ -0,0 +1,65 @@
# SLERPModelMerger

- **Description**:
SLERPModelMerger is a Python class, implemented with the PyTorch framework, that merges two models using Spherical Linear Interpolation (SLERP). Interpolation finds intermediate values between two known points; in SLERP, the model weights are treated as points on a hypersphere, and the interpolated weight is obtained by moving along the geodesic (the shortest path) between them on that hypersphere.

The class can blend or interpolate the weights of two trained models, allowing one to create an ensemble or composite model of the input models, essentially capturing the strengths of both. In ML terminology, this can be thought of as a "committee machine" where transformations applied to input data by multiple models are combined to produce a single output. This method is known to improve the robustness and performance of models, especially in scenarios where the strength of individual models varies across different sections of the input space.

- **Class Definition**:

Here is the class definition:

```python
class SLERPModelMerger(nn.Module):
    @enforce_types
    def __init__(self, model1: nn.Module, model2: nn.Module, t: float = 0.5):
        ...

    def merge(self) -> nn.Module:
        ...

    @staticmethod
    @enforce_types
    def _slerp(w1: Tensor, w2: Tensor, t: float) -> Tensor:
        ...

    @staticmethod
    @enforce_types
    def _copy_model_structure(model: nn.Module) -> nn.Module:
        ...
```

- **Parameters:**
`model1` and `model2` are the PyTorch modules (e.g., instances of `nn.Linear`, `nn.Conv2d`, etc.) whose weights are to be interpolated. The parameter `t` is the interpolation coefficient, ranging from 0 (model1) to 1 (model2), and indicates the weighting given to the two models during interpolation. For t=0 the resulting model matches model1, and for t=1 it matches model2.

- **Methods:**

- `merge()` : This method merges the input models (`model1` and `model2`), according to the interpolation parameter `t`. The merging is done by interpolating the weights of the two models using Spherical Linear Interpolation (SLERP).

- `_slerp(w1: Tensor, w2: Tensor, t: float) -> Tensor:` : This method performs Spherical Linear Interpolation (SLERP) between two tensors (a sketch of the formula appears after this list).

- `_copy_model_structure(model: nn.Module) -> nn.Module:` : This method creates a new instance of a model with the same structure as the given model.
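
For intuition, here is a minimal sketch of the standard SLERP formula applied to two weight tensors. It assumes the tensors are flattened before interpolation and is purely illustrative; the class's actual `_slerp` implementation may differ:

```python
import torch
from torch import Tensor


def slerp(w1: Tensor, w2: Tensor, t: float, eps: float = 1e-7) -> Tensor:
    """Illustrative spherical linear interpolation between two weight tensors."""
    v1, v2 = w1.flatten(), w2.flatten()
    # Angle between the two weight vectors on the hypersphere
    cos_omega = torch.dot(v1, v2) / (v1.norm() * v2.norm() + eps)
    omega = torch.acos(cos_omega.clamp(-1 + eps, 1 - eps))
    sin_omega = torch.sin(omega)
    if sin_omega.abs() < eps:
        # Nearly parallel vectors: fall back to plain linear interpolation
        return (1 - t) * w1 + t * w2
    out = (torch.sin((1 - t) * omega) / sin_omega) * v1 + (torch.sin(t * omega) / sin_omega) * v2
    return out.reshape(w1.shape)
```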

- **Usage:**

The following code shows how to use the SLERPModelMerger class to merge two PyTorch models (in this case two linear models):

```python
import torch.nn as nn

from zeta.nn.modules import SLERPModelMerger

model1 = nn.Linear(10, 10)
model2 = nn.Linear(10, 10)

merger = SLERPModelMerger(model1, model2, 0.5)
merged_model = merger.merge()

# This will output the merged state_dict
print(merged_model.state_dict())
```

The print statement outputs the state_dict of the merged model. The state_dict is a Python dictionary that maps each layer to its corresponding parameter tensors.

The weighting given to the two models during interpolation is controlled by the parameter `t`. As `t` ranges from 0 to 1, the merged model evolves from model1 to model2; by varying `t` we can generate a spectrum of models between the two.
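
For example, a quick sweep over `t` (a usage sketch that reuses the two `nn.Linear` models defined in the snippet above):

```python
# Trace a path from model1 (t=0) to model2 (t=1)
for t in (0.0, 0.25, 0.5, 0.75, 1.0):
    merged = SLERPModelMerger(model1, model2, t).merge()
    print(t, merged.state_dict()["weight"][0, :3])  # peek at a few interpolated weights
```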

This gives us a strategy to generate an ensemble of models by interpolating between two carefully chosen base models. This ensemble could then be used for model selection or for creating a more robust composite model.

- **References:**

- Ken Shoemake. Animating rotation with quaternion curves. In ACM SIGGRAPH Computer Graphics, volume 19, pp. 245–254. ACM, 1985.

Remarks: Not all models take their parameters as single positional arguments to the constructor; some expect lists, sets, or other composite objects. Such models may need additional pre-processing or configuration before they can be used with SLERPModelMerger. Try different configurations to find the one that best suits your requirements.
3 changes: 3 additions & 0 deletions mkdocs.yml
@@ -133,6 +133,9 @@ nav:
- gatedresidualblock: "zeta/nn/modules/gatedresidualblock.md"
- tripleskipblock: "zeta/nn/modules/tripleskipblock.md"
- DynamicRoutingBlock: "zeta/nn/modules/dynamicroutingblock.md"
- AverageModelMerger: "zeta/nn/modules/averagemodelmerger.md"
- SLERPModelMerger: "zeta/nn/modules/slerpmodelmerger.md"
- QuantizedLN: "zeta/nn/modules/quantizedln.md"
- zeta.nn.attention:
- FlashAttention: "zeta/nn/attention/flash_attention.md"
- MultiQueryAttention: "zeta/nn/attention/multiquery.md"
3 changes: 2 additions & 1 deletion zeta/nn/modules/__init__.py
@@ -72,8 +72,9 @@
from zeta.nn.modules.dynamic_routing_block import DynamicRoutingBlock
from zeta.nn.modules.gated_residual_block import GatedResidualBlock
from zeta.nn.modules.stochastic_depth import StochasticSkipBlocK
from zeta.nn.modules.quantized_layernorm import QuantizedLN


from zeta.nn.modules.quantized_layernorm import QuantizedLN
from zeta.nn.modules.slerp_model_merger import SLERPModelMerger
from zeta.nn.modules.avg_model_merger import AverageModelMerger

