diff --git a/tests/nn/embeddings/qftp_embeddings.py b/tests/nn/embeddings/qftp_embeddings.py
new file mode 100644
index 00000000..493cc187
--- /dev/null
+++ b/tests/nn/embeddings/qftp_embeddings.py
@@ -0,0 +1,93 @@
+import pytest
+import torch
+from zeta.nn.embeddings.qfsp_embeddings import QuantumSuperpositionEmbeddings
+
+
+def test_qsembeddings_init():
+    vocab_size = 10000
+    dim = 512
+    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
+    assert model.embed_dim == dim
+    assert model.base_embeddings.num_embeddings == vocab_size
+    assert model.superposed_embeddings.num_embeddings == vocab_size
+
+def test_qsembeddings_forward_weighted_sum():
+    vocab_size = 10000
+    dim = 512
+    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
+    x = torch.randint(0, vocab_size, (1, 10))
+    context_vector = torch.rand(1, 10)
+    embeddings = model(x, context_vector, 'weighted_sum')
+    assert embeddings.shape == (1, 10, dim)
+
+def test_qsembeddings_forward_dot_product():
+    vocab_size = 10000
+    dim = 512
+    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
+    x = torch.randint(0, vocab_size, (1, 10))
+    context_vector = torch.rand(1, 10)
+    embeddings = model(x, context_vector, 'dot_product')
+    assert embeddings.shape == (1, 10, dim)
+
+def test_qsembeddings_forward_cosine_similarity():
+    vocab_size = 10000
+    dim = 512
+    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
+    x = torch.randint(0, vocab_size, (1, 10))
+    context_vector = torch.rand(1, 10)
+    embeddings = model(x, context_vector, 'cosine_similarity')
+    assert embeddings.shape == (1, 10, dim)
+
+def test_qsembeddings_forward_gated():
+    vocab_size = 10000
+    dim = 512
+    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
+    x = torch.randint(0, vocab_size, (1, 10))
+    context_vector = torch.rand(1, 10)
+    embeddings = model(x, context_vector, 'gated')
+    assert embeddings.shape == (1, 10, dim)
+
+def test_qsembeddings_forward_concat_linear():
+    vocab_size = 10000
+    dim = 512
+    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
+    x = torch.randint(0, vocab_size, (1, 10))
+    context_vector = torch.rand(1, 10)
+    embeddings = model(x, context_vector, 'concat_linear')
+    assert embeddings.shape == (1, 10, dim)
+
+def test_qsembeddings_forward_invalid_mode():
+    vocab_size = 10000
+    dim = 512
+    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
+    x = torch.randint(0, vocab_size, (1, 10))
+    context_vector = torch.rand(1, 10)
+    with pytest.raises(ValueError):
+        model(x, context_vector, 'invalid_mode')
+
+def test_qsembeddings_forward_large_input():
+    vocab_size = 10000
+    dim = 512
+    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
+    x = torch.randint(0, vocab_size, (1000, 1000))
+    context_vector = torch.rand(1000, 1000)
+    embeddings = model(x, context_vector, 'weighted_sum')
+    assert embeddings.shape == (1000, 1000, dim)
+
+def test_qsembeddings_forward_large_dim():
+    vocab_size = 10000
+    dim = 10000
+    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
+    x = torch.randint(0, vocab_size, (1, 10))
+    context_vector = torch.rand(1, 10)
+    embeddings = model(x, context_vector, 'weighted_sum')
+    assert embeddings.shape == (1, 10, dim)
+
+def test_qsembeddings_forward_large_vocab_size():
+    vocab_size = 1000000
+    dim = 512
+    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
+    x = torch.randint(0, vocab_size, (1, 10))
+    context_vector = torch.rand(1, 10)
+    embeddings = model(x, context_vector, 'weighted_sum')
+    assert embeddings.shape == (1, 10, dim)
\ No newline at end of file
diff --git a/tests/nn/embeddings/test_QFTSPEmbeddings.py b/tests/nn/embeddings/test_QFTSPEmbeddings.py
new file mode 100644
index 00000000..4e3f334c
--- /dev/null
+++ b/tests/nn/embeddings/test_QFTSPEmbeddings.py
@@ -0,0 +1,86 @@
+import pytest
+import torch
+from zeta.nn.embeddings.qft_embeddings import QFTSPEmbeddings
+
+
+def test_qftspembeddings_init():
+    vocab_size = 10000
+    dim = 512
+    model = QFTSPEmbeddings(vocab_size, dim)
+    assert model.vocab_size == vocab_size
+    assert model.dim == dim
+
+
+def test_qftspembeddings_forward():
+    vocab_size = 10000
+    dim = 512
+    model = QFTSPEmbeddings(vocab_size, dim)
+    x = torch.randint(0, vocab_size, (1, 10))
+    embeddings = model(x)
+    assert embeddings.shape == (1, 10, dim)
+
+
+def test_qftspembeddings_forward_zero_dim():
+    vocab_size = 10000
+    dim = 0
+    model = QFTSPEmbeddings(vocab_size, dim)
+    x = torch.randint(0, vocab_size, (1, 10))
+    embeddings = model(x)
+    assert embeddings.shape == (1, 10, 0)
+
+
+def test_qftspembeddings_forward_odd_dim():
+    vocab_size = 10000
+    dim = 513
+    model = QFTSPEmbeddings(vocab_size, dim)
+    x = torch.randint(0, vocab_size, (1, 10))
+    embeddings = model(x)
+    assert embeddings.shape == (1, 10, dim)
+
+
+def test_qftspembeddings_forward_large_input():
+    vocab_size = 10000
+    dim = 512
+    model = QFTSPEmbeddings(vocab_size, dim)
+    x = torch.randint(0, vocab_size, (1000, 1000))
+    embeddings = model(x)
+    assert embeddings.shape == (1000, 1000, dim)
+
+
+def test_qftspembeddings_forward_large_dim():
+    vocab_size = 10000
+    dim = 10000
+    model = QFTSPEmbeddings(vocab_size, dim)
+    x = torch.randint(0, vocab_size, (1, 10))
+    embeddings = model(x)
+    assert embeddings.shape == (1, 10, dim)
+
+
+def test_qftspembeddings_forward_large_vocab_size():
+    vocab_size = 1000000
+    dim = 512
+    model = QFTSPEmbeddings(vocab_size, dim)
+    x = torch.randint(0, vocab_size, (1, 10))
+    embeddings = model(x)
+    assert embeddings.shape == (1, 10, dim)
+
+
+def test_qftspembeddings_forward_negative_dim():
+    vocab_size = 10000
+    dim = -512
+    with pytest.raises(ValueError):
+        model = QFTSPEmbeddings(vocab_size, dim)
+
+
+def test_qftspembeddings_forward_negative_vocab_size():
+    vocab_size = -10000
+    dim = 512
+    with pytest.raises(ValueError):
+        model = QFTSPEmbeddings(vocab_size, dim)
+
+
+def test_qftspembeddings_forward_zero_vocab_size():
+    vocab_size = 0
+    dim = 512
+    with pytest.raises(ValueError):
+        model = QFTSPEmbeddings(vocab_size, dim)
diff --git a/zeta/nn/embeddings/__init__.py b/zeta/nn/embeddings/__init__.py
index cba05081..cfc8766e 100644
--- a/zeta/nn/embeddings/__init__.py
+++ b/zeta/nn/embeddings/__init__.py
@@ -1,7 +1,4 @@
-# embeddings
-
 from zeta.nn.embeddings.abc_pos_emb import AbsolutePositionalEmbedding
-from zeta.nn.embeddings.base import BaseEmbedding
 from zeta.nn.embeddings.embedding import (
     BaseEmbedding,
     Embedding,
@@ -10,7 +7,6 @@
 from zeta.nn.embeddings.multiway_network import (
     MultiwayEmbedding,
     MultiwayNetwork,
-    # MultiwayWrapper,
 )
 from zeta.nn.embeddings.nominal_embeddings import NominalEmbedding
 from zeta.nn.embeddings.positional import PositionalEmbedding
@@ -26,9 +22,10 @@
     apply_rotary_pos_emb,
     rotate_every_two,
 )
-from zeta.nn.embeddings.yarn import *
 from zeta.nn.embeddings.yarn import YarnEmbedding
 from zeta.nn.embeddings.sine_positional import SinePositionalEmbedding
+from zeta.nn.embeddings.qft_embeddings import QFTSPEmbeddings
+from zeta.nn.embeddings.qfsp_embeddings import QuantumSuperpositionEmbeddings
 
 __all__ = [
     "AbsolutePositionalEmbedding",
@@ -37,7 +34,6 @@
     "TextEmbedding",
     "MultiwayEmbedding",
"MultiwayNetwork", - # "MultiwayWrapper", "NominalEmbedding", "PositionalEmbedding", "PositionInterpolationEmbeddings", @@ -50,4 +46,6 @@ "rotate_every_two", "YarnEmbedding", "SinePositionalEmbedding", + "QFTSPEmbeddings", + "QuantumSuperpositionEmbeddings" ] diff --git a/zeta/nn/embeddings/qfsp_embeddings.py b/zeta/nn/embeddings/qfsp_embeddings.py new file mode 100644 index 00000000..2c6d50d2 --- /dev/null +++ b/zeta/nn/embeddings/qfsp_embeddings.py @@ -0,0 +1,54 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class QuantumSuperpositionEmbeddings(nn.Module): + """ + QuantumSuperpositionEmbeddings with multiple collapse mechanisms. + + This module allows for different ways of collapsing the superposition of embeddings, + based on the provided context and selected mechanism. + """ + + def __init__(self, vocab_size, embed_dim): + super(QuantumSuperpositionEmbeddings, self).__init__() + self.embed_dim = embed_dim + self.base_embeddings = nn.Embedding(vocab_size, embed_dim) + self.superposed_embeddings = nn.Embedding(vocab_size, embed_dim) + self.linear_transform = nn.Linear(2 * embed_dim, embed_dim) + + def forward(self, input_ids, context_vector, collapse_mode='weighted_sum'): + base_embeds = self.base_embeddings(input_ids) + superposed_embeds = self.superposed_embeddings(input_ids) + + if collapse_mode == 'weighted_sum': + collapsed_embeds = base_embeds + context_vector.unsqueeze(-1) * superposed_embeds + elif collapse_mode == 'dot_product': + scale = torch.sum(superposed_embeds * context_vector.unsqueeze(-1), dim=-1, keepdim=True) + collapsed_embeds = base_embeds + scale * superposed_embeds + elif collapse_mode == 'cosine_similarity': + scale = F.cosine_similarity(superposed_embeds, context_vector.unsqueeze(-1), dim=-1).unsqueeze(-1) + collapsed_embeds = base_embeds + scale * superposed_embeds + elif collapse_mode == 'gated': + gate = torch.sigmoid(context_vector) + collapsed_embeds = base_embeds + gate.unsqueeze(-1) * superposed_embeds + elif collapse_mode == 'concat_linear': + concatenated = torch.cat([base_embeds, superposed_embeds], dim=-1) + collapsed_embeds = self.linear_transform(concatenated) + else: + raise ValueError("Invalid collapse mode selected") + + return collapsed_embeds + +# # Example Usage +# vocab_size = 10000 +# embed_dim = 512 + +# model = QuantumSuperpositionEmbeddings(vocab_size, embed_dim) +# input_ids = torch.randint(0, vocab_size, (1, 10)) +# context_vector = torch.rand(1, 10) + +# # Test different collapse modes +# for mode in ['weighted_sum', 'dot_product', 'cosine_similarity', 'gated', 'concat_linear']: +# embeddings = model(input_ids, context_vector, collapse_mode=mode) +# print(f"Collapse mode: {mode}, Embeddings shape: {embeddings.shape}") diff --git a/zeta/nn/embeddings/qft_embeddings.py b/zeta/nn/embeddings/qft_embeddings.py new file mode 100644 index 00000000..e2ca3e86 --- /dev/null +++ b/zeta/nn/embeddings/qft_embeddings.py @@ -0,0 +1,58 @@ +import torch +from torch import nn +import numpy as np + + +class QFTSPEmbeddings(nn.Module): + """Quantum Fourier Transform-inspired Shift Phase Embeddings. + + + Attributes: + vocab_size (int): The size of the vocabulary. + dim (int): The dimensionality of the embeddings. + + Methods: + forward(x: torch.Tensor) -> torch.Tensor: Forward pass of the QFTSPEmbeddings module. 
+
+    Example:
+    >>> vocab_size = 10000
+    >>> dim = 512
+    >>> model = QFTSPEmbeddings(vocab_size, dim)
+    >>> x = torch.randint(0, vocab_size, (1, 10))
+    >>> embeddings = model(x)
+    >>> print(embeddings)
+    """
+
+    def __init__(
+        self, vocab_size: int = None, dim: int = None, *args, **kwargs
+    ):
+        super().__init__()
+        self.vocab_size = vocab_size
+        self.dim = dim
+
+        self.embeddings = nn.Embedding(vocab_size, dim, *args, **kwargs)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass of the QFTSPEmbeddings module.
+
+        Args:
+            x (torch.Tensor): input tensor
+
+        Returns:
+            torch.Tensor: phase shifted embeddings
+        """
+        # real valued embeddings
+        embeds = self.embeddings(x)
+
+        # Quantum-inspired operation: Phase shift
+        # Split embed_dim into two halves for real and imaginary parts
+        phase_shift = torch.exp(2j * np.pi * torch.rand(self.dim // 2))
+        shifted_embeds = torch.cat(
+            [
+                embeds[:, :, : self.dim // 2] * phase_shift.real,
+                embeds[:, :, self.dim // 2 :] * phase_shift.imag,
+            ],
+            dim=-1,
+        )
+
+        return shifted_embeds
diff --git a/zeta/utils/cuda_memory_wrapper.py b/zeta/utils/cuda_memory_wrapper.py
index e9efadf6..1cb837eb 100644
--- a/zeta/utils/cuda_memory_wrapper.py
+++ b/zeta/utils/cuda_memory_wrapper.py
@@ -1,39 +1,49 @@
-import torch
-import functools
-import logging
-
+import torch
+import functools
+import logging
+# Logging initialization
 logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
 )
 
-
+# Main function
 def track_cuda_memory_usage(func):
+    """Track CUDA memory usage of a function.
+
+    Args:
+        func (function): The function to be tracked.
+
+    Returns:
+        function: The wrapped function.
+
+    Example:
+    >>> @track_cuda_memory_usage
+    >>> def train():
+    >>>     pass
+    >>> train()
+    """
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
         if not torch.cuda.is_available():
             logging.warning("CUDA is not available, skip tracking memory usage")
             return func(*args, **kwargs)
-
+
         torch.cuda.synchronize()
         before_memory = torch.cuda.memory_allocated()
-
+
         try:
             result = func(*args, **kwargs)
         except Exception as error:
             logging.error(f"Error occurs when running {func.__name__}: {error}")
             raise
-
+
         finally:
             torch.cuda.synchronize()
             after_memory = torch.cuda.memory_allocated()
             memory_diff = after_memory - before_memory
-            logging.info(
-                f"Memory usage of {func.__name__}: {memory_diff} bytes"
-            )
-
+            logging.info(f"Memory usage of {func.__name__}: {memory_diff} bytes")
+
             return result
-
-
-return wrapper
+    return wrapper
\ No newline at end of file
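
A minimal usage sketch of the two modules this patch adds and re-exports from zeta.nn.embeddings (shapes mirror the tests above; the collapse_mode strings are the ones handled in qfsp_embeddings.py):

import torch

from zeta.nn.embeddings import QFTSPEmbeddings, QuantumSuperpositionEmbeddings

vocab_size, dim = 10000, 512
x = torch.randint(0, vocab_size, (1, 10))  # token ids, shape (batch, seq_len)

# Phase-shifted embeddings: output shape (1, 10, 512)
qft = QFTSPEmbeddings(vocab_size, dim)
print(qft(x).shape)

# Superposition embeddings collapsed against a per-token context vector
qsp = QuantumSuperpositionEmbeddings(vocab_size, dim)
context = torch.rand(1, 10)
for mode in ["weighted_sum", "dot_product", "cosine_similarity", "gated", "concat_linear"]:
    print(mode, qsp(x, context, collapse_mode=mode).shape)  # each -> (1, 10, 512)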