Skip to content

Commit

Permalink
Add bidirectional attention support to attention pattern (#75)
Browse files Browse the repository at this point in the history
  • Loading branch information
andyrdt authored Aug 22, 2023
1 parent 8605b16 commit 4d5869b
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 3 deletions.
10 changes: 10 additions & 0 deletions python/circuitsvis/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def attention_heads(
min_value: Optional[float] = None,
negative_color: Optional[str] = None,
positive_color: Optional[str] = None,
mask_upper_tri: Optional[bool] = None,
) -> RenderedHTML:
"""Attention Heads
Expand All @@ -37,6 +38,9 @@ def attention_heads(
positive_color: Color for positive values. This can be any valid CSS
color string. Be mindful of color blindness if not using the default
here.
mask_upper_tri: Whether or not to mask the upper triangular portion of
the attention patterns. Should be true for causal attention, false for
bidirectional attention.
Returns:
Html: Attention pattern visualization
Expand All @@ -49,6 +53,7 @@ def attention_heads(
"negativeColor": negative_color,
"positiveColor": positive_color,
"tokens": tokens,
"maskUpperTri": mask_upper_tri,
}

return render(
Expand Down Expand Up @@ -90,6 +95,7 @@ def attention_pattern(
negative_color: Optional[str] = None,
show_axis_labels: Optional[bool] = None,
positive_color: Optional[str] = None,
mask_upper_tri: Optional[bool] = None,
) -> RenderedHTML:
"""Attention Pattern
Expand All @@ -112,6 +118,9 @@ def attention_pattern(
positive_color: Color for positive values. This can be any valid CSS
color string. Be mindful of color blindness if not using the default
here.
mask_upper_tri: Whether or not to mask the upper triangular portion of
the attention patterns. Should be true for causal attention, false for
bidirectional attention.
Returns:
Html: Attention pattern visualization
Expand All @@ -124,6 +133,7 @@ def attention_pattern(
"negativeColor": negative_color,
"positiveColor": positive_color,
"showAxisLabels": show_axis_labels,
"maskUpperTri": mask_upper_tri,
}

return render(
Expand Down
16 changes: 16 additions & 0 deletions react/src/attention/AttentionHeads.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ export function AttentionHeadsSelector({
onMouseEnter,
onMouseLeave,
positiveColor,
maskUpperTri,
tokens
}: AttentionHeadsProps & {
attentionHeadNames: string[];
Expand Down Expand Up @@ -89,6 +90,7 @@ export function AttentionHeadsSelector({
minValue={minValue}
negativeColor={negativeColor}
positiveColor={positiveColor}
maskUpperTri={maskUpperTri}
/>
</div>
</div>
Expand All @@ -112,6 +114,7 @@ export function AttentionHeads({
minValue,
negativeColor,
positiveColor,
maskUpperTri = true,
tokens
}: AttentionHeadsProps) {
// Attention head focussed state
Expand All @@ -137,6 +140,7 @@ export function AttentionHeads({
onMouseEnter={onMouseEnter}
onMouseLeave={onMouseLeave}
positiveColor={positiveColor}
maskUpperTri={maskUpperTri}
tokens={tokens}
/>

Expand Down Expand Up @@ -165,6 +169,7 @@ export function AttentionHeads({
negativeColor={negativeColor}
positiveColor={positiveColor}
zoomed={true}
maskUpperTri={maskUpperTri}
tokens={tokens}
/>
</div>
Expand Down Expand Up @@ -241,6 +246,17 @@ export interface AttentionHeadsProps {
*/
positiveColor?: string;

/**
* Mask upper triangular
*
* Whether or not to mask the upper triangular portion of the attention patterns.
*
* Should be true for causal attention, false for bidirectional attention.
*
* @default true
*/
maskUpperTri?: boolean;

/**
* Show axis labels
*/
Expand Down
21 changes: 18 additions & 3 deletions react/src/attention/AttentionPattern.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ export function AttentionPattern({
upperTriColor = DefaultUpperTriColor,
showAxisLabels = true,
zoomed = false,
maskUpperTri = true,
tokens
}: AttentionPatternProps) {
// Tokens must be unique (for the categories), so we add an index prefix
Expand Down Expand Up @@ -96,7 +97,7 @@ export function AttentionPattern({
// Set the background color for each block, based on the attention value
backgroundColor(context: ScriptableContext<"matrix">) {
const block = context.dataset.data[context.dataIndex] as any as Block;
if (block.srcIdx > block.destIdx) {
if (maskUpperTri && block.srcIdx > block.destIdx) {
// Color the upper triangular part separately
return colord(upperTriColor).toRgbString();
}
Expand Down Expand Up @@ -130,7 +131,10 @@ export function AttentionPattern({
title: () => "", // Hide the title
label({ raw }: TooltipItem<"matrix">) {
const block = raw as Block;
if (block.destIdx < block.srcIdx) return "N/A"; // Just show N/A for the upper triangular part
if (maskUpperTri && block.destIdx < block.srcIdx) {
// Just show N/A for the upper triangular part
return "N/A";
}
return [
`(${block.destIdx}, ${block.srcIdx})`,
`Src: ${block.srcToken}`,
Expand Down Expand Up @@ -259,11 +263,22 @@ export interface AttentionPatternProps {
*/
positiveColor?: string;

/**
* Mask upper triangular
*
* Whether or not to mask the upper triangular portion of the attention patterns.
*
* Should be true for causal attention, false for bidirectional attention.
*
* @default true
*/
maskUpperTri?: boolean;

/**
* Upper triangular color
*
* Color to use for the upper triangular part of the attention pattern to make visualization slightly nicer.
* The upper triangular part is irrelevant because of the causal mask.
* Only applied if maskUpperTri is set to true.
*
* @default rgb(200, 200, 200)
*
Expand Down

0 comments on commit 4d5869b

Please sign in to comment.