
Commit

[CLEANUP]
Kye committed Dec 17, 2023
1 parent 2c89b26 commit cbb33a9
Showing 5 changed files with 61 additions and 35 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "zetascale"
version = "0.9.4"
version = "0.9.6"
description = "Transformers at zeta scales"
authors = ["Zeta Team <kye@apac.ai>"]
license = "MIT"
29 changes: 19 additions & 10 deletions tests/nn/embeddings/qftp_embeddings.py
@@ -11,83 +11,92 @@ def test_qsembeddings_init():
    assert model.base_embeddings.num_embeddings == vocab_size
    assert model.superposed_embeddings.num_embeddings == vocab_size


def test_qsembeddings_forward_weighted_sum():
    vocab_size = 10000
    dim = 512
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    context_vector = torch.rand(1, 10)
    embeddings = model(x, context_vector, 'weighted_sum')
    embeddings = model(x, context_vector, "weighted_sum")
    assert embeddings.shape == (1, 10, dim)


def test_qsembeddings_forward_dot_product():
    vocab_size = 10000
    dim = 512
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    context_vector = torch.rand(1, 10)
    embeddings = model(x, context_vector, 'dot_product')
    embeddings = model(x, context_vector, "dot_product")
    assert embeddings.shape == (1, 10, dim)


def test_qsembeddings_forward_cosine_similarity():
    vocab_size = 10000
    dim = 512
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    context_vector = torch.rand(1, 10)
    embeddings = model(x, context_vector, 'cosine_similarity')
    embeddings = model(x, context_vector, "cosine_similarity")
    assert embeddings.shape == (1, 10, dim)


def test_qsembeddings_forward_gated():
    vocab_size = 10000
    dim = 512
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    context_vector = torch.rand(1, 10)
    embeddings = model(x, context_vector, 'gated')
    embeddings = model(x, context_vector, "gated")
    assert embeddings.shape == (1, 10, dim)


def test_qsembeddings_forward_concat_linear():
    vocab_size = 10000
    dim = 512
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    context_vector = torch.rand(1, 10)
    embeddings = model(x, context_vector, 'concat_linear')
    embeddings = model(x, context_vector, "concat_linear")
    assert embeddings.shape == (1, 10, dim)


def test_qsembeddings_forward_invalid_mode():
    vocab_size = 10000
    dim = 512
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    context_vector = torch.rand(1, 10)
    with pytest.raises(ValueError):
        model(x, context_vector, 'invalid_mode')
        model(x, context_vector, "invalid_mode")


def test_qsembeddings_forward_large_input():
    vocab_size = 10000
    dim = 512
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1000, 1000))
    context_vector = torch.rand(1000, 1000)
    embeddings = model(x, context_vector, 'weighted_sum')
    embeddings = model(x, context_vector, "weighted_sum")
    assert embeddings.shape == (1000, 1000, dim)


def test_qsembeddings_forward_large_dim():
    vocab_size = 10000
    dim = 10000
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    context_vector = torch.rand(1, 10)
    embeddings = model(x, context_vector, 'weighted_sum')
    embeddings = model(x, context_vector, "weighted_sum")
    assert embeddings.shape == (1, 10, dim)


def test_qsembeddings_forward_large_vocab_size():
    vocab_size = 1000000
    dim = 512
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    context_vector = torch.rand(1, 10)
    embeddings = model(x, context_vector, 'weighted_sum')
    assert embeddings.shape == (1, 10, dim)
    embeddings = model(x, context_vector, "weighted_sum")
    assert embeddings.shape == (1, 10, dim)
2 changes: 1 addition & 1 deletion zeta/nn/embeddings/__init__.py
@@ -47,5 +47,5 @@
"YarnEmbedding",
"SinePositionalEmbedding",
"QFTSPEmbeddings",
"QuantumSuperpositionEmbeddings"
"QuantumSuperpositionEmbeddings",
]
32 changes: 22 additions & 10 deletions zeta/nn/embeddings/qfsp_embeddings.py
@@ -2,6 +2,7 @@
import torch.nn as nn
import torch.nn.functional as F


class QuantumSuperpositionEmbeddings(nn.Module):
"""
QuantumSuperpositionEmbeddings with multiple collapse mechanisms.
@@ -17,29 +18,40 @@ def __init__(self, vocab_size, embed_dim):
        self.superposed_embeddings = nn.Embedding(vocab_size, embed_dim)
        self.linear_transform = nn.Linear(2 * embed_dim, embed_dim)

    def forward(self, input_ids, context_vector, collapse_mode='weighted_sum'):
    def forward(self, input_ids, context_vector, collapse_mode="weighted_sum"):
        base_embeds = self.base_embeddings(input_ids)
        superposed_embeds = self.superposed_embeddings(input_ids)

        if collapse_mode == 'weighted_sum':
            collapsed_embeds = base_embeds + context_vector.unsqueeze(-1) * superposed_embeds
        elif collapse_mode == 'dot_product':
            scale = torch.sum(superposed_embeds * context_vector.unsqueeze(-1), dim=-1, keepdim=True)
        if collapse_mode == "weighted_sum":
            collapsed_embeds = (
                base_embeds + context_vector.unsqueeze(-1) * superposed_embeds
            )
        elif collapse_mode == "dot_product":
            scale = torch.sum(
                superposed_embeds * context_vector.unsqueeze(-1),
                dim=-1,
                keepdim=True,
            )
            collapsed_embeds = base_embeds + scale * superposed_embeds
        elif collapse_mode == 'cosine_similarity':
            scale = F.cosine_similarity(superposed_embeds, context_vector.unsqueeze(-1), dim=-1).unsqueeze(-1)
        elif collapse_mode == "cosine_similarity":
            scale = F.cosine_similarity(
                superposed_embeds, context_vector.unsqueeze(-1), dim=-1
            ).unsqueeze(-1)
            collapsed_embeds = base_embeds + scale * superposed_embeds
        elif collapse_mode == 'gated':
        elif collapse_mode == "gated":
            gate = torch.sigmoid(context_vector)
            collapsed_embeds = base_embeds + gate.unsqueeze(-1) * superposed_embeds
        elif collapse_mode == 'concat_linear':
            collapsed_embeds = (
                base_embeds + gate.unsqueeze(-1) * superposed_embeds
            )
        elif collapse_mode == "concat_linear":
            concatenated = torch.cat([base_embeds, superposed_embeds], dim=-1)
            collapsed_embeds = self.linear_transform(concatenated)
        else:
            raise ValueError("Invalid collapse mode selected")

        return collapsed_embeds


# # Example Usage
# vocab_size = 10000
# embed_dim = 512
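The snippet below is not part of the commit; it is a minimal usage sketch of the class as it stands after this diff. The import path `zeta.nn.embeddings` is assumed from the `__all__` entry updated above, and the variable names are illustrative only.

import torch
from zeta.nn.embeddings import QuantumSuperpositionEmbeddings  # import path assumed from the __all__ diff above

vocab_size, dim = 10000, 512
model = QuantumSuperpositionEmbeddings(vocab_size, dim)

x = torch.randint(0, vocab_size, (1, 10))  # token ids, shape (batch, seq_len)
context_vector = torch.rand(1, 10)         # one context weight per token

# Every collapse mode mixes the base and superposed embedding tables differently,
# but each one returns a tensor of shape (batch, seq_len, dim).
for mode in ["weighted_sum", "dot_product", "cosine_similarity", "gated", "concat_linear"]:
    out = model(x, context_vector, collapse_mode=mode)
    print(mode, tuple(out.shape))  # e.g. weighted_sum (1, 10, 512)

Note that "concat_linear" routes through the extra linear_transform layer, so its output depends on learned weights rather than directly on context_vector.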
31 changes: 18 additions & 13 deletions zeta/utils/cuda_memory_wrapper.py
@@ -1,49 +1,54 @@
import torch
import functools
import logging
import torch
import functools
import logging

# Logging initialization
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)


# Main function
def track_cuda_memory_usage(func):
"""Track CUDA memory usage of a function.
Args:
func (function): The function to be tracked.
Returns:
function: The wrapped function.
Example:
>>> @track_cuda_memory_usage
>>> def train():
>>> pass
>>> train()
"""

@functools.wraps(func)
def wrapper(*args, **kwargs):
if not torch.cuda.is_available():
logging.warning("CUDA is not available, skip tracking memory usage")
return func(*args, **kwargs)

torch.cuda.synchronize()
before_memory = torch.cuda.memory_allocated()

try:
result = func(*args, **kwargs)
except Exception as error:
logging.error(f"Error occurs when running {func.__name__}: {error}")
raise

finally:
torch.cuda.synchronize()
after_memory = torch.cuda.memory_allocated()
memory_diff = after_memory - before_memory
logging.info(f"Memory usage of {func.__name__}: {memory_diff} bytes")

logging.info(
f"Memory usage of {func.__name__}: {memory_diff} bytes"
)

return result
return wrapper

return wrapper
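A minimal usage sketch of the decorator after this change, not part of the commit. It assumes track_cuda_memory_usage is imported directly from the module path shown in the file header above; allocate is a hypothetical function used only for illustration.

import torch
from zeta.utils.cuda_memory_wrapper import track_cuda_memory_usage  # module path taken from the file header above


@track_cuda_memory_usage
def allocate():
    # Allocate a tensor so the wrapper has something to measure; without CUDA
    # the decorator only logs a warning and runs the function unchanged.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.zeros(1024, 1024, device=device)


allocate()  # logs "Memory usage of allocate: ... bytes" when CUDA is available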
