Skip to content

Commit

Permalink
llama : add custom RoPE (ggerganov#2054)
Browse files Browse the repository at this point in the history
* Implement customizable RoPE

The original RoPE has pre-defined parameters

theta_i = 10000^(−2(i−1)/d), for i in [1, 2, ..., d/2]

Our customizable RoPE, ggml_rope_custom_inplace, uses

theta_i = scale * base^(−2(i−1)/d), for i in [1, 2, ..., d/2]

with the default matches the original

scale = 1.0
base = 10000

The new command line arguments
--rope-freq-base
--rope-freq-scale
set the two new RoPE parameter.

Recent researches show changing these two parameters extends the context limit with minimal loss.

1. Extending Context to 8K
   kaiokendev
   https://kaiokendev.github.io/til#extending-context-to-8k

2. Extending Context Window of Large Language Models via Positional Interpolation
   Shouyuan Chen, Sherman Wong, Liangjian Chen, Yuandong Tian
   https://arxiv.org/abs/2306.15595

3. NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation.
   https://www.reddit.com/user/bloc97
   https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/

For the bold, try adding the following command line parameters to your favorite model:
-c 16384 --rope-freq-base 80000 --rope-freq-scale 0.5

* ggml-metal: fix custom rope

* common: fix argument names in help

* llama: increase MEM_REQ_EVAL for MODEL_3B

It avoids crashing for quantized weights on CPU.
Better ways to calculate the required buffer size would be better.

* llama: make MEM_REQ_EVAL depend on n_ctx

* server: use proper Content-Type in curl examples

Without the header Content-Type: application/json, curl will POST with
Content-Type: application/x-www-form-urlencoded

Though our simple server doesn't care, the httplib.h used has a limit
with CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 8192

With Content-Type: application/json, we can send large json data.

* style : minor fixes, mostly indentations

* ggml : fix asserts

---------

Co-authored-by: Georgi Gerganov <[email protected]>
  • Loading branch information
jxy and ggerganov committed Jul 15, 2023
1 parent a6803ca commit 6e7cca4
Show file tree
Hide file tree
Showing 12 changed files with 185 additions and 67 deletions.
16 changes: 16 additions & 0 deletions examples/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.n_ctx = std::stoi(argv[i]);
} else if (arg == "--rope-freq-base") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.rope_freq_base = std::stof(argv[i]);
} else if (arg == "--rope-freq-scale") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.rope_freq_scale = std::stof(argv[i]);
} else if (arg == "--memory-f32") {
params.memory_f16 = false;
} else if (arg == "--top-p") {
Expand Down Expand Up @@ -493,6 +505,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
fprintf(stderr, " --cfg-smooth-factor N smooth factor between old and new logits (default: %f, 1.0 = no smoothing)\n", params.cfg_smooth_factor);
fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
fprintf(stderr, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
fprintf(stderr, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
fprintf(stderr, " --no-penalize-nl do not penalize newline token\n");
fprintf(stderr, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
Expand Down Expand Up @@ -573,6 +587,8 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
lparams.use_mlock = params.use_mlock;
lparams.logits_all = params.perplexity;
lparams.embedding = params.embedding;
lparams.rope_freq_base = params.rope_freq_base;
lparams.rope_freq_scale = params.rope_freq_scale;

return lparams;
}
Expand Down
2 changes: 2 additions & 0 deletions examples/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ struct gpt_params {
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
float rope_freq_base = 10000.0f; // RoPE base frequency
float rope_freq_scale = 1.0f; // RoPE frequency scaling factor

// sampling parameters
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
Expand Down
12 changes: 10 additions & 2 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,17 @@ int main(int argc, char ** argv) {
return 0;
}

if (params.rope_freq_base != 10000.0) {
fprintf(stderr, "%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
}

if (params.rope_freq_scale != 1.0) {
fprintf(stderr, "%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale);
}

if (params.n_ctx > 2048) {
fprintf(stderr, "%s: warning: model might not support context sizes greater than 2048 tokens (%d specified);"
"expect poor results\n", __func__, params.n_ctx);
fprintf(stderr, "%s: warning: base model only supports context sizes no greater than 2048 tokens (%d specified);"
" you are on your own\n", __func__, params.n_ctx);
} else if (params.n_ctx < 8) {
fprintf(stderr, "%s: warning: minimum context size is 8, using minimum size.\n", __func__);
params.n_ctx = 8;
Expand Down
1 change: 1 addition & 0 deletions examples/server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ Using [curl](https://curl.se/). On Windows `curl.exe` should be available in the
```sh
curl --request POST \
--url http://localhost:8080/completion \
--header "Content-Type: application/json" \
--data '{"prompt": "Building a website can be done in 10 simple steps:","n_predict": 128}'
```

Expand Down
2 changes: 2 additions & 0 deletions examples/server/chat.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ tokenize() {
--silent \
--request POST \
--url "${API_URL}/tokenize" \
--header "Content-Type: application/json" \
--data-raw "$(jq -ns --arg content "$1" '{content:$content}')" \
| jq '.tokens[]'
}
Expand Down Expand Up @@ -64,6 +65,7 @@ chat_completion() {
--no-buffer \
--request POST \
--url "${API_URL}/completion" \
--header "Content-Type: application/json" \
--data-raw "${DATA}")

printf "\n"
Expand Down
18 changes: 18 additions & 0 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
fprintf(stderr, " -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
fprintf(stderr, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
fprintf(stderr, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
fprintf(stderr, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stderr, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
fprintf(stderr, " not recommended: doubles context memory required and no measurable increase in quality\n");
Expand Down Expand Up @@ -722,6 +724,22 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
}
params.n_ctx = std::stoi(argv[i]);
}
else if (arg == "--rope-freq-base")
{
if (++i >= argc) {
invalid_param = true;
break;
}
params.rope_freq_base = std::stof(argv[i]);
}
else if (arg == "--rope-freq-scale")
{
if (++i >= argc) {
invalid_param = true;
break;
}
params.rope_freq_scale = std::stof(argv[i]);
}
else if (arg == "--memory-f32" || arg == "--memory_f32")
{
params.memory_f16 = false;
Expand Down
45 changes: 26 additions & 19 deletions ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -881,28 +881,35 @@ void ggml_metal_graph_compute(

const int n_past = ((int32_t *)(src1->data))[0];

float freq_base;
float freq_scale;
memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));

[encoder setComputePipelineState:ctx->pipeline_rope];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&n_past length:sizeof( int) atIndex:18];
[encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
[encoder setBytes:&mode length:sizeof( int) atIndex:20];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&n_past length:sizeof( int) atIndex:18];
[encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
[encoder setBytes:&mode length:sizeof( int) atIndex:20];
[encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];

[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
Expand Down
6 changes: 4 additions & 2 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -656,17 +656,19 @@ kernel void kernel_rope(
constant int & n_past,
constant int & n_dims,
constant int & mode,
constant float & freq_base,
constant float & freq_scale,
uint3 tpig[[thread_position_in_grid]]) {
const int64_t i3 = tpig[2];
const int64_t i2 = tpig[1];
const int64_t i1 = tpig[0];

const bool is_neox = mode & 2;
const float theta_scale = pow(10000.0, -2.0f/n_dims);
const float theta_scale = pow(freq_base, -2.0f/n_dims);

const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);

float theta = (float)p;
float theta = freq_scale * (float)p;

if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
Expand Down
50 changes: 38 additions & 12 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -6956,6 +6956,8 @@ struct ggml_tensor * ggml_rope_impl(
int n_past,
int n_dims,
int mode,
float freq_base,
float freq_scale,
int n_ctx,
bool inplace) {
GGML_ASSERT(n_past >= 0);
Expand All @@ -6969,12 +6971,14 @@ struct ggml_tensor * ggml_rope_impl(

ggml_scratch_save(ctx);

struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 6);

((int32_t *) b->data)[0] = n_past;
((int32_t *) b->data)[1] = n_dims;
((int32_t *) b->data)[2] = mode;
((int32_t *) b->data)[3] = n_ctx;
memcpy((int32_t *) b->data + 4, &freq_base, sizeof(float));
memcpy((int32_t *) b->data + 5, &freq_scale, sizeof(float));

ggml_scratch_load(ctx);

Expand All @@ -6993,7 +6997,7 @@ struct ggml_tensor * ggml_rope(
int n_dims,
int mode,
int n_ctx) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, false);
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, false);
}

struct ggml_tensor * ggml_rope_inplace(
Expand All @@ -7003,7 +7007,19 @@ struct ggml_tensor * ggml_rope_inplace(
int n_dims,
int mode,
int n_ctx) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, true);
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, true);
}

struct ggml_tensor * ggml_rope_custom_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
int n_dims,
int mode,
float freq_base,
float freq_scale,
int n_ctx) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, freq_base, freq_scale, n_ctx, true);
}

// ggml_rope_back
Expand Down Expand Up @@ -12074,16 +12090,21 @@ static void ggml_compute_forward_rope_f32(
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_nelements(src1) == 4);
GGML_ASSERT(ggml_nelements(src1) == 6);

if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}

float freq_base;
float freq_scale;

const int n_past = ((int32_t *) src1->data)[0];
const int n_dims = ((int32_t *) src1->data)[1];
const int mode = ((int32_t *) src1->data)[2];
const int n_ctx = ((int32_t *) src1->data)[3];
memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));

assert(n_past >= 0);

Expand Down Expand Up @@ -12112,7 +12133,7 @@ static void ggml_compute_forward_rope_f32(
// row index used to determine which thread to use
int ir = 0;

const float theta_scale = powf(10000.0, -2.0f/n_dims);
const float theta_scale = powf(freq_base, -2.0f/n_dims);

const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
Expand All @@ -12124,7 +12145,7 @@ static void ggml_compute_forward_rope_f32(
if (ir++ < ir0) continue;
if (ir > ir1) break;

float theta = (float)p;
float theta = freq_scale * (float)p;

if (is_glm) {
theta = MIN(p, n_ctx - 2);
Expand Down Expand Up @@ -12201,16 +12222,21 @@ static void ggml_compute_forward_rope_f16(
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_nelements(src1) == 4);
GGML_ASSERT(ggml_nelements(src1) == 6);

if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}

float freq_base;
float freq_scale;

const int n_past = ((int32_t *) src1->data)[0];
const int n_dims = ((int32_t *) src1->data)[1];
const int mode = ((int32_t *) src1->data)[2];
const int n_ctx = ((int32_t *) src1->data)[3];
memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));

assert(n_past >= 0);

Expand Down Expand Up @@ -12239,7 +12265,7 @@ static void ggml_compute_forward_rope_f16(
// row index used to determine which thread to use
int ir = 0;

const float theta_scale = powf(10000.0, -2.0f/n_dims);
const float theta_scale = powf(freq_base, -2.0f/n_dims);

const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
Expand All @@ -12251,7 +12277,7 @@ static void ggml_compute_forward_rope_f16(
if (ir++ < ir0) continue;
if (ir > ir1) break;

float theta = (float)p;
float theta = freq_scale * (float)p;

if (is_glm) {
theta = MIN(p, n_ctx - 2);
Expand Down Expand Up @@ -12312,7 +12338,7 @@ static void ggml_compute_forward_rope_f16(
const float x0 = GGML_FP16_TO_FP32(src[0]);
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);

dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
}
}
Expand Down Expand Up @@ -15710,7 +15736,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
// necessary for llama
if (src0->grad) {
assert(src1->type == GGML_TYPE_I32);
assert(ggml_nelements(src1) == 4);
assert(ggml_nelements(src1) == 6);
const int n_past = ((int32_t *) src1->data)[0];
const int n_dims = ((int32_t *) src1->data)[1];
const int mode = ((int32_t *) src1->data)[2];
Expand All @@ -15731,7 +15757,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
if (src0->grad) {
assert(src1->type == GGML_TYPE_I32);
assert(ggml_nelements(src1) == 4);
assert(ggml_nelements(src1) == 3);
const int n_past = ((int32_t *) src1->data)[0];
const int n_dims = ((int32_t *) src1->data)[1];
const int mode = ((int32_t *) src1->data)[2];
Expand Down
11 changes: 11 additions & 0 deletions ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -1121,6 +1121,17 @@ extern "C" {
int mode,
int n_ctx);

// custom RoPE, in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
int n_dims,
int mode,
float freq_base,
float freq_scale,
int n_ctx);

// rotary position embedding backward, i.e compute dx from dy
// a - dy
GGML_API struct ggml_tensor * ggml_rope_back(
Expand Down
Loading

0 comments on commit 6e7cca4

Please sign in to comment.