Skip to content

Commit

Permalink
[DOCSTRINGS][zeta.nn.biases ++ zeta.nn.embeddings]
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Dec 23, 2023
1 parent d07d002 commit 05f20f5
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 11 deletions.
29 changes: 29 additions & 0 deletions zeta/nn/biases/alibi.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,23 @@ def pad_at_dim(t, pad, dim=-1, value=0.0):


class AlibiPositionalBias(BaseBias):
"""
AlibiPositionalBias class represents a positional bias module for neural networks.
Args:
heads (int): Number of heads in the neural network.
num_heads (int): Number of heads in the neural network.
Attributes:
slopes (Tensor): Tensor containing the slopes for the bias.
bias (Tensor): Tensor containing the bias values.
Methods:
get_bias(i, j, device): Returns the bias tensor for the given indices.
forward(i, j): Computes and returns the bias tensor for the given indices.
"""

def __init__(self, heads, num_heads, **kwargs):
super().__init__()
self.heads = heads
Expand Down Expand Up @@ -81,6 +98,18 @@ def forward(self, i, j):


class LearnedAlibiPositionalBias(AlibiPositionalBias):
"""
LearnedAlibiPositionalBias is a subclass of AlibiPositionalBias that introduces learned biases.
Args:
heads (int): Number of attention heads.
num_heads (int): Number of heads per layer.
Attributes:
learned_logslopes (nn.Parameter): Learned logarithmic slopes.
"""

def __init__(self, heads, num_heads):
super().__init__(heads, num_heads)
log_slopes = torch.log(self.slopes)
Expand Down
9 changes: 9 additions & 0 deletions zeta/nn/embeddings/abc_pos_emb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@


class AbsolutePositionalEmbedding(nn.Module):
"""
Absolute Positional Embedding module.
Args:
dim (int): The dimension of the embedding.
max_seq_len (int): The maximum sequence length.
l2norm_embed (bool, optional): Whether to apply L2 normalization to the embeddings. Defaults to False.
"""

def __init__(self, dim, max_seq_len, l2norm_embed=False):
super().__init__()
self.scale = dim**-0.5 if not l2norm_embed else 1.0
Expand Down
11 changes: 0 additions & 11 deletions zeta/nn/embeddings/bnb_embedding.py

This file was deleted.

12 changes: 12 additions & 0 deletions zeta/nn/embeddings/positional.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,18 @@ def forward(
positions=None,
**kwargs,
):
"""
Forward pass of the PositionalEmbedding module.
Args:
x (torch.Tensor): Input tensor.
positions (torch.Tensor, optional): Positions tensor. If None, positions are generated based on the input tensor size. Default is None.
**kwargs: Additional keyword arguments.
Returns:
torch.Tensor: Embedded tensor.
"""
if positions is None:
# being consistent with Fairseq, which starts from 2.
positions = (
Expand Down
28 changes: 28 additions & 0 deletions zeta/rl/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,23 @@


class ActorCritic(nn.Module):
"""
A class representing an Actor-Critic model for Proximal Policy Optimization (PPO).
Args:
num_inputs (int): The number of input features.
num_outputs (int): The number of output actions.
hidden_size (int): The size of the hidden layer.
Attributes:
critic (nn.Sequential): The critic network.
actor (nn.Sequential): The actor network.
Methods:
forward(x): Performs a forward pass through the network.
"""

def __init__(self, num_inputs, num_outputs, hidden_size):
super(ActorCritic, self).__init__()
self.critic = nn.Sequential(
Expand All @@ -18,6 +35,17 @@ def __init__(self, num_inputs, num_outputs, hidden_size):
)

def forward(self, x):
"""
Performs a forward pass through the network.
Args:
x (torch.Tensor): The input tensor.
Returns:
dist (torch.distributions.Categorical): The probability distribution over actions.
value (torch.Tensor): The estimated value of the input state.
"""
value = self.critic(x)
probs = self.actor(x)
dist = torch.distributions.Categorical(probs)
Expand Down
28 changes: 28 additions & 0 deletions zeta/rl/vision_model_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@


class ResidualBlock(nn.Module):
"""
Residual Block module for a vision model.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
stride (int, optional): Stride value for the convolutional layers. Defaults to 1.
"""

def __init__(self, in_channels, out_channels, stride=1):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(
Expand Down Expand Up @@ -32,6 +41,25 @@ def forward(self, x):


class VisionRewardModel(nn.Module):
"""
VisionRewardModel is a neural network model that extracts image features and predicts rewards.
Args:
None
Attributes:
layer1 (ResidualBlock): The first residual block for image feature extraction.
layer2 (ResidualBlock): The second residual block for image feature extraction.
layer3 (ResidualBlock): The third residual block for image feature extraction.
layer4 (ResidualBlock): The fourth residual block for image feature extraction.
fc1 (nn.Linear): The fully connected layer for feature transformation.
fc2 (nn.Linear): The fully connected layer for reward prediction.
Methods:
forward(x): Performs forward pass through the network.
"""

def __init__(self):
super(VisionRewardModel, self).__init__()

Expand Down

0 comments on commit 05f20f5

Please sign in to comment.