Silero TTS offline cache (#628)

This commit is contained in:
Φφ 2023-04-07 18:15:57 +03:00 committed by GitHub
parent 1c413ed593
commit e563b015d8
WARNING! Although there is a key with this ID in the database it does not verify this commit! This commit is SUSPICIOUS.
GPG key ID: 4AEE18F83AFDEB23

View file

@ -21,6 +21,7 @@ params = {
'autoplay': True,
'voice_pitch': 'medium',
'voice_speed': 'medium',
'local_cache_path': '' # User can override the default cache path to something other via settings.json
}
current_params = params.copy()
@ -44,14 +45,18 @@ def xmlesc(txt):
def load_model():
torch_cache_path = torch.hub.get_dir() if params['local_cache_path'] == '' else params['local_cache_path']
model_path = torch_cache_path + "/snakers4_silero-models_master/src/silero/model/" + params['model_id'] + ".pt"
if Path(model_path).is_file():
print(f'\nUsing Silero TTS cached checkpoint found at {torch_cache_path}')
model, example_text = torch.hub.load(repo_or_dir=torch_cache_path + '/snakers4_silero-models_master/', model='silero_tts', language=params['language'], speaker=params['model_id'], source='local', path=model_path, force_reload=True)
else:
print(f'\nSilero TTS cache not found at {torch_cache_path}. Attempting to download...')
model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
model.to(params['device'])
return model
model = load_model()
def remove_tts_from_history(name1, name2, mode):
for i, entry in enumerate(shared.history['internal']):
shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
@ -132,6 +137,11 @@ def bot_prefix_modifier(string):
return string
def setup():
global model
model = load_model()
def ui():
# Gradio elements
with gr.Accordion("Silero TTS"):