diff --git a/extensions/silero_tts/requirements.txt b/extensions/silero_tts/requirements.txt index f2f0bff5..ac2785af 100644 --- a/extensions/silero_tts/requirements.txt +++ b/extensions/silero_tts/requirements.txt @@ -1,4 +1,5 @@ ipython +num2words omegaconf pydub PyYAML diff --git a/extensions/silero_tts/script.py b/extensions/silero_tts/script.py index 6ee617c8..b001ae5a 100644 --- a/extensions/silero_tts/script.py +++ b/extensions/silero_tts/script.py @@ -1,14 +1,15 @@ -import re import time from pathlib import Path import gradio as gr -import modules.chat as chat -import modules.shared as shared import torch +from extensions.silero_tts import tts_preprocessor +from modules import chat, shared +from modules.html_generator import chat_html_wrapper torch._C._jit_set_profiling_mode(False) + params = { 'activate': True, 'speaker': 'en_56', @@ -26,7 +27,7 @@ current_params = params.copy() voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115'] voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high'] voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast'] -streaming_state = shared.args.no_stream # remember if chat streaming was enabled +streaming_state = shared.args.no_stream # remember if chat streaming was enabled # Used for making text xml compatible, needed for voice pitch and speed control table = str.maketrans({ @@ -37,26 +38,27 @@ table = str.maketrans({ '"': """, }) + def xmlesc(txt): return txt.translate(table) + def load_model(): model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id']) model.to(params['device']) return model + + model = load_model() -def remove_surrounded_chars(string): - # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR - # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string' - return re.sub('\*[^\*]*?(\*|$)','',string) -def remove_tts_from_history(name1, name2): +def remove_tts_from_history(name1, name2, mode): for i, entry in enumerate(shared.history['internal']): shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]] - return chat.generate_chat_output(shared.history['visible'], name1, name2, shared.character) + return chat_html_wrapper(shared.history['visible'], name1, name2, mode) -def toggle_text_in_history(name1, name2): + +def toggle_text_in_history(name1, name2, mode): for i, entry in enumerate(shared.history['visible']): visible_reply = entry[1] if visible_reply.startswith('')[0]}\n\n{reply}"] else: shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('')[0]}"] - return chat.generate_chat_output(shared.history['visible'], name1, name2, shared.character) + return chat_html_wrapper(shared.history['visible'], name1, name2, mode) + def input_modifier(string): """ @@ -75,12 +78,13 @@ def input_modifier(string): # Remove autoplay from the last reply if shared.is_chat() and len(shared.history['internal']) > 0: - shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>','controls>')] + shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>', 'controls>')] shared.processing_message = "*Is recording a voice message...*" - shared.args.no_stream = True # Disable streaming cause otherwise the audio output will stutter and begin anew every time the message is being updated + shared.args.no_stream = True # Disable streaming cause otherwise the audio output will stutter and begin anew every time the message is being updated return string + def output_modifier(string): """ This function is applied to the model outputs. @@ -94,15 +98,11 @@ def output_modifier(string): current_params = params.copy() break - if params['activate'] == False: + if not params['activate']: return string original_string = string - string = remove_surrounded_chars(string) - string = string.replace('"', '') - string = string.replace('“', '') - string = string.replace('\n', ' ') - string = string.strip() + string = tts_preprocessor.preprocess(string) if string == '': string = '*Empty reply, try regenerating*' @@ -118,9 +118,10 @@ def output_modifier(string): string += f'\n\n{original_string}' shared.processing_message = "*Is typing...*" - shared.args.no_stream = streaming_state # restore the streaming option to the previous value + shared.args.no_stream = streaming_state # restore the streaming option to the previous value return string + def bot_prefix_modifier(string): """ This function is only applied in chat mode. It modifies @@ -130,17 +131,20 @@ def bot_prefix_modifier(string): return string + def ui(): # Gradio elements with gr.Accordion("Silero TTS"): with gr.Row(): activate = gr.Checkbox(value=params['activate'], label='Activate TTS') autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically') + show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player') voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice') with gr.Row(): v_pitch = gr.Dropdown(value=params['voice_pitch'], choices=voice_pitches, label='Voice pitch') v_speed = gr.Dropdown(value=params['voice_speed'], choices=voice_speeds, label='Voice speed') + with gr.Row(): convert = gr.Button('Permanently replace audios with the message texts') convert_cancel = gr.Button('Cancel', visible=False) @@ -148,20 +152,20 @@ def ui(): # Convert history with confirmation convert_arr = [convert_confirm, convert, convert_cancel] - convert.click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr) - convert_confirm.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr) - convert_confirm.click(remove_tts_from_history, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display']) - convert_confirm.click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) - convert_cancel.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr) + convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr) + convert_confirm.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr) + convert_confirm.click(remove_tts_from_history, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], shared.gradio['display']) + convert_confirm.click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False) + convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr) # Toggle message text in history show_text.change(lambda x: params.update({"show_text": x}), show_text, None) - show_text.change(toggle_text_in_history, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display']) - show_text.change(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) + show_text.change(toggle_text_in_history, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], shared.gradio['display']) + show_text.change(lambda: chat.save_history(timestamp=False), [], [], show_progress=False) # Event functions to update the parameters in the backend activate.change(lambda x: params.update({"activate": x}), activate, None) autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None) voice.change(lambda x: params.update({"speaker": x}), voice, None) v_pitch.change(lambda x: params.update({"voice_pitch": x}), v_pitch, None) - v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None) \ No newline at end of file + v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None) diff --git a/extensions/silero_tts/test_tts.py b/extensions/silero_tts/test_tts.py new file mode 100644 index 00000000..ad8ee764 --- /dev/null +++ b/extensions/silero_tts/test_tts.py @@ -0,0 +1,82 @@ +import time +from pathlib import Path + +import torch + +import tts_preprocessor + +torch._C._jit_set_profiling_mode(False) + + +params = { + 'activate': True, + 'speaker': 'en_49', + 'language': 'en', + 'model_id': 'v3_en', + 'sample_rate': 48000, + 'device': 'cpu', + 'show_text': True, + 'autoplay': True, + 'voice_pitch': 'medium', + 'voice_speed': 'medium', +} + +current_params = params.copy() +voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115'] +voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high'] +voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast'] + +# Used for making text xml compatible, needed for voice pitch and speed control +table = str.maketrans({ + "<": "<", + ">": ">", + "&": "&", + "'": "'", + '"': """, +}) + + +def xmlesc(txt): + return txt.translate(table) + + +def load_model(): + model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id']) + model.to(params['device']) + return model + + +model = load_model() + + +def output_modifier(string): + """ + This function is applied to the model outputs. + """ + + global model, current_params + + original_string = string + string = tts_preprocessor.preprocess(string) + processed_string = string + + if string == '': + string = '*Empty reply, try regenerating*' + else: + output_file = Path(f'extensions/silero_tts/outputs/test_{int(time.time())}.wav') + prosody = ''.format(params['voice_speed'], params['voice_pitch']) + silero_input = f'{prosody}{xmlesc(string)}' + model.save_wav(ssml_text=silero_input, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file)) + + autoplay = 'autoplay' if params['autoplay'] else '' + string = f'' + + if params['show_text']: + string += f'\n\n{original_string}\n\nProcessed:\n{processed_string}' + + print(string) + + +if __name__ == '__main__': + import sys + output_modifier(sys.argv[1]) diff --git a/extensions/silero_tts/tts_preprocessor.py b/extensions/silero_tts/tts_preprocessor.py new file mode 100644 index 00000000..1dd751e4 --- /dev/null +++ b/extensions/silero_tts/tts_preprocessor.py @@ -0,0 +1,195 @@ +import re + +from num2words import num2words + + +punctuation = r'[\s,.?!/)\'\]>]' +alphabet_map = { + "A": " Ei ", + "B": " Bee ", + "C": " See ", + "D": " Dee ", + "E": " Eee ", + "F": " Eff ", + "G": " Jee ", + "H": " Eich ", + "I": " Eye ", + "J": " Jay ", + "K": " Kay ", + "L": " El ", + "M": " Emm ", + "N": " Enn ", + "O": " Ohh ", + "P": " Pee ", + "Q": " Queue ", + "R": " Are ", + "S": " Ess ", + "T": " Tee ", + "U": " You ", + "V": " Vee ", + "W": " Double You ", + "X": " Ex ", + "Y": " Why ", + "Z": " Zed " # Zed is weird, as I (da3dsoul) am American, but most of the voice models sound British, so it matches +} + + +def preprocess(string): + # the order for some of these matter + # For example, you need to remove the commas in numbers before expanding them + string = remove_surrounded_chars(string) + string = string.replace('"', '') + string = string.replace('\u201D', '').replace('\u201C', '') # right and left quote + string = string.replace('\u201F', '') # italic looking quote + string = string.replace('\n', ' ') + string = convert_num_locale(string) + string = replace_negative(string) + string = replace_roman(string) + string = hyphen_range_to(string) + string = num_to_words(string) + + # TODO Try to use a ML predictor to expand abbreviations. It's hard, dependent on context, and whether to actually + # try to say the abbreviation or spell it out as I've done below is not agreed upon + + # For now, expand abbreviations to pronunciations + # replace_abbreviations adds a lot of unnecessary whitespace to ensure separation + string = replace_abbreviations(string) + string = replace_lowercase_abbreviations(string) + + # cleanup whitespaces + # remove whitespace before punctuation + string = re.sub(rf'\s+({punctuation})', r'\1', string) + string = string.strip() + # compact whitespace + string = ' '.join(string.split()) + + return string + + +def remove_surrounded_chars(string): + # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR + # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string' + return re.sub(r'\*[^*]*?(\*|$)', '', string) + + +def convert_num_locale(text): + # This detects locale and converts it to American without comma separators + pattern = re.compile(r'(?:\s|^)\d{1,3}(?:\.\d{3})+(,\d+)(?:\s|$)') + result = text + while True: + match = pattern.search(result) + if match is None: + break + + start = match.start() + end = match.end() + result = result[0:start] + result[start:end].replace('.', '').replace(',', '.') + result[end:len(result)] + + # removes comma separators from existing American numbers + pattern = re.compile(r'(\d),(\d)') + result = pattern.sub(r'\1\2', result) + + return result + + +def replace_negative(string): + # handles situations like -5. -5 would become negative 5, which would then be expanded to negative five + return re.sub(rf'(\s)(-)(\d+)({punctuation})', r'\1negative \3\4', string) + + +def replace_roman(string): + # find a string of roman numerals. + # Only 2 or more, to avoid capturing I and single character abbreviations, like names + pattern = re.compile(rf'\s[IVXLCDM]{{2,}}{punctuation}') + result = string + while True: + match = pattern.search(result) + if match is None: + break + + start = match.start() + end = match.end() + result = result[0:start + 1] + str(roman_to_int(result[start + 1:end - 1])) + result[end - 1:len(result)] + + return result + + +def roman_to_int(s): + rom_val = {'I': 1, 'V': 5, 'X': 10, 'L': 50, 'C': 100, 'D': 500, 'M': 1000} + int_val = 0 + for i in range(len(s)): + if i > 0 and rom_val[s[i]] > rom_val[s[i - 1]]: + int_val += rom_val[s[i]] - 2 * rom_val[s[i - 1]] + else: + int_val += rom_val[s[i]] + return int_val + + +def hyphen_range_to(text): + pattern = re.compile(r'(\d+)[-–](\d+)') + result = pattern.sub(lambda x: x.group(1) + ' to ' + x.group(2), text) + return result + + +def num_to_words(text): + # 1000 or 10.23 + pattern = re.compile(r'\d+\.\d+|\d+') + result = pattern.sub(lambda x: num2words(float(x.group())), text) + return result + + +def replace_abbreviations(string): + # abbreviations 1 to 4 characters long. It will get things like A and I, but those are pronounced with their letter + pattern = re.compile(rf'(^|[\s(.\'\[<])([A-Z]{{1,4}})({punctuation}|$)') + result = string + while True: + match = pattern.search(result) + if match is None: + break + + start = match.start() + end = match.end() + result = result[0:start] + replace_abbreviation(result[start:end]) + result[end:len(result)] + + return result + + +def replace_lowercase_abbreviations(string): + # abbreviations 1 to 4 characters long, separated by dots i.e. e.g. + pattern = re.compile(rf'(^|[\s(.\'\[<])(([a-z]\.){{1,4}})({punctuation}|$)') + result = string + while True: + match = pattern.search(result) + if match is None: + break + + start = match.start() + end = match.end() + result = result[0:start] + replace_abbreviation(result[start:end].upper()) + result[end:len(result)] + + return result + + +def replace_abbreviation(string): + result = "" + for char in string: + result += match_mapping(char) + + return result + + +def match_mapping(char): + for mapping in alphabet_map.keys(): + if char == mapping: + return alphabet_map[char] + + return char + + +def __main__(args): + print(preprocess(args[1])) + + +if __name__ == "__main__": + import sys + __main__(sys.argv)