-
Notifications
You must be signed in to change notification settings - Fork 3
/
scratch.py
32 lines (23 loc) · 674 Bytes
/
scratch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import jax
import jax.numpy as jnp
# create your pairwise function
def distance(a, b):
    """Return the norm of the difference between ``a`` and ``b``.

    The result is ``jnp.linalg.norm(a - b)`` with the default ``ord``; for
    1-D inputs that is the Euclidean (L2) distance between two points.

    Args:
        a: First point.
        b: Second point; must be subtractable from ``a``.

    Returns:
        The norm of ``a - b``.
    """
    offset = a - b
    return jnp.linalg.norm(offset)
# vmap-based combinator to operate on all pairs
def all_pairs(f):
    """Lift a two-argument function so it is applied to every pair of rows.

    Args:
        f: Function of two arguments ``f(a, b)``.

    Returns:
        A function ``g(xs, ys)`` whose result satisfies
        ``g(xs, ys)[i, j] == f(xs[i], ys[j])``: the leading axis of the
        output indexes ``xs`` and the next axis indexes ``ys``.
    """
    # Inner map: hold the first argument fixed, sweep axis 0 of the second.
    over_second = jax.vmap(f, in_axes=(None, 0))
    # Outer map: sweep axis 0 of the first argument, pass the second whole.
    return jax.vmap(over_second, in_axes=(0, None))
# transform to operate over sets
distances = all_pairs(distance)

# Test data: three 2-D points in A and two 2-D points in B (one point per row).
A = jnp.array(
    [
        [0, 0],
        [1, 1],
        [2, 2],
    ]
)
B = jnp.array(
    [
        [-10, -10],
        [-20, -20],
    ]
)

# Single pair: distance between the first point of A and the first point of B.
d00 = distance(A[0], B[0])
# 14.142136

# All pairs at once: D[i, j] is the distance from A[i] to B[j], shape (3, 2).
D = distances(A, B)
# [[14.142136 28.284271]
#  [15.556349 29.698484]
#  [16.970562 31.112698]]