Merge pull request suno-ai#146 from jn-jairo/offload-cpu
Option to offload models to cpu
gkucsko authored Apr 26, 2023
2 parents 6c26fb7 + dfbe09f commit 2c12023
Showing 1 changed file with 32 additions and 0 deletions.
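
This commit adds a `SUNO_OFFLOAD_CPU` environment variable: when it is set, models are loaded onto the CPU, the device they would normally live on is remembered in a `models_devices` dict, and each model is moved to that device only for the duration of a generation call, then moved back to the CPU. A minimal sketch of the pattern, assuming hypothetical helper names (`load`, `run`, `build_fn`) that are not part of the diff:

```python
# Minimal sketch of the offload pattern this commit adds (helper names are
# illustrative, not taken from bark/generation.py).
import os

import torch
from torch import nn

OFFLOAD_CPU = os.environ.get("SUNO_OFFLOAD_CPU", False)

models = {}          # model_key -> loaded module
models_devices = {}  # model_key -> device the model should actually run on

def load(model_key, build_fn, device):
    if OFFLOAD_CPU:
        models_devices[model_key] = device  # remember the "real" device
        device = "cpu"                      # but keep the weights on the CPU for now
    models[model_key] = build_fn().to(device)

def run(model_key, x):
    model = models[model_key]
    if OFFLOAD_CPU:
        model.to(models_devices[model_key])  # move onto the GPU only for this call
    out = model(x.to(next(model.parameters()).device))
    if OFFLOAD_CPU:
        model.to("cpu")                      # release GPU memory again
    return out

# Example: load("text", lambda: nn.Linear(4, 4), "cuda"); run("text", torch.randn(1, 4))
```

Only one model at a time occupies GPU memory this way, at the cost of a host-to-device copy per call.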
bark/generation.py: 32 additions & 0 deletions
@@ -36,6 +36,9 @@ def autocast():
global models
models = {}

global models_devices
models_devices = {}


CONTEXT_WINDOW_SIZE = 1024

@@ -84,6 +87,7 @@ def autocast():

USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False)
GLOBAL_ENABLE_MPS = os.environ.get("SUNO_ENABLE_MPS", False)
OFFLOAD_CPU = os.environ.get("SUNO_OFFLOAD_CPU", False)

REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/"

@@ -294,8 +298,12 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
if model_type not in ("text", "coarse", "fine"):
raise NotImplementedError()
global models
global models_devices
device = _grab_best_device(use_gpu=use_gpu)
model_key = f"{model_type}"
if OFFLOAD_CPU:
models_devices[model_key] = device
device = "cpu"
if model_key not in models or force_reload:
ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
clean_models(model_key=model_key)
@@ -310,11 +318,15 @@

def load_codec_model(use_gpu=True, force_reload=False):
global models
global models_devices
device = _grab_best_device(use_gpu=use_gpu)
if device == "mps":
# encodec doesn't support mps
device = "cpu"
model_key = "codec"
if OFFLOAD_CPU:
models_devices[model_key] = device
device = "cpu"
if model_key not in models or force_reload:
clean_models(model_key=model_key)
model = _load_codec_model(device)
@@ -411,12 +423,15 @@ def generate_text_semantic(
semantic_history = None
# load models if not yet exist
global models
global models_devices
if "text" not in models:
preload_models()
model_container = models["text"]
model = model_container["model"]
tokenizer = model_container["tokenizer"]
encoded_text = np.array(_tokenize(tokenizer, text)) + TEXT_ENCODING_OFFSET
if OFFLOAD_CPU:
model.to(models_devices["text"])
device = next(model.parameters()).device
if len(encoded_text) > 256:
p = round((len(encoded_text) - 256) / len(encoded_text) * 100, 1)
@@ -514,6 +529,8 @@ def generate_text_semantic(
pbar_state = req_pbar_state
pbar.close()
out = x.detach().cpu().numpy().squeeze()[256 + 256 + 1 :]
if OFFLOAD_CPU:
model.to("cpu")
assert all(0 <= out) and all(out < SEMANTIC_VOCAB_SIZE)
_clear_cuda_cache()
return out
@@ -602,9 +619,12 @@ def generate_coarse(
x_coarse_history = np.array([], dtype=np.int32)
# load models if not yet exist
global models
global models_devices
if "coarse" not in models:
preload_models()
model = models["coarse"]
if OFFLOAD_CPU:
model.to(models_devices["coarse"])
device = next(model.parameters()).device
# start loop
n_steps = int(
@@ -691,6 +711,8 @@ def generate_coarse(
n_step += 1
del x_in
del x_semantic_in
if OFFLOAD_CPU:
model.to("cpu")
gen_coarse_arr = x_coarse_in.detach().cpu().numpy().squeeze()[len(x_coarse_history) :]
del x_coarse_in
assert len(gen_coarse_arr) == n_steps
@@ -737,9 +759,12 @@ def generate_fine(
n_coarse = x_coarse_gen.shape[0]
# load models if not yet exist
global models
global models_devices
if "fine" not in models:
preload_models()
model = models["fine"]
if OFFLOAD_CPU:
model.to(models_devices["fine"])
device = next(model.parameters()).device
# make input arr
in_arr = np.vstack(
@@ -808,6 +833,8 @@ def generate_fine(
del in_buffer
gen_fine_arr = in_arr.detach().cpu().numpy().squeeze().T
del in_arr
if OFFLOAD_CPU:
model.to("cpu")
gen_fine_arr = gen_fine_arr[:, n_history:]
if n_remove_from_end > 0:
gen_fine_arr = gen_fine_arr[:, :-n_remove_from_end]
@@ -820,9 +847,12 @@ def codec_decode(fine_tokens):
"""Turn quantized audio codes into audio array using encodec."""
# load models if not yet exist
global models
global models_devices
if "codec" not in models:
preload_models()
model = models["codec"]
if OFFLOAD_CPU:
model.to(models_devices["codec"])
device = next(model.parameters()).device
arr = torch.from_numpy(fine_tokens)[None]
arr = arr.to(device)
@@ -831,4 +861,6 @@ def codec_decode(fine_tokens):
out = model.decoder(emb)
audio_arr = out.detach().cpu().numpy().squeeze()
del arr, emb, out
if OFFLOAD_CPU:
model.to("cpu")
return audio_arr
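
After this change, offloading is enabled by setting the environment variable before bark is imported; since the flag is read with `os.environ.get`, any non-empty value turns it on. A usage sketch, assuming bark's top-level `preload_models` and `generate_audio` helpers (they are not part of this diff):

```python
import os

# Must be set before bark/generation.py is imported, because OFFLOAD_CPU is read at import time.
os.environ["SUNO_OFFLOAD_CPU"] = "True"

from bark import generate_audio, preload_models

preload_models()                       # models are loaded onto the CPU
audio = generate_audio("hello world")  # each model moves to the GPU only while it runs
```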
