Copy seq_lengths before creating descriptor
PiperOrigin-RevId: 519771897
sharadmv authored and jax authors committed Mar 27, 2023
1 parent 88c2898 commit 3c3fa04
Showing 3 changed files with 72 additions and 1 deletion.
4 changes: 4 additions & 0 deletions jax/experimental/rnn.py
@@ -261,6 +261,8 @@ def lstm_ref(x: Array, h_0: Array, c_0: Array, W_ih: Dict[int, Array],
See https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#lstm
https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnRNNMode_t
"""
if seq_lengths.dtype != jnp.dtype("int32"):
raise NotImplementedError("`seq_lengths` can only be int32.")
if dropout != 0.0:
raise NotImplementedError(
'Dropout not supported in LSTM reference because we cannot determine CUDNN dropout mask.'
@@ -326,6 +328,8 @@ def lstm_cell(carry, x, *, W_ih, W_hh, b_ih, b_hh):
def lstm_fwd(x: Array, h_0: Array, c_0: Array, w: Array, seq_lengths: Array,
input_size: int, hidden_size: int, num_layers: int, dropout: float,
bidirectional: bool):
if seq_lengths.dtype != jnp.dtype("int32"):
raise NotImplementedError("`seq_lengths` can only be int32.")
y, h_n, c_n, workspace, reserve_space = rnn_fwd_p.bind(
x,
h_0,
5 changes: 4 additions & 1 deletion jaxlib/gpu/rnn_kernels.cc
@@ -316,8 +316,12 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers,
cudnnRNNDataLayout_t layout = CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED;
float padding = 0.0f;

auto seq_lengths_buf = buffers[11];
std::vector<int32_t> seq_length_vector(d.batch_size, d.max_seq_length);
int32_t* seq_length_array = &seq_length_vector[0];
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpy(
seq_length_array, seq_lengths_buf,
seq_length_vector.size() * sizeof(int32_t), gpuMemcpyDeviceToHost)));

cudnnRNNDataDescriptor_t input_data_desc;
JAX_RETURN_IF_ERROR(
@@ -367,7 +371,6 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers,
auto workspace_buf = buffers[8];
auto reserve_space_buf = buffers[9];
auto zeroed_dw_buf = buffers[10];
auto seq_lengths_buf = buffers[11];
auto dx_buf = buffers[12];
auto dh_0_buf = buffers[13];
auto dc_0_buf = buffers[14];
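The two hunks above are the substance of the fix. seq_lengths now reaches the backward kernel as a device buffer (buffers[11]), while cuDNN's cudnnSetRNNDataDescriptor takes the per-batch sequence lengths as a host-side int array, so the kernel copies the buffer to the host with gpuMemcpy before the descriptor is created; the second hunk drops the old assignment of seq_lengths_buf, which only happened after the descriptors had been built. Below is a minimal stand-alone sketch of that pattern, written against plain CUDA/cuDNN calls (cudaMemcpy, cudnnSetRNNDataDescriptor) rather than jaxlib's JAX_RETURN_IF_ERROR wrappers; the helper name MakeInputDataDescriptor, its parameters, and the float data type are illustrative assumptions, and error handling is omitted.

#include <cstdint>
#include <vector>

#include <cuda_runtime.h>
#include <cudnn.h>

// Hypothetical helper sketching the pattern above: the per-batch sequence
// lengths live in a device buffer, but cudnnSetRNNDataDescriptor reads them
// from host memory, so they are staged on the host before the descriptor is
// configured.
cudnnRNNDataDescriptor_t MakeInputDataDescriptor(const int32_t* seq_lengths_dev,
                                                 int batch_size,
                                                 int max_seq_length,
                                                 int input_size) {
  // Host staging buffer; the fill value is overwritten by the device-to-host copy.
  std::vector<int32_t> seq_lengths_host(batch_size, max_seq_length);
  cudaMemcpy(seq_lengths_host.data(), seq_lengths_dev,
             seq_lengths_host.size() * sizeof(int32_t),
             cudaMemcpyDeviceToHost);

  // Batch-major, padded layout as in the kernel above; pad with zeros.
  cudnnRNNDataDescriptor_t desc;
  cudnnCreateRNNDataDescriptor(&desc);
  float padding = 0.0f;
  cudnnSetRNNDataDescriptor(desc, CUDNN_DATA_FLOAT,
                            CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED,
                            max_seq_length, batch_size, input_size,
                            seq_lengths_host.data(), &padding);
  return desc;
}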
64 changes: 64 additions & 0 deletions tests/experimental_rnn_test.py
@@ -96,6 +96,70 @@ def f(x, h_0, c_0, weights):
np.testing.assert_allclose(h_n_ref, h_n, rtol=1e-05, atol=1e-5)
np.testing.assert_allclose(c_n_ref, c_n, rtol=1e-05, atol=1e-5)

@jtu.skip_on_devices("cpu", "tpu", "rocm")
def test_lstm_with_varying_seq_lens(self):
batch_size = 6
seq_len = 7
input_size = 8
hidden_size = 12
num_layers = 5
bidirectional = False
num_directions = 2 if bidirectional else 1

seq_lengths = jnp.array([4, 5, 1, 1, 1, 1], dtype=jnp.dtype("int32"))

root_key = jax.random.PRNGKey(1)
k1, k2, k3, k4 = jax.random.split(root_key, 4)
x = jax.random.normal(
k1, (batch_size, seq_len, input_size), dtype=jnp.float32)
h_0 = jax.random.normal(
k2, (num_directions * num_layers, batch_size, hidden_size),
dtype=jnp.float32)
c_0 = jax.random.normal(
k3, (num_directions * num_layers, batch_size, hidden_size),
dtype=jnp.float32)
weights = rnn.init_lstm_weight(k4, input_size, hidden_size, num_layers,
bidirectional)

@jax.jit
def f(x, h_0, c_0, weights):
return rnn.lstm(
x,
h_0,
c_0,
weights,
seq_lengths=seq_lengths,
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=False,
bidirectional=bidirectional)

jtu.check_grads(f, (x, h_0, c_0, weights), modes=['rev'], order=1)

# TODO(sharadmv): enable when lstm_ref works with seq_lengths
# W_ih, W_hh, b_ih, b_hh = rnn.unpack_lstm_weights(weights, input_size,
# hidden_size, num_layers,
# bidirectional)
# y_ref, h_n_ref, c_n_ref = rnn.lstm_ref(
# x,
# h_0,
# c_0,
# W_ih,
# W_hh,
# b_ih,
# b_hh,
# seq_lengths=seq_lengths,
# input_size=input_size,
# hidden_size=hidden_size,
# num_layers=num_layers,
# dropout=False,
# bidirectional=bidirectional)

# np.testing.assert_allclose(y_ref, y, rtol=1e-05, atol=1e-5)
# np.testing.assert_allclose(h_n_ref, h_n, rtol=1e-05, atol=1e-5)
# np.testing.assert_allclose(c_n_ref, c_n, rtol=1e-05, atol=1e-5)


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())
