-
Notifications
You must be signed in to change notification settings - Fork 0
/
bert_padding.py
132 lines (108 loc) · 5.76 KB
/
bert_padding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
class IndexFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
assert input.ndim >= 2
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
# return input[indices]
return torch.gather(rearrange(input, 'b ... -> b (...)'), 0,
repeat(indices, 'z -> z d', d=second_dim)).reshape(-1, *other_shape)
@staticmethod
def backward(ctx, grad_output):
indices, = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
grad_output = rearrange(grad_output, 'b ... -> b (...)')
grad_input = torch.zeros([ctx.first_axis_dim, grad_output.shape[1]],
device=grad_output.device, dtype=grad_output.dtype)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
# grad_input[indices] = grad_output
grad_input.scatter_(0, repeat(indices, 'z -> z d', d=grad_output.shape[1]), grad_output)
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
index_first_axis = IndexFirstAxis.apply
class IndexPutFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, values, indices, first_axis_dim):
ctx.save_for_backward(indices)
assert indices.ndim == 1
assert values.ndim >= 2
output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device,
dtype=values.dtype)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
output[indices] = values
# output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
return output
@staticmethod
def backward(ctx, grad_output):
indices, = ctx.saved_tensors
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
grad_values = grad_output[indices]
# grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
return grad_values, None, None
index_put_first_axis = IndexPutFirstAxis.apply
class IndexFirstAxisResidual(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
assert input.ndim >= 2
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
output = input[indices]
# We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
# memory format to channel_first. In other words, input might not be contiguous.
# If we don't detach, Pytorch complains about output being a view and is being modified inplace
return output, input.detach()
@staticmethod
def backward(ctx, grad_output, grad_residual):
indices, = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
assert grad_residual.shape[1:] == other_shape
grad_input = grad_residual
# grad_input[indices] += grad_output
indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
indices = indices.expand_as(grad_output)
grad_input.scatter_add_(0, indices, grad_output)
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
index_first_axis_residual = IndexFirstAxisResidual.apply
def unpad_input(hidden_states, attention_mask):
"""
Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
"""
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
# so we write custom forward and backward to make it a bit faster.
return (index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'), indices), indices,
cu_seqlens, max_seqlen_in_batch)
def pad_input(hidden_states, indices, batch, seqlen):
"""
Arguments:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
indices: (total_nnz)
Return:
hidden_states: (batch, seqlen, ...)
"""
dim = hidden_states.shape[-1]
# output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
# output[indices] = hidden_states
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
return rearrange(output, '(b s) ... -> b s ...', b=batch)