[feat][QuantumSuperpositionEmbeddings]
Kye committed Dec 17, 2023
1 parent 40f0f00 · commit 2c89b26
Showing 6 changed files with 322 additions and 23 deletions.
@@ -0,0 +1,93 @@
import pytest
import torch
from zeta.nn.embeddings.qfsp_embeddings import QuantumSuperpositionEmbeddings


def test_qsembeddings_init():
    vocab_size = 10000
    dim = 512
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    assert model.embed_dim == dim
    assert model.base_embeddings.num_embeddings == vocab_size
    assert model.superposed_embeddings.num_embeddings == vocab_size

def test_qsembeddings_forward_weighted_sum():
    vocab_size = 10000
    dim = 512
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    context_vector = torch.rand(1, 10)
    embeddings = model(x, context_vector, 'weighted_sum')
    assert embeddings.shape == (1, 10, dim)

def test_qsembeddings_forward_dot_product():
    vocab_size = 10000
    dim = 512
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    context_vector = torch.rand(1, 10)
    embeddings = model(x, context_vector, 'dot_product')
    assert embeddings.shape == (1, 10, dim)

def test_qsembeddings_forward_cosine_similarity():
    vocab_size = 10000
    dim = 512
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    context_vector = torch.rand(1, 10)
    embeddings = model(x, context_vector, 'cosine_similarity')
    assert embeddings.shape == (1, 10, dim)

def test_qsembeddings_forward_gated():
    vocab_size = 10000
    dim = 512
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    context_vector = torch.rand(1, 10)
    embeddings = model(x, context_vector, 'gated')
    assert embeddings.shape == (1, 10, dim)

def test_qsembeddings_forward_concat_linear():
    vocab_size = 10000
    dim = 512
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    context_vector = torch.rand(1, 10)
    embeddings = model(x, context_vector, 'concat_linear')
    assert embeddings.shape == (1, 10, dim)

def test_qsembeddings_forward_invalid_mode():
    vocab_size = 10000
    dim = 512
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    context_vector = torch.rand(1, 10)
    with pytest.raises(ValueError):
        model(x, context_vector, 'invalid_mode')

def test_qsembeddings_forward_large_input():
    vocab_size = 10000
    dim = 512
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1000, 1000))
    context_vector = torch.rand(1000, 1000)
    embeddings = model(x, context_vector, 'weighted_sum')
    assert embeddings.shape == (1000, 1000, dim)

def test_qsembeddings_forward_large_dim():
    vocab_size = 10000
    dim = 10000
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    context_vector = torch.rand(1, 10)
    embeddings = model(x, context_vector, 'weighted_sum')
    assert embeddings.shape == (1, 10, dim)

def test_qsembeddings_forward_large_vocab_size():
    vocab_size = 1000000
    dim = 512
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    context_vector = torch.rand(1, 10)
    embeddings = model(x, context_vector, 'weighted_sum')
    assert embeddings.shape == (1, 10, dim)
@@ -0,0 +1,86 @@
import pytest
import torch
from zeta.nn.embeddings.qft_embeddings import QFTSPEmbeddings


def test_qftspembeddings_init():
    vocab_size = 10000
    dim = 512
    model = QFTSPEmbeddings(vocab_size, dim)
    assert model.vocab_size == vocab_size
    assert model.dim == dim


def test_qftspembeddings_forward():
    vocab_size = 10000
    dim = 512
    model = QFTSPEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    embeddings = model(x)
    assert embeddings.shape == (1, 10, dim)


def test_qftspembeddings_forward_zero_dim():
    vocab_size = 10000
    dim = 0
    model = QFTSPEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    embeddings = model(x)
    assert embeddings.shape == (1, 10, 0)


def test_qftspembeddings_forward_odd_dim():
    vocab_size = 10000
    dim = 513
    model = QFTSPEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    embeddings = model(x)
    assert embeddings.shape == (1, 10, dim)


def test_qftspembeddings_forward_large_input():
    vocab_size = 10000
    dim = 512
    model = QFTSPEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1000, 1000))
    embeddings = model(x)
    assert embeddings.shape == (1000, 1000, dim)


def test_qftspembeddings_forward_large_dim():
    vocab_size = 10000
    dim = 10000
    model = QFTSPEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    embeddings = model(x)
    assert embeddings.shape == (1, 10, dim)


def test_qftspembeddings_forward_large_vocab_size():
    vocab_size = 1000000
    dim = 512
    model = QFTSPEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    embeddings = model(x)
    assert embeddings.shape == (1, 10, dim)


def test_qftspembeddings_forward_negative_dim():
    vocab_size = 10000
    dim = -512
    with pytest.raises(ValueError):
        model = QFTSPEmbeddings(vocab_size, dim)


def test_qftspembeddings_forward_negative_vocab_size():
    vocab_size = -10000
    dim = 512
    with pytest.raises(ValueError):
        model = QFTSPEmbeddings(vocab_size, dim)


def test_qftspembeddings_forward_zero_vocab_size():
    vocab_size = 0
    dim = 512
    with pytest.raises(ValueError):
        model = QFTSPEmbeddings(vocab_size, dim)
@@ -0,0 +1,54 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class QuantumSuperpositionEmbeddings(nn.Module):
    """
    QuantumSuperpositionEmbeddings with multiple collapse mechanisms.

    The module keeps two embedding tables (a base state and a superposed
    state) and collapses them into a single embedding, based on the provided
    context vector and the selected collapse mechanism.
    """

    def __init__(self, vocab_size, embed_dim):
        super(QuantumSuperpositionEmbeddings, self).__init__()
        self.embed_dim = embed_dim
        self.base_embeddings = nn.Embedding(vocab_size, embed_dim)
        self.superposed_embeddings = nn.Embedding(vocab_size, embed_dim)
        self.linear_transform = nn.Linear(2 * embed_dim, embed_dim)

    def forward(self, input_ids, context_vector, collapse_mode='weighted_sum'):
        # Both lookups return (batch, seq_len, embed_dim); context_vector is
        # expected to be (batch, seq_len) and is broadcast over embed_dim.
        base_embeds = self.base_embeddings(input_ids)
        superposed_embeds = self.superposed_embeddings(input_ids)

        if collapse_mode == 'weighted_sum':
            # Context acts as a per-token weight on the superposed state.
            collapsed_embeds = (
                base_embeds + context_vector.unsqueeze(-1) * superposed_embeds
            )
        elif collapse_mode == 'dot_product':
            # Per-token scalar scale from the context/superposed interaction.
            scale = torch.sum(
                superposed_embeds * context_vector.unsqueeze(-1),
                dim=-1,
                keepdim=True,
            )
            collapsed_embeds = base_embeds + scale * superposed_embeds
        elif collapse_mode == 'cosine_similarity':
            scale = F.cosine_similarity(
                superposed_embeds, context_vector.unsqueeze(-1), dim=-1
            ).unsqueeze(-1)
            collapsed_embeds = base_embeds + scale * superposed_embeds
        elif collapse_mode == 'gated':
            # Sigmoid gate in [0, 1] controls how much superposition is mixed in.
            gate = torch.sigmoid(context_vector)
            collapsed_embeds = base_embeds + gate.unsqueeze(-1) * superposed_embeds
        elif collapse_mode == 'concat_linear':
            # Concatenate both states and project back down to embed_dim.
            concatenated = torch.cat([base_embeds, superposed_embeds], dim=-1)
            collapsed_embeds = self.linear_transform(concatenated)
        else:
            raise ValueError("Invalid collapse mode selected")

        return collapsed_embeds


# # Example Usage
# vocab_size = 10000
# embed_dim = 512

# model = QuantumSuperpositionEmbeddings(vocab_size, embed_dim)
# input_ids = torch.randint(0, vocab_size, (1, 10))
# context_vector = torch.rand(1, 10)

# # Test different collapse modes
# for mode in ['weighted_sum', 'dot_product', 'cosine_similarity', 'gated', 'concat_linear']:
#     embeddings = model(input_ids, context_vector, collapse_mode=mode)
#     print(f"Collapse mode: {mode}, Embeddings shape: {embeddings.shape}")
@@ -0,0 +1,58 @@
import torch
from torch import nn
import numpy as np


class QFTSPEmbeddings(nn.Module):
    """Quantum Fourier Transform-inspired Shift Phase Embeddings.

    Attributes:
        vocab_size (int): The size of the vocabulary.
        dim (int): The dimensionality of the embeddings.

    Methods:
        forward(x: torch.Tensor) -> torch.Tensor: Forward pass of the
            QFTSPEmbeddings module.

    Example:
        >>> vocab_size = 10000
        >>> dim = 512
        >>> model = QFTSPEmbeddings(vocab_size, dim)
        >>> x = torch.randint(0, vocab_size, (1, 10))
        >>> embeddings = model(x)
        >>> print(embeddings)
    """

    def __init__(
        self, vocab_size: int = None, dim: int = None, *args, **kwargs
    ):
        super().__init__()
        # Validate sizes up front so invalid values fail with a clear
        # ValueError (as the accompanying tests expect) rather than a
        # downstream error from nn.Embedding.
        if vocab_size is None or vocab_size <= 0:
            raise ValueError("vocab_size must be a positive integer")
        if dim is None or dim < 0:
            raise ValueError("dim must be a non-negative integer")

        self.vocab_size = vocab_size
        self.dim = dim

        self.embeddings = nn.Embedding(vocab_size, dim, *args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the QFTSPEmbeddings module.

        Args:
            x (torch.Tensor): input tensor of token ids, shape (batch, seq_len)

        Returns:
            torch.Tensor: phase-shifted embeddings, shape (batch, seq_len, dim)
        """
        # Real-valued embeddings, shape (batch, seq_len, dim).
        embeds = self.embeddings(x)

        # Quantum-inspired operation: phase shift.
        # Draw one random phase per channel, then scale the first half of the
        # channels by the real part and the second half by the imaginary part.
        # Indexing each half by its own size keeps the factors aligned with
        # the slices when dim is odd, which the tests exercise.
        half = self.dim // 2
        phase_shift = torch.exp(2j * np.pi * torch.rand(self.dim))
        shifted_embeds = torch.cat(
            [
                embeds[:, :, :half] * phase_shift.real[:half],
                embeds[:, :, half:] * phase_shift.imag[half:],
            ],
            dim=-1,
        )

        return shifted_embeds
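A minimal usage sketch for QFTSPEmbeddings, assuming the import path zeta.nn.embeddings.qft_embeddings used by the tests above. The random phase factors only rescale each half of the channels, so the output stays real-valued with the same (batch, seq_len, dim) shape as a plain embedding lookup.

import torch
from zeta.nn.embeddings.qft_embeddings import QFTSPEmbeddings

vocab_size, dim = 10000, 512
model = QFTSPEmbeddings(vocab_size, dim)

x = torch.randint(0, vocab_size, (1, 10))  # (batch, seq_len) token ids
embeddings = model(x)                      # phase-shifted, still real-valued

print(embeddings.shape)  # torch.Size([1, 10, 512])
print(embeddings.dtype)  # torch.float32 -- the phase factors only rescale each half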