Skip to content

Commit

Permalink
feat(ui/circuit): show attention pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
dest1n1s committed Sep 12, 2024
1 parent 48a03a6 commit 67bfb0b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
6 changes: 6 additions & 0 deletions ui/src/components/model/circuit.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ const NodeInfo = ({ node }: { node: Node<NodeData> }) => {
<div className="text-sm">{node.data.tracingNode.key}</div>
<div className="text-sm font-bold">Score:</div>
<div className="text-sm">{node.data.tracingNode.activation.toFixed(3)}</div>
<div className="text-sm font-bold">Pattern:</div>
<div className={cn("text-sm", getAccentClassname(node.data.tracingNode.pattern, 1, "text"))}>
{node.data.tracingNode.pattern.toFixed(3)}
</div>
</div>
</div>
);
Expand Down Expand Up @@ -211,6 +215,8 @@ export const CircuitViewer = memo(
const getNodeClassNames = useCallback((node: TracingNode) => {
if (node.type === "feature") {
return cn(getAccentClassname(node.activation, node.maxActivation, "border"));
} else if (node.type === "attn-score") {
return cn(getAccentClassname(node.pattern, 1, "border"));
}
return "";
}, []);
Expand Down
1 change: 1 addition & 0 deletions ui/src/types/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ const AttnScoreNodeSchema = z.object({
query: z.number(),
key: z.number(),
activation: z.number(),
pattern: z.number(),
});

export const TracingNodeSchema = z.discriminatedUnion("type", [
Expand Down

0 comments on commit 67bfb0b

Please sign in to comment.