Commit

Fixed single GPU performance regression

JohannesGaessler committed Jun 5, 2023
1 parent 4f9640b commit 11af678
Showing 2 changed files with 134 additions and 0 deletions.
133 changes: 133 additions & 0 deletions ggml-cuda.cu
@@ -934,6 +934,30 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
    return false;
}

bool ggml_cuda_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
    size_t src0_sz = ggml_nbytes(src0);
    size_t src1_sz = ggml_nbytes(src1);

    // mul_mat_q: src0 is converted to fp32 on device
    size_t mul_mat_q_transfer = src0_sz + src1_sz;

    // mul_mat_f16: src1 is converted to fp16 on cpu
    size_t mul_mat_f16_transfer = src0_sz + sizeof(half) * ggml_nelements(src1);

    // choose the smaller one to transfer to the device
    // TODO: this is not always the best choice due to the overhead of converting to fp16
    return mul_mat_f16_transfer < mul_mat_q_transfer;
}
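// Rough illustration of the heuristic above, with hypothetical shapes: an f16 src0 of
// 4096 x 4096 is 32 MiB, an f32 src1 of 4096 x 512 is 8 MiB and its fp16 copy is 4 MiB.
// mul_mat_q would transfer 32 + 8 = 40 MiB, mul_mat_f16 only 32 + 4 = 36 MiB,
// so the f16 path would be chosen in this case.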

size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
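    // host work buffer: room for src1 converted to fp16 (used as wdata in ggml_cuda_mul_mat_f16)
    // when the f16 path is chosen, 0 otherwise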
    if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) {
        return ggml_nelements(src1) * sizeof(ggml_fp16_t);
    }
    else {
        return 0;
    }
}

void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    GGML_ASSERT(ggml_cuda_can_mul_mat(src0, src1, dst));

@@ -950,6 +974,99 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
    }
}

static void ggml_cuda_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) {
    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];

    const int64_t ne10 = src1->ne[0];
    const int64_t ne11 = src1->ne[1];

    const int nb10 = src1->nb[0];
    const int nb11 = src1->nb[1];
    const int nb12 = src1->nb[2];
    const int nb13 = src1->nb[3];

    const int nb2 = dst->nb[2];
    const int nb3 = dst->nb[3];

    const float alpha = 1.0f;
    const float beta = 0.0f;
    const int x_ne = ne01 * ne00;
    const int y_ne = ne11 * ne10;
    const int d_ne = ne11 * ne01;
    const int n_mm = ne03 * ne02;
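    // x_ne, y_ne, d_ne: elements per 2D slice of src0, src1 and dst; n_mm: number of slices,
    // used below to size the pool allocations so that every slice gets its own device buffer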

    size_t x_size, y_size, d_size;
    half * d_X = (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * x_ne, &x_size);
    half * d_Y = (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * y_ne, &y_size);
    float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);

    bool src1_cont_rows = nb10 == sizeof(float);
    bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);

    for (int64_t i03 = 0; i03 < ne03; i03++) {
        for (int64_t i02 = 0; i02 < ne02; i02++) {
            int i = i03*ne02 + i02;
            cudaStream_t cudaStream = g_cudaStreams_main[0][i % GGML_CUDA_MAX_STREAMS];
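            // slices are distributed round-robin over the main streams of device 0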

            half * c_X = d_X + i * x_ne;
            half * c_Y = d_Y + i * y_ne;
            float * c_D = d_D + i * d_ne;

            // copy src0 to device
            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, 0, ne01, cudaStream));

            // convert src1 to fp16
            // TODO: use multiple threads
            ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i03 * ne02 + i02);
            char * src1i = (char *) src1->data + i03*nb13 + i02*nb12;
            if (src1_cont_rows) {
                if (src1_cont_cols) {
                    ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11);
                }
                else {
                    for (int64_t i01 = 0; i01 < ne11; i01++) {
                        ggml_fp32_to_fp16_row((float *) (src1i + i01*nb11), tmp + i01*ne10, ne10);
                    }
                }
            }
            else {
                for (int64_t i01 = 0; i01 < ne11; i01++) {
                    for (int64_t i00 = 0; i00 < ne10; i00++) {
                        // very slow due to no inlining
                        tmp[i01*ne10 + i00] = ggml_fp32_to_fp16(*(float *) (src1i + i01*nb11 + i00*nb10));
                    }
                }
            }

            // copy src1 to device
            CUDA_CHECK(cudaMemcpyAsync(c_Y, tmp, sizeof(half) * y_ne, cudaMemcpyHostToDevice, cudaStream));

            // compute
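            // c_D (ne01 x ne11, fp32) = c_X^T * c_Y, fp16 inputs with fp32 accumulation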
            CUBLAS_CHECK(cublasSetStream(g_cublas_handles[0], cudaStream));
            CUBLAS_CHECK(
                cublasGemmEx(g_cublas_handles[0], CUBLAS_OP_T, CUBLAS_OP_N,
                        ne01, ne11, ne10,
                        &alpha, c_X, CUDA_R_16F, ne00,
                                c_Y, CUDA_R_16F, ne10,
                        &beta,  c_D, CUDA_R_32F, ne01,
                        CUBLAS_COMPUTE_32F_FAST_16F,
                        CUBLAS_GEMM_DEFAULT));

            // copy dst to host
            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
            CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
        }
    }

    CUDA_CHECK(cudaDeviceSynchronize());
    ggml_cuda_pool_free(d_X, x_size);
    ggml_cuda_pool_free(d_Y, y_size);
    ggml_cuda_pool_free(d_D, d_size);
}

void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset, int n_layer) {
    FILE * fp = fopen(fname, "rb");
    int nrows = ggml_nrows(tensor);
@@ -1054,6 +1171,22 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
            if (!ggml_cuda_can_mul_mat(tensor->src0, tensor->src1, tensor)) {
                return false;
            }

            // For prompt processing the multi GPU code is currently slower than the single GPU code that existed before.
            // To avoid a performance regression the old code is kept for now:
            if (g_device_count == 1 && tensor->src0->type == GGML_TYPE_F16 &&
                ggml_cuda_mul_mat_use_f16(tensor->src0, tensor->src1, tensor)) {

                if (params->ith != 0) {
                    return true;
                }
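                // past this point only thread 0 continues; the other threads have returned true and skip this op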
                if (params->type == GGML_TASK_COMPUTE) {
                    ggml_cuda_mul_mat_f16(tensor->src0, tensor->src1, tensor, params->wdata, params->wsize);
                    return true;
                }

                return false;
            }
            func = ggml_cuda_mul_mat;
            break;
        default:
1 change: 1 addition & 0 deletions ggml.c
@@ -14203,6 +14203,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
            if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
                node->n_tasks = 1; // TODO: this actually is doing nothing
                                   // the threads are still spinning
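                // reserve host scratch so ggml_cuda_mul_mat_f16 can convert src1 to fp16 (0 when the f16 path is not used)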
                cur = ggml_cuda_mul_mat_get_wsize(node->src0, node->src1, node);
            }
            else
#elif defined(GGML_USE_CLBLAST)
