From 9c94349e29f2a5f078fedd5ea6752feb6e725d66 Mon Sep 17 00:00:00 2001 From: Jesse Vig <45317205+jessevig@users.noreply.github.com> Date: Sat, 2 Apr 2022 05:47:17 -0700 Subject: [PATCH] Enable multiple neuron views per notebook --- bertviz/neuron_view.js | 27 +++++------ bertviz/neuron_view.py | 101 ++++++++++++++++------------------------- 2 files changed, 54 insertions(+), 74 deletions(-) diff --git a/bertviz/neuron_view.js b/bertviz/neuron_view.js index 85572d4..6ce5b33 100644 --- a/bertviz/neuron_view.js +++ b/bertviz/neuron_view.js @@ -11,6 +11,7 @@ * 01/16/21 Jesse Vig Dark mode * 02/06/21 Jesse Vig Move require config from separate jupyter notebook step * 03/23/22 Daniel SC Update requirement URLs for d3 and jQuery (source of bug not allowing end result to be displayed on browsers) + * 04/02/22 Jesse Vig Enable multiple neuron views per notebook **/ require.config({ @@ -83,16 +84,16 @@ requirejs(['jquery', 'd3'], var keys = attnData.keys[config.layer][config.head]; var att = attnData.attn[config.layer][config.head]; - $("#bertviz #vis").empty(); + $(`#${config.rootDivId} #vis`).empty(); var height = config.initialTextLength * BOXHEIGHT + HEIGHT_PADDING; - var svg = d3.select("#bertviz #vis") + var svg = d3.select(`#${config.rootDivId} #vis`) .append('svg') .attr("width", "100%") .attr("height", height + "px"); - d3.select("#bertviz") + d3.select(`#${config.rootDivId}`) .style("background-color", getColor('background')); - d3.selectAll("#bertviz .dropdown-label") + d3.selectAll(`#${config.rootDivId} .dropdown-label`) .style("color", getColor('dropdown')) renderVisExpanded(svg, leftText, rightText, queries, keys); @@ -942,11 +943,11 @@ requirejs(['jquery', 'd3'], function showCollapsed() { if (config.index != null) { - var svg = d3.select("#bertviz #vis"); + var svg = d3.select(`#${config.rootDivId} #vis`); highlightSelection(svg, config.index); } - d3.select("#bertviz #expanded").attr("visibility", "hidden"); - d3.select("#bertviz #collapsed").attr("visibility", "visible"); + d3.select(`#${config.rootDivId} #expanded`).attr("visibility", "hidden"); + d3.select(`#${config.rootDivId} #collapsed`).attr("visibility", "visible"); } function showExpanded() { @@ -955,8 +956,8 @@ requirejs(['jquery', 'd3'], highlightSelection(svg, config.index); showComputation(svg, config.index); } - d3.select("#bertviz #expanded").attr("visibility", "visible"); - d3.select("#bertviz #collapsed").attr("visibility", "hidden") + d3.select(`#${config.rootDivId} #expanded`).attr("visibility", "visible"); + d3.select(`#${config.rootDivId} #collapsed`).attr("visibility", "hidden") } function getColor(name) { @@ -977,9 +978,9 @@ requirejs(['jquery', 'd3'], config.mode = params['display_mode']; config.layer = (params['layer'] == null ? 0 : params['layer']) config.head = (params['head'] == null ? 0 : params['head']) + config.rootDivId = params['root_div_id']; - - const layerSelect = $("#bertviz #layer"); + const layerSelect = $(`#${config.rootDivId} #layer`); layerSelect.empty(); for (var i = 0; i < config.nLayers; i++) { layerSelect.append($("").val(i).text(i)); @@ -990,7 +991,7 @@ requirejs(['jquery', 'd3'], render(); }); - const headSelect = $("#bertviz #att_head"); + const headSelect = $(`#${config.rootDivId} #att_head`); headSelect.empty(); for (var i = 0; i < config.nHeads; i++) { headSelect.append($("").val(i).text(i)); @@ -1001,7 +1002,7 @@ requirejs(['jquery', 'd3'], render(); }); - $("#bertviz #filter").on('change', function (e) { + $(`#${config.rootDivId} #filter`).on('change', function (e) { config.filter = e.currentTarget.value; render(); }); diff --git a/bertviz/neuron_view.py b/bertviz/neuron_view.py index ae87842..3005a37 100644 --- a/bertviz/neuron_view.py +++ b/bertviz/neuron_view.py @@ -34,16 +34,11 @@ from IPython.core.display import display, HTML, Javascript -def show(model, model_type, tokenizer, sentence_a, sentence_b=None, display_mode='dark', layer=None, head=None, html_action='view'): - - # Generate unique div id to enable multiple visualizations in one notebook +def show(model, model_type, tokenizer, sentence_a, sentence_b=None, display_mode='dark', layer=None, head=None, + html_action='view'): if sentence_b: - vis_html = """ -