Skip to content

Commit

Permalink
add option for smaller models
Browse files Browse the repository at this point in the history
  • Loading branch information
gkucsko committed Apr 21, 2023
1 parent 9751cfb commit c372430
Showing 1 changed file with 33 additions and 16 deletions.
49 changes: 33 additions & 16 deletions bark/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
torch.cuda.is_available() and
hasattr(torch.cuda, "amp") and
hasattr(torch.cuda.amp, "autocast") and
hasattr(torch.cuda, "is_bf16_supported") and
torch.cuda.is_bf16_supported()
):
autocast = funcy.partial(torch.cuda.amp.autocast, dtype=torch.bfloat16)
Expand Down Expand Up @@ -80,23 +81,39 @@ def autocast():
CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0")


USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False)

REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/"
REMOTE_MODEL_PATHS = {
"text": {
"path": os.environ.get("SUNO_TEXT_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "text_2.pt")),
"checksum": "54afa89d65e318d4f5f80e8e8799026a",
},
"coarse": {
"path": os.environ.get(
"SUNO_COARSE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "coarse_2.pt")
),
"checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
},
"fine": {
"path": os.environ.get("SUNO_FINE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "fine_2.pt")),
"checksum": "59d184ed44e3650774a2f0503a48a97b",
},
}
if USE_SMALL_MODELS:
REMOTE_MODEL_PATHS = {
"text": {
"path": os.path.join(REMOTE_BASE_URL, "text.pt"),
"checksum": "b3e42bcbab23b688355cd44128c4cdd3",
},
"coarse": {
"path": os.path.join(REMOTE_BASE_URL, "coarse.pt"),
"checksum": "5fe964825e3b0321f9d5f3857b89194d",
},
"fine": {
"path": os.path.join(REMOTE_BASE_URL, "fine.pt"),
"checksum": "5428d1befe05be2ba32195496e58dc90",
},
}
else:
REMOTE_MODEL_PATHS = {
"text": {
"path": os.path.join(REMOTE_BASE_URL, "text_2.pt"),
"checksum": "54afa89d65e318d4f5f80e8e8799026a",
},
"coarse": {
"path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"),
"checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
},
"fine": {
"path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"),
"checksum": "59d184ed44e3650774a2f0503a48a97b",
},
}


if not hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
Expand Down

0 comments on commit c372430

Please sign in to comment.