forked from baidu-research/warp-ctc
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ctc.h
218 lines (200 loc) · 10.1 KB
/
ctc.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
/** \file ctc.h
* Contains a simple C interface to call fast CPU and GPU based computation
* of the CTC loss.
*/
#pragma once
/* API_REFERENCE decorates every public entry point declared in this header.
 *
 * On Windows it selects __declspec(dllexport) when warpctc_EXPORTS is defined
 * (i.e. while building the library itself) and __declspec(dllimport)
 * otherwise. The extern "C" linkage specifier is only emitted for C++
 * translation units: it is not valid C syntax, so a C compiler including this
 * header must not see it. On other platforms the macro expands to nothing.
 */
#if defined(_WIN32)
#  if defined(__cplusplus)
#    if defined(warpctc_EXPORTS)
#      define API_REFERENCE extern "C" __declspec(dllexport)
#    else
#      define API_REFERENCE extern "C" __declspec(dllimport)
#    endif
#  else
#    if defined(warpctc_EXPORTS)
#      define API_REFERENCE __declspec(dllexport)
#    else
#      define API_REFERENCE __declspec(dllimport)
#    endif
#  endif
#else
#  define API_REFERENCE
#endif
#include <stdio.h>
#ifdef __cplusplus
#include <cstddef>
extern "C" {
#endif
// GPUstream is the opaque stream handle type stored in ctcOptions::stream.
// It is declared here as a pointer to an incomplete struct so that users of
// this header do not need the GPU runtime headers.
#ifdef WARPCTC_WITH_HIP
// Forward declaration of the HIP stream typedef, to avoid pulling in HIP headers.
typedef struct ihipStream_t* GPUstream;
#else
// Forward declaration of the CUDA stream typedef, to avoid pulling in CUDA headers.
typedef struct CUstream_st* GPUstream;
#endif
/** Status code returned by the warp-ctc entry points declared in this header.
 *  Use ctcGetStatusString() to obtain a text description of a value.
 *  The numeric values are assigned explicitly. */
typedef enum {
    CTC_STATUS_SUCCESS          = 0,
    CTC_STATUS_MEMOPS_FAILED    = 1,
    CTC_STATUS_INVALID_VALUE    = 2,
    CTC_STATUS_EXECUTION_FAILED = 3,
    CTC_STATUS_UNKNOWN_ERROR    = 4
} ctcStatus_t;
/** Returns a single integer which specifies the API version of the warpctc
 *  library.
 *
 *  The parameter list is spelled (void) so that this is a true prototype in
 *  C as well as C++; in C an empty list () leaves the parameters unspecified.
 *
 * \return the API version number
 * */
API_REFERENCE int get_warpctc_version(void);
/** Maps a status code to a human-readable description.
 *
 * \param [in] status the status value whose description is requested
 *
 * \return C style string containing the text description
 * */
API_REFERENCE const char *ctcGetStatusString(ctcStatus_t status);
/** Selects where the CTC calculation takes place; stored in ctcOptions::loc. */
typedef enum {
    CTC_CPU = 0,
    CTC_GPU = 1
} ctcComputeLocation;
/** Structure used for options to the CTC computation. Applications
 *  should zero out the structure using memset and sizeof(struct
 *  ctcOptions) in C or default initialization (e.g. 'ctcOptions
 *  options{};' or 'auto options = ctcOptions{}') in C++ to ensure
 *  forward compatibility with added options.
 *
 *  The typedef makes the bare name `ctcOptions` a valid type name in C as
 *  well as C++, which the function prototypes in this header rely on.
 *  The anonymous union requires C11 (or a compiler extension) when this
 *  header is compiled as C. */
typedef struct ctcOptions {
    /// indicates where the ctc calculation should take place {CTC_CPU | CTC_GPU}
    ctcComputeLocation loc;

    /// only the member matching loc is meaningful
    union {
        /// used when loc == CTC_CPU, the maximum number of threads that can be used
        unsigned int num_threads;

        /// used when loc == CTC_GPU, which stream the kernels should be launched in
        GPUstream stream;
    };

    /// the label value/index that the CTC calculation should use as the blank label
    int blank_label;
} ctcOptions;
/** Compute the connectionist temporal classification (CTC) loss between
 *  a probability sequence with dtype float and a ground truth labeling.
 *  Optionally compute the gradient with respect to the inputs.
 *
 * \param [in] activations pointer to the activations in either CPU or GPU
 *             addressable memory, depending on options. We assume a fixed
 *             memory layout for this 3 dimensional tensor, which has dimension
 *             (t, n, p), where t is the time index, n is the minibatch index,
 *             and p indexes over probabilities of each symbol in the alphabet.
 *             The memory layout is (t, n, p) in C order (slowest to fastest
 *             changing index, aka row-major), or (p, n, t) in Fortran order
 *             (fastest to slowest changing index, aka column-major). We also
 *             assume strides are equal to dimensions - there is no padding
 *             between dimensions.
 *             More precisely, element (t, n, p), for a problem with minibatch
 *             examples in the mini batch, and alphabet_size symbols in the
 *             alphabet, is located at:
 *             activations[(t * minibatch + n) * alphabet_size + p]
 * \param [out] gradients if not NULL, then gradients are computed. Should be
 *              allocated in the same memory space as activations and memory
 *              ordering is identical.
 * \param [in]  flat_labels Always in CPU memory. A concatenation
 *              of all the labels for the minibatch.
 * \param [in]  label_lengths Always in CPU memory. The length of each label
 *              for each example in the minibatch.
 * \param [in]  input_lengths Always in CPU memory. The number of time steps
 *              for each sequence in the minibatch.
 * \param [in]  alphabet_size The number of possible output symbols. There
 *              should be this many probabilities for each time step.
 * \param [in]  minibatch How many examples in a minibatch.
 * \param [out] costs Always in CPU memory. The cost of each example in the
 *              minibatch.
 * \param [in,out] workspace In same memory space as activations. Should be of
 *                 size requested by get_workspace_size.
 * \param [in]  options see struct ctcOptions
 *
 * \return Status information
 *
 * */
API_REFERENCE ctcStatus_t compute_ctc_loss(const float *const activations,
                                           float *gradients,
                                           const int *const flat_labels,
                                           const int *const label_lengths,
                                           const int *const input_lengths,
                                           int alphabet_size,
                                           int minibatch,
                                           float *costs,
                                           void *workspace,
                                           struct ctcOptions options);
/** Compute the connectionist temporal classification (CTC) loss between
 *  a probability sequence of dtype double and a ground truth labeling.
 *  Optionally compute the gradient with respect to the inputs.
 *
 * \param [in] activations pointer to the activations in either CPU or GPU
 *             addressable memory, depending on options. We assume a fixed
 *             memory layout for this 3 dimensional tensor, which has dimension
 *             (t, n, p), where t is the time index, n is the minibatch index,
 *             and p indexes over probabilities of each symbol in the alphabet.
 *             The memory layout is (t, n, p) in C order (slowest to fastest
 *             changing index, aka row-major), or (p, n, t) in Fortran order
 *             (fastest to slowest changing index, aka column-major). We also
 *             assume strides are equal to dimensions - there is no padding
 *             between dimensions.
 *             More precisely, element (t, n, p), for a problem with minibatch
 *             examples in the mini batch, and alphabet_size symbols in the
 *             alphabet, is located at:
 *             activations[(t * minibatch + n) * alphabet_size + p]
 * \param [out] gradients if not NULL, then gradients are computed. Should be
 *              allocated in the same memory space as activations and memory
 *              ordering is identical.
 * \param [in]  flat_labels Always in CPU memory. A concatenation
 *              of all the labels for the minibatch.
 * \param [in]  label_lengths Always in CPU memory. The length of each label
 *              for each example in the minibatch.
 * \param [in]  input_lengths Always in CPU memory. The number of time steps
 *              for each sequence in the minibatch.
 * \param [in]  alphabet_size The number of possible output symbols. There
 *              should be this many probabilities for each time step.
 * \param [in]  minibatch How many examples in a minibatch.
 * \param [out] costs Always in CPU memory. The cost of each example in the
 *              minibatch.
 * \param [in,out] workspace In same memory space as activations. Should be of
 *                 size requested by get_workspace_size_double.
 * \param [in]  options see struct ctcOptions
 *
 * \return Status information
 *
 * */
API_REFERENCE ctcStatus_t compute_ctc_loss_double(const double *const activations,
                                                  double *gradients,
                                                  const int *const flat_labels,
                                                  const int *const label_lengths,
                                                  const int *const input_lengths,
                                                  int alphabet_size,
                                                  int minibatch,
                                                  double *costs,
                                                  void *workspace,
                                                  struct ctcOptions options);
/** For a given set of labels and minibatch size return the required workspace
 *  size when the dtype of your probabilities is float. This will need to be
 *  allocated in the same memory space as your probabilities.
 *
 * \param [in]  label_lengths Always in CPU memory. The length of each label
 *              for each example in the minibatch.
 * \param [in]  input_lengths Always in CPU memory. The number of time steps
 *              for each sequence in the minibatch.
 * \param [in]  alphabet_size How many symbols in the alphabet or, equivalently,
 *              the number of probabilities at each time step
 * \param [in]  minibatch How many examples in a minibatch.
 * \param [in]  info see struct ctcOptions
 * \param [out] size_bytes is pointer to a scalar where the memory
 *              requirement in bytes will be placed. This memory should be
 *              allocated at the same place, CPU or GPU, that the activations
 *              are in
 *
 * \return Status information
 **/
API_REFERENCE ctcStatus_t get_workspace_size(const int *const label_lengths,
                                             const int *const input_lengths,
                                             int alphabet_size,
                                             int minibatch,
                                             struct ctcOptions info,
                                             size_t *size_bytes);
/** For a given set of labels and minibatch size return the required workspace
 *  size when the dtype of your probabilities is double. This will need to be
 *  allocated in the same memory space as your probabilities.
 *
 * \param [in]  label_lengths Always in CPU memory. The length of each label
 *              for each example in the minibatch.
 * \param [in]  input_lengths Always in CPU memory. The number of time steps
 *              for each sequence in the minibatch.
 * \param [in]  alphabet_size How many symbols in the alphabet or, equivalently,
 *              the number of probabilities at each time step
 * \param [in]  minibatch How many examples in a minibatch.
 * \param [in]  info see struct ctcOptions
 * \param [out] size_bytes is pointer to a scalar where the memory
 *              requirement in bytes will be placed. This memory should be
 *              allocated at the same place, CPU or GPU, that the activations
 *              are in
 *
 * \return Status information
 **/
API_REFERENCE ctcStatus_t get_workspace_size_double(const int *const label_lengths,
                                                    const int *const input_lengths,
                                                    int alphabet_size,
                                                    int minibatch,
                                                    struct ctcOptions info,
                                                    size_t *size_bytes);
#ifdef __cplusplus
}
#endif