Skip to content

Commit

Permalink
add back legacy lamb code for backward comptibility now
Browse files Browse the repository at this point in the history
  • Loading branch information
FDecaYed committed Aug 17, 2019
1 parent 18062b6 commit 2bc766c
Show file tree
Hide file tree
Showing 3 changed files with 261 additions and 0 deletions.
150 changes: 150 additions & 0 deletions csrc/multi_tensor_lamb_stage_1.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "type_shim.h"
#include "multi_tensor_apply.cuh"

#define BLOCK_SIZE 512
#define ILP 4

// Step 1 computes the 'update' value of regular Adam optimizer.
template<typename GRAD_T, typename T, typename UPD_T>
struct LAMBStage1Functor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<5>& tl,
const float* per_tensor_decay,
const float beta1,
const float beta2,
const float beta1_correction,
const float beta2_correction,
const float epsilon,
const float clipped_global_grad_norm)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;

int tensor_loc = tl.block_to_tensor[blockIdx.x];
int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];

float decay = per_tensor_decay[tensor_num];

GRAD_T* g = (GRAD_T*)tl.addresses[0][tensor_loc];
g += chunk_idx*chunk_size;

T* p = (T*)tl.addresses[1][tensor_loc];
p += chunk_idx*chunk_size;

T* m = (T*)tl.addresses[2][tensor_loc];
m += chunk_idx*chunk_size;

T* v = (T*)tl.addresses[3][tensor_loc];
v += chunk_idx*chunk_size;

UPD_T* update = (UPD_T*)tl.addresses[4][tensor_loc];
update += chunk_idx*chunk_size;

n -= chunk_idx*chunk_size;

// see note in multi_tensor_scale_kernel.cu
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP)
{
GRAD_T r_g[ILP];
T r_p[ILP];
T r_m[ILP];
T r_v[ILP];
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
r_g[ii] = g[i];
r_p[ii] = p[i];
r_m[ii] = m[i];
r_v[ii] = v[i];
} else {
r_g[ii] = GRAD_T(0);
r_p[ii] = T(0);
r_m[ii] = T(0);
r_v[ii] = T(0);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
T scaled_grad = r_g[ii] / clipped_global_grad_norm;
r_m[ii] = r_m[ii] * beta1 + (1-beta1) * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
T next_m_unbiased = r_m[ii] / beta1_correction;
T next_v_unbiased = r_v[ii] / beta2_correction;
T denom = std::sqrt(next_v_unbiased) + epsilon;
r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
update[i] = (UPD_T)r_p[ii];
m[i] = r_m[ii];
v[i] = r_v[ii];
}
}
}
}
};

void multi_tensor_lamb_stage1_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_decay,
const int step,
const float beta1,
const float beta2,
const float epsilon,
const float global_grad_norm,
const float max_global_grad_norm)
{
using namespace at;

float clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : 1.0f;
float next_step = float(step+1);
float beta1_correction = 1.0f - std::pow(beta1, next_step);
float beta2_correction = 1.0f - std::pow(beta2, next_step);
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_1",
DISPATCH_FLOAT_AND_HALF(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1",
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
LAMBStage1Functor<scalar_t_0, scalar_t_1, scalar_t_2>(),
per_tensor_decay.data<float>(),
beta1,
beta2,
beta1_correction,
beta2_correction,
epsilon,
clipped_global_grad_norm); )))

AT_CUDA_CHECK(cudaGetLastError());

// AT_CUDA_CHECK(cudaDeviceSynchronize());
}
109 changes: 109 additions & 0 deletions csrc/multi_tensor_lamb_stage_2.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "type_shim.h"
#include "multi_tensor_apply.cuh"

#define BLOCK_SIZE 512
#define ILP 4

// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
// It computes new parameter value.
template<typename T, typename UPD_T>
struct LAMBStage2Functor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<2>& tl,
const float* per_tensor_param_norm,
const float* per_tensor_update_norm,
const float learning_rate)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;

int tensor_loc = tl.block_to_tensor[blockIdx.x];
int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];

float param_norm = per_tensor_param_norm[tensor_num];
float update_norm = per_tensor_update_norm[tensor_num];
T ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;

T* p = (T*)tl.addresses[0][tensor_loc];
p += chunk_idx*chunk_size;

UPD_T* update = (UPD_T*)tl.addresses[1][tensor_loc];
update += chunk_idx*chunk_size;

n -= chunk_idx*chunk_size;

for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP)
{
T r_p[ILP];
UPD_T r_update[ILP];
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
r_p[ii] = p[i];
r_update[ii] = update[i];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
r_p[ii] = r_p[ii] - (ratio*(T)r_update[ii]);
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
p[i] = r_p[ii];
}
}
}
}
};

void multi_tensor_lamb_stage2_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_param_norm,
at::Tensor per_tensor_update_norm,
const float learning_rate)
{
using namespace at;

DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_2",
multi_tensor_apply<2>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
LAMBStage2Functor<scalar_t_0, scalar_t_1>(),
per_tensor_param_norm.data<float>(),
per_tensor_update_norm.data<float>(),
learning_rate); ))

AT_CUDA_CHECK(cudaGetLastError());

// AT_CUDA_CHECK(cudaDeviceSynchronize());
}
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
'csrc/multi_tensor_scale_kernel.cu',
'csrc/multi_tensor_axpby_kernel.cu',
'csrc/multi_tensor_l2norm_kernel.cu',
'csrc/multi_tensor_lamb_stage_1.cu',
'csrc/multi_tensor_lamb_stage_2.cu',
'csrc/multi_tensor_adam.cu',
'csrc/multi_tensor_novograd.cu',
'csrc/multi_tensor_lamb.cu'],
Expand Down

0 comments on commit 2bc766c

Please sign in to comment.