Commit
Add attention_bias argument in transformer block and transformer layer modules, addressing change in MCore (#11289)

* fix api

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* fix ci

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* add docstring

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <yaoyu-33@users.noreply.github.com>

* fix docstring2

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <yaoyu-33@users.noreply.github.com>

* fix line too long

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

---------

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Signed-off-by: yaoyu-33 <yaoyu-33@users.noreply.github.com>
Co-authored-by: yaoyu-33 <yaoyu-33@users.noreply.github.com>
yaoyu-33 authored Nov 18, 2024
1 parent 59b8b48 commit 168c3e5
Showing 5 changed files with 113 additions and 26 deletions.
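The pattern behind all five files is small but easy to get wrong: each forward() that wraps a Megatron Core transformer layer or block now accepts an attention_bias argument and passes it through, and the one call into super().forward() switches from positional to keyword arguments. A minimal sketch of that pattern, using hypothetical stand-in classes rather than the actual NeMo/MCore modules:

import torch

class BaseTransformerLayer:
    # Stand-in for the MCore forward() that gained `attention_bias`
    # (an optional additive bias on the attention scores); signature simplified.
    def forward(self, hidden_states, attention_mask=None, rotary_pos_emb=None,
                attention_bias=None, inference_params=None, packed_seq_params=None):
        return hidden_states

class AdapterTransformerLayer(BaseTransformerLayer):
    def forward(self, hidden_states, attention_mask=None, rotary_pos_emb=None,
                attention_bias=None, inference_params=None, packed_seq_params=None):
        # Forward by keyword, not position, so a newly inserted parameter in the
        # base signature cannot silently shift values into the wrong slot.
        return super().forward(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            rotary_pos_emb=rotary_pos_emb,
            attention_bias=attention_bias,
            inference_params=inference_params,
            packed_seq_params=packed_seq_params,
        )

out = AdapterTransformerLayer().forward(torch.zeros(4, 2, 8))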
@@ -208,6 +208,7 @@ def forward(
rotary_pos_emb=None,
rotary_pos_cos=None,
rotary_pos_sin=None,
attention_bias=None,
inference_params=None,
packed_seq_params=None,
):
@@ -108,6 +108,7 @@ def forward(
rotary_pos_emb=None,
rotary_pos_cos=None,
rotary_pos_sin=None,
attention_bias=None,
inference_params=None,
packed_seq_params=None,
):
@@ -252,6 +252,7 @@ def forward(
rotary_pos_emb=None,
rotary_pos_cos=None,
rotary_pos_sin=None,
attention_bias=None,
inference_params=None,
packed_seq_params=None, # TODO: handle this
):
@@ -82,19 +82,20 @@ def forward(
rotary_pos_emb: Tensor = None,
rotary_pos_cos: Tensor = None,
rotary_pos_sin: Tensor = None,
attention_bias: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
):
hidden_states = super().forward(
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
rotary_pos_cos,
rotary_pos_sin,
inference_params,
packed_seq_params,
hidden_states=hidden_states,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
)

mlp_head_adapter = self.get_adapter_module(AdapterName.MLP_HEAD_ADAPTER)
@@ -232,6 +233,7 @@ def forward(
packed_seq_params=None,
rotary_pos_cos=None,
rotary_pos_sin=None,
attention_bias=None,
):
# hidden_states: [sq, b, h]

116 changes: 99 additions & 17 deletions nemo/collections/vlm/mllama/model/vision.py
@@ -59,6 +59,9 @@


def to_2tuple(x):
"""
Convert an input to a 2-tuple.
"""
if isinstance(x, collections.abc.Iterable):
return x
return (x, x)
@@ -71,9 +74,16 @@ def _stack_images(
max_num_images: int,
) -> Tuple[torch.Tensor, List[int]]:
"""
Takes a list of list of images and stacks them into a tensor.
This function is needed since images can be of completely
different resolutions and aspect ratios.
Stack a list of image lists into a tensor while accounting for varying resolutions and aspect ratios.
Args:
images (List[List[PIL_Image.Image]]): List of image lists for stacking.
max_num_chunks (int): Maximum number of chunks per image.
image_res (int): Target resolution for each image.
max_num_images (int): Maximum number of images to stack.
Returns:
Tuple[torch.Tensor, List[int]]: Tensor of stacked images and a list of chunk counts for each image.
"""
out_images, out_num_chunks = [], []
for imgs_sample in images:
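As a rough illustration of the padding scheme this docstring describes (not the NeMo implementation), assuming each chunked image arrives as a [num_chunks, 3, res, res] tensor:

import torch
from typing import List, Tuple

def stack_images_sketch(
    images: List[List[torch.Tensor]],   # per-sample lists of [num_chunks, 3, res, res] chunk tensors
    max_num_chunks: int,
    image_res: int,
    max_num_images: int,
) -> Tuple[torch.Tensor, List[List[int]]]:
    out_images, out_num_chunks = [], []
    for imgs_sample in images:
        # Fixed-size buffer so samples with different image and chunk counts batch together.
        buf = torch.zeros(max_num_images, max_num_chunks, 3, image_res, image_res)
        chunks = []
        for i, img in enumerate(imgs_sample[:max_num_images]):
            buf[i, : img.shape[0]] = img
            chunks.append(img.shape[0])
        out_images.append(buf)
        out_num_chunks.append(chunks)
    return torch.stack(out_images), out_num_chunks

stacked, num_chunks = stack_images_sketch(
    [[torch.randn(2, 3, 56, 56)], [torch.randn(4, 3, 56, 56), torch.randn(1, 3, 56, 56)]],
    max_num_chunks=4, image_res=56, max_num_images=2,
)
# stacked: (2, 2, 4, 3, 56, 56); num_chunks: [[2], [4, 1]]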
@@ -97,7 +107,17 @@ def build_encoder_attention_mask(
x: torch.Tensor, ar_ids: torch.Tensor, ntok: int, num_chunks: int, supported_aspect_ratios: List[List[int]]
):
"""
Build vision encoder attention mask that omits padding tiles and tokens.
Build attention masks for a vision encoder to handle padding and token alignment.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, sequence_length).
ar_ids (torch.Tensor): Aspect ratio IDs for masking.
ntok (int): Number of tokens.
num_chunks (int): Number of chunks in the data.
supported_aspect_ratios (List[List[int]]): List of supported aspect ratios.
Returns:
torch.Tensor: Tensor containing the attention mask.
"""
masks = []
for ar_id in ar_ids:
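The mask construction itself is truncated here, so purely as a generic illustration (not the NeMo implementation): an additive mask that blanks out every token belonging to a padded tile, where the number of real tiles comes from the sample's aspect-ratio id (the 1-indexed id convention and table layout below are assumptions):

import torch

ntok, num_chunks = 5, 4
supported_aspect_ratios = [[1, 1], [1, 2], [2, 2]]       # assumed layout of the table
ar_ids = torch.tensor([2, 3])                            # assumed 1-indexed aspect-ratio ids

masks = []
for ar_id in ar_ids.tolist():
    h_tiles, w_tiles = supported_aspect_ratios[ar_id - 1]
    valid_tiles = h_tiles * w_tiles
    mask = torch.zeros(num_chunks * ntok)
    mask[valid_tiles * ntok:] = float("-inf")            # tokens of padded tiles contribute nothing
    masks.append(mask)
attention_mask = torch.stack(masks)                      # (batch, num_chunks * ntok), added to attention scores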
@@ -113,6 +133,9 @@ def build_encoder_attention_mask(


def apply_scaling(freqs: torch.Tensor):
"""
Scale frequency values based on predefined thresholds and a smoothing factor.
"""
# Values obtained from grid search
scale_factor = 8
low_freq_factor = 1
Expand All @@ -137,6 +160,9 @@ def apply_scaling(freqs: torch.Tensor):

# Use this spec for an implementation using modules in TE
def get_image_transformer_layer_spec() -> ModuleSpec:
"""
Create a specification for an image transformer layer.
"""
image_transformer_submodules = TransformerLayerSubmodules(
input_layernorm=TENorm,
self_attention=ModuleSpec(
@@ -175,6 +201,10 @@ def forward_with_return_intermediate(
packed_seq_params: PackedSeqParams = None,
return_intermediate: List[int] = None,
):
"""
Perform a forward pass through the transformer layers with optional intermediate outputs.
Override regular MCore transformer layer forward pass.
"""
# hidden_states (float): [s, b, h]
# attention_mask (bool): [1, 1, s, s]
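forward_with_return_intermediate overrides the regular MCore block forward so that hidden states from selected layers can be captured alongside the final output; the core idea, sketched with toy layers:

import torch
import torch.nn as nn

layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
return_intermediate = [1, 3]                     # indices of layers whose output we want to keep

hidden_states = torch.randn(2, 8)
intermediate_hidden_states = []
for idx, layer in enumerate(layers):
    hidden_states = layer(hidden_states)
    if idx in return_intermediate:
        intermediate_hidden_states.append(hidden_states)

final_output = hidden_states
stacked = torch.stack(intermediate_hidden_states, dim=-1)   # (2, 8, 2): kept alongside the final output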

@@ -278,16 +308,22 @@ def forward_with_return_intermediate(


class ColumnParallelConv2dPatch(MegatronModule):
"""Conv2D Patching layer with model parallelism.
Column parallel over unfolded input.
Arguments:
in_channels: Input channels.
out_channels: Output channels.
kernel_size: Size of convolution kernel.
stride (default 1): Stride for convolution.
bias (default False): Use bias in Conv2d.
Input: (bsz, in_channels, width, height)
Output: (bsz, num_tokens, out_channels)
"""
Conv2D Patching layer with model parallelism. Applies convolution in a column-parallel fashion.
Args:
config (TransformerConfig): Configuration object for the layer.
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
kernel_size (Union[int, Tuple[int, int]]): Size of the convolution kernel.
stride (Union[int, Tuple[int, int]]): Stride of the convolution.
bias (Optional[bool], default=False): Whether to include a bias term.
Input:
torch.Tensor: Input tensor of shape (batch_size, in_channels, width, height).
Output:
torch.Tensor: Output tensor of shape (batch_size, num_tokens, out_channels).
"""

def __init__(
@@ -316,6 +352,7 @@ def __init__(
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward."""
x = self._unfold(x)
x = x.permute(0, 2, 1)
x = F.linear(x, self._linear.weight)
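For reference, the unfold-plus-linear trick documented above, with the tensor-parallel linear replaced by a plain weight tensor so the shape bookkeeping is visible (values illustrative):

import torch
import torch.nn.functional as F

bsz, in_channels, width, height = 2, 3, 56, 56
patch = 14
out_channels = 32
weight = torch.randn(out_channels, in_channels * patch * patch)

x = torch.randn(bsz, in_channels, width, height)
x = F.unfold(x, kernel_size=patch, stride=patch)   # (bsz, C*patch*patch, num_tokens)
x = x.permute(0, 2, 1)                             # (bsz, num_tokens, C*patch*patch)
x = F.linear(x, weight)                            # (bsz, num_tokens, out_channels)
print(x.shape)                                     # torch.Size([2, 16, 32]); 16 = (56 / 14) ** 2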
@@ -324,6 +361,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class PrecomputedTilePositionEmbedding(torch.nn.Module):
"""
Module to compute positional embeddings for tiles with optional gating.
Args:
config (TransformerConfig): Configuration object.
gated (bool, default=False): Whether to apply gating to the embeddings.
"""

def __init__(
self,
config: TransformerConfig,
@@ -340,6 +385,7 @@ def __init__(
self.gate = nn.Parameter(torch.zeros(1))

def forward(self, hidden_states: torch.Tensor, aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
"""Forward."""
embeddings = self.embedding(aspect_ratio_ids)
embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size)

@@ -351,7 +397,15 @@ def forward(self, hidden_states: torch.Tensor, aspect_ratio_ids: torch.Tensor) -


class SelfAttentionNoBias(SelfAttention):
"""Self-attention layer class without bias"""
"""
Self-attention layer implementation without bias.
Args:
config (TransformerConfig): Configuration for the transformer.
submodules (SelfAttentionSubmodules): Submodules required for self-attention.
layer_number (int): The layer number in the transformer stack.
attn_mask_type (AttnMaskType): Type of attention mask to apply.
"""

def __init__(
self,
@@ -396,6 +450,16 @@ def __init__(


class ImageTransformerLayer(TransformerLayer):
"""
Transformer layer adapted for processing image data with optional gating.
Args:
config (TransformerConfig): Transformer configuration object.
submodules (TransformerLayerSubmodules): Submodules to use in the layer.
layer_number (int, default=1): Layer number in the transformer.
hidden_dropout (float, optional): Dropout rate for hidden layers.
"""

def __init__(
self,
config: TransformerConfig,
@@ -423,9 +487,11 @@ def forward(
rotary_pos_emb=None,
rotary_pos_cos=None,
rotary_pos_sin=None,
attention_bias=None,
inference_params=None,
packed_seq_params=None,
):
"""Forward."""
# hidden_states: [s, b, h]

# Residual connection.
@@ -485,6 +551,19 @@ def forward(


class VisionEncoder(MegatronModule):
"""
Vision encoder module for processing image inputs with patch-based embeddings.
Args:
config ('CrossAttentionVisionConfig'): Configuration object for the encoder.
image_size (int, default=560): Input image size.
patch_size (int, default=14): Size of patches extracted from the image.
in_channels (int, default=3): Number of input channels.
pre_process (bool, default=True): Whether to preprocess input.
post_process (bool, default=True): Whether to postprocess output.
return_intermediate (Optional[bool]): Whether to return intermediate layers.
"""

def __init__(
self,
config: 'CrossAttentionVisionConfig',
@@ -556,7 +635,7 @@ def __init__(
self.gated_positional_embedding_gate = nn.Parameter(torch.zeros(1))

def apply_positional_embedding(self, x, aspect_ratio_ids):
# apply regular position embedding
"""Apply regular position embedding and tile positonal embedding."""
bsz, num_chunks, num_tokens, dim = x.shape
x = x.view(bsz * num_chunks, num_tokens, dim)
x = x + self.positional_embedding * (1 - self.gated_positional_embedding_gate.tanh())
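The gate in the line above follows a common recipe for grafting new embeddings onto a pretrained model: a scalar parameter initialized to zero is squashed through tanh and blends the regular positional embedding with the tile-specific one, so training starts from the original behaviour. A toy version with illustrative shapes:

import torch
import torch.nn as nn

num_tokens, dim = 4, 8
positional_embedding = nn.Parameter(torch.randn(num_tokens, dim))
tile_embedding = torch.randn(1, num_tokens, dim)   # stand-in for the precomputed tile embedding
gate = nn.Parameter(torch.zeros(1))                # tanh(0) = 0: start from the regular embedding only

x = torch.randn(2, num_tokens, dim)
x = x + positional_embedding * (1 - gate.tanh())   # regular term fades as the gate opens
x = x + tile_embedding * gate.tanh()               # gated tile term grows as the gate opens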
@@ -567,6 +646,7 @@ def apply_positional_embedding(self, x, aspect_ratio_ids):
return x

def apply_class_embedding(self, x):
"""Concat class embedding tokens."""
x = torch.cat(
[
self.class_embedding.to(x.dtype)
@@ -578,6 +658,7 @@ def apply_class_embedding(self, x):
return x

def forward(self, images: torch.Tensor, ar_ids: torch.Tensor) -> torch.Tensor:
"""Forward."""
if images.ndim == 5:
num_concurrent_media = 1
bsz, num_chunks, nch, w, h = images.shape
@@ -617,7 +698,8 @@ def forward(self, images: torch.Tensor, ar_ids: torch.Tensor) -> torch.Tensor:
return_intermediate=self.return_intermediate,
)

# [ntok * num_concurrent_media * num_chunks, bsz, hidden_size] -> [bsz, ntok * num_concurrent_media * num_chunks, hidden_size]
# [ntok * num_concurrent_media * num_chunks, bsz, hidden_size]
# -> [bsz, ntok * num_concurrent_media * num_chunks, hidden_size]
x, int_x = x.transpose(0, 1).contiguous(), int_x.transpose(0, 1).contiguous()
x = self.ln_post(x)
x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok + npad, dim)