Skip to content

Commit

Permalink
[FEAT][FLOW]
Browse files Browse the repository at this point in the history
  • Loading branch information
kyegomez committed Dec 15, 2024
1 parent 8105d94 commit cd78a41
Show file tree
Hide file tree
Showing 6 changed files with 1,241 additions and 4 deletions.
2 changes: 1 addition & 1 deletion examples/models/agi/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def main():
dummy_audio = torch.randn(1, 16000) # Batch size 1, 1-second audio at 16kHz

# Forward pass
output = model(dummy_image, dummy_audio)
model(dummy_image, dummy_audio)
logger.info("Model output obtained")


Expand Down
2 changes: 1 addition & 1 deletion examples/models/evo_transformer_mutate/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ def create_next_generation(self) -> None:
"""
parents = self.select_parents()
next_generation = parents.copy()
num_children = self.population_size - len(parents)
self.population_size - len(parents)
while len(next_generation) < self.population_size:
parent_indices = torch.randperm(len(parents))[:2]
parent1 = parents[parent_indices[0]]
Expand Down
5 changes: 3 additions & 2 deletions muon.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@

import math
import torch
import torch.nn as nn
from zeta import Muon # Assuming muon.py contains our implementation


# Simple transformer layer
class SimpleTransformer(nn.Module):
def __init__(self, d_model=256):
Expand All @@ -24,6 +24,7 @@ def forward(self, x):
out = torch.matmul(attn, v)
return self.output(out)


# Create model
model = SimpleTransformer()

Expand All @@ -32,7 +33,7 @@ def forward(self, x):
other_params = []

for name, param in model.named_parameters():
if any(x in name for x in ['query', 'key', 'value']):
if any(x in name for x in ["query", "key", "value"]):
muon_params.append(param)
else:
other_params.append(param)
Expand Down
8 changes: 8 additions & 0 deletions zeta/nn/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@
from zeta.nn.modules.crome_adapter import CROMEAdapter
from zeta.nn.modules.cog_vlm_two_adapter import CogVLMTwoAdapter
from zeta.nn.modules.sigmoid_attn import SigmoidAttention
from zeta.nn.modules.flow_matching import Flow, MixtureFlow, MixtureFlowConfig
from zeta.nn.modules.flow_transformer import FlowTransformerConfig, FlowMLP, FlowTransformer

# from zeta.nn.modules.img_reshape import image_reshape
# from zeta.nn.modules.flatten_features import flatten_features
Expand Down Expand Up @@ -455,4 +457,10 @@
"CROMEAdapter",
"CogVLMTwoAdapter",
"SigmoidAttention",
"Flow",
"MixtureFlow",
"MixtureFlowConfig",
"FlowTransformerConfig",
"FlowMLP",
"FlowTransformer",
]
Loading

0 comments on commit cd78a41

Please sign in to comment.