#include <vector>

#include "caffe/layers/focal_loss_layer.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {
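
// Focal loss (Lin et al., "Focal Loss for Dense Object Detection"):
//   FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t),
// where p_t is the predicted probability of the ground-truth class.
// The forward kernel factors the per-element loss into two pieces:
//   scale[i]   = alpha_t * (1 - p_t)^gamma   (modulating factor)
//   oriloss[i] = numerically stable sigmoid cross-entropy
// so that loss[i] = scale[i] * oriloss[i]. Elements whose target equals
// ignore_label_ contribute neither to the loss nor to the valid count.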
template <typename Dtype>
__global__ void FocalLossForwardGPU(const int nthreads,
    const Dtype* input_data, const Dtype* sigmoid_data, const Dtype* target,
    Dtype* scale, Dtype* oriloss,
    const bool has_ignore_label_, const int ignore_label_,
    Dtype* counts, float alpha, float gamma) {
  CUDA_KERNEL_LOOP(i, nthreads) {
    const int target_value = static_cast<int>( target[ i ] );
    if ( has_ignore_label_ && target_value == ignore_label_ ) {
      scale[ i ] = 0;
      oriloss[ i ] = 0;
      counts[ i ] = 0;
    } else {
      scale[ i ] = (target_value == 1 ? alpha : 1 - alpha) *
          powf(1 - ( target_value == 1 ? sigmoid_data[ i ] : ( 1 - sigmoid_data[ i ] ) ), gamma);
      oriloss[ i ] = -(input_data[ i ] * ( target[ i ] - ( input_data[ i ] >= 0 ) ) -
          log(1 + exp(input_data[ i ] - 2 * input_data[ i ] *
          ( input_data[ i ] >= 0 ))));
      counts[ i ] = 1;
    }
  }
}
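
// Derivative of the modulating factor with respect to the input logit x:
//   d(scale)/dx = alpha_t * gamma * (1 - p_t)^(gamma - 1) * (-dp_t/dx),
// with dp_t/dx = +sigmoid(x)*(1 - sigmoid(x)) for target == 1 and
// -sigmoid(x)*(1 - sigmoid(x)) otherwise. The expression
// exp(-|x|) / (exp(-|x|)^2 + 2*exp(-|x|) + 1) below equals
// sigmoid(x) * (1 - sigmoid(x)), written this way for numerical stability.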
template <typename Dtype>
__global__ void FocalLossBackwardSecondItemGPU(const int nthreads,
    const Dtype* input_data, const Dtype* sigmoid_data, const Dtype* target,
    float alpha, float gamma, Dtype* secondItem) {
  CUDA_KERNEL_LOOP(i, nthreads) {
    const int target_value = static_cast<int>( target[ i ] );
    Dtype expabsx = expf(input_data[ i ] > 0 ? -input_data[ i ] : input_data[ i ]);
    secondItem[ i ] = (target_value == 1 ? alpha : 1 - alpha) * gamma *
        powf(1 - ( target_value == 1 ? sigmoid_data[ i ] : ( 1 - sigmoid_data[ i ] ) ), gamma - 1) *
        expabsx / ( powf(expabsx, 2) + 2 * expabsx + 1 ) *
        ( target_value == 1 ? -1 : 1 );
  }
}
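
// Zeroes the gradient of every element whose target carries the ignore label,
// so ignored entries do not contribute to the backward pass.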
template <typename Dtype>
__global__ void FocalLossIgnoreDiffGPU(const int count,
    const int ignore_label, const Dtype* target, Dtype* diff) {
  CUDA_KERNEL_LOOP(i, count) {
    const int target_value = static_cast<int>( target[ i ] );
    if ( target_value == ignore_label ) {
      diff[ i ] = 0;
    }
  }
}
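
// Forward pass: run the internal sigmoid layer, launch the forward kernel to
// fill scaler_.data with the modulating factor and scaler_.diff with the
// unweighted cross-entropy, multiply the two into the loss buffer, and
// normalize the summed loss. Keeping both factors in scaler_ lets the
// backward pass apply the product rule without recomputing them.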
template <typename Dtype>
void FocalLossLayer<Dtype>::Forward_gpu(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  // The forward pass computes the sigmoid outputs.
  sigmoid_bottom_vec_[ 0 ] = bottom[ 0 ];
  sigmoid_layer_->Forward(sigmoid_bottom_vec_, sigmoid_top_vec_);
  // Compute the loss (negative log likelihood)
  const int count = bottom[ 0 ]->count();
  // Stable version of loss computation from input data
  const Dtype* input_data = bottom[ 0 ]->gpu_data();
  const Dtype* target = bottom[ 1 ]->gpu_data();
  const Dtype* sigmoid_output_data = sigmoid_output_->gpu_data();
  // Since this memory is not used for anything until it is overwritten
  // on the backward pass, we use it here to avoid having to allocate new GPU
  // memory to accumulate intermediate results in the kernel.
  Dtype* loss_data = bottom[ 0 ]->mutable_gpu_diff();
  Dtype* count_data = bottom[ 1 ]->mutable_gpu_diff();
  Dtype valid_count;
  // NOLINT_NEXT_LINE(whitespace/operators)
  FocalLossForwardGPU<Dtype> << <CAFFE_GET_BLOCKS(count),
      CAFFE_CUDA_NUM_THREADS >> >( count, input_data, sigmoid_output_data,
      target, scaler_.mutable_gpu_data(), scaler_.mutable_gpu_diff(),
      has_ignore_label_, ignore_label_, count_data, alpha_, gamma_ );
  caffe_gpu_mul(count, scaler_.gpu_data(), scaler_.gpu_diff(), loss_data);
  // Only launch another CUDA kernel if we actually need the valid count.
  if ( normalization_ == LossParameter_NormalizationMode_VALID &&
      has_ignore_label_ ) {
    caffe_gpu_asum(count, count_data, &valid_count);
  } else {
    valid_count = count;
  }
  Dtype loss;
  caffe_gpu_asum(count, loss_data, &loss);
  normalizer_ = get_normalizer(normalization_, valid_count);
  top[ 0 ]->mutable_cpu_data()[ 0 ] = loss / normalizer_;
}
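
// Backward pass: with loss = scale * oriloss, the product rule gives
//   d(loss)/dx = scale * d(oriloss)/dx + oriloss * d(scale)/dx,
// where d(oriloss)/dx = sigmoid(x) - target. The first term is assembled
// with copy/axpy/mul below; the second reuses FocalLossBackwardSecondItemGPU.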
template <typename Dtype>
void FocalLossLayer<Dtype>::Backward_gpu(
    const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down,
    const vector<Blob<Dtype>*>& bottom) {
  // On entry: scaler_.data holds scale, scaler_.diff holds oriloss
  // (both filled during Forward_gpu).
  if ( propagate_down[ 1 ] ) {
    LOG(FATAL) << this->type()
               << " Layer cannot backpropagate to label inputs.";
  }
  if ( propagate_down[ 0 ] ) {
    // First, compute the diff.
    const int count = bottom[ 0 ]->count();
    const Dtype* sigmoid_output_data = sigmoid_output_->gpu_data();
    const Dtype* target = bottom[ 1 ]->gpu_data();
    const Dtype* input_data = bottom[ 0 ]->gpu_data();
    Dtype* bottom_diff = bottom[ 0 ]->mutable_gpu_diff();
    // First term: scale * d(oriloss)/dx = scale * (sigmoid - target).
    caffe_copy(count, sigmoid_output_data, bottom_diff);
    caffe_gpu_axpy(count, Dtype(-1), target, bottom_diff);
    caffe_gpu_mul(count, scaler_.gpu_data(), bottom[ 0 ]->gpu_diff(), bottom_diff);
    // Second term: oriloss * d(scale)/dx, accumulated into scaler_.data.
    FocalLossBackwardSecondItemGPU<Dtype> << <CAFFE_GET_BLOCKS(count),
        CAFFE_CUDA_NUM_THREADS >> >( count,
        input_data, sigmoid_output_data, target, alpha_, gamma_,
        scaler_.mutable_gpu_data() );
    caffe_gpu_mul(count, scaler_.gpu_data(), scaler_.gpu_diff(),
        scaler_.mutable_gpu_data());
    caffe_gpu_add(count, scaler_.gpu_data(), bottom[ 0 ]->gpu_diff(), bottom_diff);
    // Zero out gradient of ignored targets.
    if ( has_ignore_label_ ) {
      // NOLINT_NEXT_LINE(whitespace/operators)
      FocalLossIgnoreDiffGPU<Dtype> << <CAFFE_GET_BLOCKS(count),
          CAFFE_CUDA_NUM_THREADS >> >( count, ignore_label_, target, bottom_diff );
    }
    // Scale the gradient by the loss weight and normalizer.
    Dtype loss_weight = top[ 0 ]->cpu_diff()[ 0 ] / normalizer_;
    caffe_gpu_scal(count, loss_weight, bottom_diff);
  }
}

INSTANTIATE_LAYER_GPU_FUNCS(FocalLossLayer);

}  // namespace caffe