
Commit

[feat][QuantumSuperpositionEmbeddings]
Kye committed Dec 17, 2023
1 parent 40f0f00 commit 2c89b26
Showing 6 changed files with 322 additions and 23 deletions.
93 changes: 93 additions & 0 deletions tests/nn/embeddings/qftp_embeddings.py
@@ -0,0 +1,93 @@
import pytest
import torch
from zeta.nn.embeddings.qfsp_embeddings import QuantumSuperpositionEmbeddings


def test_qsembeddings_init():
    vocab_size = 10000
    dim = 512
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    assert model.embed_dim == dim
    assert model.base_embeddings.num_embeddings == vocab_size
    assert model.superposed_embeddings.num_embeddings == vocab_size

def test_qsembeddings_forward_weighted_sum():
    vocab_size = 10000
    dim = 512
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    context_vector = torch.rand(1, 10)
    embeddings = model(x, context_vector, 'weighted_sum')
    assert embeddings.shape == (1, 10, dim)

def test_qsembeddings_forward_dot_product():
    vocab_size = 10000
    dim = 512
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    context_vector = torch.rand(1, 10)
    embeddings = model(x, context_vector, 'dot_product')
    assert embeddings.shape == (1, 10, dim)

def test_qsembeddings_forward_cosine_similarity():
    vocab_size = 10000
    dim = 512
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    context_vector = torch.rand(1, 10)
    embeddings = model(x, context_vector, 'cosine_similarity')
    assert embeddings.shape == (1, 10, dim)

def test_qsembeddings_forward_gated():
    vocab_size = 10000
    dim = 512
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    context_vector = torch.rand(1, 10)
    embeddings = model(x, context_vector, 'gated')
    assert embeddings.shape == (1, 10, dim)

def test_qsembeddings_forward_concat_linear():
    vocab_size = 10000
    dim = 512
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    context_vector = torch.rand(1, 10)
    embeddings = model(x, context_vector, 'concat_linear')
    assert embeddings.shape == (1, 10, dim)

def test_qsembeddings_forward_invalid_mode():
    vocab_size = 10000
    dim = 512
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    context_vector = torch.rand(1, 10)
    with pytest.raises(ValueError):
        model(x, context_vector, 'invalid_mode')

def test_qsembeddings_forward_large_input():
    vocab_size = 10000
    dim = 512
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1000, 1000))
    context_vector = torch.rand(1000, 1000)
    embeddings = model(x, context_vector, 'weighted_sum')
    assert embeddings.shape == (1000, 1000, dim)

def test_qsembeddings_forward_large_dim():
    vocab_size = 10000
    dim = 10000
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    context_vector = torch.rand(1, 10)
    embeddings = model(x, context_vector, 'weighted_sum')
    assert embeddings.shape == (1, 10, dim)

def test_qsembeddings_forward_large_vocab_size():
    vocab_size = 1000000
    dim = 512
    model = QuantumSuperpositionEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    context_vector = torch.rand(1, 10)
    embeddings = model(x, context_vector, 'weighted_sum')
    assert embeddings.shape == (1, 10, dim)
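All of the tests above follow the same construct-and-assert pattern. A minimal sketch for running this module directly, assuming pytest and zeta are installed and the repository root is the working directory (the path mirrors this commit's layout):

# Hypothetical runner for the new test module.
import pytest

if __name__ == "__main__":
    raise SystemExit(pytest.main(["tests/nn/embeddings/qftp_embeddings.py", "-q"]))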
86 changes: 86 additions & 0 deletions tests/nn/embeddings/test_QFTSPEmbeddings.py
@@ -0,0 +1,86 @@
import pytest
import torch
from zeta.nn.embeddings.qft_embeddings import QFTSPEmbeddings


def test_qftspembeddings_init():
    vocab_size = 10000
    dim = 512
    model = QFTSPEmbeddings(vocab_size, dim)
    assert model.vocab_size == vocab_size
    assert model.dim == dim


def test_qftspembeddings_forward():
    vocab_size = 10000
    dim = 512
    model = QFTSPEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    embeddings = model(x)
    assert embeddings.shape == (1, 10, dim)


def test_qftspembeddings_forward_zero_dim():
    vocab_size = 10000
    dim = 0
    model = QFTSPEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    embeddings = model(x)
    assert embeddings.shape == (1, 10, 0)


def test_qftspembeddings_forward_odd_dim():
    vocab_size = 10000
    dim = 513
    model = QFTSPEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    embeddings = model(x)
    assert embeddings.shape == (1, 10, dim)


def test_qftspembeddings_forward_large_input():
    vocab_size = 10000
    dim = 512
    model = QFTSPEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1000, 1000))
    embeddings = model(x)
    assert embeddings.shape == (1000, 1000, dim)


def test_qftspembeddings_forward_large_dim():
    vocab_size = 10000
    dim = 10000
    model = QFTSPEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    embeddings = model(x)
    assert embeddings.shape == (1, 10, dim)


def test_qftspembeddings_forward_large_vocab_size():
    vocab_size = 1000000
    dim = 512
    model = QFTSPEmbeddings(vocab_size, dim)
    x = torch.randint(0, vocab_size, (1, 10))
    embeddings = model(x)
    assert embeddings.shape == (1, 10, dim)


def test_qftspembeddings_forward_negative_dim():
    vocab_size = 10000
    dim = -512
    with pytest.raises(ValueError):
        model = QFTSPEmbeddings(vocab_size, dim)


def test_qftspembeddings_forward_negative_vocab_size():
    vocab_size = -10000
    dim = 512
    with pytest.raises(ValueError):
        model = QFTSPEmbeddings(vocab_size, dim)


def test_qftspembeddings_forward_zero_vocab_size():
    vocab_size = 0
    dim = 512
    with pytest.raises(ValueError):
        model = QFTSPEmbeddings(vocab_size, dim)
10 changes: 4 additions & 6 deletions zeta/nn/embeddings/__init__.py
@@ -1,7 +1,4 @@
# embeddings

from zeta.nn.embeddings.abc_pos_emb import AbsolutePositionalEmbedding
from zeta.nn.embeddings.base import BaseEmbedding
from zeta.nn.embeddings.embedding import (
    BaseEmbedding,
    Embedding,
@@ -10,7 +7,6 @@
from zeta.nn.embeddings.multiway_network import (
    MultiwayEmbedding,
    MultiwayNetwork,
    # MultiwayWrapper,
)
from zeta.nn.embeddings.nominal_embeddings import NominalEmbedding
from zeta.nn.embeddings.positional import PositionalEmbedding
@@ -26,9 +22,10 @@
    apply_rotary_pos_emb,
    rotate_every_two,
)
from zeta.nn.embeddings.yarn import *
from zeta.nn.embeddings.yarn import YarnEmbedding
from zeta.nn.embeddings.sine_positional import SinePositionalEmbedding
from zeta.nn.embeddings.qft_embeddings import QFTSPEmbeddings
from zeta.nn.embeddings.qfsp_embeddings import QuantumSuperpositionEmbeddings

__all__ = [
    "AbsolutePositionalEmbedding",
@@ -37,7 +34,6 @@
    "TextEmbedding",
    "MultiwayEmbedding",
    "MultiwayNetwork",
    # "MultiwayWrapper",
    "NominalEmbedding",
    "PositionalEmbedding",
    "PositionInterpolationEmbeddings",
@@ -50,4 +46,6 @@
    "rotate_every_two",
    "YarnEmbedding",
    "SinePositionalEmbedding",
    "QFTSPEmbeddings",
    "QuantumSuperpositionEmbeddings"
]
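With these exports in place, both new embedding classes are importable from the package root. A small usage sketch, assuming zeta is installed and using the shapes from the tests above:

# Hypothetical smoke test of the re-exported classes; shapes follow the tests above.
import torch
from zeta.nn.embeddings import QFTSPEmbeddings, QuantumSuperpositionEmbeddings

x = torch.randint(0, 10000, (1, 10))
print(QFTSPEmbeddings(10000, 512)(x).shape)  # expected: torch.Size([1, 10, 512])

qse = QuantumSuperpositionEmbeddings(10000, 512)
context = torch.rand(1, 10)
print(qse(x, context, "gated").shape)  # expected: torch.Size([1, 10, 512])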
54 changes: 54 additions & 0 deletions zeta/nn/embeddings/qfsp_embeddings.py
@@ -0,0 +1,54 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class QuantumSuperpositionEmbeddings(nn.Module):
    """
    QuantumSuperpositionEmbeddings with multiple collapse mechanisms.
    This module allows for different ways of collapsing the superposition of embeddings,
    based on the provided context and selected mechanism.
    """

    def __init__(self, vocab_size, embed_dim):
        super(QuantumSuperpositionEmbeddings, self).__init__()
        self.embed_dim = embed_dim
        self.base_embeddings = nn.Embedding(vocab_size, embed_dim)
        self.superposed_embeddings = nn.Embedding(vocab_size, embed_dim)
        self.linear_transform = nn.Linear(2 * embed_dim, embed_dim)

    def forward(self, input_ids, context_vector, collapse_mode='weighted_sum'):
        base_embeds = self.base_embeddings(input_ids)
        superposed_embeds = self.superposed_embeddings(input_ids)

        if collapse_mode == 'weighted_sum':
            collapsed_embeds = base_embeds + context_vector.unsqueeze(-1) * superposed_embeds
        elif collapse_mode == 'dot_product':
            scale = torch.sum(superposed_embeds * context_vector.unsqueeze(-1), dim=-1, keepdim=True)
            collapsed_embeds = base_embeds + scale * superposed_embeds
        elif collapse_mode == 'cosine_similarity':
            scale = F.cosine_similarity(superposed_embeds, context_vector.unsqueeze(-1), dim=-1).unsqueeze(-1)
            collapsed_embeds = base_embeds + scale * superposed_embeds
        elif collapse_mode == 'gated':
            gate = torch.sigmoid(context_vector)
            collapsed_embeds = base_embeds + gate.unsqueeze(-1) * superposed_embeds
        elif collapse_mode == 'concat_linear':
            concatenated = torch.cat([base_embeds, superposed_embeds], dim=-1)
            collapsed_embeds = self.linear_transform(concatenated)
        else:
            raise ValueError("Invalid collapse mode selected")

        return collapsed_embeds


# # Example Usage
# vocab_size = 10000
# embed_dim = 512

# model = QuantumSuperpositionEmbeddings(vocab_size, embed_dim)
# input_ids = torch.randint(0, vocab_size, (1, 10))
# context_vector = torch.rand(1, 10)

# # Test different collapse modes
# for mode in ['weighted_sum', 'dot_product', 'cosine_similarity', 'gated', 'concat_linear']:
#     embeddings = model(input_ids, context_vector, collapse_mode=mode)
#     print(f"Collapse mode: {mode}, Embeddings shape: {embeddings.shape}")
58 changes: 58 additions & 0 deletions zeta/nn/embeddings/qft_embeddings.py
@@ -0,0 +1,58 @@
import torch
from torch import nn
import numpy as np


class QFTSPEmbeddings(nn.Module):
    """Quantum Fourier Transform-inspired Shift Phase Embeddings.

    Attributes:
        vocab_size (int): The size of the vocabulary.
        dim (int): The dimensionality of the embeddings.

    Methods:
        forward(x: torch.Tensor) -> torch.Tensor: Forward pass of the QFTSPEmbeddings module.

    Example:
        >>> vocab_size = 10000
        >>> dim = 512
        >>> model = QFTSPEmbeddings(vocab_size, dim)
        >>> x = torch.randint(0, vocab_size, (1, 10))
        >>> embeddings = model(x)
        >>> print(embeddings)
    """

    def __init__(
        self, vocab_size: int = None, dim: int = None, *args, **kwargs
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim

        self.embeddings = nn.Embedding(vocab_size, dim, *args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the QFTSPEmbeddings module.

        Args:
            x (torch.Tensor): input tensor

        Returns:
            torch.Tensor: phase shifted embeddings
        """
        # real valued embeddings
        embeds = self.embeddings(x)

        # Quantum-inspired operation: Phase shift
        # Split embed_dim into two halves for real and imaginary parts
        phase_shift = torch.exp(2j * np.pi * torch.rand(self.dim // 2))
        shifted_embeds = torch.cat(
            [
                embeds[:, :, : self.dim // 2] * phase_shift.real,
                embeds[:, :, self.dim // 2 :] * phase_shift.imag,
            ],
            dim=-1,
        )

        return shifted_embeds
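The phase shift draws one random complex phase per half-dimension slot on each forward pass, then keeps the phase's real part on the first half of the embedding and its imaginary part on the second half. A minimal sketch of that step outside the module, assuming an even dim and the shapes used in the docstring example:

# Minimal sketch of the phase-shift step for one batch of embeddings (even dim assumed).
import numpy as np
import torch

dim = 512
embeds = torch.randn(1, 10, dim)

phase_shift = torch.exp(2j * np.pi * torch.rand(dim // 2))  # one complex phase per half-dim slot
shifted = torch.cat(
    [
        embeds[:, :, : dim // 2] * phase_shift.real,  # first half scaled by cos(2*pi*r)
        embeds[:, :, dim // 2 :] * phase_shift.imag,  # second half scaled by sin(2*pi*r)
    ],
    dim=-1,
)
print(shifted.shape)  # torch.Size([1, 10, 512])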