added support for loading custom whisper models from local computer and huggingface finetuned whisper models
NavodPeiris committed Aug 16, 2024
1 parent 12308ba commit 60e8454
Showing 8 changed files with 67 additions and 14 deletions.
8 changes: 7 additions & 1 deletion README.md
@@ -88,7 +88,7 @@ transcript will also indicate the timeframe in seconds where each speaker speaks
 ```
 from speechlib import Transcriptor
-file = "obama_zach.wav" # your audio file
+file = "obama1.wav" # your audio file
 voices_folder = "voices" # voices folder containing voice samples for recognition
 language = "en" # language code
 log_folder = "logs" # log folder for storing transcripts
@@ -105,6 +105,12 @@ res = transcriptor.whisper()
 # use faster-whisper (simply faster)
 res = transcriptor.faster_whisper()
+
+# use a custom trained whisper model
+res = transcriptor.custom_whisper("D:/whisper_tiny_model/tiny.pt")
+
+# use a huggingface whisper model
+res = transcriptor.huggingface_model("Jingmiao/whisper-small-chinese_base")
 res --> [["start", "end", "text", "speaker"], ["start", "end", "text", "speaker"]...]
 ```

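For context, here is a minimal end-to-end sketch of the two entry points this commit adds. The `Transcriptor` constructor call is an assumption (its full signature is not visible in this hunk); the method names, arguments, and return shape come straight from the diff above.

```
from speechlib import Transcriptor

file = "obama1.wav"       # your audio file
voices_folder = "voices"  # voice samples for speaker recognition
language = "en"           # language code
log_folder = "logs"       # log folder for storing transcripts

# hypothetical constructor call; the real argument order is not shown in this hunk
transcriptor = Transcriptor(file, log_folder, language, voices_folder)

# new in this commit: load a custom fine-tuned whisper checkpoint from the local machine
res = transcriptor.custom_whisper("D:/whisper_tiny_model/tiny.pt")

# new in this commit: load a fine-tuned whisper model hosted on huggingface
res = transcriptor.huggingface_model("Jingmiao/whisper-small-chinese_base")

# res -> [["start", "end", "text", "speaker"], ...]
```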
8 changes: 7 additions & 1 deletion library.md
@@ -72,7 +72,7 @@ transcript will also indicate the timeframe in seconds where each speaker speaks
 ```
 from speechlib import Transcriptor
-file = "obama_zach.wav" # your audio file
+file = "obama1.wav" # your audio file
 voices_folder = "voices" # voices folder containing voice samples for recognition
 language = "en" # language code
 log_folder = "logs" # log folder for storing transcripts
@@ -89,6 +89,12 @@ res = transcriptor.whisper()
 # use faster-whisper (simply faster)
 res = transcriptor.faster_whisper()
+
+# use a custom trained whisper model
+res = transcriptor.custom_whisper("D:/whisper_tiny_model/tiny.pt")
+
+# use a huggingface whisper model
+res = transcriptor.huggingface_model("Jingmiao/whisper-small-chinese_base")
 res --> [["start", "end", "text", "speaker"], ["start", "end", "text", "speaker"]...]
 ```

8 changes: 5 additions & 3 deletions setup.py
@@ -5,7 +5,7 @@

 setup(
     name="speechlib",
-    version="1.1.2",
+    version="1.1.3",
     description="speechlib is a library that can do speaker diarization, transcription and speaker recognition on an audio file to create transcripts with actual speaker names. This library also contains audio preprocessor functions.",
     packages=find_packages(),
     long_description=long_description,
@@ -19,6 +19,8 @@
         "Programming Language :: Python :: 3.10",
         "Operating System :: OS Independent",
     ],
-    install_requires=["transformers==4.36.2", "torch==2.1.2", "torchaudio==2.1.2", "pydub==0.25.1", "pyannote.audio==3.1.1", "speechbrain==0.5.16", "accelerate==0.26.1", "faster-whisper==0.10.1", "openai-whisper==20231117"],
+    install_requires=["transformers", "torch", "torchaudio", "pydub", "pyannote.audio", "speechbrain==0.5.16", "accelerate", "faster-whisper", "openai-whisper"],
     python_requires=">=3.8",
-)
+)
+
+# ["transformers==4.36.2", "torch==2.1.2", "torchaudio==2.1.2", "pydub==0.25.1", "pyannote.audio==3.1.1", "speechbrain==0.5.16", "accelerate==0.26.1", "faster-whisper==0.10.1", "openai-whisper==20231117"]
2 changes: 1 addition & 1 deletion setup_instruction.md
@@ -9,7 +9,7 @@ for publishing:
 pip install twine

 to install locally for testing:
-pip install dist/speechlib-1.1.2-py3-none-any.whl
+pip install dist/speechlib-1.1.3-py3-none-any.whl

 finally run:
 twine upload dist/*
4 changes: 2 additions & 2 deletions speechlib/core_analysis.py
@@ -13,7 +13,7 @@

 # by default use google speech-to-text API
 # if False, then use whisper finetuned version for sinhala
-def core_analysis(file_name, voices_folder, log_folder, language, modelSize, ACCESS_TOKEN, whisper_type ,quantization=False):
+def core_analysis(file_name, voices_folder, log_folder, language, modelSize, ACCESS_TOKEN, model_type, quantization=False, custom_model_path=None, hf_model_id=None):

     # <-------------------PreProcessing file-------------------------->

@@ -114,7 +114,7 @@ def core_analysis(file_name, voices_folder, log_folder, language, modelSize, ACC
print("running transcription...")
for spk_tag, spk_segments in speakers.items():
spk = speaker_map[spk_tag]
segment_out = wav_file_segmentation(file_name, spk_segments, language, modelSize, whisper_type, quantization)
segment_out = wav_file_segmentation(file_name, spk_segments, language, modelSize, model_type, quantization, custom_model_path, hf_model_id)
speakers[spk_tag] = segment_out
end_time = int(time.time())
elapsed_time = int(end_time - start_time)
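Reading the two hunks together: `model_type` replaces the old `whisper_type` parameter, and the two new keyword arguments default to `None`, so existing callers are unaffected. A hedged sketch of a direct call with the new parameters (all values below are illustrative placeholders, not from the commit):

```
res = core_analysis(
    file_name="obama1.wav",
    voices_folder="voices",
    log_folder="logs",
    language="en",
    modelSize="tiny",
    ACCESS_TOKEN="hf_xxx",   # placeholder huggingface token
    model_type="custom",     # or "whisper" / "faster-whisper" / "huggingface"
    quantization=False,
    custom_model_path="D:/whisper_tiny_model/tiny.pt",  # used only when model_type == "custom"
    hf_model_id=None,                                   # used only when model_type == "huggingface"
)
```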
8 changes: 8 additions & 0 deletions speechlib/speechlib.py
@@ -243,6 +243,14 @@ def faster_whisper(self):
         res = core_analysis(self.file, self.voices_folder, self.log_folder, self.language, self.modelSize, self.ACCESS_TOKEN, "faster-whisper", self.quantization)
         return res

+    def custom_whisper(self, custom_model_path):
+        res = core_analysis(self.file, self.voices_folder, self.log_folder, self.language, self.modelSize, self.ACCESS_TOKEN, "custom", self.quantization, custom_model_path)
+        return res
+
+    def huggingface_model(self, hf_model_id):
+        res = core_analysis(self.file, self.voices_folder, self.log_folder, self.language, self.modelSize, self.ACCESS_TOKEN, "huggingface", self.quantization, None, hf_model_id)
+        return res
+
 class PreProcessor:
     '''
     class for preprocessing audio files.
36 changes: 33 additions & 3 deletions speechlib/transcribe.py
@@ -2,14 +2,16 @@
 from .whisper_sinhala import (whisper_sinhala)
 from faster_whisper import WhisperModel
 import whisper
+import os
+from transformers import pipeline

-def transcribe(file, language, model_size, whisper_type, quantization):
+def transcribe(file, language, model_size, model_type, quantization, custom_model_path, hf_model_path):
     res = ""
     if language in ["si", "Si"]:
         res = whisper_sinhala(file)
         return res
     elif model_size in ["base", "tiny", "small", "medium", "large", "large-v1", "large-v2", "large-v3"]:
-        if whisper_type == "faster-whisper":
+        if model_type == "faster-whisper":
             if torch.cuda.is_available():
                 if quantization:
                     model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
@@ -30,7 +32,7 @@ def transcribe(file, language, model_size, whisper_type, quantization):
                 return res
             else:
                 Exception("Language code not supported.\nThese are the supported languages:\n", model.supported_languages)
-        else:
+        elif model_type == "whisper":
             try:
                 if torch.cuda.is_available():
                     model = whisper.load_model(model_size, device="cuda")
@@ -44,6 +46,34 @@ def transcribe(file, language, model_size, whisper_type, quantization):
                     res = result["text"]
                 return res
             except Exception as err:
                 print("an error occurred while transcribing: ", err)
+        elif model_type == "custom":
+            model_folder = os.path.dirname(custom_model_path)
+            model_folder = model_folder + "/"
+            print("model file: ", custom_model_path)
+            print("model folder: ", model_folder)
+            try:
+                if torch.cuda.is_available():
+                    model = whisper.load_model(custom_model_path, download_root=model_folder)
+                    result = model.transcribe(file, language=language, fp16=True)
+                    res = result["text"]
+                else:
+                    model = whisper.load_model(custom_model_path, download_root=model_folder)
+                    result = model.transcribe(file, language=language, fp16=False)
+                    res = result["text"]
+
+                return res
+            except Exception as err:
+                raise Exception(f"an error occurred while transcribing: {err}")
+        elif model_type == "huggingface":
+            try:
+                pipe = pipeline("automatic-speech-recognition", model=hf_model_path)
+                result = pipe(file)
+                res = result['text']
+                return res
+            except Exception as err:
+                raise Exception(f"an error occurred while transcribing: {err}")
+        else:
+            raise Exception(f"model_type {model_type} is not supported")
     else:
         raise Exception("only 'base', 'tiny', 'small', 'medium', 'large', 'large-v1', 'large-v2', 'large-v3' models are available.")

7 changes: 4 additions & 3 deletions speechlib/wav_segmenter.py
@@ -3,7 +3,7 @@
 from .transcribe import (transcribe)

 # segment according to speaker
-def wav_file_segmentation(file_name, segments, language, modelSize, whisper_type, quantization):
+def wav_file_segmentation(file_name, segments, language, modelSize, model_type, quantization, custom_model_path, hf_model_path):
     # Load the WAV file
     audio = AudioSegment.from_file(file_name, format="wav")
     trans = ""
@@ -27,11 +27,12 @@ def wav_file_segmentation(file_name, segments, language, modelSize, whisper_type
             clip.export(file, format="wav")

             try:
-                trans = transcribe(file, language, modelSize, whisper_type, quantization)
+                trans = transcribe(file, language, modelSize, model_type, quantization, custom_model_path, hf_model_path)

                 # return -> [[start time, end time, transcript], [start time, end time, transcript], ..]
                 texts.append([segment[0], segment[1], trans])
-            except:
+            except Exception as err:
+                # to avoid transcription exceptions that occur when transcribing silent segments we have to pass
                 pass

             # Delete the WAV file after processing
             os.remove(file)
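The new arguments are simply threaded through to `transcribe` for each exported clip. As a reminder of the mechanics around that call, a tiny pydub sketch (segment bounds are illustrative; pydub slices `AudioSegment` objects by milliseconds):

```
from pydub import AudioSegment

audio = AudioSegment.from_file("obama1.wav", format="wav")
start_ms, end_ms = 0, 5000     # illustrative segment bounds in milliseconds
clip = audio[start_ms:end_ms]  # pydub slicing is by milliseconds
clip.export("segment_0.wav", format="wav")
```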
