Merge branch 'oobabooga:main' into stt-extension

This commit is contained in:
Elias Vincent Simon 2023-03-12 19:19:43 +01:00 committed by GitHub
commit 3b4145966d
WARNING! Although there is a key with this ID in the database it does not verify this commit! This commit is SUSPICIOUS.
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 462 additions and 185 deletions

View file

@ -1,6 +1,6 @@
# Text generation web UI # Text generation web UI
A gradio web UI for running Large Language Models like GPT-J 6B, OPT, GALACTICA, GPT-Neo, and Pygmalion. A gradio web UI for running Large Language Models like GPT-J 6B, OPT, GALACTICA, LLaMA, and Pygmalion.
Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) of text generation. Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) of text generation.
@ -27,6 +27,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
* [FlexGen offload](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen). * [FlexGen offload](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen).
* [DeepSpeed ZeRO-3 offload](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed). * [DeepSpeed ZeRO-3 offload](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed).
* Get responses via API, [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-streaming.py) or [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming. * Get responses via API, [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-streaming.py) or [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming.
* [Supports the LLaMA model, including 4-bit mode](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model).
* [Supports the RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model). * [Supports the RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model).
* Supports softprompts. * Supports softprompts.
* [Supports extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions). * [Supports extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions).
@ -53,7 +54,7 @@ The third line assumes that you have an NVIDIA GPU.
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.2 pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.2
``` ```
* If you are running in CPU mode, replace the third command with this one: * If you are running it in CPU mode, replace the third command with this one:
``` ```
conda install pytorch torchvision torchaudio git -c pytorch conda install pytorch torchvision torchaudio git -c pytorch
@ -137,6 +138,8 @@ Optionally, you can use the following command-line flags:
| `--cai-chat` | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. | | `--cai-chat` | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. |
| `--cpu` | Use the CPU to generate text.| | `--cpu` | Use the CPU to generate text.|
| `--load-in-8bit` | Load the model with 8-bit precision.| | `--load-in-8bit` | Load the model with 8-bit precision.|
| `--load-in-4bit` | Load the model with 4-bit precision. Currently only works with LLaMA.|
| `--gptq-bits GPTQ_BITS` | Load a pre-quantized model with specified precision. 2, 3, 4 and 8 (bit) are supported. Currently only works with LLaMA. |
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. | | `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
| `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.| | `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.|
| `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. | | `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
@ -176,14 +179,10 @@ Check the [wiki](https://github.com/oobabooga/text-generation-webui/wiki/System-
Pull requests, suggestions, and issue reports are welcome. Pull requests, suggestions, and issue reports are welcome.
Before reporting a bug, make sure that you have created a conda environment and installed the dependencies exactly as in the *Installation* section above. Before reporting a bug, make sure that you have:
These issues are known: 1. Created a conda environment and installed the dependencies exactly as in the *Installation* section above.
2. [Searched](https://github.com/oobabooga/text-generation-webui/issues) to see if an issue already exists for the issue you encountered.
* 8-bit doesn't work properly on Windows or older GPUs.
* DeepSpeed doesn't work properly on Windows.
For these two, please try commenting on an existing issue instead of creating a new one.
## Credits ## Credits

View file

@ -5,7 +5,9 @@ Example:
python download-model.py facebook/opt-1.3b python download-model.py facebook/opt-1.3b
''' '''
import argparse import argparse
import base64
import json import json
import multiprocessing import multiprocessing
import re import re
@ -93,23 +95,28 @@ facebook/opt-1.3b
def get_download_links_from_huggingface(model, branch): def get_download_links_from_huggingface(model, branch):
base = "https://huggingface.co" base = "https://huggingface.co"
page = f"/api/models/{model}/tree/{branch}?cursor=" page = f"/api/models/{model}/tree/{branch}?cursor="
cursor = b""
links = [] links = []
classifications = [] classifications = []
has_pytorch = False has_pytorch = False
has_safetensors = False has_safetensors = False
while page is not None: while True:
content = requests.get(f"{base}{page}").content content = requests.get(f"{base}{page}{cursor.decode()}").content
dict = json.loads(content) dict = json.loads(content)
if len(dict) == 0:
break
for i in range(len(dict)): for i in range(len(dict)):
fname = dict[i]['path'] fname = dict[i]['path']
is_pytorch = re.match("pytorch_model.*\.bin", fname) is_pytorch = re.match("pytorch_model.*\.bin", fname)
is_safetensors = re.match("model.*\.safetensors", fname) is_safetensors = re.match("model.*\.safetensors", fname)
is_text = re.match(".*\.(txt|json)", fname) is_tokenizer = re.match("tokenizer.*\.model", fname)
is_text = re.match(".*\.(txt|json)", fname) or is_tokenizer
if is_text or is_safetensors or is_pytorch: if any((is_pytorch, is_safetensors, is_text, is_tokenizer)):
if is_text: if is_text:
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}") links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
classifications.append('text') classifications.append('text')
@ -123,8 +130,9 @@ def get_download_links_from_huggingface(model, branch):
has_pytorch = True has_pytorch = True
classifications.append('pytorch') classifications.append('pytorch')
#page = dict['nextUrl'] cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
page = None cursor = base64.b64encode(cursor)
cursor = cursor.replace(b'=', b'%3D')
# If both pytorch and safetensors are available, download safetensors only # If both pytorch and safetensors are available, download safetensors only
if has_pytorch and has_safetensors: if has_pytorch and has_safetensors:

View file

@ -0,0 +1,18 @@
import gradio as gr
import modules.shared as shared
import pandas as pd
df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv")
def get_prompt_by_name(name):
if name == 'None':
return ''
else:
return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n')
def ui():
if not shared.args.chat or shared.args.cai_chat:
choices = ['None'] + list(df['Prompt name'])
prompts_menu = gr.Dropdown(value=choices[0], choices=choices, label='Prompt')
prompts_menu.change(get_prompt_by_name, prompts_menu, shared.gradio['textbox'])

View file

@ -1,21 +1,45 @@
import re
import time
from pathlib import Path from pathlib import Path
import gradio as gr import gradio as gr
import torch import torch
import modules.chat as chat
import modules.shared as shared
torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_mode(False)
params = { params = {
'activate': True, 'activate': True,
'speaker': 'en_56', 'speaker': 'en_5',
'language': 'en', 'language': 'en',
'model_id': 'v3_en', 'model_id': 'v3_en',
'sample_rate': 48000, 'sample_rate': 48000,
'device': 'cpu', 'device': 'cpu',
'show_text': False,
'autoplay': True,
'voice_pitch': 'medium',
'voice_speed': 'medium',
} }
current_params = params.copy() current_params = params.copy()
voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115'] voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115']
wav_idx = 0 voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high']
voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast']
last_msg_id = 0
# Used for making text xml compatible, needed for voice pitch and speed control
table = str.maketrans({
"<": "&lt;",
">": "&gt;",
"&": "&amp;",
"'": "&apos;",
'"': "&quot;",
})
def xmlesc(txt):
return txt.translate(table)
def load_model(): def load_model():
model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id']) model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
@ -33,12 +57,59 @@ def remove_surrounded_chars(string):
new_string += char new_string += char
return new_string return new_string
def remove_tts_from_history():
suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
for i, entry in enumerate(shared.history['internal']):
reply = entry[1]
reply = re.sub("(<USER>|<user>|{{user}})", shared.settings[f'name1{suffix}'], reply)
if shared.args.chat:
reply = reply.replace('\n', '<br>')
shared.history['visible'][i][1] = reply
if shared.args.cai_chat:
return chat.generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name1{suffix}'], shared.character)
else:
return shared.history['visible']
def toggle_text_in_history():
suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
audio_str='\n\n' # The '\n\n' used after </audio>
if shared.args.chat:
audio_str='<br><br>'
if params['show_text']==True:
#for i, entry in enumerate(shared.history['internal']):
for i, entry in enumerate(shared.history['visible']):
vis_reply = entry[1]
if vis_reply.startswith('<audio'):
reply = shared.history['internal'][i][1]
reply = re.sub("(<USER>|<user>|{{user}})", shared.settings[f'name1{suffix}'], reply)
if shared.args.chat:
reply = reply.replace('\n', '<br>')
shared.history['visible'][i][1] = vis_reply.split(audio_str,1)[0]+audio_str+reply
else:
for i, entry in enumerate(shared.history['visible']):
vis_reply = entry[1]
if vis_reply.startswith('<audio'):
shared.history['visible'][i][1] = vis_reply.split(audio_str,1)[0]+audio_str
if shared.args.cai_chat:
return chat.generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name1{suffix}'], shared.character)
else:
return shared.history['visible']
def input_modifier(string): def input_modifier(string):
""" """
This function is applied to your text inputs before This function is applied to your text inputs before
they are fed into the model. they are fed into the model.
""" """
# Remove autoplay from previous chat history
if (shared.args.chat or shared.args.cai_chat)and len(shared.history['internal'])>0:
[visible_text, visible_reply] = shared.history['visible'][-1]
vis_rep_clean = visible_reply.replace('controls autoplay>','controls>')
shared.history['visible'][-1] = [visible_text, vis_rep_clean]
return string return string
def output_modifier(string): def output_modifier(string):
@ -46,7 +117,7 @@ def output_modifier(string):
This function is applied to the model outputs. This function is applied to the model outputs.
""" """
global wav_idx, model, current_params global model, current_params
for i in params: for i in params:
if params[i] != current_params[i]: if params[i] != current_params[i]:
@ -57,20 +128,34 @@ def output_modifier(string):
if params['activate'] == False: if params['activate'] == False:
return string return string
orig_string = string
string = remove_surrounded_chars(string) string = remove_surrounded_chars(string)
string = string.replace('"', '') string = string.replace('"', '')
string = string.replace('', '') string = string.replace('', '')
string = string.replace('\n', ' ') string = string.replace('\n', ' ')
string = string.strip() string = string.strip()
silent_string = False # Used to prevent unnecessary audio file generation
if string == '': if string == '':
string = 'empty reply, try regenerating' string = 'empty reply, try regenerating'
silent_string = True
output_file = Path(f'extensions/silero_tts/outputs/{wav_idx:06d}.wav') pitch = params['voice_pitch']
model.save_wav(text=string, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file)) speed = params['voice_speed']
prosody=f'<prosody rate="{speed}" pitch="{pitch}">'
string = '<speak>'+prosody+xmlesc(string)+'</prosody></speak>'
string = f'<audio src="file/{output_file.as_posix()}" controls></audio>' if not shared.still_streaming and not silent_string:
wav_idx += 1 output_file = Path(f'extensions/silero_tts/outputs/{shared.character}_{int(time.time())}.wav')
model.save_wav(ssml_text=string, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
autoplay_str = ' autoplay' if params['autoplay'] else ''
string = f'<audio src="file/{output_file.as_posix()}" controls{autoplay_str}></audio>\n\n'
else:
# Placeholder so text doesn't shift around so much
string = '<audio controls></audio>\n\n'
if params['show_text']:
string += orig_string
return string return string
@ -85,9 +170,36 @@ def bot_prefix_modifier(string):
def ui(): def ui():
# Gradio elements # Gradio elements
activate = gr.Checkbox(value=params['activate'], label='Activate TTS') with gr.Accordion("Silero TTS"):
voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice') with gr.Row():
activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically')
show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player')
voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
with gr.Row():
v_pitch = gr.Dropdown(value=params['voice_pitch'], choices=voice_pitches, label='Voice pitch')
v_speed = gr.Dropdown(value=params['voice_speed'], choices=voice_speeds, label='Voice speed')
with gr.Row():
convert = gr.Button('Permanently replace chat history audio with message text')
convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False)
convert_cancel = gr.Button('Cancel', visible=False)
# Convert history with confirmation
convert_arr = [convert_confirm, convert, convert_cancel]
convert.click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
convert_confirm.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
convert_confirm.click(remove_tts_from_history, [], shared.gradio['display'])
convert_confirm.click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
convert_cancel.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
# Toggle message text in history
show_text.change(lambda x: params.update({"show_text": x}), show_text, None)
show_text.change(toggle_text_in_history, [], shared.gradio['display'])
show_text.change(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
# Event functions to update the parameters in the backend # Event functions to update the parameters in the backend
activate.change(lambda x: params.update({"activate": x}), activate, None) activate.change(lambda x: params.update({"activate": x}), activate, None)
autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None)
voice.change(lambda x: params.update({"speaker": x}), voice, None) voice.change(lambda x: params.update({"speaker": x}), voice, None)
v_pitch.change(lambda x: params.update({"voice_pitch": x}), v_pitch, None)
v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None)

View file

@ -1,12 +1,11 @@
import os import os
from pathlib import Path from pathlib import Path
from queue import Queue
from threading import Thread
import numpy as np import numpy as np
from tokenizers import Tokenizer from tokenizers import Tokenizer
import modules.shared as shared import modules.shared as shared
from modules.callbacks import Iteratorize
np.set_printoptions(precision=4, suppress=True, linewidth=200) np.set_printoptions(precision=4, suppress=True, linewidth=200)
@ -49,11 +48,11 @@ class RWKVModel:
return context+self.pipeline.generate(context, token_count=token_count, args=args, callback=callback) return context+self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
def generate_with_streaming(self, **kwargs): def generate_with_streaming(self, **kwargs):
iterable = Iteratorize(self.generate, kwargs, callback=None) with Iteratorize(self.generate, kwargs, callback=None) as generator:
reply = kwargs['context'] reply = kwargs['context']
for token in iterable: for token in generator:
reply += token reply += token
yield reply yield reply
class RWKVTokenizer: class RWKVTokenizer:
def __init__(self): def __init__(self):
@ -73,38 +72,3 @@ class RWKVTokenizer:
def decode(self, ids): def decode(self, ids):
return self.tokenizer.decode(ids) return self.tokenizer.decode(ids)
class Iteratorize:
"""
Transforms a function that takes a callback
into a lazy iterator (generator).
"""
def __init__(self, func, kwargs={}, callback=None):
self.mfunc=func
self.c_callback=callback
self.q = Queue(maxsize=1)
self.sentinel = object()
self.kwargs = kwargs
def _callback(val):
self.q.put(val)
def gentask():
ret = self.mfunc(callback=_callback, **self.kwargs)
self.q.put(self.sentinel)
if self.c_callback:
self.c_callback(ret)
Thread(target=gentask).start()
def __iter__(self):
return self
def __next__(self):
obj = self.q.get(True,None)
if obj is self.sentinel:
raise StopIteration
else:
return obj

98
modules/callbacks.py Normal file
View file

@ -0,0 +1,98 @@
import gc
from queue import Queue
from threading import Thread
import torch
import transformers
import modules.shared as shared
# Copied from https://github.com/PygmalionAI/gradio-ui/
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
def __init__(self, sentinel_token_ids: torch.LongTensor,
starting_idx: int):
transformers.StoppingCriteria.__init__(self)
self.sentinel_token_ids = sentinel_token_ids
self.starting_idx = starting_idx
def __call__(self, input_ids: torch.LongTensor,
_scores: torch.FloatTensor) -> bool:
for sample in input_ids:
trimmed_sample = sample[self.starting_idx:]
# Can't unfold, output is still too tiny. Skip.
if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
continue
for window in trimmed_sample.unfold(
0, self.sentinel_token_ids.shape[-1], 1):
if torch.all(torch.eq(self.sentinel_token_ids, window)):
return True
return False
class Stream(transformers.StoppingCriteria):
def __init__(self, callback_func=None):
self.callback_func = callback_func
def __call__(self, input_ids, scores) -> bool:
if self.callback_func is not None:
self.callback_func(input_ids[0])
return False
class Iteratorize:
"""
Transforms a function that takes a callback
into a lazy iterator (generator).
"""
def __init__(self, func, kwargs={}, callback=None):
self.mfunc=func
self.c_callback=callback
self.q = Queue()
self.sentinel = object()
self.kwargs = kwargs
self.stop_now = False
def _callback(val):
if self.stop_now:
raise ValueError
self.q.put(val)
def gentask():
try:
ret = self.mfunc(callback=_callback, **self.kwargs)
except ValueError:
pass
clear_torch_cache()
self.q.put(self.sentinel)
if self.c_callback:
self.c_callback(ret)
self.thread = Thread(target=gentask)
self.thread.start()
def __iter__(self):
return self
def __next__(self):
obj = self.q.get(True,None)
if obj is self.sentinel:
raise StopIteration
else:
return obj
def __del__(self):
clear_torch_cache()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop_now = True
clear_torch_cache()
def clear_torch_cache():
gc.collect()
if not shared.args.cpu:
torch.cuda.empty_cache()

View file

@ -84,6 +84,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate
tmp = f"\n{asker}:" tmp = f"\n{asker}:"
for j in range(1, len(tmp)): for j in range(1, len(tmp)):
if reply[-j:] == tmp[:j]: if reply[-j:] == tmp[:j]:
reply = reply[:-j]
substring_found = True substring_found = True
return reply, next_character_found, substring_found return reply, next_character_found, substring_found
@ -91,7 +92,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate
def stop_everything_event(): def stop_everything_event():
shared.stop_everything = True shared.stop_everything = True
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1): def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
shared.stop_everything = False shared.stop_everything = False
just_started = True just_started = True
eos_token = '\n' if check else None eos_token = '\n' if check else None
@ -120,6 +121,10 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
else: else:
prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size) prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
if not regenerate:
# Display user input and "*is typing...*" imediately
yield shared.history['visible']+[[visible_text, '*Is typing...*']]
# Generate # Generate
reply = '' reply = ''
for i in range(chat_generation_attempts): for i in range(chat_generation_attempts):
@ -158,6 +163,9 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True) prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)
# Display "*is typing...*" imediately
yield '*Is typing...*'
reply = '' reply = ''
for i in range(chat_generation_attempts): for i in range(chat_generation_attempts):
for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"): for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
@ -182,7 +190,7 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi
last_visible = shared.history['visible'].pop() last_visible = shared.history['visible'].pop()
last_internal = shared.history['internal'].pop() last_internal = shared.history['internal'].pop()
for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts): for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True):
if shared.args.cai_chat: if shared.args.cai_chat:
shared.history['visible'][-1] = [last_visible[0], _history[-1][1]] shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
yield generate_chat_html(shared.history['visible'], name1, name2, shared.character) yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
@ -291,7 +299,7 @@ def save_history(timestamp=True):
fname = f"{prefix}persistent.json" fname = f"{prefix}persistent.json"
if not Path('logs').exists(): if not Path('logs').exists():
Path('logs').mkdir() Path('logs').mkdir()
with open(Path(f'logs/{fname}'), 'w') as f: with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f:
f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2)) f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
return Path(f'logs/{fname}') return Path(f'logs/{fname}')
@ -332,7 +340,7 @@ def load_character(_character, name1, name2):
shared.history['visible'] = [] shared.history['visible'] = []
if _character != 'None': if _character != 'None':
shared.character = _character shared.character = _character
data = json.loads(open(Path(f'characters/{_character}.json'), 'r').read()) data = json.loads(open(Path(f'characters/{_character}.json'), 'r', encoding='utf-8').read())
name2 = data['char_name'] name2 = data['char_name']
if 'char_persona' in data and data['char_persona'] != '': if 'char_persona' in data and data['char_persona'] != '':
context += f"{data['char_name']}'s Persona: {data['char_persona']}\n" context += f"{data['char_name']}'s Persona: {data['char_persona']}\n"
@ -372,7 +380,7 @@ def upload_character(json_file, img, tavern=False):
i += 1 i += 1
if tavern: if tavern:
outfile_name = f'TavernAI-{outfile_name}' outfile_name = f'TavernAI-{outfile_name}'
with open(Path(f'characters/{outfile_name}.json'), 'w') as f: with open(Path(f'characters/{outfile_name}.json'), 'w', encoding='utf-8') as f:
f.write(json_file) f.write(json_file)
if img is not None: if img is not None:
img = Image.open(io.BytesIO(img)) img = Image.open(io.BytesIO(img))

View file

@ -1,5 +1,6 @@
import json import json
import os import os
import sys
import time import time
import zipfile import zipfile
from pathlib import Path from pathlib import Path
@ -41,7 +42,7 @@ def load_model(model_name):
shared.is_RWKV = model_name.lower().startswith('rwkv-') shared.is_RWKV = model_name.lower().startswith('rwkv-')
# Default settings # Default settings
if not (shared.args.cpu or shared.args.load_in_8bit or shared.args.auto_devices or shared.args.disk or shared.args.gpu_memory is not None or shared.args.cpu_memory is not None or shared.args.deepspeed or shared.args.flexgen or shared.is_RWKV): if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.gptq_bits > 0, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]):
if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')): if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True) model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True)
else: else:
@ -86,6 +87,12 @@ def load_model(model_name):
return model, tokenizer return model, tokenizer
# 4-bit LLaMA
elif shared.args.gptq_bits > 0 or shared.args.load_in_4bit:
from modules.quantized_LLaMA import load_quantized_LLaMA
model = load_quantized_LLaMA(model_name)
# Custom # Custom
else: else:
command = "AutoModelForCausalLM.from_pretrained" command = "AutoModelForCausalLM.from_pretrained"

View file

@ -0,0 +1,60 @@
import os
import sys
from pathlib import Path
import accelerate
import torch
import modules.shared as shared
sys.path.insert(0, os.path.abspath(Path("repositories/GPTQ-for-LLaMa")))
from llama import load_quant
# 4-bit LLaMA
def load_quantized_LLaMA(model_name):
if shared.args.load_in_4bit:
bits = 4
else:
bits = shared.args.gptq_bits
path_to_model = Path(f'models/{model_name}')
pt_model = ''
if path_to_model.name.lower().startswith('llama-7b'):
pt_model = f'llama-7b-{bits}bit.pt'
elif path_to_model.name.lower().startswith('llama-13b'):
pt_model = f'llama-13b-{bits}bit.pt'
elif path_to_model.name.lower().startswith('llama-30b'):
pt_model = f'llama-30b-{bits}bit.pt'
elif path_to_model.name.lower().startswith('llama-65b'):
pt_model = f'llama-65b-{bits}bit.pt'
else:
pt_model = f'{model_name}-{bits}bit.pt'
# Try to find the .pt both in models/ and in the subfolder
pt_path = None
for path in [Path(p) for p in [f"models/{pt_model}", f"{path_to_model}/{pt_model}"]]:
if path.exists():
pt_path = path
if not pt_path:
print(f"Could not find {pt_model}, exiting...")
exit()
model = load_quant(path_to_model, os.path.abspath(pt_path), bits)
# Multi-GPU setup
if shared.args.gpu_memory:
max_memory = {}
for i in range(len(shared.args.gpu_memory)):
max_memory[i] = f"{shared.args.gpu_memory[i]}GiB"
max_memory['cpu'] = f"{shared.args.cpu_memory or '99'}GiB"
device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LLaMADecoderLayer"])
model = accelerate.dispatch_model(model, device_map=device_map)
# Single GPU
else:
model = model.to(torch.device('cuda:0'))
return model

View file

@ -11,6 +11,7 @@ is_RWKV = False
history = {'internal': [], 'visible': []} history = {'internal': [], 'visible': []}
character = 'None' character = 'None'
stop_everything = False stop_everything = False
still_streaming = False
# UI elements (buttons, sliders, HTML, etc) # UI elements (buttons, sliders, HTML, etc)
gradio = {} gradio = {}
@ -42,12 +43,12 @@ settings = {
'default': 'NovelAI-Sphinx Moth', 'default': 'NovelAI-Sphinx Moth',
'pygmalion-*': 'Pygmalion', 'pygmalion-*': 'Pygmalion',
'RWKV-*': 'Naive', 'RWKV-*': 'Naive',
'(rosey|chip|joi)_.*_instruct.*': 'Instruct Joi (Contrastive Search)'
}, },
'prompts': { 'prompts': {
'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:', 'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
'^(gpt4chan|gpt-4chan|4chan)': '-----\n--- 865467536\nInput text\n--- 865467537\n', '^(gpt4chan|gpt-4chan|4chan)': '-----\n--- 865467536\nInput text\n--- 865467537\n',
'(rosey|chip|joi)_.*_instruct.*': 'User: \n' '(rosey|chip|joi)_.*_instruct.*': 'User: \n',
'oasst-*': '<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>'
} }
} }
@ -68,6 +69,8 @@ parser.add_argument('--chat', action='store_true', help='Launch the web UI in ch
parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.') parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.') parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.') parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
parser.add_argument('--load-in-4bit', action='store_true', help='Load the model with 4-bit precision. Currently only works with LLaMA.')
parser.add_argument('--gptq-bits', type=int, default=0, help='Load a pre-quantized model with specified precision. 2, 3, 4 and 8bit are supported. Currently only works with LLaMA.')
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.') parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.') parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.') parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
@ -90,4 +93,5 @@ parser.add_argument('--listen', action='store_true', help='Make the web UI reach
parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.') parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.')
parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.') parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.') parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch')
args = parser.parse_args() args = parser.parse_args()

View file

@ -1,32 +0,0 @@
'''
This code was copied from
https://github.com/PygmalionAI/gradio-ui/
'''
import torch
import transformers
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
def __init__(self, sentinel_token_ids: torch.LongTensor,
starting_idx: int):
transformers.StoppingCriteria.__init__(self)
self.sentinel_token_ids = sentinel_token_ids
self.starting_idx = starting_idx
def __call__(self, input_ids: torch.LongTensor,
_scores: torch.FloatTensor) -> bool:
for sample in input_ids:
trimmed_sample = sample[self.starting_idx:]
# Can't unfold, output is still too tiny. Skip.
if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
continue
for window in trimmed_sample.unfold(
0, self.sentinel_token_ids.shape[-1], 1):
if torch.all(torch.eq(self.sentinel_token_ids, window)):
return True
return False

View file

@ -5,13 +5,13 @@ import time
import numpy as np import numpy as np
import torch import torch
import transformers import transformers
from tqdm import tqdm
import modules.shared as shared import modules.shared as shared
from modules.callbacks import (Iteratorize, Stream,
_SentinelTokenStoppingCriteria)
from modules.extensions import apply_extensions from modules.extensions import apply_extensions
from modules.html_generator import generate_4chan_html, generate_basic_html from modules.html_generator import generate_4chan_html, generate_basic_html
from modules.models import local_rank from modules.models import local_rank
from modules.stopping_criteria import _SentinelTokenStoppingCriteria
def get_max_prompt_length(tokens): def get_max_prompt_length(tokens):
@ -92,19 +92,22 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
# These models are not part of Hugging Face, so we handle them # These models are not part of Hugging Face, so we handle them
# separately and terminate the function call earlier # separately and terminate the function call earlier
if shared.is_RWKV: if shared.is_RWKV:
if shared.args.no_stream: try:
reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k) if shared.args.no_stream:
yield formatted_outputs(reply, shared.model_name) reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
else:
yield formatted_outputs(question, shared.model_name)
# RWKV has proper streaming, which is very nice.
# No need to generate 8 tokens at a time.
for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
yield formatted_outputs(reply, shared.model_name) yield formatted_outputs(reply, shared.model_name)
else:
t1 = time.time() yield formatted_outputs(question, shared.model_name)
print(f"Output generated in {(t1-t0):.2f} seconds.") # RWKV has proper streaming, which is very nice.
return # No need to generate 8 tokens at a time.
for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
yield formatted_outputs(reply, shared.model_name)
finally:
t1 = time.time()
output = encode(reply)[0]
input_ids = encode(question)
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
return
original_question = question original_question = question
if not (shared.args.chat or shared.args.cai_chat): if not (shared.args.chat or shared.args.cai_chat):
@ -113,24 +116,22 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
print(f"\n\n{question}\n--------------------\n") print(f"\n\n{question}\n--------------------\n")
input_ids = encode(question, max_new_tokens) input_ids = encode(question, max_new_tokens)
original_input_ids = input_ids
output = input_ids[0]
cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()" cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
n = shared.tokenizer.eos_token_id if eos_token is None else int(encode(eos_token)[0][-1]) eos_token_ids = [shared.tokenizer.eos_token_id]
if eos_token is not None:
eos_token_ids.append(int(encode(eos_token)[0][-1]))
stopping_criteria_list = transformers.StoppingCriteriaList()
if stopping_string is not None: if stopping_string is not None:
# The stopping_criteria code below was copied from # Copied from https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
# https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
t = encode(stopping_string, 0, add_special_tokens=False) t = encode(stopping_string, 0, add_special_tokens=False)
stopping_criteria_list = transformers.StoppingCriteriaList([ stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
_SentinelTokenStoppingCriteria(
sentinel_token_ids=t,
starting_idx=len(input_ids[0])
)
])
else:
stopping_criteria_list = None
if not shared.args.flexgen: if not shared.args.flexgen:
generate_params = [ generate_params = [
f"eos_token_id={n}", f"max_new_tokens=max_new_tokens",
f"eos_token_id={eos_token_ids}",
f"stopping_criteria=stopping_criteria_list", f"stopping_criteria=stopping_criteria_list",
f"do_sample={do_sample}", f"do_sample={do_sample}",
f"temperature={temperature}", f"temperature={temperature}",
@ -147,44 +148,23 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
] ]
else: else:
generate_params = [ generate_params = [
f"max_new_tokens={max_new_tokens if shared.args.no_stream else 8}",
f"do_sample={do_sample}", f"do_sample={do_sample}",
f"temperature={temperature}", f"temperature={temperature}",
f"stop={n}", f"stop={eos_token_ids[-1]}",
] ]
if shared.args.deepspeed: if shared.args.deepspeed:
generate_params.append("synced_gpus=True") generate_params.append("synced_gpus=True")
if shared.args.no_stream:
generate_params.append("max_new_tokens=max_new_tokens")
else:
generate_params.append("max_new_tokens=8")
if shared.soft_prompt: if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids) inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
generate_params.insert(0, "inputs_embeds=inputs_embeds") generate_params.insert(0, "inputs_embeds=inputs_embeds")
generate_params.insert(0, "filler_input_ids") generate_params.insert(0, "inputs=filler_input_ids")
else: else:
generate_params.insert(0, "input_ids") generate_params.insert(0, "inputs=input_ids")
# Generate the entire reply at once
if shared.args.no_stream:
with torch.no_grad():
output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
reply = decode(output)
if not (shared.args.chat or shared.args.cai_chat):
reply = original_question + apply_extensions(reply[len(question):], "output")
t1 = time.time()
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output)-len(input_ids[0])} tokens)")
yield formatted_outputs(reply, shared.model_name)
# Generate the reply 8 tokens at a time
else:
yield formatted_outputs(original_question, shared.model_name)
for i in tqdm(range(max_new_tokens//8+1)):
clear_torch_cache()
try:
# Generate the entire reply at once.
if shared.args.no_stream:
with torch.no_grad(): with torch.no_grad():
output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0] output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
if shared.soft_prompt: if shared.soft_prompt:
@ -193,16 +173,66 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
reply = decode(output) reply = decode(output)
if not (shared.args.chat or shared.args.cai_chat): if not (shared.args.chat or shared.args.cai_chat):
reply = original_question + apply_extensions(reply[len(question):], "output") reply = original_question + apply_extensions(reply[len(question):], "output")
yield formatted_outputs(reply, shared.model_name) yield formatted_outputs(reply, shared.model_name)
if not shared.args.flexgen: # Stream the reply 1 token at a time.
if output[-1] == n: # This is based on the trick of using 'stopping_criteria' to create an iterator.
break elif not shared.args.flexgen:
input_ids = torch.reshape(output, (1, output.shape[0]))
else:
if np.count_nonzero(input_ids[0] == n) < np.count_nonzero(output == n):
break
input_ids = np.reshape(output, (1, output.shape[0]))
if shared.soft_prompt: def generate_with_callback(callback=None, **kwargs):
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids) kwargs['stopping_criteria'].append(Stream(callback_func=callback))
clear_torch_cache()
with torch.no_grad():
shared.model.generate(**kwargs)
def generate_with_streaming(**kwargs):
return Iteratorize(generate_with_callback, kwargs, callback=None)
shared.still_streaming = True
yield formatted_outputs(original_question, shared.model_name)
with eval(f"generate_with_streaming({', '.join(generate_params)})") as generator:
for output in generator:
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
reply = decode(output)
if not (shared.args.chat or shared.args.cai_chat):
reply = original_question + apply_extensions(reply[len(question):], "output")
if output[-1] in eos_token_ids:
break
yield formatted_outputs(reply, shared.model_name)
shared.still_streaming = False
yield formatted_outputs(reply, shared.model_name)
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
else:
shared.still_streaming = True
for i in range(max_new_tokens//8+1):
clear_torch_cache()
with torch.no_grad():
output = eval(f"shared.model.generate({', '.join(generate_params)})")[0]
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
reply = decode(output)
if not (shared.args.chat or shared.args.cai_chat):
reply = original_question + apply_extensions(reply[len(question):], "output")
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
break
yield formatted_outputs(reply, shared.model_name)
input_ids = np.reshape(output, (1, output.shape[0]))
if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
shared.still_streaming = False
yield formatted_outputs(reply, shared.model_name)
finally:
t1 = time.time()
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens)")
return

View file

@ -1,9 +1,11 @@
accelerate==0.16.0 accelerate==0.17.0
bitsandbytes==0.37.0 bitsandbytes==0.37.0
flexgen==0.1.7 flexgen==0.1.7
gradio==3.18.0 gradio==3.18.0
numpy numpy
rwkv==0.1.0 requests
safetensors==0.2.8 rwkv==0.3.1
safetensors==0.3.0
sentencepiece sentencepiece
git+https://github.com/oobabooga/transformers@llama_push tqdm
git+https://github.com/zphang/transformers@llama_push

View file

@ -18,9 +18,6 @@ from modules.html_generator import generate_chat_html
from modules.models import load_model, load_soft_prompt from modules.models import load_model, load_soft_prompt
from modules.text_generation import generate_reply from modules.text_generation import generate_reply
if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream:
print('Warning: chat mode currently becomes somewhat slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n')
# Loading custom settings # Loading custom settings
settings_file = None settings_file = None
if shared.args.settings is not None and Path(shared.args.settings).exists(): if shared.args.settings is not None and Path(shared.args.settings).exists():
@ -37,7 +34,7 @@ def get_available_models():
if shared.args.flexgen: if shared.args.flexgen:
return sorted([re.sub('-np$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('-np')], key=str.lower) return sorted([re.sub('-np$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('-np')], key=str.lower)
else: else:
return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np'))], key=str.lower) return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt'))], key=str.lower)
def get_available_presets(): def get_available_presets():
return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower) return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
@ -272,10 +269,10 @@ if shared.args.chat or shared.args.cai_chat:
function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper' function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream, api_name='textgen')) gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=False, api_name='textgen'))
gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=False))
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=False))
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=False))
shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events) shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events)
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream) shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
@ -309,6 +306,7 @@ if shared.args.chat or shared.args.cai_chat:
reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else [] reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else []
shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']]) shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']])
shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']]) shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']])
shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None) shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True) shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
@ -372,9 +370,9 @@ else:
shared.gradio['interface'].queue() shared.gradio['interface'].queue()
if shared.args.listen: if shared.args.listen:
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port) shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
else: else:
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port) shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
# I think that I will need this later # I think that I will need this later
while True: while True:

View file

@ -29,6 +29,7 @@
"prompts": { "prompts": {
"default": "Common sense questions and answers\n\nQuestion: \nFactual answer:", "default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
"^(gpt4chan|gpt-4chan|4chan)": "-----\n--- 865467536\nInput text\n--- 865467537\n", "^(gpt4chan|gpt-4chan|4chan)": "-----\n--- 865467536\nInput text\n--- 865467537\n",
"(rosey|chip|joi)_.*_instruct.*": "User: \n" "(rosey|chip|joi)_.*_instruct.*": "User: \n",
"oasst-*": "<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>"
} }
} }