diff --git a/extensions/whisper_stt/script.js b/extensions/whisper_stt/script.js
index fff2b297..c4a908b5 100644
--- a/extensions/whisper_stt/script.js
+++ b/extensions/whisper_stt/script.js
@@ -1,25 +1,86 @@
-var recButton = document.getElementsByClassName("record-button")[0].cloneNode(true);
+console.log("Whisper STT script loaded");
+
+let mediaRecorder;
+let audioChunks = [];
+let isRecording = false;
+
+window.startStopRecording = function() {
+    if (!navigator.mediaDevices || !navigator.mediaDevices.getUserMedia) {
+        console.error("getUserMedia not supported on your browser!");
+        return;
+    }
+
+    if (isRecording == false) {
+        //console.log("Start recording function called");
+        navigator.mediaDevices.getUserMedia({ audio: true })
+            .then(stream => {
+                //console.log("Got audio stream");
+                mediaRecorder = new MediaRecorder(stream);
+                audioChunks = []; // Reset audio chunks
+                mediaRecorder.start();
+                //console.log("MediaRecorder started");
+                recButton.icon;
+                recordButton.innerHTML = recButton.innerHTML = "Stop";
+                isRecording = true;
+
+                mediaRecorder.addEventListener("dataavailable", event => {
+                    //console.log("Data available event, data size: ", event.data.size);
+                    audioChunks.push(event.data);
+                });
+
+                mediaRecorder.addEventListener("stop", () => {
+                    //console.log("MediaRecorder stopped");
+                    if (audioChunks.length > 0) {
+                        const audioBlob = new Blob(audioChunks, { type: "audio/webm" });
+                        //console.log("Audio blob created, size: ", audioBlob.size);
+                        const reader = new FileReader();
+                        reader.readAsDataURL(audioBlob);
+                        reader.onloadend = function() {
+                            const base64data = reader.result;
+                            //console.log("Audio converted to base64, length: ", base64data.length);
+
+                            const audioBase64Input = document.querySelector("#audio-base64 textarea");
+                            if (audioBase64Input) {
+                                audioBase64Input.value = base64data;
+                                audioBase64Input.dispatchEvent(new Event("input", { bubbles: true }));
+                                audioBase64Input.dispatchEvent(new Event("change", { bubbles: true }));
+                                //console.log("Updated textarea with base64 data");
+                            } else {
+                                console.error("Could not find audio-base64 textarea");
+                            }
+                        };
+                    } else {
+                        console.error("No audio data recorded for Whisper");
+                    }
+                });
+            });
+    } else {
+        //console.log("Stopping MediaRecorder");
+        recordButton.innerHTML = recButton.innerHTML = "Rec.";
+        isRecording = false;
+        mediaRecorder.stop();
+    }
+};
+
+const recordButton = gradioApp().querySelector("#record-button");
+recordButton.addEventListener("click", window.startStopRecording);
+
+
+function gradioApp() {
+    const elems = document.getElementsByTagName("gradio-app");
+    const gradioShadowRoot = elems.length == 0 ? null : elems[0].shadowRoot;
+    return gradioShadowRoot ? gradioShadowRoot : document;
+}
+
+
+// extra rec button next to generate button
+var recButton = recordButton.cloneNode(true);
 var generate_button = document.getElementById("Generate");
 generate_button.insertAdjacentElement("afterend", recButton);
 recButton.style.setProperty("margin-left", "-10px");
-recButton.innerText = "Rec.";
-
+recButton.innerHTML = "Rec.";
 recButton.addEventListener("click", function() {
-    var originalRecordButton = document.getElementsByClassName("record-button")[1];
-    originalRecordButton.click();
-
-    var stopRecordButtons = document.getElementsByClassName("stop-button");
-    if (stopRecordButtons.length > 1) generate_button.parentElement.removeChild(stopRecordButtons[0]);
-    var stopRecordButton = document.getElementsByClassName("stop-button")[0];
-    generate_button.insertAdjacentElement("afterend", stopRecordButton);
-
-    //stopRecordButton.style.setProperty("margin-left", "-10px");
-    stopRecordButton.style.setProperty("padding-right", "10px");
-    recButton.style.display = "none";
-
-    stopRecordButton.addEventListener("click", function() {
-        recButton.style.display = "flex";
-    });
-});
\ No newline at end of file
+    recordButton.click();
+});
diff --git a/extensions/whisper_stt/script.py b/extensions/whisper_stt/script.py
index f52d2542..e45c8b1e 100644
--- a/extensions/whisper_stt/script.py
+++ b/extensions/whisper_stt/script.py
@@ -1,8 +1,13 @@
+import base64
+import gc
+import io
 from pathlib import Path
 
 import gradio as gr
-import speech_recognition as sr
 import numpy as np
+import torch
+import whisper
+from pydub import AudioSegment
 
 from modules import shared
 
@@ -11,13 +16,16 @@ input_hijack = {
     'value': ["", ""]
 }
 
-# parameters which can be customized in settings.json of webui
+# parameters which can be customized in settings.yaml of webui
 params = {
     'whipser_language': 'english',
     'whipser_model': 'small.en',
     'auto_submit': True
 }
 
+startup_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+WHISPERMODEL = whisper.load_model(params['whipser_model'], device=startup_device)
+
 
 def chat_input_modifier(text, visible_text, state):
     global input_hijack
@@ -28,54 +36,76 @@ def chat_input_modifier(text, visible_text, state):
     return text, visible_text
 
 
-def do_stt(audio, whipser_model, whipser_language):
-    transcription = ""
-    r = sr.Recognizer()
+def do_stt(audio, whipser_language):
+    # use pydub to convert sample_rate and sample_width for whisper input
+    dubaudio = AudioSegment.from_file(io.BytesIO(audio))
+    dubaudio = dubaudio.set_channels(1)
+    dubaudio = dubaudio.set_frame_rate(16000)
+    dubaudio = dubaudio.set_sample_width(2)
 
-    # Convert to AudioData
-    audio_data = sr.AudioData(sample_rate=audio[0], frame_data=audio[1], sample_width=4)
+    # same method to get the array as openai whisper repo used from wav file
+    audio_np = np.frombuffer(dubaudio.raw_data, np.int16).flatten().astype(np.float32) / 32768.0
 
-    try:
-        transcription = r.recognize_whisper(audio_data, language=whipser_language, model=whipser_model)
-    except sr.UnknownValueError:
-        print("Whisper could not understand audio")
-    except sr.RequestError as e:
-        print("Could not request results from Whisper", e)
+    if len(whipser_language) == 0:
+        result = WHISPERMODEL.transcribe(audio=audio_np)
+    else:
+        result = WHISPERMODEL.transcribe(audio=audio_np, language=whipser_language)
+    return result["text"]
+
+def auto_transcribe(audio, auto_submit, whipser_language):
+    if audio is None or audio == "":
+        print("Whisper received no audio data")
+        return "", ""
+    audio_bytes = base64.b64decode(audio.split(',')[1])
+
+    transcription = do_stt(audio_bytes, whipser_language)
+    if auto_submit:
+        input_hijack.update({"state": True, "value": [transcription, transcription]})
 
     return transcription
 
 
-def auto_transcribe(audio, auto_submit, whipser_model, whipser_language):
-    if audio is None:
-        return "", ""
-    sample_rate, audio_data = audio
-    if not isinstance(audio_data[0], np.ndarray):  # workaround for chrome audio. Mono?
-        # Convert to 2 channels, so each sample s_i consists of the same value in both channels [val_i, val_i]
-        audio_data = np.column_stack((audio_data, audio_data))
-        audio = (sample_rate, audio_data)
-    transcription = do_stt(audio, whipser_model, whipser_language)
-    if auto_submit:
-        input_hijack.update({"state": True, "value": [transcription, transcription]})
+def reload_whispermodel(whisper_model_name: str, whisper_language: str, device: str):
+    if len(whisper_model_name) > 0:
+        global WHISPERMODEL
+        WHISPERMODEL = None
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        gc.collect()
 
-    return transcription, None
+        if device != "none":
+            if device == "cuda":
+                device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+            WHISPERMODEL = whisper.load_model(whisper_model_name, device=device)
+            params.update({"whipser_model": whisper_model_name})
+            if ".en" in whisper_model_name:
+                whisper_language = "english"
+            audio_update = gr.Audio.update(interactive=True)
+        else:
+            audio_update = gr.Audio.update(interactive=False)
+        return [whisper_model_name, whisper_language, str(device), audio_update]
 
 
 def ui():
     with gr.Accordion("Whisper STT", open=True):
         with gr.Row():
-            audio = gr.Audio(source="microphone", type="numpy")
+            audio = gr.Textbox(elem_id="audio-base64", visible=False)
+            record_button = gr.Button("Rec.", elem_id="record-button", elem_classes="custom-button")
         with gr.Row():
            with gr.Accordion("Settings", open=False):
                 auto_submit = gr.Checkbox(label='Submit the transcribed audio automatically', value=params['auto_submit'])
-                whipser_model = gr.Dropdown(label='Whisper Model', value=params['whipser_model'], choices=["tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large"])
-                whipser_language = gr.Dropdown(label='Whisper Language', value=params['whipser_language'], choices=["chinese", "german", "spanish", "russian", "korean", "french", "japanese", "portuguese", "turkish", "polish", "catalan", "dutch", "arabic", "swedish", "italian", "indonesian", "hindi", "finnish", "vietnamese", "hebrew", "ukrainian", "greek", "malay", "czech", "romanian", "danish", "hungarian", "tamil", "norwegian", "thai", "urdu", "croatian", "bulgarian", "lithuanian", "latin", "maori", "malayalam", "welsh", "slovak", "telugu", "persian", "latvian", "bengali", "serbian", "azerbaijani", "slovenian", "kannada", "estonian", "macedonian", "breton", "basque", "icelandic", "armenian", "nepali", "mongolian", "bosnian", "kazakh", "albanian", "swahili", "galician", "marathi", "punjabi", "sinhala", "khmer", "shona", "yoruba", "somali", "afrikaans", "occitan", "georgian", "belarusian", "tajik", "sindhi", "gujarati", "amharic", "yiddish", "lao", "uzbek", "faroese", "haitian creole", "pashto", "turkmen", "nynorsk", "maltese", "sanskrit", "luxembourgish", "myanmar", "tibetan", "tagalog", "malagasy", "assamese", "tatar", "hawaiian", "lingala", "hausa", "bashkir", "javanese", "sundanese"])
+                device_dropd = gr.Dropdown(label='Device', value=str(startup_device), choices=["cuda", "cpu", "none"])
+                whisper_model_dropd = gr.Dropdown(label='Whisper Model', value=params['whipser_model'], choices=["tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large"])
+                whisper_language = gr.Dropdown(label='Whisper Language', value=params['whipser_language'], choices=["english", "chinese", "german", "spanish", "russian", "korean", "french", "japanese", "portuguese", "turkish", "polish", "catalan", "dutch", "arabic", "swedish", "italian", "indonesian", "hindi", "finnish", "vietnamese", "hebrew", "ukrainian", "greek", "malay", "czech", "romanian", "danish", "hungarian", "tamil", "norwegian", "thai", "urdu", "croatian", "bulgarian", "lithuanian", "latin", "maori", "malayalam", "welsh", "slovak", "telugu", "persian", "latvian", "bengali", "serbian", "azerbaijani", "slovenian", "kannada", "estonian", "macedonian", "breton", "basque", "icelandic", "armenian", "nepali", "mongolian", "bosnian", "kazakh", "albanian", "swahili", "galician", "marathi", "punjabi", "sinhala", "khmer", "shona", "yoruba", "somali", "afrikaans", "occitan", "georgian", "belarusian", "tajik", "sindhi", "gujarati", "amharic", "yiddish", "lao", "uzbek", "faroese", "haitian creole", "pashto", "turkmen", "nynorsk", "maltese", "sanskrit", "luxembourgish", "myanmar", "tibetan", "tagalog", "malagasy", "assamese", "tatar", "hawaiian", "lingala", "hausa", "bashkir", "javanese", "sundanese"])
 
-        audio.stop_recording(
-            auto_transcribe, [audio, auto_submit, whipser_model, whipser_language], [shared.gradio['textbox'], audio]).then(
-            None, auto_submit, None, js="(check) => {if (check) { document.getElementById('Generate').click() }}")
+        audio.change(
+            auto_transcribe, [audio, auto_submit, whisper_language], [shared.gradio['textbox']]).then(
+            None, auto_submit, None, _js="(check) => {if (check) { document.getElementById('Generate').click() }}")
 
-        whipser_model.change(lambda x: params.update({"whipser_model": x}), whipser_model, None)
-        whipser_language.change(lambda x: params.update({"whipser_language": x}), whipser_language, None)
+        device_dropd.input(reload_whispermodel, [whisper_model_dropd, whisper_language, device_dropd], [whisper_model_dropd, whisper_language, device_dropd, audio])
+        whisper_model_dropd.change(reload_whispermodel, [whisper_model_dropd, whisper_language, device_dropd], [whisper_model_dropd, whisper_language, device_dropd, audio])
+        whisper_language.change(lambda x: params.update({"whipser_language": x}), whisper_language, None)
         auto_submit.change(lambda x: params.update({"auto_submit": x}), auto_submit, None)
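
For reviewers, the handoff between the two files works like this: the JS `stop` handler encodes the WebM recording as a base64 data URL and writes it into the hidden `#audio-base64` textbox, whose `change` event then drives `auto_transcribe()` on the Python side, which decodes the data URL and feeds Whisper the same float32 array layout the diff uses. A minimal standalone sketch of that decode path is below; the helper name and the `recording.txt` input are illustrative only, not part of this PR.

```python
# Sketch of the Python-side decode path mirrored from auto_transcribe()/do_stt():
# a data URL ("data:audio/webm;base64,...") is base64-decoded, resampled to
# 16 kHz mono 16-bit via pydub, and converted to the float32 array that
# whisper's transcribe() accepts.
import base64
import io

import numpy as np
import whisper
from pydub import AudioSegment


def data_url_to_whisper_array(data_url: str) -> np.ndarray:
    # strip the "data:audio/webm;base64," prefix and decode the payload
    audio_bytes = base64.b64decode(data_url.split(",")[1])
    # pydub (via ffmpeg) decodes the webm container, then normalize the format
    seg = AudioSegment.from_file(io.BytesIO(audio_bytes))
    seg = seg.set_channels(1).set_frame_rate(16000).set_sample_width(2)
    # same int16 -> float32 scaling the PR uses
    return np.frombuffer(seg.raw_data, np.int16).flatten().astype(np.float32) / 32768.0


if __name__ == "__main__":
    model = whisper.load_model("tiny.en")  # any name from the model dropdown works
    # "recording.txt" is a hypothetical dump of the hidden textbox value
    audio_np = data_url_to_whisper_array(open("recording.txt").read())
    print(model.transcribe(audio=audio_np)["text"])
```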