-
Notifications
You must be signed in to change notification settings - Fork 6
/
distance_metrics.py
87 lines (62 loc) · 2.31 KB
/
distance_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from gnn_3 import *
#from ged_graph_data import *
def pairwise_euclidean_similarity(x, y):
    """Compute the pairwise Euclidean similarity between x and y.

    For each pair (x_i, y_j) this returns the negated squared Euclidean
    distance:

        s(x_i, y_j) = -|x_i - y_j|^2 = 2 x_i^T y_j - |x_i|^2 - |y_j|^2

    Args:
        x: NxD float tensor.
        y: MxD float tensor.

    Returns:
        s: NxM float tensor, the pairwise euclidean similarity.
    """
    # Use the binomial expansion instead of materialising an NxMxD difference
    # tensor. The cross term carries a factor of 2; without it the result is
    # not a (negated) distance at all.
    cross = 2 * torch.matmul(x, torch.transpose(y, 0, 1))
    sq_norm_x = torch.sum(x * x, dim=-1, keepdim=True)        # N x 1
    sq_norm_y = torch.sum(y * y, dim=-1).reshape([1, -1])     # 1 x M
    return cross - sq_norm_x - sq_norm_y
def pairwise_dot_product_similarity(x, y):
    """Return the matrix of inner products between rows of x and rows of y.

    Entry (i, j) of the result is s(x_i, y_j) = x_i^T y_j.

    Args:
        x: NxD float tensor.
        y: MxD float tensor.

    Returns:
        NxM float tensor holding the pairwise dot-product similarity.
    """
    # Swap y's first two axes so the rows of y become columns, then multiply.
    return x @ y.transpose(0, 1)
def pairwise_cosine_similarity(x, y):
    """Compute the cosine similarity between x and y.

    This function computes the following similarity value between each pair
    of x_i and y_j: s(x_i, y_j) = x_i^T y_j / (|x_i||y_j|).

    Rows with zero norm produce a similarity of 0 instead of NaN.

    Args:
        x: NxD float tensor.
        y: MxD float tensor.

    Returns:
        s: NxM float tensor, the pairwise cosine similarity.
    """
    # L2-normalise every row (keepdim keeps the norms broadcastable against
    # the rows). Clamping the norm guards against division by zero for
    # all-zero rows; 1e-12 matches torch.nn.functional.normalize's default eps.
    x = x / torch.norm(x, dim=-1, keepdim=True).clamp(min=1e-12)
    y = y / torch.norm(y, dim=-1, keepdim=True).clamp(min=1e-12)
    return torch.matmul(x, torch.transpose(y, 0, 1))
# Registry of supported pairwise similarity metrics, keyed by the name that
# get_pairwise_similarity accepts. Each value maps (x, y) -> NxM similarities.
PAIRWISE_SIMILARITY_FUNCTION = dict(
    euclidean=pairwise_euclidean_similarity,
    dotproduct=pairwise_dot_product_similarity,
    cosine=pairwise_cosine_similarity,
)
def get_pairwise_similarity(name):
    """Get a pairwise similarity metric by name.

    Args:
        name: string, name of the similarity metric; must be a key of
            PAIRWISE_SIMILARITY_FUNCTION, i.e. one of 'euclidean',
            'dotproduct' or 'cosine'.

    Returns:
        similarity: a (x, y) -> sim function.

    Raises:
        ValueError: if name is not supported.
    """
    try:
        return PAIRWISE_SIMILARITY_FUNCTION[name]
    except KeyError:
        # Surface the documented ValueError without chaining the internal
        # KeyError. Wrapping name in a 1-tuple keeps %-formatting safe even
        # when name is itself a tuple.
        raise ValueError(
            'Similarity metric name "%s" not supported.' % (name,)
        ) from None