From 15606ed12fd07bfea79bed6413615adf002da983 Mon Sep 17 00:00:00 2001 From: Zygimantas Straznickas Date: Thu, 20 Apr 2023 18:39:14 -0700 Subject: [PATCH 1/3] Add k/v caching for autoregressive generation --- .gitignore | 2 +- bark/api.py | 8 +++- bark/generation.py | 19 +++++++++- bark/model.py | 93 +++++++++++++++++++++++++++++++++------------- 4 files changed, 93 insertions(+), 29 deletions(-) diff --git a/.gitignore b/.gitignore index 372c13e2..48e4ceb6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,2 @@ __pycache__/ - +.venv \ No newline at end of file diff --git a/bark/api.py b/bark/api.py index 8033dc67..82316168 100644 --- a/bark/api.py +++ b/bark/api.py @@ -10,6 +10,7 @@ def text_to_semantic( history_prompt: Optional[str] = None, temp: float = 0.7, silent: bool = False, + use_kv_caching = False, ): """Generate semantic array from text. @@ -27,6 +28,7 @@ def text_to_semantic( history_prompt=history_prompt, temp=temp, silent=silent, + use_kv_caching=use_kv_caching ) return x_semantic @@ -37,6 +39,7 @@ def semantic_to_waveform( temp: float = 0.7, silent: bool = False, output_full: bool = False, + use_kv_caching = False ): """Generate audio array from semantic input. @@ -55,6 +58,7 @@ def semantic_to_waveform( history_prompt=history_prompt, temp=temp, silent=silent, + use_kv_caching=use_kv_caching ) fine_tokens = generate_fine( coarse_tokens, @@ -88,6 +92,7 @@ def generate_audio( waveform_temp: float = 0.7, silent: bool = False, output_full: bool = False, + use_kv_caching = False ): """Generate audio array from input text. @@ -103,7 +108,7 @@ def generate_audio( numpy audio array at sample frequency 24khz """ semantic_tokens = text_to_semantic( - text, history_prompt=history_prompt, temp=text_temp, silent=silent, + text, history_prompt=history_prompt, temp=text_temp, silent=silent, use_kv_caching=use_kv_caching ) out = semantic_to_waveform( semantic_tokens, @@ -111,6 +116,7 @@ def generate_audio( temp=waveform_temp, silent=silent, output_full=output_full, + use_kv_caching=use_kv_caching ) if output_full: full_generation, audio_arr = out diff --git a/bark/generation.py b/bark/generation.py index b5476bc1..5753125f 100644 --- a/bark/generation.py +++ b/bark/generation.py @@ -359,6 +359,7 @@ def generate_text_semantic( max_gen_duration_s=None, allow_early_stop=True, model=None, + use_kv_caching=False ): """Generate semantic tokens from text.""" assert isinstance(text, str) @@ -420,8 +421,14 @@ def generate_text_semantic( pbar = tqdm.tqdm(disable=silent, total=100) pbar_state = 0 tot_generated_duration_s = 0 + kv_cache = None for n in range(n_tot_steps): - logits = model(x, merge_context=True) + if use_kv_caching and kv_cache is not None: + x_input = x[:, [-1]] + else: + x_input = x + + logits, kv_cache = model(x_input, merge_context=True, use_cache=use_kv_caching, past_key_values=kv_cache) relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE] if allow_early_stop: relevant_logits = torch.hstack( @@ -498,6 +505,7 @@ def generate_coarse( max_coarse_history=630, # min 60 (faster), max 630 (more context) sliding_window_len=60, model=None, + use_kv_caching=False ): """Generate coarse audio codes from semantic tokens.""" assert ( @@ -592,11 +600,18 @@ def generate_coarse( x_coarse_in[:, -max_coarse_history:], ] ) + kv_cache = None for _ in range(sliding_window_len): if n_step >= n_steps: continue is_major_step = n_step % N_COARSE_CODEBOOKS == 0 - logits = model(x_in) + + if use_kv_caching and kv_cache is not None: + x_input = x_in[:, [-1]] + else: + x_input = x_in + + logits, kv_cache = 
model(x_input, use_cache=use_kv_caching, past_key_values=kv_cache) logit_start_idx = ( SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE ) diff --git a/bark/model.py b/bark/model.py index bbf9b689..463557c5 100644 --- a/bark/model.py +++ b/bark/model.py @@ -43,7 +43,7 @@ def __init__(self, config): self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) .view(1, 1, config.block_size, config.block_size)) - def forward(self, x): + def forward(self, x, layer_past=None, use_cache=False): B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) # calculate query, key, values for all heads in batch and move head forward to be the batch dim @@ -52,14 +52,34 @@ def forward(self, x): q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + k = torch.cat((past_key, k), dim=-2) + v = torch.cat((past_value, v), dim=-2) + + FULL_T = k.shape[-2] + + if use_cache is True: + present = (k, v) + else: + present = None + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) if self.flash: # efficient attention using Flash Attention CUDA kernels - y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True) + if layer_past is not None: + # in theory the attention is still causal but because we're computing it incrementally, + # the last query can attend on all previous keys/values, which which is equivalent to non-causal + is_causal = False + else: + is_causal = True + + y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout, is_causal=is_causal) else: # manual implementation of attention att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) + att = att.masked_fill(self.bias[:,:,FULL_T-T:FULL_T,:FULL_T] == 0, float('-inf')) att = F.softmax(att, dim=-1) att = self.attn_dropout(att) y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) @@ -67,7 +87,7 @@ def forward(self, x): # output projection y = self.resid_dropout(self.c_proj(y)) - return y + return (y, present) class MLP(nn.Module): @@ -95,10 +115,11 @@ def __init__(self, config, layer_idx): self.mlp = MLP(config) self.layer_idx = layer_idx - def forward(self, x): - x = x + self.attn(self.ln_1(x)) + def forward(self, x, layer_past=None, use_cache=False): + attn_output, prev_kvs = self.attn(self.ln_1(x), layer_past=layer_past, use_cache=use_cache) + x = x + attn_output x = x + self.mlp(self.ln_2(x)) - return x + return (x, prev_kvs) @dataclass class GPTConfig: @@ -142,33 +163,55 @@ def get_num_params(self, non_embedding=True): n_params -= self.transformer.wpe.weight.numel() return n_params - def forward(self, idx, merge_context=False): + def forward(self, idx, merge_context=False, past_key_values=None, position_ids=None, use_cache=False): device = idx.device b, t = idx.size() - if merge_context: - assert(idx.shape[1] >= 256+256+1) - t = idx.shape[1] - 256 + if past_key_values is not None: + assert t == 1 + tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) else: - assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" - - # forward the GPT model itself - if merge_context: - tok_emb = torch.cat([ - 
self.transformer.wte(idx[:,:256]) + self.transformer.wte(idx[:,256:256+256]), - self.transformer.wte(idx[:,256+256:]) - ], dim=1) + if merge_context: + assert(idx.shape[1] >= 256+256+1) + t = idx.shape[1] - 256 + else: + assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" + + # forward the GPT model itself + if merge_context: + tok_emb = torch.cat([ + self.transformer.wte(idx[:,:256]) + self.transformer.wte(idx[:,256:256+256]), + self.transformer.wte(idx[:,256+256:]) + ], dim=1) + else: + tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.transformer.h)) else: - tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + past_length = past_key_values[0][0].size(-2) + + if position_ids is None: + position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) # shape (1, t) + assert position_ids.shape == (1, t) + + pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd) - pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) - pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd) x = self.transformer.drop(tok_emb + pos_emb) - for block in self.transformer.h: - x = block(x) + + presents = () if use_cache else None + + for i, (block, layer_past) in enumerate(zip(self.transformer.h, past_key_values)): + x, kv = block(x, layer_past=layer_past, use_cache=use_cache) + + if use_cache: + presents = presents + (kv,) + x = self.transformer.ln_f(x) # inference-time mini-optimization: only forward the lm_head on the very last position logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim - return logits + return (logits, presents) From bee9e030802e612e1941d68822680b5be6895d7e Mon Sep 17 00:00:00 2001 From: Zygimantas Straznickas Date: Sat, 22 Apr 2023 12:23:55 -0700 Subject: [PATCH 2/3] Rename variables and add comments --- .gitignore | 3 +-- bark/generation.py | 4 ++-- bark/model.py | 40 +++++++++++++++++++++------------------- 3 files changed, 24 insertions(+), 23 deletions(-) diff --git a/.gitignore b/.gitignore index 48e4ceb6..ba0430d2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1 @@ -__pycache__/ -.venv \ No newline at end of file +__pycache__/ \ No newline at end of file diff --git a/bark/generation.py b/bark/generation.py index 5753125f..4e860d83 100644 --- a/bark/generation.py +++ b/bark/generation.py @@ -428,7 +428,7 @@ def generate_text_semantic( else: x_input = x - logits, kv_cache = model(x_input, merge_context=True, use_cache=use_kv_caching, past_key_values=kv_cache) + logits, kv_cache = model(x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache) relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE] if allow_early_stop: relevant_logits = torch.hstack( @@ -611,7 +611,7 @@ def generate_coarse( else: x_input = x_in - logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_key_values=kv_cache) + logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache) logit_start_idx = ( SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE ) diff --git a/bark/model.py b/bark/model.py index 463557c5..bb999324 100644 --- a/bark/model.py +++ b/bark/model.py @@ -43,7 +43,7 @@ def __init__(self, config): self.register_buffer("bias", 
torch.tril(torch.ones(config.block_size, config.block_size)) .view(1, 1, config.block_size, config.block_size)) - def forward(self, x, layer_past=None, use_cache=False): + def forward(self, x, past_kv=None, use_cache=False): B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) # calculate query, key, values for all heads in batch and move head forward to be the batch dim @@ -52,9 +52,9 @@ def forward(self, x, layer_past=None, use_cache=False): q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - if layer_past is not None: - past_key = layer_past[0] - past_value = layer_past[1] + if past_kv is not None: + past_key = past_kv[0] + past_value = past_kv[1] k = torch.cat((past_key, k), dim=-2) v = torch.cat((past_value, v), dim=-2) @@ -68,9 +68,11 @@ def forward(self, x, layer_past=None, use_cache=False): # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) if self.flash: # efficient attention using Flash Attention CUDA kernels - if layer_past is not None: - # in theory the attention is still causal but because we're computing it incrementally, - # the last query can attend on all previous keys/values, which which is equivalent to non-causal + if past_kv is not None: + # When `past_kv` is provided, we're doing incremental decoding and `q.shape[2] == 1`: q only contains + # the query for the last token. scaled_dot_product_attention interprets this as the first token in the + # sequence, so if is_causal=True it will mask out all attention from it. This is not what we want, so + # to work around this we set is_causal=False. is_causal = False else: is_causal = True @@ -115,8 +117,8 @@ def __init__(self, config, layer_idx): self.mlp = MLP(config) self.layer_idx = layer_idx - def forward(self, x, layer_past=None, use_cache=False): - attn_output, prev_kvs = self.attn(self.ln_1(x), layer_past=layer_past, use_cache=use_cache) + def forward(self, x, past_kv=None, use_cache=False): + attn_output, prev_kvs = self.attn(self.ln_1(x), past_kv=past_kv, use_cache=use_cache) x = x + attn_output x = x + self.mlp(self.ln_2(x)) return (x, prev_kvs) @@ -163,10 +165,10 @@ def get_num_params(self, non_embedding=True): n_params -= self.transformer.wpe.weight.numel() return n_params - def forward(self, idx, merge_context=False, past_key_values=None, position_ids=None, use_cache=False): + def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False): device = idx.device b, t = idx.size() - if past_key_values is not None: + if past_kv is not None: assert t == 1 tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) else: @@ -185,11 +187,11 @@ def forward(self, idx, merge_context=False, past_key_values=None, position_ids=N else: tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) - if past_key_values is None: + if past_kv is None: past_length = 0 - past_key_values = tuple([None] * len(self.transformer.h)) + past_kv = tuple([None] * len(self.transformer.h)) else: - past_length = past_key_values[0][0].size(-2) + past_length = past_kv[0][0].size(-2) if position_ids is None: position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device) @@ -201,17 +203,17 @@ def forward(self, idx, merge_context=False, past_key_values=None, position_ids=N x = self.transformer.drop(tok_emb + pos_emb) - presents = () if use_cache else None + new_kv = () if use_cache 
else None - for i, (block, layer_past) in enumerate(zip(self.transformer.h, past_key_values)): - x, kv = block(x, layer_past=layer_past, use_cache=use_cache) + for i, (block, past_layer_kv) in enumerate(zip(self.transformer.h, past_kv)): + x, kv = block(x, past_kv=past_layer_kv, use_cache=use_cache) if use_cache: - presents = presents + (kv,) + new_kv = new_kv + (kv,) x = self.transformer.ln_f(x) # inference-time mini-optimization: only forward the lm_head on the very last position logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim - return (logits, presents) + return (logits, new_kv) From acfd65b1a96e6cb7893c0059984520693d1381b3 Mon Sep 17 00:00:00 2001 From: Zygimantas Straznickas Date: Sat, 22 Apr 2023 12:27:47 -0700 Subject: [PATCH 3/3] Add newline to .gitignore --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index ba0430d2..c18dd8d8 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1 @@ -__pycache__/ \ No newline at end of file +__pycache__/
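
For reference, a minimal sketch of exercising the new flag through the public API after applying this series. It assumes the package's usual top-level exports (`generate_audio`, `preload_models`) and uses an illustrative prompt string; `use_kv_caching` defaults to False, so existing callers are unaffected unless they opt in.

    from bark import generate_audio, preload_models

    # Load the text, coarse, and fine models once up front.
    preload_models()

    # With use_kv_caching=True, generate_text_semantic and generate_coarse keep the
    # per-layer (key, value) cache returned by the model and, after the first step of
    # each decoding loop, feed only the most recent token (x[:, [-1]]) instead of
    # re-running attention over the whole prefix.
    audio_array = generate_audio(
        "Hello, my name is Suno.",
        use_kv_caching=True,
    )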