Fix the calculation of layer_norm_bwd (PaddlePaddle#53224)
* Fix the calculation of layer_norm_bwd

* fix
ZzSean committed Apr 24, 2023
1 parent bfa5d6b commit a0aff19
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions paddle/phi/kernels/funcs/layer_norm_impl.cu.h
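
The change is mechanical throughout the file: every place the backward kernels used to divide by real_sqrt(var + epsilon) now multiplies by rsqrt_(var + epsilon), i.e. the reciprocal square root. A minimal standalone C++ sketch of the equivalence (host-side, hypothetical helper name; not the actual Paddle device code):

#include <cassert>
#include <cmath>

// Host-side reciprocal square root; stands in for the device-side rsqrt_.
float rsqrt_host(float x) { return 1.0f / std::sqrt(x); }

int main() {
  const float var = 0.25f, epsilon = 1e-5f, d_y = 2.0f, scale = 3.0f;
  // Old form: divide by the standard deviation.
  float d_x_old = d_y * scale / std::sqrt(var + epsilon);
  // New form: multiply by the reciprocal standard deviation.
  float d_x_new = d_y * scale * rsqrt_host(var + epsilon);
  assert(std::fabs(d_x_old - d_x_new) < 1e-5f);
  return 0;
}
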
@@ -1603,13 +1603,13 @@ __global__ void LayerNormBackwardGradientAll(
 
   for (int64_t i = beg_idx; i < end_idx; i += stride) {
     int row_idx = i / feature_size;
-    auto var_val = real_sqrt(static_cast<U>(var[row_idx]) + epsilon);
+    auto var_val = rsqrt_(static_cast<U>(var[row_idx]) + epsilon);
     d_scale_partial += static_cast<U>(d_y[i]) *
-                       (static_cast<U>(x[i]) - mean[row_idx]) / var_val;
+                       (static_cast<U>(x[i]) - mean[row_idx]) * var_val;
     d_bias_partial += static_cast<U>(d_y[i]);
     if (HasDx) {
       d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
-                              static_cast<U>(scale[blockIdx.x + col_offset]) /
+                              static_cast<U>(scale[blockIdx.x + col_offset]) *
                               var_val);
     }
   }
@@ -1659,10 +1659,10 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
   for (int64_t i = beg_idx; i < end_idx; i += stride) {
     int row_idx = i / feature_size;
     auto var_val =
-        static_cast<U>(real_sqrt(static_cast<float>(var[row_idx]) + epsilon));
+        static_cast<U>(rsqrt_(static_cast<float>(var[row_idx]) + epsilon));
     if (HasDScale) {
       d_scale_or_d_bias_partial += static_cast<U>(d_y[i]) *
-                                   (static_cast<U>(x[i]) - mean[row_idx]) /
+                                   (static_cast<U>(x[i]) - mean[row_idx]) *
                                    var_val;
     } else {  // d_bias != nullptr
       d_scale_or_d_bias_partial += static_cast<U>(d_y[i]);
@@ -1671,10 +1671,10 @@
     if (HasDx) {
       if (scale != nullptr) {
         d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
-                                static_cast<U>(scale[blockIdx.x + col_offset]) /
+                                static_cast<U>(scale[blockIdx.x + col_offset]) *
                                 var_val);
       } else {
-        d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
+        d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) * var_val);
       }
     }
   }
@@ -1762,13 +1762,13 @@ __global__ void LayerNormBackwardGradientOnlyDX(
   U d_x_mean_partial = static_cast<U>(0), d_x_var_partial = static_cast<U>(0);
   for (int64_t i = beg_idx; i < end_idx; i += BlockDim) {
     auto var_val =
-        static_cast<U>(real_sqrt(static_cast<float>(block_var) + epsilon));
+        static_cast<U>(rsqrt_(static_cast<float>(block_var) + epsilon));
     if (scale != nullptr) {
       int col_idx = i % feature_size;
       d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
-                              static_cast<U>(scale[col_idx]) / var_val);
+                              static_cast<U>(scale[col_idx]) * var_val);
     } else {
-      d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
+      d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) * var_val);
     }
     d_x_mean_partial += static_cast<U>(d_x[i]);
     d_x_var_partial +=
@@ -1812,21 +1812,20 @@ __global__ void LayerNormBackwardWhenBatchSizeIsOne(
   int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
   using ScaleBiasT = LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>;
   if (idx < feature_size) {
-    auto var_val =
-        static_cast<U>(real_sqrt(static_cast<float>(var[0]) + epsilon));
+    auto var_val = static_cast<U>(rsqrt_(static_cast<float>(var[0]) + epsilon));
     if (d_x != nullptr) {
       if (d_scale == nullptr) {
-        d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) / var_val);
+        d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) * var_val);
       } else {
         d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) *
-                                  static_cast<U>(scale[idx]) / var_val);
+                                  static_cast<U>(scale[idx]) * var_val);
       }
     }
 
     if (d_scale != nullptr) {
       d_scale[idx] =
           static_cast<ScaleBiasT>(static_cast<U>(d_y[idx]) *
-                                  (static_cast<U>(x[idx]) - mean[0]) / var_val);
+                                  (static_cast<U>(x[idx]) - mean[0]) * var_val);
     }
 
     if (d_bias != nullptr) {
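
Taken together, the hunks only change the form of the expressions, not the quantities being accumulated. A rough host-side restatement of the per-element work for one row in the fixed multiplication-by-rsqrt form (hypothetical standalone code mirroring the names in the diff; the row-wise mean/variance correction terms for d_x handled elsewhere in layer_norm_impl.cu.h are not shown):

#include <cmath>
#include <vector>

// Per-row accumulation mirroring the fixed kernels: every former
// "/ sqrt(var + epsilon)" is now "* (1 / sqrt(var + epsilon))".
void AccumulateRow(const std::vector<float>& x,
                   const std::vector<float>& d_y,
                   const std::vector<float>& scale,
                   float mean,
                   float var,
                   float epsilon,
                   std::vector<float>* d_x,
                   std::vector<float>* d_scale,
                   std::vector<float>* d_bias) {
  const float inv_std = 1.0f / std::sqrt(var + epsilon);  // rsqrt_ equivalent
  for (size_t j = 0; j < x.size(); ++j) {
    (*d_scale)[j] += d_y[j] * (x[j] - mean) * inv_std;
    (*d_bias)[j] += d_y[j];
    // Element-wise part of d_x only; the row-wise corrections are
    // applied elsewhere in the file.
    (*d_x)[j] = d_y[j] * scale[j] * inv_std;
  }
}
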
