Merge pull request suno-ai#27 from zygi/main
Add key/value caching for autoregressive generation
gkucsko authored Apr 22, 2023
2 parents 874af1b + acfd65b commit 3247106
Showing 4 changed files with 94 additions and 29 deletions.
1 change: 0 additions & 1 deletion .gitignore
@@ -1,2 +1 @@
__pycache__/

8 changes: 7 additions & 1 deletion bark/api.py
@@ -10,6 +10,7 @@ def text_to_semantic(
history_prompt: Optional[str] = None,
temp: float = 0.7,
silent: bool = False,
use_kv_caching = False,
):
"""Generate semantic array from text.
@@ -27,6 +28,7 @@ def text_to_semantic(
history_prompt=history_prompt,
temp=temp,
silent=silent,
use_kv_caching=use_kv_caching
)
return x_semantic

@@ -37,6 +39,7 @@ def semantic_to_waveform(
temp: float = 0.7,
silent: bool = False,
output_full: bool = False,
use_kv_caching = False
):
"""Generate audio array from semantic input.
@@ -55,6 +58,7 @@ def semantic_to_waveform(
history_prompt=history_prompt,
temp=temp,
silent=silent,
use_kv_caching=use_kv_caching
)
fine_tokens = generate_fine(
coarse_tokens,
@@ -88,6 +92,7 @@ def generate_audio(
waveform_temp: float = 0.7,
silent: bool = False,
output_full: bool = False,
use_kv_caching = False
):
"""Generate audio array from input text.
@@ -103,14 +108,15 @@ def generate_audio(
numpy audio array at sample frequency 24khz
"""
semantic_tokens = text_to_semantic(
text, history_prompt=history_prompt, temp=text_temp, silent=silent,
text, history_prompt=history_prompt, temp=text_temp, silent=silent, use_kv_caching=use_kv_caching
)
out = semantic_to_waveform(
semantic_tokens,
history_prompt=history_prompt,
temp=waveform_temp,
silent=silent,
output_full=output_full,
use_kv_caching=use_kv_caching
)
if output_full:
full_generation, audio_arr = out
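The api.py changes only thread the new flag through the public entry points. A minimal usage sketch follows (the prompt text and output filename are illustrative; generate_audio and SAMPLE_RATE are the package's existing exports, and scipy is assumed to be available for writing the WAV file):

# Hedged usage sketch: enable the new KV cache end to end through the public API.
from scipy.io.wavfile import write as write_wav
from bark import SAMPLE_RATE, generate_audio

audio_array = generate_audio(
    "Hello, my name is Suno.",   # illustrative prompt
    use_kv_caching=True,         # flag added by this PR, defaults to False
)
write_wav("bark_out.wav", SAMPLE_RATE, audio_array)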
19 changes: 17 additions & 2 deletions bark/generation.py
@@ -359,6 +359,7 @@ def generate_text_semantic(
max_gen_duration_s=None,
allow_early_stop=True,
model=None,
use_kv_caching=False
):
"""Generate semantic tokens from text."""
assert isinstance(text, str)
@@ -420,8 +421,14 @@ def generate_text_semantic(
pbar = tqdm.tqdm(disable=silent, total=100)
pbar_state = 0
tot_generated_duration_s = 0
kv_cache = None
for n in range(n_tot_steps):
logits = model(x, merge_context=True)
if use_kv_caching and kv_cache is not None:
x_input = x[:, [-1]]
else:
x_input = x

logits, kv_cache = model(x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache)
relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE]
if allow_early_stop:
relevant_logits = torch.hstack(
@@ -498,6 +505,7 @@ def generate_coarse(
max_coarse_history=630, # min 60 (faster), max 630 (more context)
sliding_window_len=60,
model=None,
use_kv_caching=False
):
"""Generate coarse audio codes from semantic tokens."""
assert (
@@ -592,11 +600,18 @@ def generate_coarse(
x_coarse_in[:, -max_coarse_history:],
]
)
kv_cache = None
for _ in range(sliding_window_len):
if n_step >= n_steps:
continue
is_major_step = n_step % N_COARSE_CODEBOOKS == 0
logits = model(x_in)

if use_kv_caching and kv_cache is not None:
x_input = x_in[:, [-1]]
else:
x_input = x_in

logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)
logit_start_idx = (
SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE
)
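Both generation loops above follow the same incremental-decoding pattern: the full context is fed once to prime the cache, and afterwards only the newest token is passed in together with the cached keys/values. A standalone sketch of that pattern (batch size 1, greedy sampling, and the step count are stand-ins, not the actual Bark sampling logic; `model` is any module with the forward(x, use_cache=..., past_kv=...) -> (logits, kv) signature introduced in this PR):

# Decoding-loop sketch: prime the cache on the first pass, then feed one token per step.
import torch

def decode(model, x, n_steps, use_kv_caching=True):
    kv_cache = None
    for _ in range(n_steps):
        if use_kv_caching and kv_cache is not None:
            x_input = x[:, [-1]]   # only the newest token: shape (1, 1)
        else:
            x_input = x            # full context on the first (or uncached) pass
        logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)
        next_token = torch.argmax(logits[0, -1]).view(1, 1)  # greedy pick for illustration
        x = torch.cat((x, next_token), dim=1)
    return x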
95 changes: 70 additions & 25 deletions bark/model.py
@@ -43,7 +43,7 @@ def __init__(self, config):
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
.view(1, 1, config.block_size, config.block_size))

def forward(self, x):
def forward(self, x, past_kv=None, use_cache=False):
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

# calculate query, key, values for all heads in batch and move head forward to be the batch dim
@@ -52,22 +52,44 @@ def forward(self, x):
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

if past_kv is not None:
past_key = past_kv[0]
past_value = past_kv[1]
k = torch.cat((past_key, k), dim=-2)
v = torch.cat((past_value, v), dim=-2)

FULL_T = k.shape[-2]

if use_cache is True:
present = (k, v)
else:
present = None

# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
if self.flash:
# efficient attention using Flash Attention CUDA kernels
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
if past_kv is not None:
# When `past_kv` is provided, we're doing incremental decoding and `q.shape[2] == 1`: q only contains
# the query for the last token. scaled_dot_product_attention interprets this as the first token in the
# sequence, so if is_causal=True it will mask out all attention from it. This is not what we want, so
# to work around this we set is_causal=False.
is_causal = False
else:
is_causal = True

y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout, is_causal=is_causal)
else:
# manual implementation of attention
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
att = att.masked_fill(self.bias[:,:,FULL_T-T:FULL_T,:FULL_T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
att = self.attn_dropout(att)
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

# output projection
y = self.resid_dropout(self.c_proj(y))
return y
return (y, present)

class MLP(nn.Module):

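The cached attention path concatenates the stored keys/values along the sequence dimension and, on the flash path, must drop the causal mask: the lone query is the last position and may attend to everything. A shape sketch of that step (all sizes are illustrative; requires torch>=2.0 for scaled_dot_product_attention):

# Shape sketch for the cached attention step (illustrative sizes, not Bark's).
import torch
import torch.nn.functional as F

B, n_head, head_size, cached_len = 1, 2, 8, 7
past_k = torch.randn(B, n_head, cached_len, head_size)  # keys cached from prior steps
past_v = torch.randn(B, n_head, cached_len, head_size)
q = torch.randn(B, n_head, 1, head_size)                # query for the newest token only
k_new = torch.randn(B, n_head, 1, head_size)
v_new = torch.randn(B, n_head, 1, head_size)

k = torch.cat((past_k, k_new), dim=-2)                   # FULL_T = cached_len + 1
v = torch.cat((past_v, v_new), dim=-2)

# is_causal=True would treat the single query as position 0 and mask out the cache,
# so the cached path runs with is_causal=False: the last token may see every position.
y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
print(y.shape)                                           # torch.Size([1, 2, 1, 8])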
@@ -95,10 +117,11 @@ def __init__(self, config, layer_idx):
self.mlp = MLP(config)
self.layer_idx = layer_idx

def forward(self, x):
x = x + self.attn(self.ln_1(x))
def forward(self, x, past_kv=None, use_cache=False):
attn_output, prev_kvs = self.attn(self.ln_1(x), past_kv=past_kv, use_cache=use_cache)
x = x + attn_output
x = x + self.mlp(self.ln_2(x))
return x
return (x, prev_kvs)

@dataclass
class GPTConfig:
@@ -142,33 +165,55 @@ def get_num_params(self, non_embedding=True):
n_params -= self.transformer.wpe.weight.numel()
return n_params

def forward(self, idx, merge_context=False):
def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False):
device = idx.device
b, t = idx.size()
if merge_context:
assert(idx.shape[1] >= 256+256+1)
t = idx.shape[1] - 256
if past_kv is not None:
assert t == 1
tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
else:
assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

# forward the GPT model itself
if merge_context:
tok_emb = torch.cat([
self.transformer.wte(idx[:,:256]) + self.transformer.wte(idx[:,256:256+256]),
self.transformer.wte(idx[:,256+256:])
], dim=1)
if merge_context:
assert(idx.shape[1] >= 256+256+1)
t = idx.shape[1] - 256
else:
assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

# forward the GPT model itself
if merge_context:
tok_emb = torch.cat([
self.transformer.wte(idx[:,:256]) + self.transformer.wte(idx[:,256:256+256]),
self.transformer.wte(idx[:,256+256:])
], dim=1)
else:
tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)

if past_kv is None:
past_length = 0
past_kv = tuple([None] * len(self.transformer.h))
else:
tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
past_length = past_kv[0][0].size(-2)

if position_ids is None:
position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0) # shape (1, t)
assert position_ids.shape == (1, t)

pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd)

pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)

x = self.transformer.drop(tok_emb + pos_emb)
for block in self.transformer.h:
x = block(x)

new_kv = () if use_cache else None

for i, (block, past_layer_kv) in enumerate(zip(self.transformer.h, past_kv)):
x, kv = block(x, past_kv=past_layer_kv, use_cache=use_cache)

if use_cache:
new_kv = new_kv + (kv,)

x = self.transformer.ln_f(x)

# inference-time mini-optimization: only forward the lm_head on the very last position
logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim

return logits
return (logits, new_kv)
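Inside GPT.forward the per-layer caches are gathered into a tuple, and the position ids of a cached step start at the cached length rather than at zero. A bookkeeping sketch of that layout (tensor contents and sizes are stand-ins):

# Cache-layout sketch: past_kv is a tuple over layers, each entry a (k, v) pair of
# shape (batch, n_head, cached_len, head_size); the cached length is read off dim -2
# and the single new token is embedded at position past_length, not at 0.
import torch

n_layer, batch, n_head, cached_len, head_size = 3, 1, 4, 10, 16
past_kv = tuple(
    (torch.randn(batch, n_head, cached_len, head_size),
     torch.randn(batch, n_head, cached_len, head_size))
    for _ in range(n_layer)
)
past_length = past_kv[0][0].size(-2)                 # 10
t = 1                                                # one new token per cached step
position_ids = torch.arange(past_length, t + past_length, dtype=torch.long).unsqueeze(0)
print(past_length, position_ids)                     # 10 tensor([[10]])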
