Skip to content

Commit

Permalink
add fake classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
gkucsko committed Apr 12, 2023
1 parent 0981150 commit 2c03817
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 11 deletions.
42 changes: 32 additions & 10 deletions bark/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import os
import re
import requests
import sys

from encodec import EncodecModel
import funcy
Expand Down Expand Up @@ -63,27 +62,43 @@ def autocast():

default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0")
os.makedirs(CACHE_DIR, exist_ok=True)


REMOTE_BASE_URL = "http://s3.amazonaws.com/suno-public/bark/models/v0/"
REMOTE_MODEL_PATHS = {
"text": os.environ.get("SUNO_TEXT_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "text_2.pt")),
"coarse": os.environ.get(
"SUNO_COARSE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "coarse_2.pt")
),
"fine": os.environ.get("SUNO_FINE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "fine_2.pt")),
"text": {
"path": os.environ.get("SUNO_TEXT_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "text_2.pt")),
"checksum": "54afa89d65e318d4f5f80e8e8799026a",
},
"coarse": {
"path": os.environ.get(
"SUNO_COARSE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "coarse_2.pt")
),
"checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
},
"fine": {
"path": os.environ.get("SUNO_FINE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "fine_2.pt")),
"checksum": "59d184ed44e3650774a2f0503a48a97b",
},
}


def _compute_md5(s):
def _string_md5(s):
m = hashlib.md5()
m.update(s.encode("utf-8"))
return m.hexdigest()


def _md5(fname):
hash_md5 = hashlib.md5()
with open(fname, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()


def _get_ckpt_path(model_type):
model_name = _compute_md5(REMOTE_MODEL_PATHS[model_type])
model_name = _string_md5(REMOTE_MODEL_PATHS[model_type]["path"])
return os.path.join(CACHE_DIR, f"{model_name}.pt")


Expand All @@ -97,6 +112,7 @@ def _parse_s3_filepath(s3_filepath):


def _download(from_s3_path, to_local_path):
os.makedirs(CACHE_DIR, exist_ok=True)
response = requests.get(from_s3_path, stream=True)
total_size_in_bytes = int(response.headers.get('content-length', 0))
block_size = 1024 # 1 Kibibyte
Expand Down Expand Up @@ -164,9 +180,15 @@ def _load_model(ckpt_path, device, model_type="text"):
ModelClass = FineGPT
else:
raise NotImplementedError()
if (
os.path.exists(ckpt_path) and
_md5(ckpt_path) != REMOTE_MODEL_PATHS[model_type]["checksum"]
):
print(f"found outdated {model_type} model, removing...")
os.remove(ckpt_path)
if not os.path.exists(ckpt_path):
print(f"{model_type} model not found, downloading...")
_download(REMOTE_MODEL_PATHS[model_type], ckpt_path)
_download(REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path)
checkpoint = torch.load(ckpt_path, map_location=device)
# this is a hack
model_args = checkpoint["model_args"]
Expand Down
Loading

0 comments on commit 2c03817

Please sign in to comment.