Differences between the attention computation in JAX and Torch #24694
-
I may have overlooked/misunderstood something in the documentation, so before opening an issue I am opening a discussion for this. The outputs of the two attention implementations below do not match, even though the intermediate Q/K/V projections do:

```python
import numpy as np
import torch
from torch import nn
import jax
import jax.numpy as jnp
import equinox as eqx

def port_weights(eqx_layer, torch_layer):
    """Port weights from a torch layer to an Equinox layer."""
    new_weight = torch_layer.state_dict()["weight"].numpy()
    where = lambda l: l.weight
    eqx_layer = eqx.tree_at(where, eqx_layer, new_weight)
    return eqx_layer

in_channels = 32
batch_size = 2
key = jax.random.PRNGKey(1)
key1, key2, key3 = jax.random.split(key, 3)
# Linear layers in Equinox
qj = eqx.nn.Linear(in_channels, in_channels, key=key1, use_bias=False)
kj = eqx.nn.Linear(in_channels, in_channels, key=key2, use_bias=False)
vj = eqx.nn.Linear(in_channels, in_channels, key=key3, use_bias=False)
# Linear layers in Torch
qt = nn.Linear(in_channels, in_channels, bias=False)
kt = nn.Linear(in_channels, in_channels, bias=False)
vt = nn.Linear(in_channels, in_channels, bias=False)
# Porting weights from torch to equinox
qj = port_weights(qj, qt)
kj = port_weights(kj, kt)
vj = port_weights(vj, vt)
# Random sample
x = np.random.rand(batch_size, in_channels).astype(np.float32)
xt = torch.from_numpy(x)
xj = jax.numpy.asarray(x)
# Torch outputs
with torch.no_grad():
    qt_o = qt(xt)
    kt_o = kt(xt)
    vt_o = vt(xt)
    attn_t = nn.functional.scaled_dot_product_attention(qt_o, kt_o, vt_o, is_causal=False)
# JAX outputs
qj_o = eqx.filter_vmap(qj)(xj)
kj_o = eqx.filter_vmap(kj)(xj)
vj_o = eqx.filter_vmap(vj)(xj)
attn_j = jax.nn.dot_product_attention(qj_o, kj_o, vj_o, is_causal=False)
# Check intermediate outputs
np.allclose(qt_o.numpy(), qj_o, atol=1e-6)  # these intermediate checks pass
np.allclose(kt_o.numpy(), kj_o, atol=1e-6)
np.allclose(vt_o.numpy(), vj_o, atol=1e-6)
# Check attention outputs
np.allclose(attn_t.numpy(), attn_j, atol=1e-6)  # this fails!
```

Am I missing/misunderstanding something here? 🤔
Replies: 2 comments
-
I wonder if @kaixih can weigh in here?
-
@dfm I got it. The `scaled_dot_product_attention(...)` in torch accepts `BNTH`, while `jax.nn.dot_product_attention` in JAX accepts `BTNH`. The JAX implementation makes more sense to me, as I am not aware of why torch chose that layout. Guess I overlooked the dimensions in the documentation.
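
For completeness, here is a minimal sketch (not from the original thread; the shapes, seed, and tolerance are arbitrary assumptions) showing that the two functions agree once the axes are laid out to match each API: torch's `scaled_dot_product_attention` takes `(B, N, T, H)`, while `jax.nn.dot_product_attention` takes `(B, T, N, H)`.

```python
import numpy as np
import torch
import jax
import jax.numpy as jnp

B, N, T, H = 2, 4, 8, 16  # batch, num_heads, seq_len, head_dim (arbitrary)
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((B, N, T, H)).astype(np.float32) for _ in range(3))

# Torch expects (B, N, T, H): heads before the sequence axis.
attn_t = torch.nn.functional.scaled_dot_product_attention(
    torch.from_numpy(q), torch.from_numpy(k), torch.from_numpy(v), is_causal=False
).numpy()

# JAX expects (B, T, N, H): sequence axis before the heads axis.
to_btnh = lambda a: jnp.transpose(jnp.asarray(a), (0, 2, 1, 3))
attn_j = jax.nn.dot_product_attention(
    to_btnh(q), to_btnh(k), to_btnh(v), is_causal=False
)

# Transpose the JAX result back to (B, N, T, H) before comparing.
print(np.allclose(attn_t, np.transpose(np.asarray(attn_j), (0, 2, 1, 3)), atol=1e-5))
```

With the transpose in place the outputs match to float32 tolerance, consistent with the conclusion above.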