Skip to content

Commit

Permalink
Merge pull request suno-ai#62 from suno-ai/test_model_control
Browse files Browse the repository at this point in the history
Simplify small model and gpu/cpu choice
  • Loading branch information
gkucsko authored Apr 22, 2023
2 parents 009ff7c + 8313b57 commit 97c6019
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 45 deletions.
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,20 @@ Bark is a transformer-based text-to-audio model created by [Suno](https://suno.a
## 🤖 Usage

```python
from bark import SAMPLE_RATE, generate_audio
from bark import SAMPLE_RATE, generate_audio, preload_models
from IPython.display import Audio

# download and load all models
preload_models()

# generate audio from text
text_prompt = """
Hello, my name is Suno. And, uh — and I like pizza. [laughs]
But I also have other interests such as playing tic tac toe.
"""
audio_array = generate_audio(text_prompt)

# play text in notebook
Audio(audio_array, rate=SAMPLE_RATE)
```

Expand Down
101 changes: 57 additions & 44 deletions bark/generation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import contextlib
import gc
import hashlib
import os
import re
Expand Down Expand Up @@ -84,36 +85,33 @@ def autocast():
USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False)

REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/"
if USE_SMALL_MODELS:
REMOTE_MODEL_PATHS = {
"text": {
"path": os.path.join(REMOTE_BASE_URL, "text.pt"),
"checksum": "b3e42bcbab23b688355cd44128c4cdd3",
},
"coarse": {
"path": os.path.join(REMOTE_BASE_URL, "coarse.pt"),
"checksum": "5fe964825e3b0321f9d5f3857b89194d",
},
"fine": {
"path": os.path.join(REMOTE_BASE_URL, "fine.pt"),
"checksum": "5428d1befe05be2ba32195496e58dc90",
},
}
else:
REMOTE_MODEL_PATHS = {
"text": {
"path": os.path.join(REMOTE_BASE_URL, "text_2.pt"),
"checksum": "54afa89d65e318d4f5f80e8e8799026a",
},
"coarse": {
"path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"),
"checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
},
"fine": {
"path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"),
"checksum": "59d184ed44e3650774a2f0503a48a97b",
},
}

REMOTE_MODEL_PATHS = {
"text_small": {
"path": os.path.join(REMOTE_BASE_URL, "text.pt"),
"checksum": "b3e42bcbab23b688355cd44128c4cdd3",
},
"coarse_small": {
"path": os.path.join(REMOTE_BASE_URL, "coarse.pt"),
"checksum": "5fe964825e3b0321f9d5f3857b89194d",
},
"fine_small": {
"path": os.path.join(REMOTE_BASE_URL, "fine.pt"),
"checksum": "5428d1befe05be2ba32195496e58dc90",
},
"text": {
"path": os.path.join(REMOTE_BASE_URL, "text_2.pt"),
"checksum": "54afa89d65e318d4f5f80e8e8799026a",
},
"coarse": {
"path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"),
"checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
},
"fine": {
"path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"),
"checksum": "59d184ed44e3650774a2f0503a48a97b",
},
}


if not hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
Expand All @@ -137,8 +135,9 @@ def _md5(fname):
return hash_md5.hexdigest()


def _get_ckpt_path(model_type):
model_name = _string_md5(REMOTE_MODEL_PATHS[model_type]["path"])
def _get_ckpt_path(model_type, use_small=False):
model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["path"])
return os.path.join(CACHE_DIR, f"{model_name}.pt")


Expand Down Expand Up @@ -204,9 +203,10 @@ def clean_models(model_key=None):
if k in models:
del models[k]
_clear_cuda_cache()
gc.collect()


def _load_model(ckpt_path, device, model_type="text"):
def _load_model(ckpt_path, device, use_small=False, model_type="text"):
if "cuda" not in device:
logger.warning("No GPU being used. Careful, inference might be extremely slow!")
if model_type == "text":
Expand All @@ -220,15 +220,17 @@ def _load_model(ckpt_path, device, model_type="text"):
ModelClass = FineGPT
else:
raise NotImplementedError()
model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
model_info = REMOTE_MODEL_PATHS[model_key]
if (
os.path.exists(ckpt_path) and
_md5(ckpt_path) != REMOTE_MODEL_PATHS[model_type]["checksum"]
_md5(ckpt_path) != model_info["checksum"]
):
logger.warning(f"found outdated {model_type} model, removing.")
os.remove(ckpt_path)
if not os.path.exists(ckpt_path):
logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
_download(REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path)
_download(model_info["path"], ckpt_path)
checkpoint = torch.load(ckpt_path, map_location=device)
# this is a hack
model_args = checkpoint["model_args"]
Expand Down Expand Up @@ -278,8 +280,8 @@ def _load_codec_model(device):
return model


def load_model(ckpt_path=None, use_gpu=True, force_reload=False, model_type="text"):
_load_model_f = funcy.partial(_load_model, model_type=model_type)
def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="text"):
_load_model_f = funcy.partial(_load_model, model_type=model_type, use_small=use_small)
if model_type not in ("text", "coarse", "fine"):
raise NotImplementedError()
global models
Expand All @@ -289,8 +291,7 @@ def load_model(ckpt_path=None, use_gpu=True, force_reload=False, model_type="tex
device = "cuda"
model_key = str(device) + f"__{model_type}"
if model_key not in models or force_reload:
if ckpt_path is None:
ckpt_path = _get_ckpt_path(model_type)
ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
clean_models(model_key=model_key)
model = _load_model_f(ckpt_path, device)
models[model_key] = model
Expand All @@ -311,17 +312,29 @@ def load_codec_model(use_gpu=True, force_reload=False):
return models[model_key]


def preload_models(text_ckpt_path=None, coarse_ckpt_path=None, fine_ckpt_path=None, use_gpu=True):
def preload_models(
text_use_gpu=True,
text_use_small=False,
coarse_use_gpu=True,
coarse_use_small=False,
fine_use_gpu=True,
fine_use_small=False,
codec_use_gpu=True,
force_reload=False,
):
_ = load_model(
ckpt_path=text_ckpt_path, model_type="text", use_gpu=use_gpu, force_reload=True
model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload
)
_ = load_model(
ckpt_path=coarse_ckpt_path, model_type="coarse", use_gpu=use_gpu, force_reload=True
model_type="coarse",
use_gpu=coarse_use_gpu,
use_small=coarse_use_small,
force_reload=force_reload,
)
_ = load_model(
ckpt_path=fine_ckpt_path, model_type="fine", use_gpu=use_gpu, force_reload=True
model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload
)
_ = load_codec_model(use_gpu=use_gpu, force_reload=True)
_ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)


####
Expand Down

0 comments on commit 97c6019

Please sign in to comment.