diff --git a/pyproject.toml b/pyproject.toml
index 68ff3d05..f65cd5c6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "zetascale"
-version = "0.9.4"
+version = "0.9.6"
 description = "Transformers at zeta scales"
 authors = ["Zeta Team "]
 license = "MIT"
diff --git a/tests/nn/embeddings/qftp_embeddings.py b/tests/nn/embeddings/qftp_embeddings.py
index 493cc187..f2327199 100644
--- a/tests/nn/embeddings/qftp_embeddings.py
+++ b/tests/nn/embeddings/qftp_embeddings.py
@@ -11,51 +11,57 @@ def test_qsembeddings_init():
     assert model.base_embeddings.num_embeddings == vocab_size
     assert model.superposed_embeddings.num_embeddings == vocab_size
 
+
 def test_qsembeddings_forward_weighted_sum():
     vocab_size = 10000
     dim = 512
     model = QuantumSuperpositionEmbeddings(vocab_size, dim)
     x = torch.randint(0, vocab_size, (1, 10))
     context_vector = torch.rand(1, 10)
-    embeddings = model(x, context_vector, 'weighted_sum')
+    embeddings = model(x, context_vector, "weighted_sum")
     assert embeddings.shape == (1, 10, dim)
 
+
 def test_qsembeddings_forward_dot_product():
     vocab_size = 10000
     dim = 512
     model = QuantumSuperpositionEmbeddings(vocab_size, dim)
     x = torch.randint(0, vocab_size, (1, 10))
     context_vector = torch.rand(1, 10)
-    embeddings = model(x, context_vector, 'dot_product')
+    embeddings = model(x, context_vector, "dot_product")
     assert embeddings.shape == (1, 10, dim)
 
+
 def test_qsembeddings_forward_cosine_similarity():
     vocab_size = 10000
     dim = 512
     model = QuantumSuperpositionEmbeddings(vocab_size, dim)
     x = torch.randint(0, vocab_size, (1, 10))
     context_vector = torch.rand(1, 10)
-    embeddings = model(x, context_vector, 'cosine_similarity')
+    embeddings = model(x, context_vector, "cosine_similarity")
     assert embeddings.shape == (1, 10, dim)
 
+
 def test_qsembeddings_forward_gated():
     vocab_size = 10000
     dim = 512
     model = QuantumSuperpositionEmbeddings(vocab_size, dim)
     x = torch.randint(0, vocab_size, (1, 10))
     context_vector = torch.rand(1, 10)
-    embeddings = model(x, context_vector, 'gated')
+    embeddings = model(x, context_vector, "gated")
     assert embeddings.shape == (1, 10, dim)
 
+
 def test_qsembeddings_forward_concat_linear():
     vocab_size = 10000
     dim = 512
     model = QuantumSuperpositionEmbeddings(vocab_size, dim)
     x = torch.randint(0, vocab_size, (1, 10))
     context_vector = torch.rand(1, 10)
-    embeddings = model(x, context_vector, 'concat_linear')
+    embeddings = model(x, context_vector, "concat_linear")
     assert embeddings.shape == (1, 10, dim)
 
+
 def test_qsembeddings_forward_invalid_mode():
     vocab_size = 10000
     dim = 512
@@ -63,7 +69,8 @@ def test_qsembeddings_forward_invalid_mode():
     x = torch.randint(0, vocab_size, (1, 10))
     context_vector = torch.rand(1, 10)
     with pytest.raises(ValueError):
-        model(x, context_vector, 'invalid_mode')
+        model(x, context_vector, "invalid_mode")
+
 
 def test_qsembeddings_forward_large_input():
     vocab_size = 10000
@@ -71,23 +78,25 @@ def test_qsembeddings_forward_large_input():
     model = QuantumSuperpositionEmbeddings(vocab_size, dim)
     x = torch.randint(0, vocab_size, (1000, 1000))
     context_vector = torch.rand(1000, 1000)
-    embeddings = model(x, context_vector, 'weighted_sum')
+    embeddings = model(x, context_vector, "weighted_sum")
     assert embeddings.shape == (1000, 1000, dim)
 
+
 def test_qsembeddings_forward_large_dim():
     vocab_size = 10000
     dim = 10000
     model = QuantumSuperpositionEmbeddings(vocab_size, dim)
     x = torch.randint(0, vocab_size, (1, 10))
     context_vector = torch.rand(1, 10)
-    embeddings = model(x, context_vector, 'weighted_sum')
+    embeddings = model(x, context_vector, "weighted_sum")
     assert embeddings.shape == (1, 10, dim)
 
+
 def test_qsembeddings_forward_large_vocab_size():
     vocab_size = 1000000
     dim = 512
     model = QuantumSuperpositionEmbeddings(vocab_size, dim)
     x = torch.randint(0, vocab_size, (1, 10))
     context_vector = torch.rand(1, 10)
-    embeddings = model(x, context_vector, 'weighted_sum')
-    assert embeddings.shape == (1, 10, dim)
\ No newline at end of file
+    embeddings = model(x, context_vector, "weighted_sum")
+    assert embeddings.shape == (1, 10, dim)
diff --git a/zeta/nn/embeddings/__init__.py b/zeta/nn/embeddings/__init__.py
index cfc8766e..18c6a063 100644
--- a/zeta/nn/embeddings/__init__.py
+++ b/zeta/nn/embeddings/__init__.py
@@ -47,5 +47,5 @@
     "YarnEmbedding",
     "SinePositionalEmbedding",
     "QFTSPEmbeddings",
-    "QuantumSuperpositionEmbeddings"
+    "QuantumSuperpositionEmbeddings",
 ]
diff --git a/zeta/nn/embeddings/qfsp_embeddings.py b/zeta/nn/embeddings/qfsp_embeddings.py
index 2c6d50d2..d7bde425 100644
--- a/zeta/nn/embeddings/qfsp_embeddings.py
+++ b/zeta/nn/embeddings/qfsp_embeddings.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+
 class QuantumSuperpositionEmbeddings(nn.Module):
     """
     QuantumSuperpositionEmbeddings with multiple collapse mechanisms.
@@ -17,22 +18,32 @@ def __init__(self, vocab_size, embed_dim):
         self.superposed_embeddings = nn.Embedding(vocab_size, embed_dim)
         self.linear_transform = nn.Linear(2 * embed_dim, embed_dim)
 
-    def forward(self, input_ids, context_vector, collapse_mode='weighted_sum'):
+    def forward(self, input_ids, context_vector, collapse_mode="weighted_sum"):
         base_embeds = self.base_embeddings(input_ids)
         superposed_embeds = self.superposed_embeddings(input_ids)
 
-        if collapse_mode == 'weighted_sum':
-            collapsed_embeds = base_embeds + context_vector.unsqueeze(-1) * superposed_embeds
-        elif collapse_mode == 'dot_product':
-            scale = torch.sum(superposed_embeds * context_vector.unsqueeze(-1), dim=-1, keepdim=True)
+        if collapse_mode == "weighted_sum":
+            collapsed_embeds = (
+                base_embeds + context_vector.unsqueeze(-1) * superposed_embeds
+            )
+        elif collapse_mode == "dot_product":
+            scale = torch.sum(
+                superposed_embeds * context_vector.unsqueeze(-1),
+                dim=-1,
+                keepdim=True,
+            )
             collapsed_embeds = base_embeds + scale * superposed_embeds
-        elif collapse_mode == 'cosine_similarity':
-            scale = F.cosine_similarity(superposed_embeds, context_vector.unsqueeze(-1), dim=-1).unsqueeze(-1)
+        elif collapse_mode == "cosine_similarity":
+            scale = F.cosine_similarity(
+                superposed_embeds, context_vector.unsqueeze(-1), dim=-1
+            ).unsqueeze(-1)
             collapsed_embeds = base_embeds + scale * superposed_embeds
-        elif collapse_mode == 'gated':
+        elif collapse_mode == "gated":
             gate = torch.sigmoid(context_vector)
-            collapsed_embeds = base_embeds + gate.unsqueeze(-1) * superposed_embeds
-        elif collapse_mode == 'concat_linear':
+            collapsed_embeds = (
+                base_embeds + gate.unsqueeze(-1) * superposed_embeds
+            )
+        elif collapse_mode == "concat_linear":
             concatenated = torch.cat([base_embeds, superposed_embeds], dim=-1)
             collapsed_embeds = self.linear_transform(concatenated)
         else:
@@ -40,6 +51,7 @@ def forward(self, input_ids, context_vector, collapse_mode='weighted_sum'):
 
         return collapsed_embeds
 
+
 # # Example Usage
 # vocab_size = 10000
 # embed_dim = 512
diff --git a/zeta/utils/cuda_memory_wrapper.py b/zeta/utils/cuda_memory_wrapper.py
index 1cb837eb..02ad005d 100644
--- a/zeta/utils/cuda_memory_wrapper.py
+++ b/zeta/utils/cuda_memory_wrapper.py
@@ -1,49 +1,54 @@
-import torch 
-import functools 
-import logging 
+import torch
+import functools
+import logging
 
 # Logging initialization
 logging.basicConfig(
-    level=logging.INFO, 
-    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+    level=logging.INFO,
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
 )
 
+
 # Main function
 def track_cuda_memory_usage(func):
     """Track CUDA memory usage of a function.
 
     Args:
         func (function): The function to be tracked.
-    
+
     Returns:
         function: The wrapped function.
-    
+
     Example:
         >>> @track_cuda_memory_usage
         >>> def train():
         >>>     pass
         >>> train()
     """
+
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
         if not torch.cuda.is_available():
             logging.warning("CUDA is not available, skip tracking memory usage")
             return func(*args, **kwargs)
-        
+
         torch.cuda.synchronize()
         before_memory = torch.cuda.memory_allocated()
-        
+
         try:
             result = func(*args, **kwargs)
         except Exception as error:
             logging.error(f"Error occurs when running {func.__name__}: {error}")
             raise
-        
+
         finally:
             torch.cuda.synchronize()
             after_memory = torch.cuda.memory_allocated()
             memory_diff = after_memory - before_memory
-            logging.info(f"Memory usage of {func.__name__}: {memory_diff} bytes")
-        
+            logging.info(
+                f"Memory usage of {func.__name__}: {memory_diff} bytes"
+            )
+
         return result
-    return wrapper
\ No newline at end of file
+
+    return wrapper