diff --git a/README.md b/README.md
index 5c388e63..d18f3ae5 100644
--- a/README.md
+++ b/README.md
@@ -336,6 +336,40 @@ niva(
```
+
+### ZetaCloud
+Train or finetune any model on any cluster in 1 click with zetacloud, just pass in your file and the GPU type and quantity you want! To gain access first `pip install zetascale` then run `zeta -h` in the terminal.
+
+- Flexible Pricing with pooling from many clouds
+- Easy Deployment with 1 click
+- Various options for cloud providers!
+
+```bash
+Zetacloud CLI
+
+options:
+ -h, --help show this help message and exit
+ -t TASK_NAME, --task_name TASK_NAME
+ Task name
+ -c CLUSTER_NAME, --cluster_name CLUSTER_NAME
+ Cluster name
+ -cl CLOUD, --cloud CLOUD
+ Cloud provider
+ -g GPUS, --gpus GPUS GPUs
+ -f FILENAME, --filename FILENAME
+ Filename
+ -s, --stop Stop flag
+ -d, --down Down flag
+ -sr, --status_report Status report flag
+
+```
+
+- A simple run example code would be like:
+
+```bash
+zeta -f train.py -g A100:8
+```
+
# Documentation
[Click here for the documentation, it's at zeta.apac.ai](https://zeta.apac.ai)
diff --git a/docs/corporate/zeta_cloud.md b/docs/corporate/zeta_cloud.md
index 5f20b967..61cce3e1 100644
--- a/docs/corporate/zeta_cloud.md
+++ b/docs/corporate/zeta_cloud.md
@@ -60,3 +60,106 @@ The estimated timeline for shipping Zeta Cloud is as follows:
| API Access for Third-Party Integrations | Providing API access for integration with other tools and services. | Developers, businesses needing integrations. | Monthly/Annual subscription or pay-per-use. |
+
+
+# GTM - Go To Market
+
+### **Contents**
+
+1. Positioning Statement
+2. Early Adopter Segments
+3. Branding
+4. Channel Strategy
+5. Initial Marketing Methods
+6. Testing Plan
+7. LTV/CAC
+
+---
+
+### **1. Positioning Statement**
+
+*For AI engineers and data scientists who struggle with the complexities of model training and deployment, Zeta Cloud is a new cloud-based AI service that simplifies these processes. Unlike traditional cloud services, we offer an automated, user-friendly platform with a strong focus on accessibility and efficiency.*
+
+---
+
+### **2. Early Adopter Segments**
+
+**Segment Characteristics:**
+- Demographics: AI engineers and data scientists in mid-sized tech companies and startups.
+- Unmet Needs: Simplification of AI model deployment, efficient resource management, cost-effective scaling.
+- Behaviors: Active users of cloud computing services, frequent participants in tech forums and communities.
+- Psychographics: Value innovation, efficiency, and user-friendly interfaces.
+- Multi-party Decision Making: End users (engineers and scientists), economic buyers (company executives), and key influencers (tech thought leaders and industry experts).
+
+**Implications for Targeted Marketing:**
+- Focused engagement in tech forums and communities.
+- Tailored content marketing addressing specific needs and pain points.
+- Leveraging influencers and thought leaders to reach decision-makers.
+
+---
+
+### **3. Branding**
+
+**Strengths of Product Name:**
+- 'Zeta Cloud' conveys a sense of technological advancement and cloud-based efficiency.
+
+**Brand Association Words:**
+- Innovative, Efficient, User-Friendly, Accessible, Empowering, Reliable.
+
+**Aspirational Brand Similarities:**
+- Brands like AWS, Google Cloud, and Azure for their technological prowess and market presence.
+
+---
+
+### **4. Channel Strategy**
+
+**Channels:**
+- Own Website: Primary channel for direct sales and customer engagement.
+- Sales Force: Blend of inside sales for smaller accounts and field sales for larger, enterprise-level deals.
+- Channel Partners: Collaborations with tech marketplaces and value-added resellers.
+
+**Partner Responsibilities and Margins:**
+- Education and initial engagement by Zeta Cloud, with partners focusing on closing sales and after-sales service.
+- Attractive margins to incentivize partner engagement and commitment.
+
+---
+
+### **5. Initial Marketing Methods**
+
+**Hypothesized Effective Methods:**
+1. **Content Marketing:** Strength - establishes thought leadership; Weakness - time-intensive.
+2. **Social Media and Community Engagement:** Strength - builds brand awareness; Weakness - requires consistent, high-quality engagement.
+3. **Paid Digital Advertising (e.g., Google Ads, LinkedIn):** Strength - targets specific segments; Weakness - can be costly.
+
+**Performance Metrics:**
+- Engagement rates, conversion rates, customer acquisition costs.
+
+**Secondary Marketing Methods:**
+- Email marketing, PR activities, and webinars; secondary due to longer lead times and higher resource requirements.
+
+---
+
+### **6. Testing Plan**
+
+**Completed Tests:**
+- Initial A/B testing on website messaging and layout.
+
+**Upcoming Tests:**
+- Content marketing effectiveness: Measuring engagement and conversion rates from different content types.
+- Social media ad campaigns: Assessing customer acquisition costs and conversion rates.
+- Budget for tests: Approximately $20,000 over three months.
+
+---
+
+### **7. LTV/CAC**
+
+**LTV Targets:**
+- Average annual revenue per customer: $5,000.
+- Variable contribution margin: 70%.
+- Retention rate: 85% annually.
+
+**CAC Projections:**
+- Mix of free and paid methods: 40% free methods (referrals), 60% paid methods.
+- Viral coefficient: 0.5.
+- CAC for paid methods: $500 - $1,000, varying by channel.
+
diff --git a/docs/zeta/cloud/main.md b/docs/zeta/cloud/main.md
new file mode 100644
index 00000000..8aaeade3
--- /dev/null
+++ b/docs/zeta/cloud/main.md
@@ -0,0 +1,126 @@
+
+# ZetaCloud Documentation
+
+## Overview
+
+ZetaCloud is a versatile command-line tool that simplifies the process of training or fine-tuning machine learning models on remote GPU clusters. With just a few commands, you can effortlessly manage your tasks and harness the computational power of various GPUs. This comprehensive documentation will guide you through every aspect of the ZetaCloud CLI, from installation to advanced usage.
+
+## Table of Contents
+
+1. [Installation](#installation)
+2. [ZetaCloud CLI](#zetacloud-cli)
+ - [Options](#options)
+3. [Basic Usage](#basic-usage)
+ - [Example 1: Starting a Task](#example-1-starting-a-task)
+ - [Example 2: Stopping a Task](#example-2-stopping-a-task)
+ - [Example 3: Checking Task Status](#example-3-checking-task-status)
+4. [Advanced Usage](#advanced-usage)
+ - [Example 4: Cluster Selection](#example-4-cluster-selection)
+ - [Example 5: Choosing the Cloud Provider](#example-5-choosing-the-cloud-provider)
+5. [Additional Information](#additional-information)
+6. [References](#references)
+
+---
+
+## 1. Installation
+
+Getting started with ZetaCloud is quick and straightforward. Follow these steps to set up ZetaCloud on your machine:
+
+1. Open your terminal or command prompt.
+
+2. Install the `zetascale` package using `pip`:
+
+ ```bash
+ pip install zetascale
+ ```
+
+3. After a successful installation, you can access the ZetaCloud CLI by running the following command:
+
+ ```bash
+ zeta -h
+ ```
+
+ This command will display a list of available options and basic usage information for ZetaCloud.
+
+## 2. ZetaCloud CLI
+
+The ZetaCloud Command-Line Interface (CLI) provides a set of powerful options that enable you to manage tasks on GPU clusters effortlessly. Below are the available options:
+
+### Options
+
+- `-h, --help`: Display the help message and exit.
+- `-t TASK_NAME, --task_name TASK_NAME`: Specify the name of your task.
+- `-c CLUSTER_NAME, --cluster_name CLUSTER_NAME`: Specify the name of the cluster you want to use.
+- `-cl CLOUD, --cloud CLOUD`: Choose the cloud provider (e.g., AWS, Google Cloud, Azure).
+- `-g GPUS, --gpus GPUS`: Specify the number and type of GPUs required for your task.
+- `-f FILENAME, --filename FILENAME`: Provide the filename of your Python script or code.
+- `-s, --stop`: Use this flag to stop a running task.
+- `-d, --down`: Use this flag to terminate a cluster.
+- `-sr, --status_report`: Check the status of your task.
+
+## 3. Basic Usage
+
+ZetaCloud's basic usage covers essential tasks such as starting, stopping, and checking the status of your tasks. Let's explore these tasks with examples.
+
+### Example 1: Starting a Task
+
+To start a task, you need to specify the Python script you want to run and the GPU configuration. Here's an example command:
+
+```bash
+zeta -f train.py -g A100:8
+```
+
+In this example:
+- `-f train.py` indicates that you want to run the Python script named `train.py`.
+- `-g A100:8` specifies that you require 8 NVIDIA A100 GPUs for your task.
+
+### Example 2: Stopping a Task
+
+If you need to stop a running task, you can use the following command:
+
+```bash
+zeta -s
+```
+
+This command will stop the currently running task.
+
+### Example 3: Checking Task Status
+
+To check the status of your task, use the following command:
+
+```bash
+zeta -sr
+```
+
+This command will provide you with a detailed status report for your active task.
+
+## 4. Advanced Usage
+
+ZetaCloud also offers advanced options that allow you to fine-tune your tasks according to your specific requirements.
+
+### Example 4: Cluster Selection
+
+You can select a specific cluster for your task by providing the cluster name with the `-c` option:
+
+```bash
+zeta -f train.py -g A100:8 -c my_cluster
+```
+
+This command will run your task on the cluster named `my_cluster`.
+
+### Example 5: Choosing the Cloud Provider
+
+ZetaCloud supports multiple cloud providers. You can specify your preferred cloud provider using the `-cl` option:
+
+```bash
+zeta -f train.py -g A100:8 -cl AWS
+```
+
+This command will execute your task on a cloud provider's infrastructure, such as AWS.
+
+## 5. Additional Information
+
+- ZetaCloud simplifies the process of utilizing GPU clusters, allowing you to focus on your machine learning tasks rather than infrastructure management.
+
+- You can easily adapt ZetaCloud to various cloud providers, making it a versatile tool for your machine learning needs.
+
diff --git a/docs/zeta/nn/modules/fused_dropout_layernorm.md b/docs/zeta/nn/modules/fused_dropout_layernorm.md
new file mode 100644
index 00000000..eab36b9c
--- /dev/null
+++ b/docs/zeta/nn/modules/fused_dropout_layernorm.md
@@ -0,0 +1,137 @@
+# FusedDropoutLayerNorm Documentation
+
+## Overview
+
+The `FusedDropoutLayerNorm` module in PyTorch is designed to combine two commonly used operations in neural networks: dropout and layer normalization. This fusion aims to enhance the efficiency of the model by reducing the overhead associated with sequential operations. The module is particularly useful in scenarios where both dropout and layer normalization are critical for the model's performance.
+
+## Class Definition
+
+### `FusedDropoutLayerNorm`
+
+```python
+class FusedDropoutLayerNorm(nn.Module):
+ """
+ This class fuses Dropout and LayerNorm into a single module for efficiency.
+
+ Args:
+ dim (int): Input dimension of the layer.
+ dropout (float, optional): Probability of an element to be zeroed. Defaults to 0.1.
+ eps (float, optional): A value added to the denominator for numerical stability. Defaults to 1e-5.
+ elementwise_affine (bool, optional): A flag to enable learning of affine parameters. Defaults to True.
+ """
+```
+
+## Constructor Parameters
+
+| Parameter | Type | Description | Default Value |
+|---------------------|---------|----------------------------------------------------------|---------------|
+| `dim` | int | The input dimension of the layer. | - |
+| `dropout` | float | Dropout probability. | 0.1 |
+| `eps` | float | Epsilon for numerical stability in LayerNorm. | 1e-5 |
+| `elementwise_affine`| bool | Enables learning of affine parameters in LayerNorm. | True |
+
+## Methods
+
+### `forward`
+
+```python
+def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of FusedDropoutLayerNorm.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: The output tensor after applying dropout and layer normalization.
+ """
+```
+
+## Examples
+
+### Basic Usage
+
+```python
+import torch
+from torch import nn
+from zeta.nn import FusedDropoutLayerNorm
+
+# Initialize the module
+model = FusedDropoutLayerNorm(dim=512)
+
+# Create a sample input tensor
+x = torch.randn(1, 512)
+
+# Forward pass
+output = model(x)
+
+# Check output shape
+print(output.shape) # Expected: torch.Size([1, 512])
+```
+
+### Integration in a Neural Network
+
+```python
+import torch
+import torch.nn as nn
+from zeta.nn import FusedDropoutLayerNorm
+
+class SampleModel(nn.Module):
+ def __init__(self):
+ super(SampleModel, self).__init__()
+ self.linear = nn.Linear(512, 512)
+ self.fused_dropout_layernorm = FusedDropoutLayerNorm(512)
+
+ def forward(self, x):
+ x = self.linear(x)
+ x = self.fused_dropout_layernorm(x)
+ return x
+
+# Example
+model = SampleModel()
+input_tensor = torch.randn(10, 512)
+output = model(input_tensor)
+print(output.shape) # Expected: torch.Size([10, 512])
+```
+
+### Custom Configuration
+
+```python
+import torch
+from zeta.nn import FusedDropoutLayerNorm
+
+# Custom configuration
+dropout_rate = 0.2
+epsilon = 1e-6
+elementwise_affine = False
+
+# Initialize the module with custom configuration
+model = FusedDropoutLayerNorm(512, dropout=dropout_rate, eps=epsilon, elementwise_affine=elementwise_affine)
+
+# Sample input
+x = torch.randn(1, 512)
+
+# Forward pass
+output = model(x)
+print(output.shape) # Expected: torch.Size([1, 512])
+```
+
+## Architecture and Working
+
+The `FusedDropoutLayerNorm` module is architecturally a combination of two PyTorch layers: `nn.Dropout` and `nn.LayerNorm`. The fusion of these layers into a single module ensures that the operations are performed sequentially and efficiently, thereby reducing the computational overhead.
+
+- **Dropout**: This operation randomly zeroes some of the elements of the input tensor with probability `dropout` during training. It helps prevent overfitting.
+- **Layer Normalization**: This operation normalizes the input across the features. It stabilizes the learning process and accelerates the training of deep neural networks.
+
+By integrating these two operations, `FusedDropoutLayerNorm` ensures a streamlined process where the dropout is applied first, followed by layer normalization. This design choice is made for computational efficiency and is particularly beneficial in transformer models and other deep learning architectures where both operations are frequently used.
+
+## Purpose and Importance
+
+The primary purpose of `FusedDropoutLayerNorm` is to provide a more efficient way to apply both dropout and layer normalization in a model. This efficiency is particularly crucial in
+
+ large-scale models where computational resources and runtime are significant concerns. The module is designed to be versatile and can be easily integrated into various neural network architectures, especially those involving transformer models.
+
+## Conclusion
+
+The `FusedDropoutLayerNorm` module in PyTorch is a practical and efficient solution for models that require both dropout and layer normalization. Its fused architecture not only enhances computational efficiency but also simplifies the model design process. The module is flexible, allowing for easy customization and integration into diverse neural network architectures.
+
diff --git a/docs/zeta/nn/modules/fused_gelu_dense.md b/docs/zeta/nn/modules/fused_gelu_dense.md
new file mode 100644
index 00000000..77868b86
--- /dev/null
+++ b/docs/zeta/nn/modules/fused_gelu_dense.md
@@ -0,0 +1,140 @@
+# `FusedDenseGELUDense`
+
+## Overview
+
+The `FusedDenseGELUDense` module is a versatile neural network layer designed for efficient computation of dense layers with GELU (Gaussian Error Linear Unit) activations. This documentation will provide an in-depth understanding of the module's architecture, purpose, parameters, and usage examples.
+
+## Table of Contents
+
+1. [Introduction](#introduction)
+2. [Architecture](#architecture)
+3. [Purpose](#purpose)
+4. [Class Definition](#class-definition)
+ - [Parameters](#parameters)
+ - [Internal Layers](#internal-layers)
+5. [Functionality and Usage](#functionality-and-usage)
+ - [Forward Pass](#forward-pass)
+6. [Examples](#examples)
+ - [Basic Usage](#basic-usage)
+ - [Custom Configuration](#custom-configuration)
+ - [Quantization with bitsandbytes](#quantization-with-bitsandbytes)
+7. [Additional Information](#additional-information)
+8. [References](#references)
+
+---
+
+## 1. Introduction
+
+The `FusedDenseGELUDense` module combines dense layers with GELU activations in a single neural network layer. This fusion improves computational efficiency and is particularly useful in various deep learning applications.
+
+## 2. Architecture
+
+The `FusedDenseGELUDense` layer consists of two dense sub-layers, each followed by a GELU activation function. It takes an input tensor and passes it through these sub-layers to produce the final output.
+
+## 3. Purpose
+
+The primary purpose of the `FusedDenseGELUDense` layer is to efficiently compute dense transformations with GELU activations. It is designed for use in neural networks, providing a convenient way to incorporate these operations into deep learning models.
+
+## 4. Class Definition
+
+### Parameters
+
+- `dim` (int): Input dimension.
+- `dim_out` (int): Output dimension.
+- `bias` (bool, optional): Whether to include bias terms. Defaults to True.
+- `has_fp16_weights` (bool, optional): Whether to use fp16 weights. Defaults to False.
+- `threshold` (float, optional): Threshold for quantization. Defaults to 6.0.
+
+### Internal Layers
+
+The `FusedDenseGELUDense` layer consists of the following internal layers:
+
+1. `dense1`: The first dense layer.
+2. `act`: The GELU activation function.
+3. `dense2`: The second dense layer.
+
+## 5. Functionality and Usage
+
+### Forward Pass
+
+The `forward` method of the `FusedDenseGELUDense` layer performs the following operations:
+
+1. Applies the first dense layer (`dense1`) to the input tensor.
+2. Applies the GELU activation function (`act`) to the result.
+3. Applies the second dense layer (`dense2`) to the GELU-activated output.
+
+## 6. Examples
+
+### Basic Usage
+
+Here's a basic example of using the `FusedDenseGELUDense` layer:
+
+```python
+import torch
+from zeta.nn import FusedDenseGELUDense
+
+# Create an instance of FusedDenseGELUDense
+model = FusedDenseGELUDense(dim=512, dim_out=1024)
+
+# Generate random input tensor
+x = torch.randn(1, 512)
+
+# Forward pass
+out = model(x)
+
+# Check the output shape
+print(out.shape) # torch.Size([1, 512])
+```
+
+### Custom Configuration
+
+You can customize the layer by specifying different parameters:
+
+```python
+# Create a custom FusedDenseGELUDense layer
+custom_model = FusedDenseGELUDense(
+ dim=256, dim_out=512, bias=False, has_fp16_weights=True, threshold=4.0
+)
+
+# Generate random input tensor
+x = torch.randn(1, 256)
+
+# Forward pass with the custom configuration
+out = custom_model(x)
+```
+
+### Quantization with bitsandbytes
+
+You can enable quantization using the `bitsandbytes` library by providing a quantized implementation of the dense layers:
+
+```python
+# Install bitsandbytes if not already installed
+# pip install bitsandbytes
+
+import torch
+from zeta.nn import FusedDenseGELUDense
+
+# Create an instance of FusedDenseGELUDense with quantization
+quantized_model = FusedDenseGELUDense(
+ dim=512, dim_out=1024, has_fp16_weights=True, threshold=4.0
+)
+
+# Generate random input tensor
+x = torch.randn(1, 512)
+
+# Forward pass with quantization
+out = quantized_model(x)
+```
+
+## 7. Additional Information
+
+- The `FusedDenseGELUDense` layer efficiently combines dense and GELU activation operations.
+- Custom configurations for bias, weight precision, and threshold are supported.
+- Quantization can be enabled using the `bitsandbytes` library for further efficiency.
+
+## 8. References
+
+For more information on GELU activations and dense layers in PyTorch, refer to the official PyTorch documentation:
+
+- [GELU Activation Function](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html)
+- [Dense Layer](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html)
diff --git a/mkdocs.yml b/mkdocs.yml
index 30720331..780107f8 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -68,6 +68,7 @@ nav:
- Home:
- Overview: "index.md"
- Contributing: "contributing.md"
+ - ZetaCloud: "zeta/cloud/main.md"
- Zeta:
- Overview: "zeta/index.md"
- zeta.nn:
@@ -109,6 +110,8 @@ nav:
- MultiModalAdapterDenseNetwork: "zeta/nn/modules/mm_adapter.md"
- CustomMLP: "zeta/nn/modules/custom_mlp.md"
- PolymorphicNeuronLayer: "zeta/nn/modules/polymorphic_activation.md"
+ - FusedDenseGELUDense: "zeta/nn/modules/fused_gelu_dense.md"
+ - FusedDropoutLayerNorm: "zeta/nn/modules/fused_dropout_layernorm.md"
- zeta.nn.attention:
- FlashAttention: "zeta/nn/attention/flash_attention.md"
- MultiQueryAttention: "zeta/nn/attention/multiquery.md"
diff --git a/pyproject.toml b/pyproject.toml
index b70ed317..83fb9e25 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "zetascale"
-version = "0.9.9"
+version = "1.1.7"
description = "Transformers at zeta scales"
authors = ["Zeta Team "]
license = "MIT"
@@ -40,6 +40,9 @@ beartype = "0.16.4"
tiktoken = "0.5.2"
tqdm = "4.66.1"
rich = "13.7.0"
+argparse = "^1.4.0"
+skypilot = "0.4.1"
+
[build-system]
requires = ["poetry-core>=1.0.0"]
@@ -73,6 +76,7 @@ preview = true
-
+[tool.poetry.scripts]
+zeta = 'zeta.cli.main:main'
diff --git a/requirements.txt b/requirements.txt
index 7e8f4724..87e024db 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -28,4 +28,5 @@ torchaudio==2.1.2
mkdocs
mkdocs-material
mkdocs-glightbox
-glightbox
\ No newline at end of file
+skypilot==0.4.1
+argparse
\ No newline at end of file
diff --git a/tests/cloud/main.py b/tests/cloud/main.py
new file mode 100644
index 00000000..46a81395
--- /dev/null
+++ b/tests/cloud/main.py
@@ -0,0 +1,101 @@
+import pytest
+from unittest.mock import MagicMock, patch
+from zeta.cloud.main import zetacloud
+
+
+@patch("zeta.cloud.main.skyapi")
+@patch("zeta.cloud.main.logger")
+def test_zetacloud_basic(mock_logger, mock_skyapi):
+ # Arrange
+ mock_task = MagicMock()
+ mock_skyapi.create_task.return_value = mock_task
+
+ # Act
+ zetacloud(task_name="test_task")
+
+ # Assert
+ mock_skyapi.create_task.assert_called_once_with(
+ name="test_task",
+ setup="pip install requirements.txt",
+ run="python train.py",
+ workdir=".",
+ )
+ mock_logger.info.assert_called_with(
+ "Task: {} has been created".format(mock_task)
+ )
+ mock_task.set_resources.assert_called_once()
+ mock_skyapi.launch.assert_called_once_with(mock_task, "[ZetaTrainingRun]")
+
+
+# ... replicate this test with different arguments for thoroughness
+
+
+@patch("zeta.cloud.main.skyapi")
+@patch("zeta.cloud.main.logger")
+def test_zetacloud_with_stop(mock_logger, mock_skyapi):
+ # Arrange
+ mock_task = MagicMock()
+ mock_skyapi.create_task.return_value = mock_task
+
+ # Act
+ zetacloud(task_name="test_task", stop=True)
+
+ # Assert
+ mock_skyapi.stop.assert_called_once_with("[ZetaTrainingRun]")
+ mock_logger.info.assert_called_with(
+ "Cluster: [ZetaTrainingRun] has been stopped"
+ )
+
+
+@patch("zeta.cloud.main.skyapi")
+@patch("zeta.cloud.main.logger")
+def test_zetacloud_with_down(mock_logger, mock_skyapi):
+ # Arrange
+ mock_task = MagicMock()
+ mock_skyapi.create_task.return_value = mock_task
+
+ # Act
+ zetacloud(task_name="test_task", down=True)
+
+ # Assert
+ mock_skyapi.down.assert_called_once_with("[ZetaTrainingRun]")
+ mock_logger.info.assert_called_with(
+ "Cluster: [ZetaTrainingRun] has been deleted"
+ )
+
+
+@patch("zeta.cloud.main.skyapi")
+@patch("zeta.cloud.main.logger")
+def test_zetacloud_with_status_report(mock_logger, mock_skyapi):
+ # Arrange
+ mock_task = MagicMock()
+ mock_skyapi.create_task.return_value = mock_task
+
+ # Act
+ zetacloud(task_name="test_task", status_report=True)
+
+ # Assert
+ mock_skyapi.status.assert_called_once_with(
+ cluster_names=["[ZetaTrainingRun]"]
+ )
+ mock_logger.info.assert_called_with(
+ "Cluster: [ZetaTrainingRun] has been reported on"
+ )
+
+
+@patch("zeta.cloud.main.skyapi")
+@patch("zeta.cloud.main.logger")
+def test_zetacloud_with_exception(mock_logger, mock_skyapi):
+ # Arrange
+ mock_skyapi.create_task.side_effect = Exception("Test exception")
+
+ # Act
+ with pytest.raises(Exception):
+ zetacloud(task_name="test_task")
+
+ # Assert
+ mock_logger.error.assert_called_once()
+
+
+# ... replicate similar tests with minor changes for thoroughness
+# Examples: test different cloud providers, test other parameter combinations, etc.
diff --git a/tests/nn/modules/test_fused_dropout_layernom.py b/tests/nn/modules/test_fused_dropout_layernom.py
new file mode 100644
index 00000000..e38567d8
--- /dev/null
+++ b/tests/nn/modules/test_fused_dropout_layernom.py
@@ -0,0 +1,70 @@
+import torch
+from torch import nn
+from zeta.nn.modules.fused_dropout_layernom import FusedDropoutLayerNorm
+
+
+def test_class_init():
+ model = FusedDropoutLayerNorm(512)
+
+ assert isinstance(model.dropout, nn.Dropout)
+ assert isinstance(model.layer_norm, nn.LayerNorm)
+
+
+def test_class_init_with_args():
+ model = FusedDropoutLayerNorm(
+ 512, dropout=0.2, eps=1e-6, elementwise_affine=False
+ )
+
+ assert isinstance(model.dropout, nn.Dropout)
+ assert isinstance(model.layer_norm, nn.LayerNorm)
+ assert model.dropout.p == 0.2
+ assert model.layer_norm.eps == 1e-6
+ assert model.layer_norm.elementwise_affine is False
+
+
+def test_forward():
+ model = FusedDropoutLayerNorm(512)
+ x = torch.randn(1, 512)
+ out = model(x)
+
+ assert out.shape == torch.Size([1, 512])
+
+
+def test_forward_with_different_input():
+ model = FusedDropoutLayerNorm(512)
+ x = torch.randn(2, 512)
+ out = model(x)
+
+ assert out.shape == torch.Size([2, 512])
+
+
+def test_forward_with_different_dim():
+ model = FusedDropoutLayerNorm(256)
+ x = torch.randn(1, 256)
+ out = model(x)
+
+ assert out.shape == torch.Size([1, 256])
+
+
+def test_forward_with_different_dropout():
+ model = FusedDropoutLayerNorm(512, dropout=0.2)
+ x = torch.randn(1, 512)
+ out = model(x)
+
+ assert out.shape == torch.Size([1, 512])
+
+
+def test_forward_with_different_eps():
+ model = FusedDropoutLayerNorm(512, eps=1e-6)
+ x = torch.randn(1, 512)
+ out = model(x)
+
+ assert out.shape == torch.Size([1, 512])
+
+
+def test_forward_with_no_elementwise_affine():
+ model = FusedDropoutLayerNorm(512, elementwise_affine=False)
+ x = torch.randn(1, 512)
+ out = model(x)
+
+ assert out.shape == torch.Size([1, 512])
diff --git a/tests/nn/modules/test_fused_gelu_dense.py b/tests/nn/modules/test_fused_gelu_dense.py
new file mode 100644
index 00000000..f0390bf7
--- /dev/null
+++ b/tests/nn/modules/test_fused_gelu_dense.py
@@ -0,0 +1,81 @@
+import pytest
+import torch
+from zeta.nn.modules.fused_gelu_dense import FusedDenseGELUDense
+
+
+def test_class_init():
+ model = FusedDenseGELUDense(512, 1024)
+
+ assert model.dim == 512
+ assert model.dim_out == 1024
+ assert model.bias == True
+ assert model.has_fp16_weights == False
+ assert model.threshold == 6.0
+
+
+def test_class_init_with_args():
+ model = FusedDenseGELUDense(
+ 512, 1024, bias=False, has_fp16_weights=True, threshold=5.0
+ )
+
+ assert model.dim == 512
+ assert model.dim_out == 1024
+ assert model.bias == False
+ assert model.has_fp16_weights == True
+ assert model.threshold == 5.0
+
+
+def test_forward():
+ model = FusedDenseGELUDense(512, 1024)
+ x = torch.randn(1, 512)
+ out = model(x)
+
+ assert out.shape == torch.Size([1, 512])
+
+
+def test_forward_with_different_input():
+ model = FusedDenseGELUDense(512, 1024)
+ x = torch.randn(2, 512)
+ out = model(x)
+
+ assert out.shape == torch.Size([2, 512])
+
+
+def test_forward_with_different_dim():
+ model = FusedDenseGELUDense(256, 512)
+ x = torch.randn(1, 256)
+ out = model(x)
+
+ assert out.shape == torch.Size([1, 256])
+
+
+def test_forward_with_different_dim_out():
+ model = FusedDenseGELUDense(512, 2048)
+ x = torch.randn(1, 512)
+ out = model(x)
+
+ assert out.shape == torch.Size([1, 512])
+
+
+def test_forward_with_no_bias():
+ model = FusedDenseGELUDense(512, 1024, bias=False)
+ x = torch.randn(1, 512)
+ out = model(x)
+
+ assert out.shape == torch.Size([1, 512])
+
+
+def test_forward_with_fp16_weights():
+ model = FusedDenseGELUDense(512, 1024, has_fp16_weights=True)
+ x = torch.randn(1, 512)
+ out = model(x)
+
+ assert out.shape == torch.Size([1, 512])
+
+
+def test_forward_with_different_threshold():
+ model = FusedDenseGELUDense(512, 1024, threshold=5.0)
+ x = torch.randn(1, 512)
+ out = model(x)
+
+ assert out.shape == torch.Size([1, 512])
diff --git a/tests/nn/modules/test_image_projector.py b/tests/nn/modules/test_image_projector.py
index 58f3e2a2..92d696d9 100644
--- a/tests/nn/modules/test_image_projector.py
+++ b/tests/nn/modules/test_image_projector.py
@@ -249,17 +249,13 @@ def test_patch_projector_output_shape_consistency(sample_input_tensor):
# Test case for edge case: invalid max_patch_size
def test_patch_projector_invalid_max_patch_size():
with pytest.raises(ValueError):
- ImagePatchCreatorProjector(
- max_patch_size=0, embedding_dim=768
- )
+ ImagePatchCreatorProjector(max_patch_size=0, embedding_dim=768)
# Test case for edge case: invalid embedding_dim
def test_patch_projector_invalid_embedding_dim():
with pytest.raises(ValueError):
- ImagePatchCreatorProjector(
- max_patch_size=16, embedding_dim=0
- )
+ ImagePatchCreatorProjector(max_patch_size=16, embedding_dim=0)
# Test case for edge case: invalid input tensor shape
diff --git a/tests/ops/test_einops_poly.py b/tests/ops/test_einops_poly.py
index a1ad7c44..85f0f14e 100644
--- a/tests/ops/test_einops_poly.py
+++ b/tests/ops/test_einops_poly.py
@@ -148,9 +148,7 @@ def test_reduce_many_with_sum_reduction():
# Additional tests for rearrange_with_anon_dims function
def test_rearrange_with_anon_dims_invalid_dim_list():
with pytest.raises(ValueError):
- rearrange_with_anon_dims(
- input_data, pattern="...a b c", a=(1,)
- )
+ rearrange_with_anon_dims(input_data, pattern="...a b c", a=(1,))
def test_rearrange_with_anon_dims_invalid_pattern():
diff --git a/tests/optim/lion8b.py b/tests/optim/lion8b.py
new file mode 100644
index 00000000..75fa2b8b
--- /dev/null
+++ b/tests/optim/lion8b.py
@@ -0,0 +1,131 @@
+import pytest
+import torch
+from zeta.optim.lion8b import DecoupledLionW_8bit
+
+
+def test_optimizer_init():
+ params = [torch.randn(3, 3, requires_grad=True) for _ in range(2)]
+ optimizer = DecoupledLionW_8bit(params)
+
+ assert len(optimizer.param_groups) == 1
+ assert optimizer.param_groups[0]["lr"] == 1e-3
+ assert optimizer.param_groups[0]["betas"] == (0.9, 0.99)
+ assert optimizer.param_groups[0]["weight_decay"] == 0
+
+
+def test_optimizer_init_invalid_lr():
+ params = [torch.randn(3, 3, requires_grad=True) for _ in range(2)]
+ with pytest.raises(ValueError):
+ DecoupledLionW_8bit(params, lr=-1)
+
+
+def test_optimizer_init_invalid_betas():
+ params = [torch.randn(3, 3, requires_grad=True) for _ in range(2)]
+ with pytest.raises(ValueError):
+ DecoupledLionW_8bit(params, betas=(-1, 0.99))
+ with pytest.raises(ValueError):
+ DecoupledLionW_8bit(params, betas=(0.9, -1))
+
+
+def test_optimizer_init_invalid_weight_decay():
+ params = [torch.randn(3, 3, requires_grad=True) for _ in range(2)]
+ with pytest.raises(ValueError):
+ DecoupledLionW_8bit(params, weight_decay=-1)
+
+
+def test_step_without_closure():
+ params = [torch.randn(3, 3, requires_grad=True) for _ in range(2)]
+ optimizer = DecoupledLionW_8bit(params)
+ loss = optimizer.step()
+
+ assert loss is None
+
+
+def test_step_with_closure():
+ params = [torch.randn(3, 3, requires_grad=True) for _ in range(2)]
+ optimizer = DecoupledLionW_8bit(params)
+ closure = lambda: torch.sum(params[0] ** 2 + params[1] ** 2)
+ loss = optimizer.step(closure)
+
+ assert loss is not None
+ assert loss == closure()
+
+
+def test_step_param_no_grad():
+ params = [torch.randn(3, 3, requires_grad=False) for _ in range(2)]
+ optimizer = DecoupledLionW_8bit(params)
+ optimizer.step_param(params[0], optimizer.param_groups[0])
+
+ assert params[0].grad is None
+
+
+def test_step_param_with_grad():
+ params = [torch.randn(3, 3, requires_grad=True) for _ in range(2)]
+ optimizer = DecoupledLionW_8bit(params)
+ closure = lambda: torch.sum(params[0] ** 2 + params[1] ** 2)
+ closure().backward()
+ optimizer.step_param(params[0], optimizer.param_groups[0])
+
+ assert params[0].grad is not None
+
+
+def test_step_param_not_cuda():
+ params = [torch.randn(3, 3, requires_grad=True) for _ in range(2)]
+ optimizer = DecoupledLionW_8bit(params, quantize=True)
+ closure = lambda: torch.sum(params[0] ** 2 + params[1] ** 2)
+ closure().backward()
+
+ with pytest.raises(NotImplementedError):
+ optimizer.step_param(params[0], optimizer.param_groups[0])
+
+
+def test_optimizer_init_invalid_weight_decay():
+ params = [torch.randn(3, 3, requires_grad=True) for _ in range(2)]
+ with pytest.raises(ValueError):
+ DecoupledLionW_8bit(params, weight_decay=-1)
+
+
+def test_step_without_closure():
+ params = [torch.randn(3, 3, requires_grad=True) for _ in range(2)]
+ optimizer = DecoupledLionW_8bit(params)
+ loss = optimizer.step()
+
+ assert loss is None
+
+
+def test_step_with_closure():
+ params = [torch.randn(3, 3, requires_grad=True) for _ in range(2)]
+ optimizer = DecoupledLionW_8bit(params)
+ closure = lambda: torch.sum(params[0] ** 2 + params[1] ** 2)
+ loss = optimizer.step(closure)
+
+ assert loss is not None
+ assert loss == closure()
+
+
+def test_step_param_no_grad():
+ params = [torch.randn(3, 3, requires_grad=False) for _ in range(2)]
+ optimizer = DecoupledLionW_8bit(params)
+ optimizer.step_param(params[0], optimizer.param_groups[0])
+
+ assert params[0].grad is None
+
+
+def test_step_param_with_grad():
+ params = [torch.randn(3, 3, requires_grad=True) for _ in range(2)]
+ optimizer = DecoupledLionW_8bit(params)
+ closure = lambda: torch.sum(params[0] ** 2 + params[1] ** 2)
+ closure().backward()
+ optimizer.step_param(params[0], optimizer.param_groups[0])
+
+ assert params[0].grad is not None
+
+
+def test_step_param_not_cuda():
+ params = [torch.randn(3, 3, requires_grad=True) for _ in range(2)]
+ optimizer = DecoupledLionW_8bit(params, quantize=True)
+ closure = lambda: torch.sum(params[0] ** 2 + params[1] ** 2)
+ closure().backward()
+
+ with pytest.raises(NotImplementedError):
+ optimizer.step_param(params[0], optimizer.param_groups[0])
diff --git a/tests/test_test_example.py b/tests/test_test_example.py
index ad15eee2..b707a6d9 100644
--- a/tests/test_test_example.py
+++ b/tests/test_test_example.py
@@ -1,4 +1,4 @@
-from zeta import MultiheadAttention
+
import time
import unittest
diff --git a/zeta/__init__.py b/zeta/__init__.py
index 31ae3141..e0099777 100644
--- a/zeta/__init__.py
+++ b/zeta/__init__.py
@@ -11,3 +11,4 @@
from zeta.optim import * # noqa: F403, E402
from zeta.ops import * # noqa: F403, E402
from zeta.quant import * # noqa: F403, E402
+from zeta.cloud import * # noqa: F403, E402
diff --git a/zeta/cli/__init__.py b/zeta/cli/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/zeta/cli/main.py b/zeta/cli/main.py
new file mode 100644
index 00000000..98b5e2dc
--- /dev/null
+++ b/zeta/cli/main.py
@@ -0,0 +1,66 @@
+import argparse
+from zeta.cloud.main import zetacloud
+
+
+def main():
+ """Main function for the CLI
+
+ Args:
+ task_name (str, optional): _description_. Defaults to None.
+ cluster_name (str, optional): _description_. Defaults to "[ZetaTrainingRun]".
+ cloud (Any, optional): _description_. Defaults to AWS().
+ gpus (str, optional): _description_. Defaults to None.
+
+ Examples:
+ $ zetacloud -t "test" -c "[ZetaTrainingRun]" -cl AWS -g "1 V100"
+
+
+ """
+ parser = argparse.ArgumentParser(description="Zetacloud CLI")
+ parser.add_argument("-t", "--task_name", type=str, help="Task name")
+ parser.add_argument(
+ "-c",
+ "--cluster_name",
+ type=str,
+ default="[ZetaTrainingRun]",
+ help="Cluster name",
+ )
+ parser.add_argument(
+ "-cl", "--cloud", type=str, default="AWS", help="Cloud provider"
+ )
+ parser.add_argument("-g", "--gpus", type=str, help="GPUs")
+ parser.add_argument(
+ "-f", "--filename", type=str, default="train.py", help="Filename"
+ )
+ parser.add_argument("-s", "--stop", action="store_true", help="Stop flag")
+ parser.add_argument("-d", "--down", action="store_true", help="Down flag")
+ parser.add_argument(
+ "-sr", "--status_report", action="store_true", help="Status report flag"
+ )
+
+ # Generate API key
+ # parser.add_argument(
+ # "-k", "--generate_api_key", action="store_true", help="Generate key flag"
+ # )
+
+ # Sign In
+ # parser.add_argument(
+ # "-si", "--sign_in", action="store_true", help="Sign in flag"
+ # )
+
+ args = parser.parse_args()
+
+ zetacloud(
+ task_name=args.task_name,
+ cluster_name=args.cluster_name,
+ cloud=args.cloud,
+ gpus=args.gpus,
+ filename=args.filename,
+ stop=args.stop,
+ down=args.down,
+ status_report=args.status_report,
+ )
+
+
+# if __name__ == "__main__":
+# main()
diff --git a/zeta/cloud/__init__.py b/zeta/cloud/__init__.py
new file mode 100644
index 00000000..05c279eb
--- /dev/null
+++ b/zeta/cloud/__init__.py
@@ -0,0 +1,4 @@
+from zeta.cloud.sky_api import SkyInterface
+from zeta.cloud.main import zetacloud
+
+__all__ = ["zetacloud", "SkyInterface"]
diff --git a/zeta/cloud/main.py b/zeta/cloud/main.py
new file mode 100644
index 00000000..3d46183d
--- /dev/null
+++ b/zeta/cloud/main.py
@@ -0,0 +1,73 @@
+import logging
+from typing import Any
+
+from sky import AWS, Resources
+
+from zeta.cloud.sky_api import SkyInterface
+
+skyapi = SkyInterface(stream_logs_enabled=True)
+
+
+# Logger
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+def zetacloud(
+ task_name: str = None,
+ cluster_name: str = "ZetaTrainingRun",
+ setup: str = "pip install -r requirements.txt",
+ cloud: Any = AWS(),
+ gpus: str = "V100:4",
+ filename: str = "train.py",
+ stop: bool = False,
+ down: bool = False,
+ status_report: bool = False,
+ *args,
+ **kwargs,
+):
+ """zetacloud
+
+ Args:
+ task_name (str, optional): _description_. Defaults to None.
+ cluster_name (str, optional): _description_. Defaults to "[ZetaTrainingRun]".
+ cloud (Any, optional): _description_. Defaults to AWS().
+ gpus (str, optional): _description_. Defaults to None.
+ """
+ try:
+ task = skyapi.create_task(
+ name=task_name,
+ setup=setup,
+ run=f"python {filename}",
+ workdir=".",
+ )
+ logger.info(f"Task: {task} has been created")
+
+ # Set the resources
+ task.set_resources(Resources(accelerators=gpus))
+ # logger.info(f"Resources: {task.resources} have been set")
+
+ # Execute the task on the cluster
+ execution = skyapi.launch(task, cluster_name)
+ print(execution)
+ logger.info(
+ f"Task: {task} has been launched on cluster: {cluster_name}"
+ )
+
+ if stop:
+ skyapi.stop(cluster_name)
+ logger.info(f"Cluster: {cluster_name} has been stopped")
+
+ if down:
+ skyapi.down(cluster_name)
+ logger.info(f"Cluster: {cluster_name} has been deleted")
+
+ if status_report:
+ skyapi.status(cluster_names=[cluster_name])
+ logger.info(f"Cluster: {cluster_name} has been reported on")
+
+ except Exception as error:
+ print(
+ f"There has been an error: {error} the root cause is:"
+ f" {error.__cause__}"
+ )
diff --git a/zeta/cloud/sky_api.py b/zeta/cloud/sky_api.py
new file mode 100644
index 00000000..6fd1f776
--- /dev/null
+++ b/zeta/cloud/sky_api.py
@@ -0,0 +1,202 @@
+from typing import List
+
+import sky
+from sky import Task
+
+
+class SkyInterface:
+ """
+
+ SkyInterface is a wrapper around the sky Python API. It provides a
+ simplified interface for launching, executing, stopping, starting, and
+ tearing down clusters.
+
+ Attributes:
+ clusters (dict): A dictionary of clusters that have been launched.
+ The keys are the names of the clusters and the values are the handles
+ to the clusters.
+
+ Methods:
+ launch: Launch a cluster
+ execute: Execute a task on a cluster
+ stop: Stop a cluster
+ start: Start a cluster
+ down: Tear down a cluster
+ status: Get the status of a cluster
+ autostop: Set the autostop of a cluster
+
+ Example:
+ >>> sky_interface = SkyInterface()
+ >>> job_id = sky_interface.launch("task", "cluster_name")
+ >>> sky_interface.execute("task", "cluster_name")
+ >>> sky_interface.stop("cluster_name")
+ >>> sky_interface.start("cluster_name")
+ >>> sky_interface.down("cluster_name")
+ >>> sky_interface.status()
+ >>> sky_interface.autostop("cluster_name")
+
+
+ """
+
+ def __init__(
+ self,
+ task_name: str = None,
+ cluster_name: str = None,
+ gpus: str = "T4:1",
+ stream_logs_enabled: bool = False,
+ *args,
+ **kwargs,
+ ):
+ self.task_name = task_name
+ self.cluster_name = cluster_name
+ self.gpus = gpus
+ self.stream_logs_enabled = stream_logs_enabled
+ self.clusters = {}
+
+ def launch(self, task: Task = None, cluster_name: str = None, **kwargs):
+ """Launch a task on a cluster
+
+ Args:
+ task (str): code to execute on the cluster
+ cluster_name (_type_, optional): _description_. Defaults to None.
+
+ Returns:
+ _type_: _description_
+ """
+ cluster = None
+ try:
+ cluster = sky.launch(
+ task=task,
+ cluster_name=cluster_name,
+ stream_logs=self.stream_logs_enabled,
+ **kwargs,
+ )
+ print(f"Launched job {cluster} on cluster {cluster_name}")
+ return cluster
+ except Exception as error:
+ # Deep error logging
+ print(
+ f"Error launching job {cluster} on cluster {cluster_name} with"
+ f" error {error}"
+ )
+ raise error
+
+ def execute(self, task: Task = None, cluster_name: str = None, **kwargs):
+ """Execute a task on a cluster
+
+ Args:
+ task (_type_): _description_
+ cluster_name (_type_): _description_
+
+ Raises:
+ ValueError: _description_
+
+ Returns:
+ _type_: _description_
+ """
+ if cluster_name not in self.clusters:
+ raise ValueError("Cluster {} does not exist".format(cluster_name))
+ try:
+ return sky.exec(
+ task=task,
+ cluster_name=cluster_name,
+ stream_logs=self.stream_logs_enabled,
+ **kwargs,
+ )
+ except Exception as e:
+ print("Error executing on cluster:", e)
+
+ def stop(self, cluster_name: str = None, **kwargs):
+ """Stop a cluster
+
+ Args:
+ cluster_name (str): name of the cluster to stop
+ """
+ try:
+ sky.stop(cluster_name, **kwargs)
+ except (ValueError, RuntimeError) as e:
+ print("Error stopping cluster:", e)
+
+ def start(self, cluster_name: str = None, **kwargs):
+ """start a cluster
+
+ Args:
+ cluster_name (str): name of the cluster to start
+ """
+ try:
+ sky.start(cluster_name, **kwargs)
+ except Exception as e:
+ print("Error starting cluster:", e)
+
+ def down(self, cluster_name: str = None, **kwargs):
+ """Down a cluster
+
+ Args:
+ cluster_name (str): name of the cluster to tear down
+ """
+ try:
+ sky.down(cluster_name, **kwargs)
+ if cluster_name in self.clusters:
+ del self.clusters[cluster_name]
+ except (ValueError, RuntimeError) as e:
+ print("Error tearing down cluster:", e)
+
+ def status(self, cluster_names: List[str] = None, **kwargs):
+ """Save a cluster
+
+ Returns:
+ r: the status of the cluster
+ """
+ try:
+ return sky.status(cluster_names, **kwargs)
+ except Exception as e:
+ print("Error getting status:", e)
+
+ def autostop(self, cluster_name: str = None, **kwargs):
+ """Autostop a cluster
+
+ Args:
+ cluster_name (str): name of the cluster to autostop
+ """
+ try:
+ sky.autostop(cluster_name, **kwargs)
+ except Exception as e:
+ print("Error setting autostop:", e)
+
+ def create_task(
+ self,
+ name: str = None,
+ setup: str = None,
+ run: str = None,
+ workdir: str = None,
+ task: str = None,
+ *args,
+ **kwargs,
+ ):
+ """_summary_
+
+ Args:
+ name (str, optional): _description_. Defaults to None.
+ setup (str, optional): _description_. Defaults to None.
+ run (str, optional): _description_. Defaults to None.
+ workdir (str, optional): _description_. Defaults to None.
+ task (str, optional): _description_. Defaults to None.
+
+ Returns:
+ _type_: _description_
+
+ # A Task that will sync up local workdir '.', containing
+ # requirements.txt and train.py.
+ sky.Task(setup='pip install requirements.txt',
+ run='python train.py',
+ workdir='.')
+
+ # An empty Task for provisioning a cluster.
+ task = sky.Task(num_nodes=n).set_resources(...)
+
+ # Chaining setters.
+ sky.Task().set_resources(...).set_file_mounts(...)
+ """
+ return Task(
+ name=name, setup=setup, run=run, workdir=workdir, *args, **kwargs
+ )
diff --git a/zeta/models/__init__.py b/zeta/models/__init__.py
index 9dab6ca3..5d17fc25 100644
--- a/zeta/models/__init__.py
+++ b/zeta/models/__init__.py
@@ -22,4 +22,4 @@
"LLama2",
"Andromeda",
"NaViT",
-]
\ No newline at end of file
+]
diff --git a/zeta/nn/modules/fused_dropout_layernom.py b/zeta/nn/modules/fused_dropout_layernom.py
new file mode 100644
index 00000000..8850d47b
--- /dev/null
+++ b/zeta/nn/modules/fused_dropout_layernom.py
@@ -0,0 +1,51 @@
+import torch
+from torch import nn
+
+
+class FusedDropoutLayerNorm(nn.Module):
+ """FusedDropoutLayerNorm
+
+ Args:
+ dim (int): Input dimension
+ dropout (float, optional): Dropout. Defaults to 0.1.
+ eps (float, optional): Epsilon. Defaults to 1e-5.
+ elementwise_affine (bool, optional): Elementwise affine. Defaults to True.
+
+ Examples:
+ >>> x = torch.randn(1, 512)
+ >>> model = FusedDropoutLayerNorm(512)
+ >>> out = model(x)
+ >>> out.shape
+ torch.Size([1, 512])
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ dropout: float = 0.1,
+ eps: float = 1e-5,
+ elementwise_affine: bool = True,
+ *args,
+ **kwargs,
+ ):
+ super(FusedDropoutLayerNorm, self).__init__()
+
+ # Dropout initialization
+ self.dropout = nn.Dropout(dropout)
+
+ # LayerNorm initialization
+ self.layer_norm = nn.LayerNorm(
+ dim, eps=eps, elementwise_affine=elementwise_affine, *args, **kwargs
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass
+
+ Args:
+ x (torch.Tensor): tensor
+
+ Returns:
+
+ """
+ x = self.dropout(x)
+ return self.layer_norm(x)
diff --git a/zeta/nn/modules/fused_gelu_dense.py b/zeta/nn/modules/fused_gelu_dense.py
new file mode 100644
index 00000000..885ac458
--- /dev/null
+++ b/zeta/nn/modules/fused_gelu_dense.py
@@ -0,0 +1,87 @@
+import torch
+from torch import nn
+
+
+class FusedDenseGELUDense(nn.Module):
+ """FuseFusedDenseGELUDense
+
+ Args
+ dim (int): Input dimension
+ dim_out (int): Output dimension
+ bias (bool, optional): Bias. Defaults to True.
+ has_fp16_weights (bool, optional): Use fp16 weights. Defaults to False.
+ threshold (float, optional): Threshold for quantization. Defaults to 6.0.
+
+ Examples:
+ >>> x = torch.randn(1, 512)
+ >>> model = FusedDenseGELUDense(512, 1024)
+ >>> out = model(x)
+ >>> out.shape
+ torch.Size([1, 512])
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ dim_out: int,
+ bias: bool = True,
+ has_fp16_weights: bool = False,
+ threshold: float = 6.0,
+ *args,
+ **kwargs,
+ ):
+ super(FusedDenseGELUDense, self).__init__()
+ self.dim = dim
+ self.dim_out = dim_out
+ self.bias = bias
+ self.has_fp16_weights = has_fp16_weights
+ self.threshold = threshold
+
+ try:
+ import bitsandbytes as bnb
+
+ # Using bitsandbytes for quantization
+ self.dense1 = bnb.nn.Linear8bitLt(
+ dim,
+ dim_out,
+ bias=bias,
+ has_fp16_weights=has_fp16_weights,
+ threshold=threshold,
+ *args,
+ **kwargs,
+ )
+
+ # Reverse
+ self.dense2 = bnb.nn.Linear8bitLt(
+ dim_out,
+ dim,
+ bias=bias,
+ has_fp16_weights=has_fp16_weights,
+ threshold=threshold,
+ *args,
+ **kwargs,
+ )
+
+ except ModuleNotFoundError:
+ # Using torch.nn.Linear
+ self.dense1 = nn.Linear(dim, dim_out, bias=bias * args, **kwargs)
+
+ # Dense 2
+ self.dense2 = nn.Linear(dim_out, dim, bias=bias * args, **kwargs)
+
+ # Activation
+ self.act = nn.GELU()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass
+
+ Args:
+ x (torch.Tensor): x input
+
+ Returns:
+ torch.Tensor: _description_
+ """
+ x = self.dense1(x)
+ x = self.act(x)
+ x = self.dense2(x)
+ return x
diff --git a/zeta/nn/modules/simple_mamba.py b/zeta/nn/modules/simple_mamba.py
new file mode 100644
index 00000000..67a1a959
--- /dev/null
+++ b/zeta/nn/modules/simple_mamba.py
@@ -0,0 +1,52 @@
+import torch
+from torch import nn
+from zeta.nn.modules.rms_norm import RMSNorm
+from zeta.nn.modules.residual import Residual
+
+
+class Mamba(nn.Module):
+ def __init__(
+ self,
+ vocab_size: int,
+ dim: int,
+ depth: int,
+ bias: bool = False,
+ *args,
+ **kwargs,
+ ):
+ super().__init__()
+ self.embedding = nn.Embedding(vocab_size, dim)
+ self.layers = nn.ModuleList(
+ [
+ Residual(self.rmsnorm, nn.Linear(dim, dim, bias=bias))
+ for _ in range(depth)
+ ]
+ )
+ self.rmsnorm = RMSNorm(dim)
+ self.linear = nn.Linear(dim, vocab_size, bias=bias)
+ self.linear.weight = self.embedding.weight
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.embedding(x)
+
+ for layer in self.layers:
+ x = layer(x)
+
+ x = self.rmsnorm(x)
+ logits = self.linear(x)
+
+ return logits
+
+
+# class MambaBlock(nn.Module):
+# def __init__(
+# self,
+# dim,
+# inner_dim,
+# bias: bool = False,
+# conv_bias=None,
+# dim_conv=None,
+# *args,
+# **kwargs,
+# ):
+# super().__init__()
diff --git a/zeta/optim/__init__.py b/zeta/optim/__init__.py
index 5b6cea92..b7e81e34 100644
--- a/zeta/optim/__init__.py
+++ b/zeta/optim/__init__.py
@@ -12,6 +12,7 @@
from zeta.optim.stable_adam import StableAdamWUnfused
from zeta.optim.gradient_ascent import GradientAscent
from zeta.optim.gradient_equillibrum import GradientEquilibrum
+from zeta.optim.lion8b import DecoupledLionW8Bit
__all__ = [
"BatchedOptimizer",
@@ -26,4 +27,5 @@
"StableAdamWUnfused",
"GradientAscent",
"GradientEquilibrum",
+ "DecoupledLionW8Bit",
]
diff --git a/zeta/optim/batched_optimizer.py b/zeta/optim/batched_optimizer.py
index 36cc0b5e..8b0300a8 100644
--- a/zeta/optim/batched_optimizer.py
+++ b/zeta/optim/batched_optimizer.py
@@ -206,7 +206,6 @@ def step(self, closure=None):
with torch.enable_grad():
loss = closure()
-
for group, group_params_names in zip(
self.param_groups, self.parameters_names
):
diff --git a/zeta/optim/lion8b.py b/zeta/optim/lion8b.py
new file mode 100644
index 00000000..31e147a1
--- /dev/null
+++ b/zeta/optim/lion8b.py
@@ -0,0 +1,490 @@
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
+
+import torch
+
+
+class DecoupledLionW8Bit(torch.optim.Optimizer):
+ """LION optimizer with ~8 bits of state per parameter.
+
+ This optimizer is a drop-in replacement for our regular LION optimizer
+ with decoupled weight decay, but uses less memory, writes smaller
+ checkpoints, and offers almost-numerically-identical convergence.
+
+ Its state saved per parameter is just an int8, though there are auxiliary
+ scaling factors that bring the total memory per parameter to ~8.5 bits.
+ The exact quantization scheme is considered an implementation detail
+ and may change.
+
+ When training on CPUs, however, no quantization will actually take place.
+
+ See the LION paper (https://arxiv.org/abs/2302.06675) for details about
+ the algorithm itself.
+
+ Args:
+ params: iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr: learning rate
+ betas: two coefficients between 0 and 1 used to combine the current
+ gradients and the momentum. The first coefficient is the weight
+ of the gradient when computing the update. The second is the
+ weight of the gradient when computing the new momentum.
+ weight decay: Weights are multiplied by 1 - `weight_decay` after
+ each optimizer step. Note that we use decoupled weight decay,
+ meaning that this decay does not contribute to the momentum.
+ compress_state_dict: if True, this optimizer's `state_dict` will
+ include quantized optimizer states. Otherwise, the optimizer
+ states are converted to bfloat16 Tensors matching the shapes of
+ their corresponding parameters. The former uses ~8.5 bits per
+ parameter while the latter uses 16 bits per parameter. However,
+ the former is less thoroughly tested and will not work with
+ FSDP or other weight sharding approaches.
+ quantize: If False, optimizer states will not actually be quantized.
+ This option is available so that one can easily debug whether
+ the quantization is causing any convergence issues. Because
+ quantization is only supported for CUDA parameters, attempting to
+ update a non-CUDA tensor will raise an error.
+ error_correction: If True, float16 and bfloat16 parameters will be
+ given an extra state variable, "errors." This tensor will be
+ of the same shape as the parameter but of dtype uint8. This
+ auxiliary variable is used to better approximate float32 updates
+ by retaining information across optimizer steps.
+
+ Raises:
+ NotImplementedError - If any of `quantize`, `compress_state_dict`,
+ or `error_correction` are `True` and either a) there is no CUDA
+ device, or b) step() is executed on a non-CUDA parameter.
+ """
+
+ def __init__(
+ self,
+ params: Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]],
+ lr: float = 1e-3,
+ betas: Tuple[float, float] = (0.9, 0.99),
+ weight_decay: float = 0,
+ quantize: bool = True,
+ compress_state_dict: bool = False,
+ error_correction: bool = False,
+ _fused: bool = True, # XXX this flag is mostly for testing...
+ ):
+ if lr < 0.0:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= betas[0] <= 1.0:
+ raise ValueError(
+ "Invalid beta parameter at index 0: {}".format(betas[0])
+ )
+ if not 0.0 <= betas[1] <= 1.0:
+ raise ValueError(
+ "Invalid beta parameter at index 1: {}".format(betas[1])
+ )
+ if not 0.0 <= weight_decay:
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay)
+ )
+
+ if not torch.cuda.is_available():
+ needs_cuda = " requires a CUDA device."
+ if quantize:
+ raise NotImplementedError("Quantization" + needs_cuda)
+ if error_correction:
+ raise NotImplementedError("Error correction" + needs_cuda)
+ if compress_state_dict:
+ raise NotImplementedError("Quantized state dict" + needs_cuda)
+
+ _fused = _fused and quantize
+ self._quantize = quantize
+ self._error_correction = error_correction
+ self._compress_state_dict = compress_state_dict
+
+ defaults = {
+ "lr": lr,
+ "initial_lr": lr,
+ "betas": betas,
+ "weight_decay": weight_decay,
+ "fused": _fused,
+ }
+ super().__init__(params, defaults)
+
+ @torch.no_grad()
+ def step(self, closure: Optional[Callable] = None):
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group["params"]:
+ self.step_param(p, group)
+
+ return loss
+
+ def step_param(self, p: torch.Tensor, hparams: Dict[str, Any]) -> None:
+ if not p.requires_grad or p.grad is None:
+ return
+ if self._quantize and not p.is_cuda:
+ raise NotImplementedError(
+ f"Can't use quantization with param on {p.device} "
+ + f"({p.shape}, {p.dtype}). If you need "
+ + "to use DecoupledLionW_8bit without a CUDA device, try "
+ + "creating this optimizer with quantize=False."
+ )
+ state = self.state[p] # type:ignore using tensor as key
+ if "exp_avg" not in state:
+ mom = torch.zeros_like(p)
+ state["exp_avg"] = _MaybeQuantizedTensor(
+ mom, try_quantize=self._quantize
+ )
+ need_errs = (p.dtype != torch.float32) and self._error_correction
+ if state.get("errors") is None and need_errs:
+ numel = p.numel()
+ numel += numel % 2 # ensure even number of bytes
+ errors = torch.zeros(numel, dtype=torch.uint8, device=p.device)
+ # as of torch 2.1, FSDP can't shard ints for no reason
+ state["errors"] = errors.view(torch.bfloat16)
+ decay_factor = hparams["weight_decay"]
+ decay_factor *= hparams["lr"] / hparams["initial_lr"]
+ errors: Optional[torch.Tensor] = None
+ if "errors" in state:
+ errors = state["errors"]
+ assert errors is not None # pyright
+ errors = errors.view(dtype=torch.uint8)
+ errors = errors[: p.numel()].view(
+ p.shape
+ ) # strip padding + reshape
+ _lion8b_step(
+ momentums=state["exp_avg"],
+ weights=p,
+ grads=p.grad,
+ beta1=hparams["betas"][0],
+ beta2=hparams["betas"][1],
+ lr=hparams["lr"],
+ weight_decay=decay_factor,
+ fused=hparams["fused"],
+ errors=errors,
+ )
+
+ def __setstate__(self, state: Dict[str, Dict[str, Any]]) -> None:
+ # we override this function to quantize optimizer states when
+ # loading a state dict
+ opt_state, _ = state.values() # other val is param_groups
+ for param_id in opt_state:
+ param_state = opt_state[param_id]
+ new_state = {}
+ if any(k.startswith("exp_avg") for k in param_state):
+ # the keys can either be just "exp_avg" or
+ # "exp_avg::quantized" and "exp_avg::scales", depending on
+ # whether we saved it as quantized or not. The former case
+ # gives us interop with regular LION.
+ qtensor = _MaybeQuantizedTensor(
+ None, try_quantize=self._quantize
+ )
+ qtensor.load_state_dict(param_state, name="exp_avg")
+ new_state["exp_avg"] = qtensor
+ if "errors" in param_state:
+ # we need to cast back to the correct dtype since optimizer
+ # load_state_dict casts to param dtype for fp params; see
+ # https://github.com/pytorch/pytorch/blob/a25eee1d77d93079614fab3ea4ac66e64fb2343b/torch/optim/optimizer.py#L626C7-L626C7 # noqa
+ errs = (
+ param_state["errors"]
+ .to(dtype=torch.uint8)
+ .view(torch.bfloat16)
+ )
+ new_state["errors"] = errs
+ opt_state[param_id] = new_state
+ super().__setstate__(state)
+
+ def state_dict(self):
+ # If the user hasn't opted into storing compressed state dicts
+ # we have to make sure our states are regular torch.Tensors. This
+ # is mostly needed to make FSDP happy in the case that we want to
+ # resume training with a number of devices where
+ # (param numel / device count) % quantization group size != 0
+ # for any param.
+ d = super().state_dict()
+ opt_state, _ = d.values() # other val is param_groups
+ for param_id in opt_state:
+ # make a copy so that we don't mutate our self.state; opt_state
+ # isn't the same as self.state, but its consituent dicts are
+ # the same as those in self.state
+ param_state = {k: v for k, v in opt_state[param_id].items()}
+ if "exp_avg" in param_state: # true if we've taken any steps
+ qtensor = param_state.pop("exp_avg")
+ assert isinstance(qtensor, _MaybeQuantizedTensor) # pyright
+ param_state.update(
+ qtensor.state_dict(
+ name="exp_avg",
+ allow_quantized=self._compress_state_dict,
+ )
+ )
+ if "errors" in param_state:
+ # fsdp apparently needs the states to be the same shape
+ # as the params
+ param_state["errors"] = (
+ param_state["errors"]
+ .view(torch.uint8)
+ .to(dtype=torch.bfloat16)
+ )
+ opt_state[param_id] = param_state
+ return d
+
+
+class _MaybeQuantizedTensor:
+ """Helper class so 8b LION doesn't have to know quantization details.
+
+ Important points about this class:
+ * It handles CPU tensors not being quantized
+ * It knows how to save + load state dicts, handling both the quantized
+ and not quantized cases
+ * It implements some parts of the torch.Tensor interface that we need,
+ but is not intended to be a full torch.Tensor replacement
+ """
+
+ def __init__(self, data: Optional[torch.Tensor], try_quantize: bool = True):
+ super().__init__()
+ self.data: Optional[torch.Tensor] = None
+ self.quantized: Optional[torch.Tensor] = None
+ self.scales: Optional[torch.Tensor] = None
+ self._try_quantize = try_quantize and torch.cuda.is_available()
+
+ # conditionally import CUDA kernels
+ self._f_encode = None
+ self._f_decode = None
+ if self._try_quantize:
+ from turbo import dequantize8b, quantize8b
+
+ self._f_encode = quantize8b
+ self._f_decode = dequantize8b
+
+ if data is not None:
+ self.set_data(data)
+
+ def state_dict(
+ self, name: str, allow_quantized: bool = False
+ ) -> Dict[str, torch.Tensor]:
+ if self.is_quantized() and allow_quantized:
+ assert self.quantized is not None # pyright
+ assert self.scales is not None # pyright
+ return {
+ f"{name}::quantized": self.quantized,
+ f"{name}::scales": self.scales,
+ }
+ return {name: self.materialize().to(dtype=torch.bfloat16)}
+
+ def load_state_dict(self, d: Dict[str, torch.Tensor], name: str) -> None:
+ # we allow other keys in the state dict for convenience, so you can
+ # just pass this the whole opt state for a parameters
+ d = {k: v for k, v in d.items() if k.startswith(name)}
+ if name in d:
+ if len(d) != 1:
+ raise ValueError(
+ f"If state dict specifies {name}, it must not "
+ + f"specify other keys. Got {list(d.keys())}"
+ )
+ self.set_data(d[name])
+ return
+
+ self.quantized = d[f"{name}::quantized"].to(dtype=torch.int8)
+ self.scales = d[f"{name}::scales"].to(dtype=torch.float16)
+
+ def set_data(self, data: torch.Tensor) -> None:
+ if self._try_quantize:
+ if not data.is_cuda:
+ raise NotImplementedError(
+ f"Attempting to quantize a non-CUDA {data.dtype} tensor "
+ + f"on device {data.device} with shape {data.shape}."
+ )
+ self.data = None
+ assert self._f_encode is not None # pyright
+ self.quantized, self.scales = self._f_encode(data)
+ else:
+ self.data = data.to(dtype=torch.float32)
+ self.quantized = None
+ self.scales = None
+
+ def is_quantized(self) -> bool:
+ return self.data is None
+
+ def materialize(self) -> torch.Tensor:
+ if not self.is_quantized():
+ assert self.data is not None # pyright
+ return self.data
+ assert self._f_decode is not None # pyright
+ assert self.quantized is not None # pyright
+ assert self.scales is not None # pyright
+ return self._f_decode(self.quantized, self.scales)
+
+ @property # property to mirror Tensor interface
+ def is_cuda(self) -> bool:
+ if self.is_quantized():
+ assert self.quantized is not None # pyright
+ return self.quantized.is_cuda
+ assert self.data is not None # pyright
+ return self.data.is_cuda
+
+ @property # property to mirror Tensor interface
+ def shape(self) -> Tuple[int]:
+ if self.is_quantized():
+ assert self.quantized is not None # pyright
+ return self.quantized.shape
+ assert self.data is not None # pyright
+ return self.data.shape
+
+ def numel(self) -> int:
+ if self.is_quantized():
+ assert self.quantized is not None # pyright
+ return self.quantized.numel()
+ assert self.data is not None # pyright
+ return self.data.numel()
+
+ def __repr__(self):
+ return (
+ f"{self.__class__.__name__} quantized={self.is_quantized()} "
+ + f"shape={self.shape}"
+ )
+
+
+def lion_step_unfused(
+ grads: torch.Tensor,
+ weights: torch.Tensor,
+ momentums: torch.Tensor,
+ lr: float,
+ beta1: float,
+ beta2: float,
+ weight_decay: float = 0,
+) -> torch.Tensor:
+ # f32 cast to match fused impl + for compatibility with f32 grads or weights
+ momentums = momentums.to(dtype=torch.float32)
+ grads = grads.to(dtype=torch.float32)
+
+ update = momentums.lerp(grads, 1 - beta1).sign_()
+ if weight_decay > 0:
+ weights.mul_(1.0 - weight_decay)
+
+ weights.add_(update, alpha=-lr)
+ momentums.lerp_(grads, 1.0 - beta2)
+ return momentums # f32 upcast means not necessarily modified in place
+
+
+def lion8b_step_fused(
+ grads: torch.Tensor,
+ weights: torch.Tensor,
+ momentums: torch.Tensor,
+ scales: torch.Tensor,
+ lr: float,
+ beta1: float,
+ beta2: float,
+ weight_decay: float,
+ errors: Optional[torch.Tensor] = None,
+) -> None:
+ # just to save space in lists of allowed dtypes
+ f16, bf16, f32 = torch.float16, torch.bfloat16, torch.float32
+
+ use_errors = (errors is not None) and (weights.dtype in (f16, bf16))
+ orig_shape = weights.shape
+
+ # ------------------------------------------------ wall of error checking
+ quantize_group_size = 32
+ num_groups = (
+ weights.numel() + quantize_group_size - 1
+ ) // quantize_group_size
+ if num_groups != scales.numel():
+ raise ValueError(
+ f"Expected {num_groups} quantization scales but "
+ + f" received {scales.numel()}"
+ )
+
+ for name, tensor, allowed_dtypes in [
+ ("grad", grads, (f16, bf16, f32)),
+ ("param", weights, (f16, bf16, f32)),
+ ("momentum", momentums, [torch.int8]),
+ ("scales", scales, [f16]),
+ ("errors", errors, [torch.uint8]),
+ ]:
+ if name == "errors" and not use_errors:
+ continue
+ if not tensor.is_cuda:
+ raise ValueError(
+ f"{name} must be on a CUDA device, not {tensor.device}"
+ )
+ if not tensor.is_contiguous():
+ raise ValueError(f"{name} is not contiguous!")
+ strides_unequal = tensor.stride() != weights.stride()
+ if name not in ("scales", "errors") and strides_unequal:
+ raise ValueError(
+ f"{name} stride {tensor.stride()} != "
+ + f"param stride {weights.stride()}"
+ )
+ if tensor.dtype not in allowed_dtypes:
+ raise ValueError(
+ f"{name} must have dtype {allowed_dtypes}, not "
+ + f"{tensor.dtype}"
+ )
+ if (name != "scales") and (orig_shape != tensor.shape):
+ raise ValueError(
+ f"Param shape {orig_shape} != " + f"{name} shape {tensor.shape}"
+ )
+
+ if grads.dtype in (torch.float16, torch.bfloat16):
+ allowed_dtypes = (grads.dtype, torch.float32)
+ if weights.dtype not in allowed_dtypes:
+ raise ValueError(
+ f"Weights must be f32 or match grad dtype {grads.dtype}"
+ )
+
+ # ------------------------------------------------ actual function call
+ from turbo import lion8b_step_cuda
+
+ return lion8b_step_cuda(
+ grads=grads,
+ weights=weights,
+ momentums=momentums,
+ scales=scales,
+ lr=lr,
+ beta1=beta1,
+ beta2=beta2,
+ weight_decay=weight_decay,
+ errors=errors,
+ )
+
+
+def _lion8b_step(
+ grads: torch.Tensor,
+ weights: torch.Tensor,
+ momentums: _MaybeQuantizedTensor,
+ lr: float,
+ beta1: float,
+ beta2: float,
+ weight_decay: float = 0,
+ errors: Optional[torch.Tensor] = None,
+ fused: bool = True,
+) -> None:
+ if fused and not momentums.is_quantized():
+ raise NotImplementedError(
+ "Fused LION step only implemented with quantization."
+ )
+
+ if momentums.is_quantized() and fused:
+ assert momentums.quantized is not None # pyright
+ assert momentums.scales is not None # pyright
+ return lion8b_step_fused(
+ grads=grads,
+ weights=weights,
+ momentums=momentums.quantized,
+ scales=momentums.scales,
+ lr=lr,
+ beta1=beta1,
+ beta2=beta2,
+ weight_decay=weight_decay,
+ errors=errors,
+ )
+
+ momentums_float = momentums.materialize()
+ new_momentums = lion_step_unfused(
+ grads=grads,
+ weights=weights,
+ momentums=momentums_float,
+ lr=lr,
+ beta1=beta1,
+ beta2=beta2,
+ weight_decay=weight_decay,
+ )
+ momentums.set_data(new_momentums)