rag_explanation.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification
from sentence_transformers import SentenceTransformer
import numpy as np
from scipy.special import softmax
from nltk.tokenize import sent_tokenize
import matplotlib.pyplot as plt
# The sentence splitter below requires the NLTK "punkt" tokenizer data;
# uncomment these two lines on the first run to download it:
# import nltk
# nltk.download('punkt')

def generate_summary(text, model, tokenizer, max_length=256):
    """Generate a concise summary of 1-2 sentences from the input text."""
    system_prompt = "You are a helpful assistant. Please provide a very concise summary of the following text in 1-2 sentences."
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": text}
    ]
    # Append the assistant header so the model starts generating the reply directly.
    input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer.encode(input_text, return_tensors="pt")
    inputs = inputs.to(model.device)
    outputs = model.generate(
        inputs,
        max_new_tokens=max_length,
        temperature=0.3,
        top_p=0.9,
        do_sample=True,
    )
    full_output = tokenizer.decode(outputs[0])
    # Keep only the assistant turn of the chat-formatted output.
    summary = full_output.split("<|im_start|>assistant\n")[-1].split("<|im_end|>")[0].strip()
    return summary

def load_model():
    checkpoint = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
    return model, tokenizer

def load_embedding_model():
    checkpoint = "BAAI/bge-m3"
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    model = SentenceTransformer(checkpoint).to(device)
    return model

def load_cross_encoder_model():
    tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-v2-m3')
    model = AutoModelForSequenceClassification.from_pretrained('BAAI/bge-reranker-v2-m3')
    model.eval()
    return model, tokenizer

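def rerank_pair_example():
    """Illustration only (not used by the pipeline below): score one (query, passage)
    pair with the cross-encoder loaded above. The reranker returns a single raw logit
    per pair, where higher means more relevant; these logits are unbounded, which is
    why calculate_attributions() offers optional softmax / min-max normalization.
    The example pair below is made up.
    """
    model, tokenizer = load_cross_encoder_model()
    pair = [["What is semantic segmentation?",
             "Semantic segmentation assigns a class label to every pixel of an image."]]
    with torch.no_grad():
        inputs = tokenizer(pair, padding=True, truncation=True,
                           return_tensors='pt', max_length=512)
        score = model(**inputs, return_dict=True).logits.view(-1).float()
    return score.item()
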
def calculate_attributions(original_text, summary, model, cross_encoder_tokenizer=None, model_type='embedding', agg='mean', normalization_type=None, window_size=1):
    """
    Calculate attribution scores between summary and input text sentences using either an embedding or a cross-encoder model.

    Args:
        original_text: Input text to analyze
        summary: Generated summary text
        model: Either a SentenceTransformer for embeddings or an AutoModelForSequenceClassification for cross-encoding
        cross_encoder_tokenizer: Tokenizer for the cross-encoder model (only needed if model_type='cross-encoder')
        model_type: Either 'embedding' or 'cross-encoder'
        agg: Aggregation method ('mean', 'max', or 'weighted') - only used for the embedding model
        normalization_type: Type of normalization to apply to scores ('softmax', 'min-max', or None)
        window_size: Number of sentences to consider in each window (default=1 for single sentences)

    Returns:
        Tuple containing:
        - Dictionary mapping input sentences to their attribution scores
        - List of input sentences
        - List of summary sentences
        - Score matrix (summary sentences x input sentences, normalized if requested)
    """
    # Split texts into sentences
    input_sentences = sent_tokenize(original_text)
    summary_sentences = sent_tokenize(summary)

    # Create sliding windows of sentences
    input_windows = []
    for i in range(len(input_sentences) - window_size + 1):
        window = ' '.join(input_sentences[i:i + window_size])
        input_windows.append(window)

    if model_type == 'embedding':
        # Calculate embeddings for windows instead of single sentences
        input_embeddings = model.encode(input_windows)
        summary_embeddings = model.encode(summary_sentences)
        # Calculate the cosine-similarity matrix
        similarity_matrix = np.zeros((len(summary_sentences), len(input_windows)))
        for i, sum_emb in enumerate(summary_embeddings):
            for j, inp_emb in enumerate(input_embeddings):
                similarity = np.dot(sum_emb, inp_emb) / (np.linalg.norm(sum_emb) * np.linalg.norm(inp_emb))
                similarity_matrix[i, j] = similarity
    elif model_type == 'cross-encoder':
        # Calculate relevance scores using the cross-encoder
        similarity_matrix = np.zeros((len(summary_sentences), len(input_windows)))
        for i, sum_sent in enumerate(summary_sentences):
            # Create pairs for cross-encoding
            pairs = [[sum_sent, inp_window] for inp_window in input_windows]
            # Calculate scores in batches to avoid memory issues
            batch_size = 8
            for batch_start in range(0, len(pairs), batch_size):
                batch_pairs = pairs[batch_start:batch_start + batch_size]
                with torch.no_grad():
                    inputs = cross_encoder_tokenizer(batch_pairs, padding=True, truncation=True,
                                                     return_tensors='pt', max_length=512)
                    # Move inputs to the same device as the model
                    inputs = {k: v.to(model.device) for k, v in inputs.items()}
                    scores = model(**inputs, return_dict=True).logits.view(-1,).float()
                    similarity_matrix[i, batch_start:batch_start + len(batch_pairs)] = scores.cpu().numpy()
    else:
        raise ValueError(f"Invalid model type: {model_type}. Please use 'embedding' or 'cross-encoder'.")

    # Expand window scores back to individual sentences
    if window_size > 1:
        # Initialize the expanded matrix with zeros
        expanded_scores = np.zeros((similarity_matrix.shape[0], len(input_sentences)))
        # For each window, assign its score to all sentences in that window
        for i in range(len(input_windows)):
            for j in range(window_size):
                if i + j < len(input_sentences):
                    expanded_scores[:, i + j] = np.maximum(
                        expanded_scores[:, i + j],
                        similarity_matrix[:, i]
                    )
        similarity_matrix = expanded_scores

    # Apply normalization if specified
    if normalization_type == 'softmax':
        scores = softmax(similarity_matrix, axis=1)
    elif normalization_type == 'min-max':
        # Min-max normalization as an alternative
        scores = (similarity_matrix - similarity_matrix.min()) / (similarity_matrix.max() - similarity_matrix.min())
    else:
        scores = similarity_matrix

    # For the cross-encoder, simply average the scores across summary sentences
    if model_type == 'cross-encoder':
        final_scores = np.mean(scores, axis=0)
    else:
        # Aggregate scores based on the specified method (embedding model only)
        if agg == 'max':
            final_scores = np.max(scores, axis=0)
        elif agg == 'weighted':
            weights = softmax(np.max(similarity_matrix, axis=1))
            final_scores = np.average(scores, axis=0, weights=weights)
        else:  # default to mean
            final_scores = np.mean(scores, axis=0)

    # Create the attribution dictionary
    attributions = {
        sentence: score
        for sentence, score in zip(input_sentences, final_scores)
    }
    return attributions, input_sentences, summary_sentences, scores

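def attribution_example():
    """Minimal usage sketch for calculate_attributions() (illustration only, not
    called by main()). It relies on the loaders defined above; the toy text and
    summary are made up.
    """
    embedding_model = load_embedding_model()
    text = "The cat sat on the mat. It then fell asleep in the sun."
    summary = "A cat napped on a mat."
    attributions, input_sentences, summary_sentences, scores = calculate_attributions(
        text, summary, embedding_model, model_type='embedding',
        agg='mean', normalization_type='softmax', window_size=1)
    # Print one attribution score per input sentence
    for sentence, score in attributions.items():
        print(f"{score:.3f}  {sentence}")
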
def create_token_heatmap(sentences, attributions, line_height=5.0, show=False):
    """Render the input sentences word by word, colored by their attribution scores."""
    # Normalize attributions to [0, 1] for the colormap (guard against a constant vector)
    attributions = np.array(attributions, dtype=float)
    min_attr = np.min(attributions)
    max_attr = np.max(attributions)
    attr_range = max_attr - min_attr
    norm_attr = (attributions - min_attr) / attr_range if attr_range > 0 else np.zeros_like(attributions)

    # Create the figure; cap the height so long texts do not produce huge canvases
    total_height = len(sentences) * line_height
    fig = plt.figure(figsize=(12, min(12, max(6, total_height / 2))), dpi=100)
    gs = fig.add_gridspec(1, 2, width_ratios=[30, 1])
    ax = fig.add_subplot(gs[0])
    cax = fig.add_subplot(gs[1])

    # Remove axis decorations
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

    current_y = total_height - line_height
    renderer = fig.canvas.get_renderer()

    # Initialize with the first word of the first sentence
    if sentences and len(sentences[0].split()) > 0:
        color = plt.cm.viridis(norm_attr[0])
        text = ax.text(0.1, current_y, sentences[0].split()[0],
                       color=color,
                       fontsize=12)

    # Place all remaining words continuously, sentence by sentence
    for i, (sentence, score) in enumerate(zip(sentences, norm_attr)):
        color = plt.cm.viridis(score)
        words = sentence.split()
        # Skip the first word of the first sentence as it is already placed
        start_idx = 1 if i == 0 else 0
        for word in words[start_idx:]:
            # Anchor each word to the previous one with a small horizontal gap
            text = ax.annotate(
                f" {word}",
                xycoords=text,
                xy=(1, 0),
                xytext=(2, 0),
                textcoords="offset points",
                color=color,
                fontsize=12,
                verticalalignment="bottom",
            )
            # Wrap to the next line if the word overflows the axes
            bbox = text.get_window_extent(renderer=renderer)
            if bbox.x1 > ax.get_window_extent(renderer=renderer).x1:
                text.remove()  # drop the overflowing annotation before re-placing the word
                current_y -= line_height
                text = ax.text(0.1, current_y, word,
                               color=color,
                               fontsize=12)

    # Adjust limits
    ax.set_xlim(-0.1, 1.1)
    ax.set_ylim(-line_height, total_height + line_height)

    # Add the colorbar
    norm = plt.Normalize(vmin=min_attr, vmax=max_attr)
    sm = plt.cm.ScalarMappable(cmap=plt.cm.viridis, norm=norm)
    plt.colorbar(sm, cax=cax, label='Attribution Score')
    plt.suptitle('Text Attribution Visualization (Lighter = Higher Attribution)', fontsize=12)
    plt.tight_layout()
    if show:
        plt.show()
    return fig

def determine_threshold(similarity_matrix, method='dynamic'):
    """
    Automatically determine a similarity threshold using various methods.

    Methods:
    - 'percentile': Use statistics of the similarity distribution
    - 'elbow': Find the elbow point in the sorted similarities
    - 'dynamic': Adaptive thresholding based on gaps in the distribution
    """
    # Flatten the similarity matrix and remove zeros/very low values
    similarities = similarity_matrix.flatten()
    similarities = similarities[similarities > 0.1]  # remove the noise floor

    if method == 'percentile':
        # Use mean + 1 std dev as the threshold
        return float(np.mean(similarities) + np.std(similarities))
    elif method == 'elbow':
        # Sort similarities and find the "elbow" point
        sorted_sims = np.sort(similarities)
        n_points = len(sorted_sims)
        # Calculate the curvature at each point
        max_curvature = 0
        threshold = 0.3  # fallback
        for i in range(1, n_points - 1):
            # Approximate the curvature using three points
            y_diff = sorted_sims[i+1] - 2*sorted_sims[i] + sorted_sims[i-1]
            x_diff = 1  # constant x-spacing
            curvature = abs(y_diff / (1 + x_diff**2)**1.5)
            if curvature > max_curvature:
                max_curvature = curvature
                threshold = sorted_sims[i]
        return float(threshold)
    else:  # method == 'dynamic'
        # Find natural breaks in the similarity distribution
        sorted_sims = np.sort(similarities)
        gaps = sorted_sims[1:] - sorted_sims[:-1]
        # Calculate local statistics in windows
        window_size = max(len(gaps) // 10, 1)  # 10% of the data points
        significance_factors = []
        for i in range(len(gaps) - window_size):
            window = gaps[i:i+window_size]
            local_mean = np.mean(window)
            local_std = np.std(window)
            if local_std == 0:
                significance_factors.append(0)
            else:
                # How many standard deviations is this gap from the local mean?
                significance_factors.append((gaps[i] - local_mean) / local_std)
        # Find the most significant gap
        if significance_factors:
            max_significance_idx = np.argmax(significance_factors)
            threshold = sorted_sims[max_significance_idx]
            # Keep the threshold in a reasonable range
            min_threshold = np.percentile(similarities, 60)  # at least the top 40%
            max_threshold = np.percentile(similarities, 90)  # at most the top 10%
            threshold = np.clip(threshold, min_threshold, max_threshold)
        else:
            threshold = np.percentile(similarities, 75)  # fallback to the top 25%
        return float(threshold)

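def threshold_example():
    """Tiny worked example for determine_threshold() (illustration only): a made-up
    similarity matrix with weak matches around 0.2 and strong matches around 0.8.
    With this few values the local gap statistics degenerate (the std of a 1-element
    window is 0), so the result is driven by the 60th/90th-percentile clipping above;
    the returned threshold (about 0.45 here) still falls in the gap between the two groups.
    """
    toy = np.array([[0.15, 0.20, 0.22, 0.75, 0.80],
                    [0.18, 0.21, 0.25, 0.78, 0.82]])
    return determine_threshold(toy, method='dynamic')
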
def plot_sankey_diagram(similarity_matrix, input_sentences, summary_sentences, max_connections=3):
    """
    Create a Sankey diagram using the pySankey library with a fixed order of sentences.
    """
    from pysankey import sankey
    import matplotlib.colors as mcolors

    # Automatically determine the threshold
    threshold = determine_threshold(similarity_matrix, method='dynamic')

    def truncate_text(text, max_len=50):
        return text[:max_len] + "..." if len(text) > max_len else text

    # Create input and summary labels with truncated text
    input_label_map = {
        f'Input {i+1}': f'Input {i+1}: {truncate_text(sent)}'
        for i, sent in enumerate(input_sentences)
    }
    summary_label_map = {
        f'Summary {i+1}': f'Summary {i+1}: {truncate_text(sent)}'
        for i, sent in enumerate(summary_sentences)
    }

    # Create lists for the diagram
    lefts = []    # input sentence labels
    rights = []   # summary sentence labels
    weights = []  # similarity scores

    # Filter connections and collect the data
    for i in range(len(summary_sentences)):
        scores = similarity_matrix[i]
        top_indices = np.argsort(scores)[-max_connections:]
        top_scores = scores[top_indices]
        # Only keep connections above the threshold
        valid_mask = top_scores > threshold
        top_indices = top_indices[valid_mask]
        top_scores = top_scores[valid_mask]
        for idx, score in zip(top_indices, top_scores):
            lefts.append(f'Input {idx+1}')
            rights.append(f'Summary {i+1}')
            weights.append(float(score))

    # Get the unique labels in reverse order
    unique_lefts = sorted(set(lefts), key=lambda x: -int(x.split()[1]))
    unique_rights = sorted(set(rights), key=lambda x: -int(x.split()[1]))

    # Create the label lists with the full text
    leftLabels = [input_label_map[label] for label in unique_lefts]
    rightLabels = [summary_label_map[label] for label in unique_rights]

    # Update the actual connection labels to include the text
    lefts = [input_label_map[label] for label in lefts]
    rights = [summary_label_map[label] for label in rights]

    # Create the color dictionary
    colors = plt.cm.Set2(np.linspace(0, 1, len(summary_sentences)))
    colorDict = {}
    # Use a neutral color for all input labels
    for label in leftLabels:
        colorDict[label] = mcolors.rgb2hex(plt.cm.Greys(0.2))
    # Use distinct colors for the summary labels
    for i, label in enumerate(rightLabels):
        colorDict[label] = mcolors.rgb2hex(colors[i])

    # Create the figure
    plt.figure(figsize=(15, max(8, len(input_sentences) * 0.5)))

    # Create the Sankey diagram
    sankey(
        left=lefts,
        right=rights,
        leftWeight=weights,
        rightWeight=weights,
        colorDict=colorDict,
        leftLabels=leftLabels,
        rightLabels=rightLabels,
        aspect=20,
        fontsize=8,
        rightColor=True  # color flows based on the right (summary) labels
    )
    plt.title(f'Information Flow (threshold={threshold:.3f}, max_connections={max_connections})')
    plt.tight_layout()
    plt.show()

def run_parameter_sweep(text, summary, embedding_model, cross_encoder_model, cross_encoder_tokenizer):
    """Run a parameter sweep across model types, window sizes, and aggregation methods."""
    from pathlib import Path

    # Create the plots directory if it doesn't exist
    plots_dir = Path("plots")
    plots_dir.mkdir(exist_ok=True)

    # Parameter combinations to test
    model_types = ['embedding', 'cross-encoder']
    window_sizes = [1, 2, 3]
    agg_methods = ['mean', 'max']

    for model_type in model_types:
        # Select the appropriate model and tokenizer
        if model_type == 'embedding':
            model = embedding_model
            tokenizer = None
        else:  # cross-encoder
            model = cross_encoder_model
            tokenizer = cross_encoder_tokenizer

        # Note: agg only affects the embedding model; for the cross-encoder the
        # sweep over agg values produces identical scores.
        for window_size in window_sizes:
            for agg in agg_methods:
                # Calculate attributions for this parameter combination
                attributions, input_sentences, summary_sentences, similarity_matrix = calculate_attributions(
                    text,
                    summary,
                    model,
                    cross_encoder_tokenizer=tokenizer,
                    model_type=model_type,
                    window_size=window_size,
                    agg=agg
                )

                # Create and save the heatmap
                fig = create_token_heatmap(input_sentences, list(attributions.values()), show=False)
                filename = plots_dir / f"heatmap_{model_type}_window{window_size}_agg{agg}.png"
                fig.savefig(filename, bbox_inches='tight')
                plt.close(fig)
                print(f"Saved {filename}")

def main():
    text = """This section briefly summarizes the state of the art in the area of semantic segmentation and se-
mantic instance segmentation. As the majority of state-of-the-art techniques in this area are deep
learning approaches we will focus on this area. Early deep learning-based approaches that aim at
assigning semantic classes to the pixels of an image are based on patch classification. Here the
image is decomposed into superpixels in a preprocessing step e.g. by applying the SLIC algorithm
[1]. The superpixels are then padded and classified by using a neural network architecture for im-
age classification. Typically, nowadays there are Convolutional Neural Networks which consist of
a series of 2D convolutions, followed by a number of fully connected layers that use the extracted
features predict a probability for each possible class 2.1.
Other approaches are based on so-called Fully Convolutional Neural Networks (FCNs). Here
not an image patch but the whole image are taken as input and the output is a two-dimensional
feature map that assigns class probabilities to each pixel. Conceptually FCNs are similar to CNNs
used for classification but the fully connected layers are usually replaced by transposed convolu-
tions which have learnable parameters and can learn to upsample the extracted features to the final
pixel-wise classification result.
Standard architectures FCN architectures that are commonly used for semantic segmentation are
e.g. U-Net [73] or architectures with VGG [82] or ResNet-based [34] feature encoders. The archi-
tecture change leads to several advantages such as better computational efficiency, less parameters
and that the network can process images of varying size.
As explained before in semantic segmentation the network only predicts semantic labels for
each pixel of an image. This can be sufficient to also identify object instances in a post-processing
step but as there usually are also challenging cases with e.g. overlapping instances there is also
the need for network architectures that additionally assign instance labels to the input pixels. In
the following two types of deep learning-based semantic instance segmentation techniques will be
reviewed – proposal-based and instance embedding-based techniques."""
    model, tokenizer = load_model()
    embedding_model = load_embedding_model()
    cross_encoder_model, cross_encoder_tokenizer = load_cross_encoder_model()

    summary = generate_summary(text, model, tokenizer)
    print("\nSummary:")
    print(summary)

    # Run the parameter sweep and save the plots
    run_parameter_sweep(
        text,
        summary,
        embedding_model,
        cross_encoder_model,
        cross_encoder_tokenizer
    )

    # Plot the Sankey diagram
    attributions, input_sentences, summary_sentences, similarity_matrix = calculate_attributions(
        text, summary, embedding_model, model_type='embedding', window_size=1, agg='mean')
    plot_sankey_diagram(similarity_matrix, input_sentences, summary_sentences)


if __name__ == "__main__":
    main()