Skip to content

Commit

Permalink
make kv caching default in inference
Browse files (browse the repository at this point in the history)
  • Loading branch information
gkucsko committed Apr 22, 2023
1 parent 3247106 commit 009ff7c
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions bark/api.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -10,7 +10,6 @@ def text_to_semantic(
history_prompt: Optional[str] = None,
temp: float = 0.7,
silent: bool = False,
- use_kv_caching = False,
):
"""Generate semantic array from text.
Expand All @@ -28,7 +27,7 @@ def text_to_semantic(
history_prompt=history_prompt,
temp=temp,
silent=silent,
- use_kv_caching=use_kv_caching
+ use_kv_caching=True
)
return x_semantic

Expand All @@ -39,7 +38,6 @@ def semantic_to_waveform(
temp: float = 0.7,
silent: bool = False,
output_full: bool = False,
- use_kv_caching = False
):
"""Generate audio array from semantic input.
Expand All @@ -58,7 +56,7 @@ def semantic_to_waveform(
history_prompt=history_prompt,
temp=temp,
silent=silent,
- use_kv_caching=use_kv_caching
+ use_kv_caching=True
)
fine_tokens = generate_fine(
coarse_tokens,
Expand Down Expand Up @@ -92,7 +90,6 @@ def generate_audio(
waveform_temp: float = 0.7,
silent: bool = False,
output_full: bool = False,
- use_kv_caching = False
):
"""Generate audio array from input text.
Expand All @@ -108,15 +105,17 @@ def generate_audio(
numpy audio array at sample frequency 24khz
"""
semantic_tokens = text_to_semantic(
- text, history_prompt=history_prompt, temp=text_temp, silent=silent, use_kv_caching=use_kv_caching
+ text,
+ history_prompt=history_prompt,
+ temp=text_temp,
+ silent=silent,
)
out = semantic_to_waveform(
semantic_tokens,
history_prompt=history_prompt,
temp=waveform_temp,
silent=silent,
output_full=output_full,
- use_kv_caching=use_kv_caching
)
if output_full:
full_generation, audio_arr = out
Expand Down

0 comments on commit 009ff7c

Please sign in to comment.