Merge branch 'dev' into IllogicalDesigns-main

This commit is contained in:
oobabooga 2023-07-13 14:15:57 -07:00
commit 5bfe55fc76
52 changed files with 2011 additions and 960 deletions

View file

@ -23,7 +23,8 @@ async def run(user_input, history):
'history': history,
'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
'character': 'Example',
'instruction_template': 'Vicuna-v1.1',
'instruction_template': 'Vicuna-v1.1', # Will get autodetected if unset
# 'context_instruct': '', # Optional
'your_name': 'You',
'regenerate': False,
@ -34,7 +35,7 @@ async def run(user_input, history):
# Generation params. If 'preset' is set to something other than 'None', the
# values in presets/preset-name.yaml are used instead of the individual numbers.
'preset': 'None',
'preset': 'None',
'do_sample': True,
'temperature': 0.7,
'top_p': 0.1,

View file

@ -17,7 +17,8 @@ def run(user_input, history):
'history': history,
'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
'character': 'Example',
'instruction_template': 'Vicuna-v1.1',
'instruction_template': 'Vicuna-v1.1', # Will get autodetected if unset
# 'context_instruct': '', # Optional
'your_name': 'You',
'regenerate': False,

View file

@ -4,8 +4,9 @@ import requests
HOST = '0.0.0.0:5000'
def generate(prompt, tokens = 200):
request = { 'prompt': prompt, 'max_new_tokens': tokens }
def generate(prompt, tokens=200):
request = {'prompt': prompt, 'max_new_tokens': tokens}
response = requests.post(f'http://{HOST}/api/v1/generate', json=request)
if response.status_code == 200:
@ -23,7 +24,7 @@ def print_basic_model_info(response):
print("Model: ", response['result']['model_name'])
print("Lora(s): ", response['result']['lora_names'])
for setting in basic_settings:
print(setting, "=", response['result']['shared.settings'][setting])
print(setting, "=", response['result']['shared.settings'][setting])
# model info
@ -54,7 +55,7 @@ def complex_model_load(model):
'action': 'load',
'model_name': model,
'args': {
'gptq_for_llama': False, # Use AutoGPTQ by default, set to True for gptq-for-llama
'loader': 'AutoGPTQ',
'bf16': False,
'load_in_8bit': False,
@ -74,18 +75,18 @@ def complex_model_load(model):
'rwkv_strategy': None,
'rwkv_cuda_on': False,
# b&b 4-bit
#'load_in_4bit': False,
#'compute_dtype': 'float16',
#'quant_type': 'nf4',
#'use_double_quant': False,
# b&b 4-bit
# 'load_in_4bit': False,
# 'compute_dtype': 'float16',
# 'quant_type': 'nf4',
# 'use_double_quant': False,
#"cpu": false,
#"auto_devices": false,
#"gpu_memory": null,
#"cpu_memory": null,
#"disk": false,
#"disk_cache_dir": "cache",
# "cpu": false,
# "auto_devices": false,
# "gpu_memory": null,
# "cpu_memory": null,
# "disk": false,
# "disk_cache_dir": "cache",
},
}
@ -104,26 +105,25 @@ def complex_model_load(model):
req['args']['load_in_8bit'] = True
elif '-hf' in model or 'fp16' in model:
if '7b' in model:
req['args']['bf16'] = True # for 24GB
req['args']['bf16'] = True # for 24GB
elif '13b' in model:
req['args']['load_in_8bit'] = True # for 24GB
req['args']['load_in_8bit'] = True # for 24GB
elif 'ggml' in model:
#req['args']['threads'] = 16
# req['args']['threads'] = 16
if '7b' in model:
req['args']['n_gpu_layers'] = 100
elif '13b' in model:
req['args']['n_gpu_layers'] = 100
elif '30b' in model or '33b' in model:
req['args']['n_gpu_layers'] = 59 # 24GB
req['args']['n_gpu_layers'] = 59 # 24GB
elif '65b' in model:
req['args']['n_gpu_layers'] = 42 # 24GB
req['args']['n_gpu_layers'] = 42 # 24GB
elif 'rwkv' in model:
req['args']['rwkv_cuda_on'] = True
if '14b' in model:
req['args']['rwkv_strategy'] = 'cuda f16i8' # 24GB
req['args']['rwkv_strategy'] = 'cuda f16i8' # 24GB
else:
req['args']['rwkv_strategy'] = 'cuda f16' # 24GB
req['args']['rwkv_strategy'] = 'cuda f16' # 24GB
return model_api(req)
@ -134,7 +134,7 @@ if __name__ == '__main__':
resp = complex_model_load(model)
if 'error' in resp:
print (f"{model} FAIL Error: {resp['error']['message']}")
print(f"{model} FAIL Error: {resp['error']['message']}")
continue
else:
print_basic_model_info(resp)
@ -142,17 +142,17 @@ if __name__ == '__main__':
ans = generate("0,1,1,2,3,5,8,13,", tokens=2)
if '21' in ans:
print (f"{model} PASS ({ans})")
print(f"{model} PASS ({ans})")
else:
print (f"{model} FAIL ({ans})")
print(f"{model} FAIL ({ans})")
except Exception as e:
print (f"{model} FAIL Exception: {repr(e)}")
print(f"{model} FAIL Exception: {repr(e)}")
# 0,1,1,2,3,5,8,13, is the fibonacci sequence, the next number is 21.
# Some results below.
""" $ ./model-api-example.py
""" $ ./model-api-example.py
Model: 4bit_gpt4-x-alpaca-13b-native-4bit-128g-cuda
Lora(s): []
truncation_length = 2048

View file

@ -23,7 +23,7 @@ async def run(context):
# Generation params. If 'preset' is set to something other than 'None', the
# values in presets/preset-name.yaml are used instead of the individual numbers.
'preset': 'None',
'preset': 'None',
'do_sample': True,
'temperature': 0.7,
'top_p': 0.1,

View file

@ -15,7 +15,7 @@ def run(prompt):
# Generation params. If 'preset' is set to something other than 'None', the
# values in presets/preset-name.yaml are used instead of the individual numbers.
'preset': 'None',
'preset': 'None',
'do_sample': True,
'temperature': 0.7,
'top_p': 0.1,

View file

@ -5,13 +5,13 @@ services:
context: .
args:
# specify which cuda version your card supports: https://developer.nvidia.com/cuda-gpus
TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST}
WEBUI_VERSION: ${WEBUI_VERSION}
TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST:-7.5}
WEBUI_VERSION: ${WEBUI_VERSION:-HEAD}
env_file: .env
ports:
- "${HOST_PORT}:${CONTAINER_PORT}"
- "${HOST_API_PORT}:${CONTAINER_API_PORT}"
- "${HOST_API_STREAM_PORT}:${CONTAINER_API_STREAM_PORT}"
- "${HOST_PORT:-7860}:${CONTAINER_PORT:-7860}"
- "${HOST_API_PORT:-5000}:${CONTAINER_API_PORT:-5000}"
- "${HOST_API_STREAM_PORT:-5005}:${CONTAINER_API_STREAM_PORT:-5005}"
stdin_open: true
tty: true
volumes:

View file

@ -23,13 +23,15 @@ from tqdm.contrib.concurrent import thread_map
class ModelDownloader:
def __init__(self, max_retries = 5):
def __init__(self, max_retries=5):
self.s = requests.Session()
if max_retries:
self.s.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=max_retries))
self.s.mount('https://huggingface.co', HTTPAdapter(max_retries=max_retries))
if os.getenv('HF_USER') is not None and os.getenv('HF_PASS') is not None:
self.s.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS'))
if os.getenv('HF_TOKEN') is not None:
self.s.headers = {'authorization': f'Bearer {os.getenv("HF_TOKEN")}'}
def sanitize_model_and_branch_names(self, model, branch):
if model[-1] == '/':
@ -73,11 +75,11 @@ class ModelDownloader:
if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')):
is_lora = True
is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname)
is_pytorch = re.match("(pytorch|adapter|gptq)_model.*\.bin", fname)
is_safetensors = re.match(".*\.safetensors", fname)
is_pt = re.match(".*\.pt", fname)
is_ggml = re.match(".*ggml.*\.bin", fname)
is_tokenizer = re.match("(tokenizer|ice).*\.model", fname)
is_tokenizer = re.match("(tokenizer|ice|spiece).*\.model", fname)
is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer
if any((is_pytorch, is_safetensors, is_pt, is_ggml, is_tokenizer, is_text)):
if 'lfs' in dict[i]:

View file

@ -4,7 +4,7 @@ from threading import Thread
from websockets.server import serve
from extensions.api.util import build_parameters, try_start_cloudflared
from extensions.api.util import build_parameters, try_start_cloudflared, with_api_lock
from modules import shared
from modules.chat import generate_chat_reply
from modules.text_generation import generate_reply
@ -12,72 +12,82 @@ from modules.text_generation import generate_reply
PATH = '/api/v1/stream'
@with_api_lock
async def _handle_stream_message(websocket, message):
    """Stream a plain text completion for one websocket request.

    ``message`` is a JSON string with at least a ``prompt`` key. Each
    ``text_stream`` event carries only the text that has not been sent yet,
    and a final ``stream_end`` event closes the exchange.
    """
    payload = json.loads(message)

    prompt = payload['prompt']
    generate_params = build_parameters(payload)
    stopping_strings = generate_params.pop('stopping_strings')
    generate_params['stream'] = True

    reply_stream = generate_reply(
        prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)

    # Number of characters already delivered to the client.
    sent_chars = 0
    message_num = 0

    for partial_reply in reply_stream:
        new_text = partial_reply[sent_chars:]
        # U+FFFD marks a partial unicode character; hold it back until a
        # later iteration yields the complete character.
        if new_text is None or chr(0xfffd) in new_text:
            continue

        event = {
            'event': 'text_stream',
            'message_num': message_num,
            'text': new_text
        }
        await websocket.send(json.dumps(event))

        # Yield to the event loop between chunks.
        await asyncio.sleep(0)
        sent_chars += len(new_text)
        message_num += 1

    closing_event = {
        'event': 'stream_end',
        'message_num': message_num
    }
    await websocket.send(json.dumps(closing_event))
@with_api_lock
async def _handle_chat_stream_message(websocket, message):
    """Stream a chat reply for one websocket request.

    ``message`` is a JSON string with at least a ``user_input`` key; the
    optional ``regenerate`` and ``_continue`` flags default to False. Every
    value yielded by the chat generator is sent as the ``history`` of a
    ``text_stream`` event, followed by one ``stream_end`` event.
    """
    request = json.loads(message)

    user_input = request['user_input']
    generate_params = build_parameters(request, chat=True)
    generate_params['stream'] = True
    regenerate = request.get('regenerate', False)
    _continue = request.get('_continue', False)

    history_stream = generate_chat_reply(
        user_input, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)

    message_num = 0
    for history in history_stream:
        event = {
            'event': 'text_stream',
            'message_num': message_num,
            'history': history
        }
        await websocket.send(json.dumps(event))

        # Yield to the event loop between chunks.
        await asyncio.sleep(0)
        message_num += 1

    closing_event = {
        'event': 'stream_end',
        'message_num': message_num
    }
    await websocket.send(json.dumps(closing_event))
async def _handle_connection(websocket, path):
if path == '/api/v1/stream':
async for message in websocket:
message = json.loads(message)
prompt = message['prompt']
generate_params = build_parameters(message)
stopping_strings = generate_params.pop('stopping_strings')
generate_params['stream'] = True
generator = generate_reply(
prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)
# As we stream, only send the new bytes.
skip_index = 0
message_num = 0
for a in generator:
to_send = a[skip_index:]
if to_send is None or chr(0xfffd) in to_send: # partial unicode character, don't send it yet.
continue
await websocket.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': to_send
}))
await asyncio.sleep(0)
skip_index += len(to_send)
message_num += 1
await websocket.send(json.dumps({
'event': 'stream_end',
'message_num': message_num
}))
await _handle_stream_message(websocket, message)
elif path == '/api/v1/chat-stream':
async for message in websocket:
body = json.loads(message)
user_input = body['user_input']
generate_params = build_parameters(body, chat=True)
generate_params['stream'] = True
regenerate = body.get('regenerate', False)
_continue = body.get('_continue', False)
generator = generate_chat_reply(
user_input, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)
message_num = 0
for a in generator:
await websocket.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'history': a
}))
await asyncio.sleep(0)
message_num += 1
await websocket.send(json.dumps({
'event': 'stream_end',
'message_num': message_num
}))
await _handle_chat_stream_message(websocket, message)
else:
print(f'Streaming api: unknown path: {path}')

View file

@ -1,3 +1,6 @@
import asyncio
import functools
import threading
import time
import traceback
from threading import Thread
@ -8,6 +11,13 @@ from modules.chat import load_character_memoized
from modules.presets import load_preset_memoized
# We use a thread local to store the asyncio lock, so that each thread
# has its own lock. This isn't strictly necessary, but it means the
# locking keeps working if we support multiple worker threads in the
# future and thus handle multiple requests in parallel.
api_tls = threading.local()
def build_parameters(body, chat=False):
generate_params = {
@ -49,7 +59,10 @@ def build_parameters(body, chat=False):
if chat:
character = body.get('character')
instruction_template = body.get('instruction_template')
instruction_template = body.get('instruction_template', shared.settings['instruction_template'])
if str(instruction_template) == "None":
instruction_template = "Vicuna-v1.1"
name1, name2, _, greeting, context, _ = load_character_memoized(character, str(body.get('your_name', shared.settings['name1'])), shared.settings['name2'], instruct=False)
name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character_memoized(instruction_template, '', '', instruct=True)
generate_params.update({
@ -62,7 +75,7 @@ def build_parameters(body, chat=False):
'greeting': greeting,
'name1_instruct': name1_instruct,
'name2_instruct': name2_instruct,
'context_instruct': context_instruct,
'context_instruct': body.get('context_instruct', context_instruct),
'turn_template': turn_template,
'chat-instruct_command': str(body.get('chat-instruct_command', shared.settings['chat-instruct_command'])),
'history': body.get('history', {'internal': [], 'visible': []})
@ -97,3 +110,35 @@ def _start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Call
time.sleep(3)
raise Exception('Could not start cloudflared.')
def _get_api_lock(tls) -> asyncio.Lock:
"""
The streaming and blocking API implementations each run on their own
thread, and multiplex requests using asyncio. If multiple outstanding
requests are received at once, we will try to acquire the shared lock
shared.generation_lock multiple times in succession in the same thread,
which will cause a deadlock.
To avoid this, we use this wrapper function to block on an asyncio
lock, and then try and grab the shared lock only while holding
the asyncio lock.
"""
if not hasattr(tls, "asyncio_lock"):
tls.asyncio_lock = asyncio.Lock()
return tls.asyncio_lock
def with_api_lock(func):
    """
    Decorator for streaming API coroutines that need shared.generation_lock.

    The wrapped coroutine only runs while the asyncio lock stored on
    ``api_tls`` is held; the lock is released once the coroutine finishes,
    whether it returns or raises.
    """
    @functools.wraps(func)
    async def api_wrapper(*args, **kwargs):
        lock = _get_api_lock(api_tls)
        async with lock:
            result = await func(*args, **kwargs)
        return result

    return api_wrapper

View file

@ -6,6 +6,7 @@ import gradio as gr
from modules import chat, shared
from modules.utils import gradio
from modules.logging_colors import logger
params = {
'activate': True,
@ -13,10 +14,12 @@ params = {
'selected_voice': 'None',
'autoplay': False,
'show_text': True,
'model': 'eleven_monolingual_v1',
}
voices = None
wav_idx = 0
LANG_MODELS = ['eleven_monolingual_v1', 'eleven_multilingual_v1']
def update_api_key(key):
@ -108,7 +111,7 @@ def output_modifier(string):
output_file = Path(f'extensions/elevenlabs_tts/outputs/{wav_idx:06d}.mp3'.format(wav_idx))
print(f'Outputting audio to {str(output_file)}')
try:
audio = elevenlabs.generate(text=string, voice=params['selected_voice'], model="eleven_monolingual_v1")
audio = elevenlabs.generate(text=string, voice=params['selected_voice'], model=params['model'])
elevenlabs.save(audio, str(output_file))
autoplay = 'autoplay' if params['autoplay'] else ''
@ -132,7 +135,12 @@ def ui():
global voices
if not voices:
voices = refresh_voices()
params['selected_voice'] = voices[0]
selected = params['selected_voice']
if selected == 'None':
params['selected_voice'] = voices[0]
elif selected not in voices:
logger.error(f'Selected voice {selected} not available, switching to {voices[0]}')
params['selected_voice'] = voices[0]
# Gradio elements
with gr.Row():
@ -145,7 +153,14 @@ def ui():
refresh = gr.Button(value='Refresh')
with gr.Row():
api_key = gr.Textbox(placeholder="Enter your API key.", label='API Key')
if params['api_key']:
api_key = gr.Textbox(value=params['api_key'], label='API Key')
update_api_key(params['api_key'])
else:
api_key = gr.Textbox(placeholder="Enter your API key.", label='API Key')
with gr.Row():
model = gr.Dropdown(value=params['model'], choices=LANG_MODELS, label='Language model')
with gr.Row():
convert = gr.Button('Permanently replace audios with the message texts')
@ -175,6 +190,7 @@ def ui():
activate.change(lambda x: params.update({'activate': x}), activate, None)
voice.change(lambda x: params.update({'selected_voice': x}), voice, None)
api_key.change(update_api_key, api_key, None)
model.change(lambda x: params.update({'model': x}), model, None)
# connect.click(check_valid_api, [], connection_status)
refresh.click(refresh_voices_dd, [], voice)
# Event functions to update the parameters in the backend

View file

@ -2,6 +2,7 @@ import gradio as gr
from deep_translator import GoogleTranslator
params = {
"activate": True,
"language string": "ja",
}
@ -13,6 +14,8 @@ def input_modifier(string):
This function is applied to your text inputs before
they are fed into the model.
"""
if not params['activate']:
return string
return GoogleTranslator(source=params['language string'], target='en').translate(string)
@ -21,6 +24,8 @@ def output_modifier(string):
"""
This function is applied to the model outputs.
"""
if not params['activate']:
return string
return GoogleTranslator(source='en', target=params['language string']).translate(string)
@ -40,7 +45,12 @@ def ui():
language_name = list(language_codes.keys())[list(language_codes.values()).index(params['language string'])]
# Gradio elements
language = gr.Dropdown(value=language_name, choices=[k for k in language_codes], label='Language')
with gr.Row():
activate = gr.Checkbox(value=params['activate'], label='Activate translation')
with gr.Row():
language = gr.Dropdown(value=language_name, choices=[k for k in language_codes], label='Language')
# Event functions to update the parameters in the backend
activate.change(lambda x: params.update({"activate": x}), activate, None)
language.change(lambda x: params.update({"language string": language_codes[x]}), language, None)

View file

@ -38,6 +38,8 @@ As of now, the following multimodal pipelines are supported:
|[LLaVA 7B](https://github.com/haotian-liu/LLaVA)|`llava-7b`|[LLaVA 7B](https://huggingface.co/wojtab/llava-7b-v0-4bit-128g)|GPTQ 4-bit quant, old CUDA|built-in|
|[MiniGPT-4 7B](https://github.com/Vision-CAIR/MiniGPT-4)|`minigpt4-7b`|[Vicuna v0 7B](https://huggingface.co/TheBloke/vicuna-7B-GPTQ-4bit-128g)|GPTQ 4-bit quant, new format|[Wojtab/minigpt-4-pipeline](https://github.com/Wojtab/minigpt-4-pipeline)|
|[MiniGPT-4 13B](https://github.com/Vision-CAIR/MiniGPT-4)|`minigpt4-13b`|[Vicuna v0 13B](https://huggingface.co/anon8231489123/vicuna-13b-GPTQ-4bit-128g)|GPTQ 4-bit quant, old CUDA|[Wojtab/minigpt-4-pipeline](https://github.com/Wojtab/minigpt-4-pipeline)|
|[InstructBLIP 7B](https://github.com/salesforce/LAVIS/tree/main/projects/instructblip)|`instructblip-7b`|[Vicuna v1.1 7B](https://huggingface.co/TheBloke/vicuna-7B-1.1-GPTQ-4bit-128g)|GPTQ 4-bit quant|[kjerk/instructblip-pipeline](https://github.com/kjerk/instructblip-pipeline)|
|[InstructBLIP 13B](https://github.com/salesforce/LAVIS/tree/main/projects/instructblip)|`instructblip-13b`|[Vicuna v1.1 13B](https://huggingface.co/TheBloke/vicuna-13B-1.1-GPTQ-4bit-128g)|GPTQ 4-bit quant|[kjerk/instructblip-pipeline](https://github.com/kjerk/instructblip-pipeline)|
Some pipelines could support different LLMs but do note that while it might work, it isn't a supported configuration.

View file

@ -1,8 +1,8 @@
# Adds ngrok ingress, to use add `--extension ngrok` to the command line options
#
# Parameters can be customized in settings.json of webui, e.g.:
# Parameters can be customized in settings.json of webui, e.g.:
# {"ngrok": {"basic_auth":"user:password"} }
# or
# or
# {"ngrok": {"oauth_provider":"google", "oauth_allow_emails":["asdf@asdf.com"]} }
#
# See this example for full list of options: https://github.com/ngrok/ngrok-py/blob/main/examples/ngrok-connect-full.py
@ -22,6 +22,7 @@ options = {
'session_metadata': 'text-generation-webui',
}
def ui():
settings = shared.settings.get("ngrok")
if settings:
@ -33,4 +34,3 @@ def ui():
logging.info(f"Ingress established at: {tunnel.url()}")
except ModuleNotFoundError:
logging.error("===> ngrok library not found, please run `pip install -r extensions/ngrok/requirements.txt`")

View file

@ -218,12 +218,11 @@ but there are some exceptions.
| ✅❌ | langchain | https://github.com/hwchase17/langchain | OPENAI_API_BASE=http://127.0.0.1:5001/v1 even with a good 30B-4bit model the result is poor so far. It assumes zero shot python/json coding. Some model tailored prompt formatting improves results greatly. |
| ✅❌ | Auto-GPT | https://github.com/Significant-Gravitas/Auto-GPT | OPENAI_API_BASE=http://127.0.0.1:5001/v1 Same issues as langchain. Also assumes a 4k+ context |
| ✅❌ | babyagi | https://github.com/yoheinakajima/babyagi | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
| ❌ | guidance | https://github.com/microsoft/guidance | logit_bias and logprobs not yet supported |
## Future plans
* better error handling
* model changing, esp. something for swapping loras or embedding models
* consider switching to FastAPI + starlette for SSE (openai SSE seems non-standard)
* do something about rate limiting or locking requests for completions, most systems will only be able to handle a single request at a time before OOM
## Bugs? Feedback? Comments? Pull requests?

View file

@ -0,0 +1,597 @@
import time
import yaml
import tiktoken
import torch
import torch.nn.functional as F
from transformers import LogitsProcessor, LogitsProcessorList
from modules import shared
from modules.text_generation import encode, decode, generate_reply
from extensions.openai.defaults import get_default_req_params, default, clamp
from extensions.openai.utils import end_line, debug_msg
from extensions.openai.errors import *
# Thanks to @Cypherfox [Cypherfoxy] for the logits code, blame to @matatonic
class LogitsBiasProcessor(LogitsProcessor):
    """Logits processor that adds a fixed bias to selected token scores.

    ``logit_bias`` maps token ids (ints or numeric strings) to the value
    that is added to that token's logit. Only row 0 of ``logits`` is
    modified. An empty or missing mapping leaves the logits untouched.
    """

    def __init__(self, logit_bias=None):
        # Avoid a shared mutable default argument.
        self.logit_bias = {} if logit_bias is None else logit_bias
        super().__init__()

    def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
        if self.logit_bias:
            keys = [int(key) for key in self.logit_bias]
            # Biases may be fractional, so keep them as floats instead of
            # truncating them to integers.
            values = [float(val) for val in self.logit_bias.values()]
            # Build the bias on the same device/dtype as the logits rather
            # than assuming CUDA is available.
            bias = torch.tensor(values, dtype=logits.dtype, device=logits.device)
            logits[0, keys] += bias
        return logits
class LogprobProcessor(LogitsProcessor):
    """Logits processor that records the most likely next tokens.

    When ``logprobs`` is not None, every call stores the top
    ``logprobs + 1`` tokens of row 0, decoded to text, together with their
    log-probabilities in ``token_alternatives``. The logits themselves are
    returned unchanged.
    """

    def __init__(self, logprobs=None):
        self.logprobs = logprobs
        self.token_alternatives = {}
        super().__init__()

    def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
        if self.logprobs is None:
            return logits

        # Expected range of logprobs is 0-5.
        log_probs = F.log_softmax(logits, dim=1)
        # XXX hack: the selected token is not known yet, so one extra
        # candidate is requested instead of including the chosen token's
        # probability explicitly.
        best_values, best_indices = torch.topk(log_probs, k=self.logprobs + 1)
        decoded_tokens = [decode(tok) for tok in best_indices[0]]
        self.token_alternatives = dict(zip(decoded_tokens, best_values[0].tolist()))
        return logits
def convert_logprobs_to_tiktoken(model, logprobs):
    """Re-key a ``{token_text: logprob}`` mapping using ``model``'s tiktoken encoding.

    Each token text is encoded with the tiktoken encoder for ``model`` and
    replaced by the text of its first tiktoken token. Token texts that
    encode to nothing (e.g. an empty string) are kept unchanged. If tiktoken
    does not know ``model`` (KeyError), ``logprobs`` is returned as is.
    """
    try:
        encoder = tiktoken.encoding_for_model(model)
        converted = {}
        for token, prob in logprobs.items():
            token_ids = encoder.encode(token)
            if not token_ids:
                # Nothing to map to; indexing [0] here would raise IndexError.
                converted[token] = prob
                continue
            # just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall.
            converted[encoder.decode([token_ids[0]])] = prob
        return converted
    except KeyError:
        # assume native tokens if we can't find the tokenizer
        return logprobs
def marshal_common_params(body):
    """Translate the common OpenAI request fields in ``body`` into generation parameters.

    Starts from ``get_default_req_params()``, overlays values from
    ``shared.settings`` and from the request body, and attaches logits
    processors when ``logit_bias`` and/or ``logprobs`` are requested.

    Returns the request-parameter dict. In addition to generation settings
    it contains ``requested_model`` and, when ``logprobs`` was requested,
    ``logprob_proc`` (the LogprobProcessor instance).

    Raises InvalidRequestError when ``n`` is anything other than 1.
    """
    # Request Parameters
    # Try to use openai defaults or map them to something with the same intent
    req_params = get_default_req_params()

    # Common request parameters
    req_params['truncation_length'] = shared.settings['truncation_length']
    req_params['add_bos_token'] = shared.settings.get('add_bos_token', req_params['add_bos_token'])
    req_params['seed'] = shared.settings.get('seed', req_params['seed'])
    req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings']

    # OpenAI API Parameters
    # model - ignored for now, TODO: When we can reliably load a model or lora from a name only change this
    req_params['requested_model'] = body.get('model', shared.model_name)

    req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
    req_params['temperature'] = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999)  # fixup absolute 0.0/2.0
    req_params['top_p'] = clamp(default(body, 'top_p', req_params['top_p']), 0.001, 1.0)
    n = default(body, 'n', 1)
    if n != 1:
        raise InvalidRequestError(message="Only n = 1 is supported.", param='n')

    if 'stop' in body:  # str or array, max len 4 (ignored)
        if isinstance(body['stop'], str):
            req_params['stopping_strings'] = [body['stop']]  # non-standard parameter
        elif isinstance(body['stop'], list):
            req_params['stopping_strings'] = body['stop']

    # presence_penalty - ignored
    # frequency_penalty - ignored
    # user - ignored

    logits_processor = []
    logit_bias = body.get('logit_bias', None)
    if logit_bias:  # {str: float, ...}
        # XXX convert tokens from tiktoken based on requested model
        # Ex.: 'logit_bias': {'1129': 100, '11442': 100, '16243': 100}
        try:
            encoder = tiktoken.encoding_for_model(req_params['requested_model'])
            new_logit_bias = {}
            for logit, bias in logit_bias.items():
                # Map each tiktoken id to the local tokenizer's id(s) for
                # the same text and give all of them the same bias.
                for x in encode(encoder.decode([int(logit)]))[0]:
                    new_logit_bias[str(int(x))] = bias
            # Route diagnostics through the debug helper, not stdout.
            debug_msg(f"{logit_bias} -> {new_logit_bias}")
            logit_bias = new_logit_bias
        except KeyError:
            pass  # assume native tokens if we can't find the tokenizer

        logits_processor = [LogitsBiasProcessor(logit_bias)]

    # logprobs: coming to chat eventually
    if 'logprobs' in body:
        logprobs = default(body, 'logprobs', 0)  # maybe cap at topk? don't clamp 0-5.
        req_params['logprob_proc'] = LogprobProcessor(logprobs)
        logits_processor.append(req_params['logprob_proc'])

    if logits_processor:  # requires logits_processor support
        req_params['logits_processor'] = LogitsProcessorList(logits_processor)

    return req_params
def messages_to_prompt(body: dict, req_params: dict, max_tokens):
    """Build a single text prompt from an OpenAI-style chat request.

    The messages in ``body['messages']`` are formatted with the instruction
    template named by ``shared.settings['instruction_template']``. A generic
    ``user:`` / ``assistant:`` format is used when no template is set or the
    template cannot be loaded. Matching stopping strings are appended to
    ``req_params['stopping_strings']``, which is created if missing.

    Returns a ``(prompt, token_count)`` tuple, where ``token_count`` is the
    number of tokens in the encoded prompt.

    Raises InvalidRequestError when ``functions`` / ``function_call`` are
    used, when ``messages`` is missing, when a message has an unsupported
    role, or when the prompt reaches ``req_params['truncation_length']``.
    If ``max_tokens`` would exceed the context, only a warning is printed.
    """
    # functions
    if body.get('functions', []):  # chat only
        raise InvalidRequestError(message="functions is not supported.", param='functions')
    if body.get('function_call', ''):  # chat only, 'none', 'auto', {'name': 'func'}
        raise InvalidRequestError(message="function_call is not supported.", param='function_call')

    if 'messages' not in body:
        raise InvalidRequestError(message="messages is required", param='messages')

    messages = body['messages']

    role_formats = {
        'user': 'user: {message}\n',
        'assistant': 'assistant: {message}\n',
        'system': '{message}',
        'context': 'You are a helpful assistant. Answer as concisely as possible.',
        'prompt': 'assistant:',
    }

    if 'stopping_strings' not in req_params:
        req_params['stopping_strings'] = []

    # Instruct models can be much better
    if shared.settings['instruction_template']:
        try:
            template_path = f"characters/instruction-following/{shared.settings['instruction_template']}.yaml"
            # Close the template file deterministically instead of leaking the handle.
            with open(template_path, 'r') as template_file:
                instruct = yaml.safe_load(template_file)

            template = instruct['turn_template']
            system_message_template = "{message}"
            system_message_default = instruct['context']
            bot_start = template.find('<|bot|>')  # So far, 100% of instruction templates have this token
            user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct['user'])
            bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct['bot'])
            bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')

            role_formats = {
                'user': user_message_template,
                'assistant': bot_message_template,
                'system': system_message_template,
                'context': system_message_default,
                'prompt': bot_prompt,
            }

            if 'Alpaca' in shared.settings['instruction_template']:
                req_params['stopping_strings'].extend(['\n###'])
            elif instruct['user']:  # WizardLM and some others have no user prompt.
                req_params['stopping_strings'].extend(['\n' + instruct['user'], instruct['user']])

            debug_msg(f"Loaded instruction role format: {shared.settings['instruction_template']}")

        except Exception as e:
            # Deliberate best-effort: any problem with the template falls
            # back to the generic role format defined above.
            req_params['stopping_strings'].extend(['\nuser:'])

            print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
            print("Warning: Loaded default instruction-following template for model.")

    else:
        req_params['stopping_strings'].extend(['\nuser:'])
        print("Warning: Loaded default instruction-following template for model.")

    system_msgs = []
    chat_msgs = []

    # You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}
    context_msg = role_formats['system'].format(message=role_formats['context']) if role_formats['context'] else ''
    context_msg = end_line(context_msg)

    # Maybe they sent both? This is not documented in the API, but some clients seem to do this.
    if 'prompt' in body:
        context_msg = end_line(role_formats['system'].format(message=body['prompt'])) + context_msg

    for m in messages:
        role = m['role']
        content = m['content']
        # name = m.get('name', None)
        # function_call = m.get('function_call', None) # user name or function name with output in content

        # Validate the role before looking up its format, so unsupported
        # roles produce a proper API error instead of a KeyError.
        if role == 'function':
            raise InvalidRequestError(message="role: function is not supported.", param='messages')
        if role not in role_formats:
            raise InvalidRequestError(message=f"role: {role} is not supported.", param='messages')

        msg = role_formats[role].format(message=content)
        if role == 'system':
            system_msgs.append(msg)
        else:
            chat_msgs.append(msg)

    system_msg = '\n'.join(system_msgs)
    system_msg = end_line(system_msg)

    prompt = system_msg + context_msg + ''.join(chat_msgs) + role_formats['prompt']

    token_count = len(encode(prompt)[0])

    if token_count >= req_params['truncation_length']:
        err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens."
        raise InvalidRequestError(message=err_msg)

    if max_tokens > 0 and token_count + max_tokens > req_params['truncation_length']:
        err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens and max_tokens is {max_tokens}."
        print(f"Warning: {err_msg}")
        # raise InvalidRequestError(message=err_msg)

    return prompt, token_count
def chat_completions(body: dict, is_legacy: bool = False) -> dict:
    """Handle a non-streaming chat completion request.

    ``body`` is the decoded request. With ``is_legacy`` the result list is
    returned under ``data`` and the token limit is read from ``length``;
    otherwise ``choices`` and ``max_tokens`` are used.

    Returns the response dict with one choice, token usage counts and, when
    ``logprobs`` was requested, the top log-probabilities recorded by the
    logprob processor. Request validation errors from the helper functions
    propagate as InvalidRequestError.
    """
    # Chat Completions
    object_type = 'chat.completions'
    created_time = int(time.time())
    cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
    resp_list = 'data' if is_legacy else 'choices'

    # common params
    req_params = marshal_common_params(body)
    req_params['stream'] = False
    requested_model = req_params.pop('requested_model')
    # Popped exactly once: a second pop would return None and silently
    # disable the logprobs output below.
    logprob_proc = req_params.pop('logprob_proc', None)
    req_params['top_k'] = 20  # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.

    # chat default max_tokens is 'inf', but also flexible
    max_tokens = 0
    max_tokens_str = 'length' if is_legacy else 'max_tokens'
    if max_tokens_str in body:
        max_tokens = default(body, max_tokens_str, req_params['truncation_length'])
        req_params['max_new_tokens'] = max_tokens
    else:
        req_params['max_new_tokens'] = req_params['truncation_length']
    # Effective generation limit, captured before generation for the
    # finish_reason check below.
    max_new_tokens = req_params['max_new_tokens']

    # format the prompt from messages
    prompt, token_count = messages_to_prompt(body, req_params, max_tokens)

    # generate reply #######################################
    debug_msg({'prompt': prompt, 'req_params': req_params})
    stopping_strings = req_params.pop('stopping_strings', [])
    generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)

    # Each yielded value replaces the previous one, so only the last is kept.
    answer = ''
    for a in generator:
        answer = a

    # strip extra leading space off new generated content
    if answer and answer[0] == ' ':
        answer = answer[1:]

    completion_token_count = len(encode(answer)[0])
    stop_reason = "stop"
    # Compare against the effective limit; the raw max_tokens is 0 when the
    # client did not send one, which would always report "length".
    if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_new_tokens:
        stop_reason = "length"

    resp = {
        "id": cmpl_id,
        "object": object_type,
        "created": created_time,
        "model": shared.model_name,  # TODO: add Lora info?
        resp_list: [{
            "index": 0,
            "finish_reason": stop_reason,
            "message": {"role": "assistant", "content": answer}
        }],
        "usage": {
            "prompt_tokens": token_count,
            "completion_tokens": completion_token_count,
            "total_tokens": token_count + completion_token_count
        }
    }

    if logprob_proc:  # not official for chat yet
        top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
        resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
    # else:
    #     resp[resp_list][0]["logprobs"] = None

    return resp
# generator
def stream_chat_completions(body: dict, is_legacy: bool = False):
    """Stream a chat completion as a sequence of chunk dicts (generator).

    Yields one initial chunk with empty content, then one chunk per newly
    generated piece of text, then a final empty chunk that carries
    ``finish_reason`` and ``usage``.

    Args:
        body: Request payload. The prompt is built from it by
            ``messages_to_prompt``; the token limit is read from ``'length'``
            (legacy) or ``'max_tokens'``.
        is_legacy: Selects the legacy field names (``'data'`` instead of
            ``'choices'``, ``'length'`` instead of ``'max_tokens'``).

    Yields:
        dict: chunks whose ``object`` is ``'chat.completions.chunk'``.
    """
    # Chat Completions
    stream_object_type = 'chat.completions.chunk'
    created_time = int(time.time())
    cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
    resp_list = 'data' if is_legacy else 'choices'

    # common params
    req_params = marshal_common_params(body)
    req_params['stream'] = True
    requested_model = req_params.pop('requested_model')
    # Popped exactly once: the chunk builder below closes over this name, so
    # popping it a second time (which always yields None) would silently
    # disable logprobs for every chunk after the first.
    logprob_proc = req_params.pop('logprob_proc', None)
    req_params['top_k'] = 20  # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.

    # chat default max_tokens is 'inf', but also flexible
    max_tokens = 0
    max_tokens_str = 'length' if is_legacy else 'max_tokens'
    has_explicit_limit = max_tokens_str in body
    if has_explicit_limit:
        max_tokens = default(body, max_tokens_str, req_params['truncation_length'])
        req_params['max_new_tokens'] = max_tokens
    else:
        req_params['max_new_tokens'] = req_params['truncation_length']

    # format the prompt from messages
    prompt, token_count = messages_to_prompt(body, req_params, max_tokens)

    def chat_streaming_chunk(content):
        # Build one streaming chunk carrying `content`.
        chunk = {
            "id": cmpl_id,
            "object": stream_object_type,
            "created": created_time,
            "model": shared.model_name,
            resp_list: [{
                "index": 0,
                "finish_reason": None,
                # So yeah... do both methods? delta and messages.
                "message": {'role': 'assistant', 'content': content},
                "delta": {'role': 'assistant', 'content': content},
            }],
        }

        if logprob_proc:  # not official for chat yet
            top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
            chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
        # else:
        #    chunk[resp_list][0]["logprobs"] = None
        return chunk

    # begin streaming
    yield chat_streaming_chunk('')

    # generate reply #######################################
    debug_msg({'prompt': prompt, 'req_params': req_params})

    stopping_strings = req_params.pop('stopping_strings', [])
    generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)

    answer = ''
    seen_content = ''
    completion_token_count = 0

    for a in generator:
        answer = a

        len_seen = len(seen_content)
        new_content = answer[len_seen:]

        if not new_content or chr(0xfffd) in new_content:  # partial unicode character, don't send it yet.
            continue

        seen_content = answer

        # strip extra leading space off new generated content
        if len_seen == 0 and new_content[0] == ' ':
            new_content = new_content[1:]

        completion_token_count += len(encode(new_content)[0])
        yield chat_streaming_chunk(new_content)

    stop_reason = "stop"
    # Only compare against max_tokens when the request supplied one; the
    # placeholder value 0 would otherwise make every reply report "length".
    hit_context_limit = token_count + completion_token_count >= req_params['truncation_length']
    hit_requested_limit = has_explicit_limit and completion_token_count >= max_tokens
    if hit_context_limit or hit_requested_limit:
        stop_reason = "length"

    chunk = chat_streaming_chunk('')
    chunk[resp_list][0]['finish_reason'] = stop_reason
    chunk['usage'] = {
        "prompt_tokens": token_count,
        "completion_tokens": completion_token_count,
        "total_tokens": token_count + completion_token_count
    }

    yield chunk
def completions(body: dict, is_legacy: bool = False):
    """Run a (non-streaming) text completion and return the response dict.

    Args:
        body: Request payload. The prompt is read from ``'context'`` (legacy)
            or ``'prompt'``; the token limit from ``'length'`` (legacy) or
            ``'max_tokens'``; ``'echo'`` and ``'best_of'`` are also honoured.
        is_legacy: Selects the legacy field names (``'data'`` instead of
            ``'choices'`` in the response as well).

    Returns:
        dict: a ``'text_completion'`` object with one choice and ``usage``.

    Raises:
        InvalidRequestError: if the prompt is missing, is a list that is not a
            list of token ids (batched generation), or prompt tokens plus
            ``max_tokens`` exceed ``truncation_length``.
    """
    # Legacy
    # Text Completions
    object_type = 'text_completion'
    created_time = int(time.time())
    cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
    resp_list = 'data' if is_legacy else 'choices'

    # ... encoded as a string, array of strings, array of tokens, or array of token arrays.
    prompt_str = 'context' if is_legacy else 'prompt'
    if prompt_str not in body:
        raise InvalidRequestError("Missing required input", param=prompt_str)

    prompt = body[prompt_str]
    is_token_list = isinstance(prompt, list) and bool(prompt) and isinstance(prompt[0], int)
    if isinstance(prompt, list) and not is_token_list:
        raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)

    # common params
    req_params = marshal_common_params(body)
    req_params['stream'] = False
    max_tokens_str = 'length' if is_legacy else 'max_tokens'
    max_tokens = default(body, max_tokens_str, req_params['max_new_tokens'])
    req_params['max_new_tokens'] = max_tokens
    requested_model = req_params.pop('requested_model')
    # Popped exactly once; a second pop would always return None and force
    # "logprobs": None in the response.
    logprob_proc = req_params.pop('logprob_proc', None)

    if is_token_list:
        # Token-id prompts are converted only now, because the conversion
        # needs requested_model, which is not known before the params are
        # marshalled.
        # NOTE(review): the exact return shapes of encode()/decode() come from
        # modules.text_generation and are not visible here - confirm the
        # resulting prompt is what generate_reply expects.
        try:
            encoder = tiktoken.encoding_for_model(requested_model)
            prompt = encode(encoder.decode(prompt))[0]
        except KeyError:
            prompt = decode(prompt)[0]

    token_count = len(encode(prompt)[0])

    if token_count + max_tokens > req_params['truncation_length']:
        err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
        # print(f"Warning: ${err_msg}")
        raise InvalidRequestError(message=err_msg, param=max_tokens_str)

    req_params['echo'] = default(body, 'echo', req_params['echo'])
    req_params['top_k'] = default(body, 'best_of', req_params['top_k'])

    # generate reply #######################################
    debug_msg({'prompt': prompt, 'req_params': req_params})
    stopping_strings = req_params.pop('stopping_strings', [])
    generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)

    # Each item from the generator replaces the previous one; keep the last.
    answer = ''
    for a in generator:
        answer = a

    # strip extra leading space off new generated content
    if answer and answer[0] == ' ':
        answer = answer[1:]

    completion_token_count = len(encode(answer)[0])
    stop_reason = "stop"
    if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
        stop_reason = "length"

    resp = {
        "id": cmpl_id,
        "object": object_type,
        "created": created_time,
        "model": shared.model_name,  # TODO: add Lora info?
        resp_list: [{
            "index": 0,
            "finish_reason": stop_reason,
            "text": answer,
        }],
        "usage": {
            "prompt_tokens": token_count,
            "completion_tokens": completion_token_count,
            "total_tokens": token_count + completion_token_count
        }
    }

    if logprob_proc:
        top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
        resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
    else:
        resp[resp_list][0]["logprobs"] = None

    return resp
# generator
def stream_completions(body: dict, is_legacy: bool = False):
    """Stream a text completion as a sequence of chunk dicts (generator).

    Yields one initial chunk with empty text, then one chunk per newly
    generated piece of text, then a final empty chunk that carries
    ``finish_reason`` and ``usage``.

    Args:
        body: Request payload. The prompt is read from ``'context'`` (legacy)
            or ``'prompt'``; the token limit from ``'length'`` (legacy) or
            ``'max_tokens'``; ``'echo'`` and ``'best_of'`` are also honoured.
        is_legacy: Selects the legacy field names (``'data'`` instead of
            ``'choices'`` in each chunk as well).

    Yields:
        dict: chunks whose ``object`` is ``'text_completion.chunk'``.

    Raises:
        InvalidRequestError: if the prompt is missing, is a list that is not a
            list of token ids (batched generation), or prompt tokens plus
            ``max_tokens`` exceed ``truncation_length``.
    """
    # Legacy
    # Text Completions
    # object_type = 'text_completion'
    stream_object_type = 'text_completion.chunk'
    created_time = int(time.time())
    cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
    resp_list = 'data' if is_legacy else 'choices'

    # ... encoded as a string, array of strings, array of tokens, or array of token arrays.
    prompt_str = 'context' if is_legacy else 'prompt'
    if prompt_str not in body:
        raise InvalidRequestError("Missing required input", param=prompt_str)

    prompt = body[prompt_str]
    is_token_list = isinstance(prompt, list) and bool(prompt) and isinstance(prompt[0], int)
    if isinstance(prompt, list) and not is_token_list:
        raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)

    # common params
    req_params = marshal_common_params(body)
    req_params['stream'] = True
    max_tokens_str = 'length' if is_legacy else 'max_tokens'
    max_tokens = default(body, max_tokens_str, req_params['max_new_tokens'])
    req_params['max_new_tokens'] = max_tokens
    requested_model = req_params.pop('requested_model')
    # Popped exactly once: the chunk builder below closes over this name, so a
    # second pop (which always yields None) would disable logprobs.
    logprob_proc = req_params.pop('logprob_proc', None)

    if is_token_list:
        # Token-id prompts are converted only now, because the conversion
        # needs requested_model, which is not known before the params are
        # marshalled.
        # NOTE(review): the exact return shapes of encode()/decode() come from
        # modules.text_generation and are not visible here - confirm the
        # resulting prompt is what generate_reply expects.
        try:
            encoder = tiktoken.encoding_for_model(requested_model)
            prompt = encode(encoder.decode(prompt))[0]
        except KeyError:
            prompt = decode(prompt)[0]

    token_count = len(encode(prompt)[0])

    if token_count + max_tokens > req_params['truncation_length']:
        err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
        # print(f"Warning: ${err_msg}")
        raise InvalidRequestError(message=err_msg, param=max_tokens_str)

    req_params['echo'] = default(body, 'echo', req_params['echo'])
    req_params['top_k'] = default(body, 'best_of', req_params['top_k'])

    def text_streaming_chunk(content):
        # Build one streaming chunk carrying `content`.
        chunk = {
            "id": cmpl_id,
            "object": stream_object_type,
            "created": created_time,
            "model": shared.model_name,
            resp_list: [{
                "index": 0,
                "finish_reason": None,
                "text": content,
            }],
        }
        if logprob_proc:
            top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
            chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
        else:
            chunk[resp_list][0]["logprobs"] = None
        return chunk

    # begin streaming
    yield text_streaming_chunk('')

    # generate reply #######################################
    debug_msg({'prompt': prompt, 'req_params': req_params})
    stopping_strings = req_params.pop('stopping_strings', [])
    generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)

    answer = ''
    seen_content = ''
    completion_token_count = 0

    for a in generator:
        answer = a

        len_seen = len(seen_content)
        new_content = answer[len_seen:]

        if not new_content or chr(0xfffd) in new_content:  # partial unicode character, don't send it yet.
            continue

        seen_content = answer

        # strip extra leading space off new generated content
        if len_seen == 0 and new_content[0] == ' ':
            new_content = new_content[1:]

        chunk = text_streaming_chunk(new_content)
        completion_token_count += len(encode(new_content)[0])
        yield chunk

    stop_reason = "stop"
    if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
        stop_reason = "length"

    chunk = text_streaming_chunk('')
    chunk[resp_list][0]["finish_reason"] = stop_reason
    chunk["usage"] = {
        "prompt_tokens": token_count,
        "completion_tokens": completion_token_count,
        "total_tokens": token_count + completion_token_count
    }

    yield chunk

View file

@ -0,0 +1,67 @@
import copy
# Baseline generation parameters; get_default_req_params() hands out deep copies of this table.
# Slightly different defaults for OpenAI's API
# Data type is important, Ex. use 0.0 for a float 0
# (default() in this module coerces a request value to the type of the default it is given.)
default_req_params = {
'max_new_tokens': 16, # 'Inf' for chat
'temperature': 1.0,
'top_p': 1.0,
'top_k': 1, # choose 20 for chat in absence of another default
'repetition_penalty': 1.18,
'repetition_penalty_range': 0,
'encoder_repetition_penalty': 1.0,
'suffix': None,
'stream': False,
'echo': False,
'seed': -1,
# 'n' : default(body, 'n', 1), # 'n' doesn't have a direct map
'truncation_length': 2048, # first use shared.settings value
'add_bos_token': True,
'do_sample': True,
'typical_p': 1.0,
'epsilon_cutoff': 0.0, # In units of 1e-4
'eta_cutoff': 0.0, # In units of 1e-4
'tfs': 1.0,
'top_a': 0.0,
'min_length': 0,
'no_repeat_ngram_size': 0,
'num_beams': 1,
'penalty_alpha': 0.0,
'length_penalty': 1.0,
'early_stopping': False,
'mirostat_mode': 0,
'mirostat_tau': 5.0,
'mirostat_eta': 0.1,
'ban_eos_token': False,
'skip_special_tokens': True,
'custom_stopping_strings': '',
# Keys listed below are not part of the defaults; per the original notes they are
# added to a request's parameter dict conditionally or temporarily.
# 'logits_processor' - conditionally passed
# 'stopping_strings' - temporarily used
# 'logprobs' - temporarily used
# 'requested_model' - temporarily used
}
def get_default_req_params():
    """Return an independent deep copy of ``default_req_params``.

    Callers can mutate the returned dict freely without touching the
    module-level defaults.
    """
    fresh_params = copy.deepcopy(default_req_params)
    return fresh_params
# little helper to get defaults if arg is present but None and should be the same type as default.
def default(dic, key, default):
    """Look up ``key`` in ``dic`` and return it with the same type as ``default``.

    Args:
        dic: Mapping to read from (only ``.get`` is used).
        key: Key to look up.
        default: Fallback value; its exact type is the type the result must have.

    Returns:
        The stored value when its type already matches ``default``; the value
        converted to ``type(default)`` when converting back reproduces the
        original (e.g. ``1`` -> ``1.0``); otherwise ``default`` itself. A
        missing key or a stored ``None`` therefore yields ``default``.
    """
    val = dic.get(key, default)
    if type(val) is not type(default):
        # maybe it's just something like 1 instead of 1.0
        try:
            v = type(default)(val)
            if type(val)(v) == val:  # if it's the same value passed in, it's ok.
                return v
        except Exception:
            # Conversion failures fall back to the default. Catching Exception
            # rather than using a bare except lets KeyboardInterrupt and
            # SystemExit propagate.
            pass

        val = default

    return val
def clamp(value, minvalue, maxvalue):
    """Limit ``value`` to the range ``[minvalue, maxvalue]``.

    The upper bound is applied first and the lower bound second, so when the
    bounds are inverted the result is ``minvalue``.
    """
    capped = maxvalue if maxvalue < value else value
    return capped if capped > minvalue else minvalue

102
extensions/openai/edits.py Normal file
View file

@ -0,0 +1,102 @@
import time
import yaml
import os
from modules import shared
from extensions.openai.defaults import get_default_req_params
from extensions.openai.utils import debug_msg
from extensions.openai.errors import *
from modules.text_generation import encode, generate_reply
def edits(instruction: str, input: str, temperature=1.0, top_p=1.0) -> dict:
    """Apply ``instruction`` to ``input`` with the loaded model and return an edit response.

    The prompt is built from the instruction template named in
    ``shared.settings['instruction_template']``; the built-in Alpaca-style
    template is used when that setting is empty, names an Alpaca template, or
    its YAML file cannot be loaded.

    Args:
        instruction: The editing instruction.
        input: The text to edit.
        temperature: Sampling temperature passed to generation.
        top_p: Nucleus sampling value passed to generation.

    Returns:
        dict: an ``"edit"`` object with one choice and token ``usage``.

    Raises:
        InvalidRequestError: if the formatted prompt leaves no room for new
            tokens within ``shared.settings['truncation_length']``.
    """
    created_time = int(time.time() * 1000)

    # Request parameters
    req_params = get_default_req_params()
    stopping_strings = []

    # Alpaca is verbose so a good default prompt
    default_template = (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
    )

    instruction_template = default_template

    # Use the special instruction/input/response template for anything trained like Alpaca
    if shared.settings['instruction_template']:
        if 'Alpaca' in shared.settings['instruction_template']:
            stopping_strings.extend(['\n###'])
        else:
            try:
                template_path = f"characters/instruction-following/{shared.settings['instruction_template']}.yaml"
                # Context manager so the template file is closed even when parsing fails.
                with open(template_path, 'r') as template_file:
                    instruct = yaml.safe_load(template_file)

                template = instruct['turn_template']
                template = template\
                    .replace('<|user|>', instruct.get('user', ''))\
                    .replace('<|bot|>', instruct.get('bot', ''))\
                    .replace('<|user-message|>', '{instruction}\n{input}')

                # Keep everything up to the bot-message placeholder; the model writes the rest.
                instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
                if instruct['user']:
                    stopping_strings.extend(['\n' + instruct['user'], instruct['user']])

            except Exception as e:
                # Best effort: any problem with the custom template falls back to the default.
                instruction_template = default_template
                print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
                print("Warning: Loaded default instruction-following template (Alpaca) for model.")
    else:
        stopping_strings.extend(['\n###'])
        print("Warning: Loaded default instruction-following template (Alpaca) for model.")

    edit_task = instruction_template.format(instruction=instruction, input=input)

    truncation_length = shared.settings['truncation_length']

    token_count = len(encode(edit_task)[0])
    max_tokens = truncation_length - token_count

    if max_tokens < 1:
        err_msg = f"This model maximum context length is {truncation_length} tokens. However, your messages resulted in over {truncation_length - max_tokens} tokens."
        raise InvalidRequestError(err_msg, param='input')

    req_params['max_new_tokens'] = max_tokens
    req_params['truncation_length'] = truncation_length
    req_params['temperature'] = temperature
    req_params['top_p'] = top_p
    req_params['seed'] = shared.settings.get('seed', req_params['seed'])
    req_params['add_bos_token'] = shared.settings.get('add_bos_token', req_params['add_bos_token'])
    req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings']

    debug_msg({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})

    generator = generate_reply(edit_task, req_params, stopping_strings=stopping_strings, is_chat=False)

    # Each item from the generator replaces the previous one; keep the last.
    answer = ''
    for a in generator:
        answer = a

    # some replies have an extra leading space to fit the instruction template, just clip it off from the reply.
    if edit_task[-1] != '\n' and answer and answer[0] == ' ':
        answer = answer[1:]

    completion_token_count = len(encode(answer)[0])

    resp = {
        "object": "edit",
        "created": created_time,
        "choices": [{
            "text": answer,
            "index": 0,
        }],
        "usage": {
            "prompt_tokens": token_count,
            "completion_tokens": completion_token_count,
            "total_tokens": token_count + completion_token_count
        }
    }

    return resp

View file

@ -0,0 +1,54 @@
import os
from sentence_transformers import SentenceTransformer
from extensions.openai.utils import float_list_to_base64, debug_msg
from extensions.openai.errors import *
# Name of the sentence-transformers model to load; overridable through the
# OPENEDAI_EMBEDDING_MODEL environment variable.
st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", "all-mpnet-base-v2")
# Loaded lazily by get_embeddings_model().
embeddings_model = None
def load_embedding_model(model):
    """Instantiate the ``SentenceTransformer`` named ``model``.

    Prints a status line either way.

    Raises:
        ServiceUnavailableError: if constructing the model (or reading its
            ``max_seq_length`` for the status line) fails; the original
            exception's ``repr`` is passed as ``internal_message``.
    """
    try:
        loaded = SentenceTransformer(model)
        print(f"\nLoaded embedding model: {model}, max sequence length: {loaded.max_seq_length}")
    except Exception as e:
        print(f"\nError: Failed to load embedding model: {model}")
        raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e))
    else:
        return loaded
def get_embeddings_model():
    """Return the shared embedding model, loading it on first use.

    Nothing is loaded when ``st_model`` is empty; in that case the current
    value of ``embeddings_model`` is returned unchanged.
    """
    global embeddings_model
    if embeddings_model or not st_model:
        return embeddings_model

    embeddings_model = load_embedding_model(st_model)  # lazy load the model
    return embeddings_model
def get_embeddings_model_name():
    """Return the configured embedding model name (the module-level ``st_model``)."""
    model_name = st_model
    return model_name
def embeddings(input: list, encoding_format: str):
    """Embed each item of ``input`` and return an embeddings list response.

    Args:
        input: Items to embed; passed as-is to the model's ``encode``.
        encoding_format: ``"base64"`` encodes each vector with
            ``float_list_to_base64``; any other value returns plain lists.

    Returns:
        dict: a ``"list"`` object whose ``data`` holds one ``"embedding"``
        entry per vector, with zeroed ``usage`` counters.
    """
    vectors = get_embeddings_model().encode(input).tolist()

    if encoding_format == "base64":
        data = [{"object": "embedding", "embedding": float_list_to_base64(emb), "index": n} for n, emb in enumerate(vectors)]
    else:
        data = [{"object": "embedding", "embedding": emb, "index": n} for n, emb in enumerate(vectors)]

    response = {
        "object": "list",
        "data": data,
        "model": st_model,  # return the real model
        "usage": {
            "prompt_tokens": 0,
            "total_tokens": 0,
        }
    }

    # The message is formatted even when debugging is off, so an empty result
    # must not be indexed.
    vector_size = len(vectors[0]) if vectors else 0
    debug_msg(f"Embeddings return size: {vector_size}, number: {len(vectors)}")

    return response

View file

@ -0,0 +1,31 @@
class OpenAIError(Exception):
    """Base class for the errors raised by this extension.

    Attributes are plain instance fields: ``message`` (text for the client),
    ``code`` (numeric status, 500 by default) and ``internal_message``
    (extra detail supplied by the raiser).
    """

    def __init__(self, message=None, code=500, internal_message=''):
        self.message = message
        self.code = code
        self.internal_message = internal_message

    def __repr__(self):
        return "%s(message=%r, code=%d)" % (
            self.__class__.__name__,
            self.message,
            self.code,
        )


class InvalidRequestError(OpenAIError):
    """A request was malformed; ``param`` names the offending parameter."""

    def __init__(self, message, param, code=400, error_type='InvalidRequestError', internal_message=''):
        # super() must resolve to OpenAIError.__init__: the previous
        # super(OpenAIError, self) form skipped it and called
        # Exception.__init__ instead, leaving message/code/internal_message
        # unset (so __repr__ raised AttributeError).
        super().__init__(message, code, internal_message)
        self.error_type = error_type
        self.param = param

    def __repr__(self):
        return "%s(message=%r, code=%d, param=%s)" % (
            self.__class__.__name__,
            self.message,
            self.code,
            self.param,
        )


class ServiceUnavailableError(OpenAIError):
    """A backing service (model, external API) could not serve the request."""

    def __init__(self, message=None, code=500, error_type='ServiceUnavailableError', internal_message=''):
        # See InvalidRequestError.__init__ for why plain super() is required.
        super().__init__(message, code, internal_message)
        self.error_type = error_type

View file

@ -0,0 +1,49 @@
import os
import time
import requests
from extensions.openai.errors import *
def generations(prompt: str, size: str, response_format: str, n: int):
    """Generate images by forwarding the prompt to a Stable Diffusion web UI.

    Args:
        prompt: Text prompt, forwarded unchanged.
        size: ``"<width>x<height>"``, e.g. ``"512x512"``.
        response_format: ``'b64_json'`` returns raw base64 strings; any other
            value returns ``data:image/png;base64,...`` URLs.
        n: Number of images, sent as ``batch_size``.

    Returns:
        dict: ``{'created': <unix time>, 'data': [...]}`` with one entry per image.

    Raises:
        ServiceUnavailableError: if ``SD_WEBUI_URL`` is not set, or the
            backend answers with a non-200 status or without ``images``.
    """
    # Stable Diffusion callout wrapper for txt2img
    # Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E
    # the results will be limited and likely poor. SD has hundreds of models and dozens of settings.
    # If you want high quality tailored results you should just use the Stable Diffusion API directly.
    # it's too general an API to try and shape the result with specific tags like "masterpiece", etc,
    # Will probably work best with the stock SD models.
    # SD configuration is beyond the scope of this API.
    # At this point I will not add the edits and variations endpoints (ie. img2img) because they
    # require changing the form data handling to accept multipart form data, also to properly support
    # url return types will require file management and a web serving files... Perhaps later!

    width, height = [int(x) for x in size.split('x')]  # ignore the restrictions on size

    # to hack on better generation, edit default payload.
    payload = {
        'prompt': prompt,  # ignore prompt limit of 1000 characters
        'width': width,
        'height': height,
        'batch_size': n,
        'restore_faces': True,  # slightly less horrible
    }

    resp = {
        'created': int(time.time()),
        'data': []
    }

    # TODO: support SD_WEBUI_AUTH username:password pair.
    sd_base_url = os.environ.get('SD_WEBUI_URL')
    if not sd_base_url:
        # Report a missing backend as a service error instead of a raw KeyError.
        raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
    sd_url = f"{sd_base_url}/sdapi/v1/txt2img"

    response = requests.post(url=sd_url, json=payload)
    try:
        r = response.json()
    except ValueError:
        # Non-JSON body (e.g. an HTML error page): handled by the error path below.
        r = {}
    if not isinstance(r, dict):
        r = {}

    if response.status_code != 200 or 'images' not in r:
        # 'detail' may be a list of {'msg': ...} entries or a plain string.
        detail = r.get('detail')
        if isinstance(detail, list) and detail and isinstance(detail[0], dict) and 'msg' in detail[0]:
            error_message = detail[0]['msg']
        elif isinstance(detail, str) and detail:
            error_message = detail
        else:
            error_message = 'Unknown error calling Stable Diffusion'
        raise ServiceUnavailableError(error_message, code=response.status_code)
    # r['parameters']...
    for b64_json in r['images']:
        if response_format == 'b64_json':
            resp['data'].append({'b64_json': b64_json})
        else:
            resp['data'].append({'url': f'data:image/png;base64,{b64_json}'})  # yeah it's lazy. requests.get() will not work with this

    return resp

View file

@ -0,0 +1,79 @@
from modules import shared
from modules.utils import get_available_models
from modules.models import load_model, unload_model
from modules.models_settings import (get_model_settings_from_yamls,
update_model_parameters)
from extensions.openai.embeddings import get_embeddings_model_name
from extensions.openai.errors import *
def get_current_model_list() -> list:
    """Return a one-element list holding ``shared.model_name``."""
    current_name = shared.model_name  # The real chat/completions model, maybe "None"
    return [current_name]
def get_pseudo_model_list() -> list:
    """Return placeholder model names advertised alongside the real model."""
    # these are expected by so much, so include some here as a dummy
    chat_placeholder = 'gpt-3.5-turbo'
    embedding_placeholder = 'text-embedding-ada-002'
    return [chat_placeholder, embedding_placeholder]
def load_model(model_name: str) -> dict:
    """Switch the running model to ``model_name`` and return an engine descriptor.

    Nothing is loaded when ``model_name`` is one of the pseudo models, the
    embedding model, or the model that is already current.

    Returns:
        dict: ``{"id", "object": "engine", "owner": "self", "ready": True}``.

    Raises:
        OpenAIError: if the backend loader returns no model.
    """
    # This function has the same name as modules.models.load_model, so the
    # module-level import is shadowed by this definition. Import the real
    # loader under an alias; calling load_model() here would recurse.
    from modules.models import load_model as _load_backend_model

    resp = {
        "id": model_name,
        "object": "engine",
        "owner": "self",
        "ready": True,
    }
    if model_name not in get_pseudo_model_list() + [get_embeddings_model_name()] + get_current_model_list():  # Real model only
        # No args. Maybe it works anyways!
        # TODO: hack some heuristics into args for better results

        shared.model_name = model_name
        unload_model()

        model_settings = get_model_settings_from_yamls(shared.model_name)
        shared.settings.update(model_settings)
        update_model_parameters(model_settings, initial=True)

        if shared.settings['mode'] != 'instruct':
            shared.settings['instruction_template'] = None

        shared.model, shared.tokenizer = _load_backend_model(shared.model_name)

        if not shared.model:  # load failed.
            shared.model_name = "None"
            # Report the requested name; shared.model_name was just reset.
            raise OpenAIError(f"Model load failed for: {model_name}")

    return resp
def list_models(is_legacy: bool = False) -> dict:
    """List the current, embedding, pseudo and available models.

    Args:
        is_legacy: When true, entries use the legacy ``"engine"`` shape and
            the first entry is marked not ready if no model is loaded.

    Returns:
        dict: ``{"object": "list", "data": [...]}``.
    """
    # TODO: Lora's?
    current = get_current_model_list()
    embedding = [get_embeddings_model_name()]
    pseudo = get_pseudo_model_list()
    available = get_available_models()
    all_model_list = current + embedding + pseudo + available

    if not is_legacy:
        entries = [{"id": model_id, "object": "model", "owned_by": "user", "permission": []} for model_id in all_model_list]
        return {"object": "list", "data": entries}

    entries = [{"id": model_id, "object": "engine", "owner": "user", "ready": True} for model_id in all_model_list]
    if not shared.model:
        entries[0]['ready'] = False
    return {"object": "list", "data": entries}
def model_info(model_name: str) -> dict:
    """Return a static ``"model"`` descriptor for ``model_name``."""
    info = {}
    info["id"] = model_name
    info["object"] = "model"
    info["owned_by"] = "user"
    info["permission"] = []
    return info

View file

@ -0,0 +1,69 @@
import time
import numpy as np
from numpy.linalg import norm
from extensions.openai.embeddings import get_embeddings_model
# When true, moderations() reports every category as unflagged with score 0.
moderations_disabled = False  # return 0/false
# Filled on first use by get_category_embeddings().
category_embeddings = None
# Not referenced by the code visible in this module.
antonym_embeddings = None
categories = [
    "sexual",
    "hate",
    "harassment",
    "self-harm",
    "sexual/minors",
    "hate/threatening",
    "violence/graphic",
    "self-harm/intent",
    "self-harm/instructions",
    "harassment/threatening",
    "violence",
]
# A category is flagged when its score is strictly above this value.
flag_threshold = 0.5
def get_category_embeddings():
    """Return the category-name -> embedding mapping, computing it on first use."""
    global category_embeddings
    if category_embeddings is not None:
        return category_embeddings

    vectors = get_embeddings_model().encode(categories).tolist()
    category_embeddings = dict(zip(categories, vectors))
    return category_embeddings
def cosine_similarity(a, b):
    """Return the dot product of ``a`` and ``b`` divided by the product of their norms."""
    dot_product = np.dot(a, b)
    magnitude = norm(a) * norm(b)
    return dot_product / magnitude
# seems most openai like with all-mpnet-base-v2
def mod_score(a, b):
    """Return twice the dot product of ``a`` and ``b``."""
    score = 2.0 * np.dot(a, b)
    return score
def moderations(input):
    """Score ``input`` against the moderation categories.

    Args:
        input: A string or an iterable of strings.

    Returns:
        dict: a moderation response whose ``results`` list holds, per scored
        embedding, the ``flagged`` verdict, per-category flags and scores.
        When no embedding model is available or moderation is disabled, a
        single all-clear result is returned instead.
    """
    global category_embeddings, categories, flag_threshold, moderations_disabled
    results = {
        "id": f"modr-{int(time.time()*1e9)}",
        "model": "text-moderation-001",
        "results": [],
    }

    embeddings_model = get_embeddings_model()
    if not embeddings_model or moderations_disabled:
        all_clear = {
            'categories': {C: False for C in categories},
            'category_scores': {C: 0.0 for C in categories},
            'flagged': False,
        }
        results['results'] = [all_clear]
        return results

    category_embeddings = get_category_embeddings()

    # input, string or array
    texts = [input] if isinstance(input, str) else input

    for text in texts:
        for vector in embeddings_model.encode([text]).tolist():
            category_scores = {C: mod_score(category_embeddings[C], vector) for C in categories}
            category_flags = {C: bool(category_scores[C] > flag_threshold) for C in categories}
            flagged = any(category_flags.values())

            results['results'].append({
                'flagged': flagged,
                'categories': category_flags,
                'category_scores': category_scores,
            })

    print(results)

    return results

View file

@ -1,2 +1,3 @@
flask_cloudflared==0.0.12
sentence-transformers
sentence-transformers
tiktoken

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,38 @@
from extensions.openai.utils import float_list_to_base64
from modules.text_generation import encode, decode
def token_count(prompt):
    """Return the number of tokens in the first row of ``encode(prompt)``, wrapped in a results dict."""
    n_tokens = len(encode(prompt)[0])
    return {'results': [{'tokens': n_tokens}]}
def token_encode(input, encoding_format=''):
    """Encode ``input`` to tokens and return them in a results dict.

    With ``encoding_format == "base64"`` the tokens are passed through
    ``float_list_to_base64``; otherwise they are returned as produced by
    ``encode``.
    """
    # if isinstance(input, list):
    tokens = encode(input)[0]

    if encoding_format == "base64":
        encoded_tokens = float_list_to_base64(tokens)
    else:
        encoded_tokens = tokens

    entry = {
        'encoding_format': encoding_format,
        'tokens': encoded_tokens,
        'length': len(tokens),
    }
    return {'results': [entry]}
def token_decode(tokens, encoding_format):
    """Decode ``tokens`` and return element 0 of the result in a results dict.

    ``encoding_format`` is accepted but not used.
    """
    # if isinstance(input, list):
    # if encoding_format == "base64":
    #     tokens = base64_to_float_list(tokens)
    decoded = decode(tokens)
    return {'results': [{'text': decoded[0]}]}

View file

@ -0,0 +1,29 @@
import os
import base64
import numpy as np
def float_list_to_base64(float_list):
    """Pack ``float_list`` as float32 values and return the bytes base64-encoded as an ASCII ``str``."""
    # float32 is the layout the OpenAPI client expects
    raw_bytes = np.array(float_list, dtype="float32").tobytes()
    return base64.b64encode(raw_bytes).decode('ascii')
def end_line(s):
    """Return ``s`` terminated by a newline; empty or falsy values pass through unchanged."""
    if not s or s[-1] == '\n':
        return s
    return s + '\n'
def debug_msg(*args, **kwargs):
    """Forward the arguments to ``print`` only when ``OPENEDAI_DEBUG`` is present in the environment."""
    if os.environ.get('OPENEDAI_DEBUG') is None:
        return
    print(*args, **kwargs)

View file

@ -0,0 +1,215 @@
import gradio
import torch
from transformers import LogitsProcessor
import numpy as np
from modules import shared
# Extension settings. The two colour flags are rewritten by the checkbox callbacks in ui()
# and read by output_modifier(); 'ppl_scale' is read by the perplexity colour scales.
params = {
'color_by_perplexity': False,
'color_by_probability': False,
'ppl_scale': 15.0, # No slider for this right now, because I don't think it really needs to be changed. Very large perplexity scores don't show up often.
#'probability_dropdown': False
}
class PerplexityLogits(LogitsProcessor):
    """Logits processor that records per-step statistics without changing the logits.

    On every call it appends to its lists the perplexity of the current
    distribution, the id of the last token in ``input_ids``, the top-5 token
    ids and probabilities, and the probability that had been assigned to the
    token that was actually selected.
    """

    def __init__(self, verbose=False):
        self.generated_token_ids = []
        self.selected_probs = []
        self.top_token_ids_list = []
        self.top_probs_list = []
        self.perplexities_list = []
        self.last_probs = None
        self.verbose = verbose

    def __call__(self, input_ids, scores):
        """Record statistics for this step and return ``scores`` unchanged."""
        probs = torch.softmax(scores, dim=-1, dtype=torch.float)
        log_probs = torch.nan_to_num(torch.log(probs))
        entropy = -torch.sum(probs*log_probs)
        entropy = entropy.cpu().numpy()
        perplexity = round(float(np.exp(entropy)), 4)
        self.perplexities_list.append(perplexity)
        last_token_id = int(input_ids[0][-1].cpu().numpy().item())
        # Store the generated tokens (not sure why this isn't accessible in the output endpoint!)
        self.generated_token_ids.append(last_token_id)
        # Get last probability, and add to the list if it wasn't there
        if len(self.selected_probs) > 0:
            # Is the selected token in the top tokens?
            if self.verbose:
                print(shared.tokenizer.decode(last_token_id))
                print([shared.tokenizer.decode(token_id) for token_id in self.top_token_ids_list[-1]])
                print(self.top_probs_list[-1])
            # NOTE(review): top_token_ids comes from torch.topk(...).tolist(); if
            # scores carries a batch dimension that list is nested, and this
            # membership test would compare an int against lists - confirm the
            # shape of `scores`.
            if last_token_id in self.top_token_ids_list[-1]:
                idx = self.top_token_ids_list[-1].index(last_token_id)
                self.selected_probs.append(self.top_probs_list[-1][idx])
            else:
                self.top_token_ids_list[-1].append(last_token_id)
                last_prob = round(float(self.last_probs[last_token_id]), 4)
                self.top_probs_list[-1].append(last_prob)
                self.selected_probs.append(last_prob)
        else:
            self.selected_probs.append(1.0)  # Placeholder for the last token of the prompt

        if self.verbose:
            pplbar = "-"
            if not np.isnan(perplexity):
                pplbar = "*"*round(perplexity)
            # Decode the id here: the name `last_token` was never defined, so
            # verbose mode used to raise NameError on this line.
            last_token = shared.tokenizer.decode(last_token_id)
            print(f"{last_token}\t{perplexity:.2f}\t{pplbar}")

        # Get top 5 probabilities
        top_tokens_and_probs = torch.topk(probs, 5)
        top_probs = top_tokens_and_probs.values.cpu().numpy().astype(float).tolist()
        top_token_ids = top_tokens_and_probs.indices.cpu().numpy().astype(int).tolist()

        self.top_token_ids_list.append(top_token_ids)
        self.top_probs_list.append(top_probs)

        probs = probs.cpu().numpy().flatten()
        self.last_probs = probs  # Need to keep this as a reference for top probs

        # Doesn't actually modify the logits!
        return scores
# Stores the perplexity and top probabilities
ppl_logits_processor = None


def logits_processor_modifier(logits_processor_list, input_ids):
    """Create a fresh ``PerplexityLogits``, publish it globally and append it to the list.

    ``input_ids`` is accepted but not used.
    """
    global ppl_logits_processor
    recorder = PerplexityLogits()
    ppl_logits_processor = recorder
    logits_processor_list.append(recorder)
# Wrap generated tokens found in `text` in coloured <span> tags according to the enabled
# params flags, print the average perplexity, and return the (possibly modified) text.
# NOTE(review): ppl_logits_processor starts as None at module level; this function
# dereferences it unconditionally, so it relies on logits_processor_modifier() having run first.
def output_modifier(text):
global ppl_logits_processor
# TODO: It's probably more efficient to do this above rather than modifying all these lists
# Remove last element of perplexities_list, top_token_ids_list, top_tokens_list, top_probs_list since everything is off by one because this extension runs before generation
perplexities = ppl_logits_processor.perplexities_list[:-1]
top_token_ids_list = ppl_logits_processor.top_token_ids_list[:-1]
top_tokens_list = [[shared.tokenizer.decode(token_id) for token_id in top_token_ids] for top_token_ids in top_token_ids_list]
top_probs_list = ppl_logits_processor.top_probs_list[:-1]
# Remove first element of generated_token_ids, generated_tokens, selected_probs because they are for the last token of the prompt
gen_token_ids = ppl_logits_processor.generated_token_ids[1:]
gen_tokens = [shared.tokenizer.decode(token_id) for token_id in gen_token_ids]
sel_probs = ppl_logits_processor.selected_probs[1:]
end_part = '</span>' # Helps with finding the index after replacing part of the text.
in_code = False # Since the <span> tags mess up code blocks, avoid coloring while inside a code block, based on finding tokens with '`' in them
# Three parallel branches follow, differing only in which colour scale is used.
# In each, `i` is the offset in `text` from which the next token is searched; after a
# replacement it is advanced past the closing tag that was just inserted.
# top_tokens/top_probs are unpacked but unused; they still take part in zip(), which stops at the shortest list.
if params['color_by_probability'] and params['color_by_perplexity']:
i = 0
for token, prob, ppl, top_tokens, top_probs in zip(gen_tokens, sel_probs, perplexities, top_tokens_list, top_probs_list):
# A token containing a backtick toggles the in-code state and is itself left uncoloured.
if '`' in token:
in_code = not in_code
continue
if in_code:
continue
color = probability_perplexity_color_scale(prob, ppl)
# Only the first occurrence at or after offset i is replaced.
if token in text[i:]:
text = text[:i] + text[i:].replace(token, add_color_html(token, color), 1)
i += text[i:].find(end_part) + len(end_part)
elif params['color_by_perplexity']:
i = 0
for token, ppl, top_tokens, top_probs in zip(gen_tokens, perplexities, top_tokens_list, top_probs_list):
if '`' in token:
in_code = not in_code
continue
if in_code:
continue
color = perplexity_color_scale(ppl)
if token in text[i:]:
text = text[:i] + text[i:].replace(token, add_color_html(token, color), 1)
i += text[i:].find(end_part) + len(end_part)
elif params['color_by_probability']:
i = 0
for token, prob, top_tokens, top_probs in zip(gen_tokens, sel_probs, top_tokens_list, top_probs_list):
if '`' in token:
in_code = not in_code
continue
if in_code:
continue
color = probability_color_scale(prob)
if token in text[i:]:
text = text[:i] + text[i:].replace(token, add_color_html(token, color), 1)
i += text[i:].find(end_part) + len(end_part)
# The mean is taken over the trimmed perplexities list, whichever flags are set.
print('Average perplexity:', round(np.mean(perplexities), 4))
return text
# Green-yellow-red color scale
def probability_color_scale(prob):
    """Map a probability to an ``rrggbb`` hex string on a red-yellow-green scale.

    0 maps to red (``ff0000``), 0.5 to yellow (``ffff00``) and 1 to green
    (``00ff00``). Values outside ``[0, 1]`` are clamped first; without the
    clamp, ``hex()`` of a negative channel produced a malformed colour.
    """
    prob = max(0.0, min(prob, 1.0))
    if prob <= 0.5:
        red = 255
        green = int(255*prob*2)
    else:
        red = int(255 - 255*(prob - 0.5)*2)
        green = 255
    # :02x zero-pads each channel to two hex digits.
    return f'{red:02x}{green:02x}00'
# Red component only, white for 0 perplexity (sorry if you're not in dark mode)
def perplexity_color_scale(ppl):
    """Map a perplexity to an ``rrggbb`` hex string that fades from white to red.

    The green and blue channels share one value: 255 at a perplexity of 1,
    dropping by ``params['ppl_scale']`` per unit of perplexity. The channel is
    clamped to ``[0, 255]`` so the result is always six hex digits (a
    perplexity below 1 used to produce a three-digit channel).
    """
    channel = int(255.0 - params['ppl_scale']*(float(ppl)-1.0))
    channel = min(max(channel, 0), 255)
    shade = f'{channel:02x}'
    return 'ff' + shade + shade
# Green-yellow-red for probability and blue component for perplexity
def probability_perplexity_color_scale(prob, ppl):
    """Map a probability and a perplexity to one ``rrggbb`` hex string.

    Red and green follow the red-yellow-green probability scale; blue grows
    with perplexity by ``params['ppl_scale']`` per unit above 1, clamped to
    ``[0, 255]``. ``prob`` is clamped to ``[0, 1]`` first; without the clamp,
    ``hex()`` of a negative channel produced a malformed colour.
    """
    blue = min(max(int(params['ppl_scale']*(float(ppl)-1.0)), 0), 255)

    prob = max(0.0, min(prob, 1.0))
    if prob <= 0.5:
        red = 255
        green = int(255*prob*2)
    else:
        red = int(255 - 255*(prob - 0.5)*2)
        green = 255
    # :02x zero-pads each channel to two hex digits.
    return f'{red:02x}{green:02x}{blue:02x}'
def add_color_html(token, color):
    """Wrap ``token`` in a ``<span>`` whose text colour is ``#<color>``."""
    opening_tag = f'<span style="color: #{color}">'
    return f'{opening_tag}{token}</span>'
"""
# This is still very broken at the moment, needs CSS too but I'm not very good at CSS (and neither is GPT-4 apparently) so I still need to figure that out.
def add_dropdown_html(token, color, top_tokens, top_probs):
html = f'<span class="hoverable" style="color: #{color}">{token}<div class="dropdown"><table class="dropdown-content">'
for token, prob in zip(top_tokens, top_probs):
# TODO: Background color? Bold for selected token?
# Bigger issue: Why is there a newline after the first token, and the dropdown fails there?
# The HTML ends up like <p><span>word</span></p><div>...</div>,
# even though for all other tokens it shows up correctly.
row_color = probability_color_scale(prob)
html += f'<tr><td style="color: #{row_color}">{token}</td><td style="color: #{row_color}">{prob}</td></tr>'
html += '</table></div></span>'
return html
"""
def ui():
    # Create the two colouring checkboxes and keep the module-level `params`
    # dict in sync with them. Both start unchecked (value=False).
    color_by_ppl_check = gradio.Checkbox(value=False, label="Color by perplexity", info="Higher perplexity is more red. If also showing probability, higher perplexity has more blue component.")

    def update_color_by_ppl_check(x):
        # Store the new checkbox state under 'color_by_perplexity'.
        params.update({'color_by_perplexity': x})
    # The handler has no output component (None); it only updates `params`.
    color_by_ppl_check.change(update_color_by_ppl_check, color_by_ppl_check, None)
    color_by_prob_check = gradio.Checkbox(value=False, label="Color by probability", info="Green-yellow-red linear scale, with 100% green, 50% yellow, 0% red.")

    def update_color_by_prob_check(x):
        # Store the new checkbox state under 'color_by_probability'.
        params.update({'color_by_probability': x})
    color_by_prob_check.change(update_color_by_prob_check, color_by_prob_check, None)
# Doesn't work yet...
"""
prob_dropdown_check = gradio.Checkbox(value=False, label="Probability dropdown")
def update_prob_dropdown_check(x):
params.update({'probability_dropdown': x})
prob_dropdown_check.change(update_prob_dropdown_check, prob_dropdown_check, None)
"""

View file

@ -7,6 +7,7 @@ from transformers import BlipForConditionalGeneration, BlipProcessor
from modules import chat, shared
from modules.ui import gather_interface_values
from modules.utils import gradio
# If 'state' is True, will hijack the next chat generation with
# custom input text given by 'value' in the format [text, visible_text]
@ -42,6 +43,6 @@ def ui():
# Prepare the input hijack, update the interface values, call the generation function, and clear the picture
picture_select.upload(
lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None).then(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
chat.generate_chat_reply_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then(
gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
chat.generate_chat_reply_wrapper, shared.input_params, gradio('display', 'history'), show_progress=False).then(
lambda: None, None, picture_select, show_progress=False)

View file

@ -2,3 +2,4 @@ beautifulsoup4==4.12.2
chromadb==0.3.18
posthog==2.4.2
sentence_transformers==2.2.2
lxml

View file

@ -75,7 +75,7 @@ def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads)
cumulative += 'Processing the HTML sources...'
yield cumulative
for content in contents:
soup = BeautifulSoup(content, features="html.parser")
soup = BeautifulSoup(content, features="lxml")
for script in soup(["script", "style"]):
script.extract()

View file

@ -0,0 +1,15 @@
# whisper_stt
Allows you to enter your inputs in chat mode using your microphone.
## Settings
To adjust your default settings, you can add the following to your settings.yaml file.
```
whisper_stt-whipser_language: chinese
whisper_stt-whipser_model: tiny
whisper_stt-auto_submit: False
```
See source documentation for [model names](https://github.com/openai/whisper#available-models-and-languages) and [languages](https://github.com/openai/whisper/blob/main/whisper/tokenizer.py) you can use.

View file

@ -8,8 +8,15 @@ input_hijack = {
'value': ["", ""]
}
# parameters which can be customized in settings.json of webui
params = {
'whipser_language': 'english',
'whipser_model': 'small.en',
'auto_submit': True
}
def do_stt(audio):
def do_stt(audio, whipser_model, whipser_language):
transcription = ""
r = sr.Recognizer()
@ -17,7 +24,7 @@ def do_stt(audio):
audio_data = sr.AudioData(sample_rate=audio[0], frame_data=audio[1], sample_width=4)
try:
transcription = r.recognize_whisper(audio_data, language="english", model="base.en")
transcription = r.recognize_whisper(audio_data, language=whipser_language, model=whipser_model)
except sr.UnknownValueError:
print("Whisper could not understand audio")
except sr.RequestError as e:
@ -26,11 +33,10 @@ def do_stt(audio):
return transcription
def auto_transcribe(audio, auto_submit):
def auto_transcribe(audio, auto_submit, whipser_model, whipser_language):
if audio is None:
return "", ""
transcription = do_stt(audio)
transcription = do_stt(audio, whipser_model, whipser_language)
if auto_submit:
input_hijack.update({"state": True, "value": [transcription, transcription]})
@ -38,10 +44,18 @@ def auto_transcribe(audio, auto_submit):
def ui():
with gr.Row():
audio = gr.Audio(source="microphone")
auto_submit = gr.Checkbox(label='Submit the transcribed audio automatically', value=True)
with gr.Accordion("Whisper STT", open=True):
with gr.Row():
audio = gr.Audio(source="microphone")
with gr.Row():
with gr.Accordion("Settings", open=False):
auto_submit = gr.Checkbox(label='Submit the transcribed audio automatically', value=params['auto_submit'])
whipser_model = gr.Dropdown(label='Whisper Model', value=params['whipser_model'], choices=["tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large"])
whipser_language = gr.Dropdown(label='Whisper Language', value=params['whipser_language'], choices=["chinese", "german", "spanish", "russian", "korean", "french", "japanese", "portuguese", "turkish", "polish", "catalan", "dutch", "arabic", "swedish", "italian", "indonesian", "hindi", "finnish", "vietnamese", "hebrew", "ukrainian", "greek", "malay", "czech", "romanian", "danish", "hungarian", "tamil", "norwegian", "thai", "urdu", "croatian", "bulgarian", "lithuanian", "latin", "maori", "malayalam", "welsh", "slovak", "telugu", "persian", "latvian", "bengali", "serbian", "azerbaijani", "slovenian", "kannada", "estonian", "macedonian", "breton", "basque", "icelandic", "armenian", "nepali", "mongolian", "bosnian", "kazakh", "albanian", "swahili", "galician", "marathi", "punjabi", "sinhala", "khmer", "shona", "yoruba", "somali", "afrikaans", "occitan", "georgian", "belarusian", "tajik", "sindhi", "gujarati", "amharic", "yiddish", "lao", "uzbek", "faroese", "haitian creole", "pashto", "turkmen", "nynorsk", "maltese", "sanskrit", "luxembourgish", "myanmar", "tibetan", "tagalog", "malagasy", "assamese", "tatar", "hawaiian", "lingala", "hausa", "bashkir", "javanese", "sundanese"])
audio.change(
auto_transcribe, [audio, auto_submit], [shared.gradio['textbox'], audio]).then(
auto_transcribe, [audio, auto_submit, whipser_model, whipser_language], [shared.gradio['textbox'], audio]).then(
None, auto_submit, None, _js="(check) => {if (check) { document.getElementById('Generate').click() }}")
whipser_model.change(lambda x: params.update({"whipser_model": x}), whipser_model, None)
whipser_language.change(lambda x: params.update({"whipser_language": x}), whipser_language, None)
auto_submit.change(lambda x: params.update({"auto_submit": x}), auto_submit, None)

View file

@ -97,6 +97,8 @@ llama-65b-gptq-3bit:
.*raven:
mode: 'instruct'
instruction_template: 'RWKV-Raven'
.*ctx8192:
truncation_length: 8192
.*moss-moon.*sft:
mode: 'instruct'
instruction_template: 'MOSS'
@ -143,6 +145,7 @@ llama-65b-gptq-3bit:
.*wizard.*mega:
mode: 'instruct'
instruction_template: 'Wizard-Mega'
custom_stopping_strings: '"</s>"'
.*ziya-:
mode: 'instruct'
instruction_template: 'Ziya'
@ -243,3 +246,26 @@ TheBloke_WizardLM-30B-GPTQ:
.*xgen.*-inst:
truncation_length: 8192
instruction_template: 'Vicuna-v0'
.*(platypus|gplatty|superplatty):
mode: 'instruct'
instruction_template: 'Alpaca'
.*longchat:
mode: 'instruct'
instruction_template: 'Vicuna-v1.1'
.*vicuna-33b:
mode: 'instruct'
instruction_template: 'Vicuna-v1.1'
.*redmond-hermes-coder:
mode: 'instruct'
instruction_template: 'Alpaca'
truncation_length: 8192
.*wizardcoder-15b:
mode: 'instruct'
instruction_template: 'Alpaca'
truncation_length: 8192
.*wizardlm-.*-v1.1:
mode: 'instruct'
instruction_template: 'Vicuna-v1.1'
.*godzilla:
mode: 'instruct'
instruction_template: 'Alpaca'

View file

@ -9,9 +9,9 @@ from modules.models import reload_model
def add_lora_to_model(lora_names):
if 'GPTQForCausalLM' in shared.model.__class__.__name__:
if 'GPTQForCausalLM' in shared.model.__class__.__name__ or shared.args.loader == 'AutoGPTQ':
add_lora_autogptq(lora_names)
elif shared.model.__class__.__name__ in ['ExllamaModel', 'ExllamaHF']:
elif shared.model.__class__.__name__ in ['ExllamaModel', 'ExllamaHF'] or shared.args.loader == 'ExLlama':
add_lora_exllama(lora_names)
else:
add_lora_transformers(lora_names)
@ -67,14 +67,15 @@ def add_lora_autogptq(lora_names):
return
if len(lora_names) == 0:
if len(shared.lora_names) > 0:
reload_model()
reload_model()
shared.lora_names = []
return
else:
if len(lora_names) > 1:
logger.warning('AutoGPTQ can only work with 1 LoRA at the moment. Only the first one in the list will be loaded.')
if not shared.args.no_inject_fused_attention:
logger.warning('Fused Atttention + AutoGPTQ may break Lora loading. Disable it.')
peft_config = GPTQLoraConfig(
inference_mode=True,
@ -107,18 +108,19 @@ def add_lora_transformers(lora_names):
# If any LoRA needs to be removed, start over
if len(removed_set) > 0:
# shared.model may no longer be PeftModel
if hasattr(shared.model, 'disable_adapter'):
shared.model.disable_adapter()
if hasattr(shared.model, 'disable_adapter'):
shared.model.disable_adapter()
shared.model = shared.model.base_model.model
if len(lora_names) > 0:
params = {}
if not shared.args.cpu:
params['dtype'] = shared.model.dtype
if hasattr(shared.model, "hf_device_map"):
params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()}
elif shared.args.load_in_8bit:
params['device_map'] = {'': 0}
if shared.args.load_in_4bit or shared.args.load_in_8bit:
params['peft_type'] = shared.model.dtype
else:
params['dtype'] = shared.model.dtype
if hasattr(shared.model, "hf_device_map"):
params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()}
logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_names[0]}"), adapter_name=lora_names[0], **params)

View file

@ -1,19 +1,47 @@
import builtins
import io
import requests
from modules.logging_colors import logger
original_open = open
original_get = requests.get
class RequestBlocker:
def __enter__(self):
self.original_get = requests.get
requests.get = my_get
def __exit__(self, exc_type, exc_value, traceback):
requests.get = self.original_get
requests.get = original_get
class OpenMonkeyPatch:
    """Context manager that replaces ``builtins.open`` with ``my_open``.

    On exit, ``builtins.open`` is set back to the module-level
    ``original_open``.
    """

    def __enter__(self):
        # Nothing is returned, so `with OpenMonkeyPatch() as x` binds x to None.
        builtins.open = my_open

    def __exit__(self, exc_type, exc_value, traceback):
        # Returns None, so an exception raised inside the `with` body propagates.
        builtins.open = original_open
def my_get(url, **kwargs):
    """Stand-in for ``requests.get`` that never contacts the requested host.

    The ``url`` argument is ignored: a message is logged and a GET request is
    sent to ``http://127.0.0.1/`` instead, forwarding the caller's keyword
    arguments. ``allow_redirects`` defaults to True when the caller did not
    supply it.
    """
    logger.info('Unwanted HTTP request redirected to localhost :)')
    if 'allow_redirects' not in kwargs:
        kwargs['allow_redirects'] = True
    return requests.api.request('get', 'http://127.0.0.1/', **kwargs)
# Kindly provided by our friend WizardLM-30B
def my_open(*args, **kwargs):
    """Replacement for ``builtins.open`` that rewrites ``index.html`` on read.

    Files whose path does not end in ``index.html`` are opened with the
    original ``open`` unchanged. For ``index.html`` the file is read in full,
    the iframe-resizer ``<script>`` tag loaded from cdnjs is removed, any
    remaining ``cdnjs.cloudflare.com`` reference is pointed at ``127.0.0.1``,
    and an in-memory file object holding the patched contents is returned.

    Args:
        *args: Positional arguments as accepted by ``open``.
        **kwargs: Keyword arguments as accepted by ``open``.

    Returns:
        The object returned by the original ``open``, or for ``index.html``
        an ``io.BytesIO`` (binary mode) or ``io.StringIO`` (text mode) with
        the patched contents.
    """
    # builtins.open names its first parameter `file`, so callers may pass the
    # path by keyword; fall back to that instead of failing on args[0].
    target = args[0] if args else kwargs.get('file')
    if not str(target).endswith('index.html'):
        return original_open(*args, **kwargs)

    with original_open(*args, **kwargs) as f:
        file_contents = f.read()

    script_tag = '<script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script>'
    if isinstance(file_contents, str):
        # Text-mode read: patch with str patterns and hand back a text stream
        # so the caller gets the type it asked for.
        file_contents = file_contents.replace(script_tag, '')
        file_contents = file_contents.replace('cdnjs.cloudflare.com', '127.0.0.1')
        return io.StringIO(file_contents)

    file_contents = file_contents.replace(script_tag.encode(), b'')
    file_contents = file_contents.replace(b'cdnjs.cloudflare.com', b'127.0.0.1')
    return io.BytesIO(file_contents)

View file

@ -3,6 +3,7 @@ import copy
import functools
import json
import re
from datetime import datetime
from pathlib import Path
import gradio as gr
@ -388,8 +389,25 @@ def load_history(file, history):
return history
def save_history_at_user_request(history, character, mode):
    """Save ``history`` to a timestamped JSON file under ``logs/``.

    The file name is ``<prefix>_<YYYYmmdd-HHMMSS>.json``. The prefix is the
    character name in 'chat' / 'chat-instruct' mode when a character is
    selected, otherwise the capitalized mode name. If ``mode`` has no
    ``capitalize`` method (for example ``None``), the name is the bare
    timestamp.

    Returns:
        Whatever ``save_history`` returns for the chosen path.
    """
    def make_timestamp_path(character=None):
        return f"logs/{character or ''}{'_' if character else ''}{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"

    if mode in ['chat', 'chat-instruct'] and character not in ['', 'None', None]:
        path = make_timestamp_path(character)
    else:
        # Try to use mode as the file name, otherwise just use the timestamp.
        # Only a non-string mode is expected to fail here, so catch exactly
        # that instead of swallowing every exception.
        try:
            path = make_timestamp_path(mode.capitalize())
        except AttributeError:
            path = make_timestamp_path()

    return save_history(history, path)
def save_persistent_history(history, character, mode):
if mode in ['chat', 'chat-instruct'] and character not in ['', 'None', None] and not shared.args.multi_user:
if mode in ['chat', 'chat-instruct'] and character not in ['', 'None', None] and not shared.args.multi_user:
save_history(history, path=Path(f'logs/{character}_persistent.json'))
@ -460,11 +478,16 @@ def load_character(character, name1, name2, instruct=False):
if character not in ['None', '', None]:
folder = 'characters' if not instruct else 'characters/instruction-following'
picture = generate_pfp_cache(character)
filepath = None
for extension in ["yml", "yaml", "json"]:
filepath = Path(f'{folder}/{character}.{extension}')
if filepath.exists():
break
if filepath is None:
logger.error(f"Could not find character file for {character} in {folder} folder. Please check your spelling.")
return name1, name2, picture, greeting, context, turn_template.replace("\n", r"\n")
file_contents = open(filepath, 'r', encoding='utf-8').read()
data = json.loads(file_contents) if extension == "json" else yaml.safe_load(file_contents)

View file

@ -1,10 +1,10 @@
import sys
from pathlib import Path
from torch import version as torch_version
from modules import shared
from modules.logging_colors import logger
from modules.text_generation import get_max_prompt_length
try:
from exllama.generator import ExLlamaGenerator
@ -90,7 +90,11 @@ class ExllamaModel:
self.generator.disallow_tokens(None)
self.generator.end_beam_search()
# Tokenizing the input
ids = self.generator.tokenizer.encode(prompt)
ids = ids[:, -get_max_prompt_length(state):]
self.generator.gen_begin_reuse(ids)
initial_len = self.generator.sequence[0].shape[0]
has_leading_space = False
@ -116,3 +120,6 @@ class ExllamaModel:
def encode(self, string, **kwargs):
return self.tokenizer.encode(string)
def decode(self, string, **kwargs):
return self.tokenizer.decode(string)[0]

View file

@ -106,15 +106,23 @@ def _apply_history_modifier_extensions(history):
return history
# Extension functions that override the default tokenizer output - currently only the first one will work
# Extension functions that override the default tokenizer output - The order of execution is not defined
def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
for extension, _ in iterator():
if hasattr(extension, function_name):
return getattr(extension, function_name)(state, prompt, input_ids, input_embeds)
prompt, input_ids, input_embeds = getattr(extension, function_name)(state, prompt, input_ids, input_embeds)
return prompt, input_ids, input_embeds
# Allow extensions to add their own logits processors to the stack being run.
# Each extension would call `processor_list.append({their LogitsProcessor}())`.
def _apply_logits_processor_extensions(function_name, processor_list, input_ids):
    """Call the ``function_name`` hook of every extension that defines it.

    Each hook is called with ``(processor_list, input_ids)``. Its return
    value is ignored, and this function itself returns None.
    """
    for extension, _ in iterator():
        if not hasattr(extension, function_name):
            continue
        hook = getattr(extension, function_name)
        hook(processor_list, input_ids)
# Get prompt length in tokens after applying extension functions which override the default tokenizer output
# currently only the first one will work
def _apply_custom_tokenized_length(prompt):
@ -183,6 +191,7 @@ EXTENSION_MAP = {
"history": _apply_history_modifier_extensions,
"bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"),
"tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
'logits_processor': partial(_apply_logits_processor_extensions, 'logits_processor_modifier'),
"input_hijack": _apply_input_hijack,
"custom_generate_chat_prompt": _apply_custom_generate_chat_prompt,
"custom_generate_reply": _apply_custom_generate_reply,

View file

@ -49,6 +49,7 @@ class LlamaCppModel:
'n_batch': shared.args.n_batch,
'use_mmap': not shared.args.no_mmap,
'use_mlock': shared.args.mlock,
'low_vram': shared.args.low_vram,
'n_gpu_layers': shared.args.n_gpu_layers
}
@ -65,6 +66,9 @@ class LlamaCppModel:
return self.model.tokenize(string)
def decode(self, tokens):
return self.model.detokenize(tokens)
def generate(self, prompt, state, callback=None):
prompt = prompt if type(prompt) is str else prompt.decode()
completion_chunks = self.model.create_completion(

View file

@ -34,6 +34,7 @@ loaders_and_params = {
'n_batch',
'threads',
'no_mmap',
'low_vram',
'mlock',
'llama_cpp_seed',
],
@ -53,14 +54,14 @@ loaders_and_params = {
'trust_remote_code',
'transformers_info'
],
'ExLlama' : [
'ExLlama': [
'gpu_split',
'max_seq_len',
'compress_pos_emb',
'alpha_value',
'exllama_info',
],
'ExLlama_HF' : [
'ExLlama_HF': [
'gpu_split',
'max_seq_len',
'compress_pos_emb',

View file

@ -61,6 +61,10 @@ def load_model(model_name, loader=None):
'ExLlama_HF': ExLlama_HF_loader
}
p = Path(model_name)
if p.exists():
model_name = p.parts[-1]
if loader is None:
if shared.args.loader is not None:
loader = shared.args.loader
@ -95,11 +99,18 @@ def load_tokenizer(model_name, model):
if any(s in model_name.lower() for s in ['gpt-4chan', 'gpt4chan']) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
elif path_to_model.exists():
tokenizer = AutoTokenizer.from_pretrained(
path_to_model,
trust_remote_code=shared.args.trust_remote_code,
use_fast=False
)
try:
tokenizer = AutoTokenizer.from_pretrained(
path_to_model,
trust_remote_code=shared.args.trust_remote_code,
use_fast=False
)
except ValueError:
tokenizer = AutoTokenizer.from_pretrained(
path_to_model,
trust_remote_code=shared.args.trust_remote_code,
use_fast=True
)
if tokenizer.__class__.__name__ == 'LlamaTokenizer':
pairs = [
@ -328,6 +339,7 @@ def clear_torch_cache():
def unload_model():
shared.model = shared.tokenizer = None
shared.lora_names = []
shared.model_dirty_from_training = False
clear_torch_cache()

View file

@ -99,7 +99,10 @@ def apply_model_settings_to_state(model, state):
for k in model_settings:
if k in state:
state[k] = model_settings[k]
if k in ['wbits', 'groupsize']:
state[k] = str(model_settings[k])
else:
state[k] = model_settings[k]
return state

View file

@ -126,6 +126,7 @@ class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
'''
Copied from the transformers library
'''
def __init__(self, penalty: float, _range: int):
if not isinstance(penalty, float) or not (penalty > 0):
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

View file

@ -12,6 +12,7 @@ tokenizer = None
is_seq2seq = False
model_name = "None"
lora_names = []
model_dirty_from_training = False
# Chat variables
stop_everything = False
@ -120,6 +121,7 @@ parser.add_argument('--use_double_quant', action='store_true', help='use_double_
parser.add_argument('--threads', type=int, default=0, help='Number of threads to use.')
parser.add_argument('--n_batch', type=int, default=512, help='Maximum number of prompt tokens to batch together when calling llama_eval.')
parser.add_argument('--no-mmap', action='store_true', help='Prevent mmap from being used.')
parser.add_argument('--low-vram', action='store_true', help='Low VRAM Mode')
parser.add_argument('--mlock', action='store_true', help='Force the system to keep the model in RAM.')
parser.add_argument('--cache-capacity', type=str, help='Maximum cache capacity. Examples: 2000MiB, 2GiB. When provided without units, bytes will be assumed.')
parser.add_argument('--n-gpu-layers', type=int, default=0, help='Number of layers to offload to the GPU.')
@ -179,7 +181,7 @@ parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authent
# API
parser.add_argument('--api', action='store_true', help='Enable the API extension.')
parser.add_argument('--api-blocking-port', type=int, default=5000, help='The listening port for the blocking API.')
parser.add_argument('--api-streaming-port', type=int, default=5005, help='The listening port for the streaming API.')
parser.add_argument('--api-streaming-port', type=int, default=5005, help='The listening port for the streaming API.')
parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.')
# Multimodal

View file

@ -8,6 +8,7 @@ import traceback
import numpy as np
import torch
import transformers
from transformers import LogitsProcessorList
import modules.shared as shared
from modules.callbacks import (
@ -264,6 +265,13 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria())
processor = state.get('logits_processor', LogitsProcessorList([]))
# In case folks just pass in a processor by itself.
if type(processor) != LogitsProcessorList:
processor = LogitsProcessorList([processor])
apply_extensions('logits_processor', processor, input_ids)
generate_params['logits_processor'] = processor
t0 = time.time()
try:
if not is_chat and not shared.is_seq2seq:

View file

@ -1,18 +1,23 @@
import os
os.environ["WANDB_MODE"] = "offline"
# os.environ["WANDB_DISABLED"] = "true"
import json
import math
import random
import shutil
import sys
import threading
import time
import traceback
from datetime import datetime
from pathlib import Path
import gradio as gr
import torch
import transformers
import shutil
from datetime import datetime
from modules.models import load_model, unload_model
from datasets import Dataset, load_dataset
from peft import (
@ -29,6 +34,7 @@ from modules.evaluate import (
save_past_evaluations
)
from modules.logging_colors import logger
from modules.utils import natural_keys
# This mapping is from a very recent commit, not yet released.
# If not available, default to a backup map for some common model types.
@ -56,7 +62,7 @@ train_log = {}
train_template = {}
WANT_INTERRUPT = False
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss"]
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss", "add_eos_token", "min_chars", "report_to"]
def create_train_interface():
@ -104,6 +110,7 @@ def create_train_interface():
raw_text_file = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'txt'), value='None', label='Text file', info='The raw text file to use for training.')
ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'txt')}, 'refresh-button')
hard_cut_string = gr.Textbox(label='Hard Cut String', value='\\n\\n\\n', info='String that indicates a hard cut between text parts. Helps prevent unwanted overlap.')
min_chars = gr.Number(label='Ignore small blocks', value=0, info='Ignore Hard Cut blocks that have less or equal characters than this number')
with gr.Row():
overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
@ -115,9 +122,12 @@ def create_train_interface():
optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Different optimizer implementation options, for advanced users. Effects of different options are not well documented yet.')
train_only_after = gr.Textbox(label='Train Only After', value='', info='Only consider text *after* this string in any given chunk for training. For Alpaca datasets, use "### Response:" to only train the response and ignore the input.')
stop_at_loss = gr.Slider(label='Stop at loss', minimum=0.0, maximum=3.0, step=0.1, value=0.00, info='The process will automatically stop once the desired loss value is reached. (reasonable numbers are 1.5-1.8)')
add_eos_token = gr.Checkbox(label='Add EOS token', value=False, info="Adds EOS token for each dataset item. In case of raw text, the EOS will be added at the Hard Cut")
with gr.Row():
higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.')
with gr.Row():
report_to = gr.Radio(label="Save detailed logs with", value="None", choices=["None", "wandb", "tensorboard"], interactive=True)
with gr.Row():
start_button = gr.Button("Start LoRA Training")
@ -148,7 +158,9 @@ def create_train_interface():
refresh_table = gr.Button('Refresh the table', elem_classes="small-button")
# Training events
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss]
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss, add_eos_token, min_chars, report_to]
copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
start_button.click(do_train, all_params, output)
stop_button.click(do_interrupt, None, None, queue=False)
@ -223,7 +235,7 @@ def backup_adapter(input_folder):
creation_date_str = creation_date.strftime("Backup-%Y-%m-%d")
# Create the new subfolder
subfolder_path = Path(f"{input_folder}/{creation_date_str}")
subfolder_path = Path(f"{input_folder}/{creation_date_str}")
subfolder_path.mkdir(parents=True, exist_ok=True)
# Check if the file already exists in the subfolder
@ -240,6 +252,7 @@ def backup_adapter(input_folder):
except Exception as e:
print("An error occurred in backup_adapter:", str(e))
def calc_trainable_parameters(model):
trainable_params = 0
all_param = 0
@ -252,11 +265,11 @@ def calc_trainable_parameters(model):
all_param += num_params
if param.requires_grad:
trainable_params += num_params
return trainable_params,all_param
return trainable_params, all_param
def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float):
def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float, add_eos_token: bool, min_chars: int, report_to: str):
if shared.args.monkey_patch:
from monkeypatch.peft_tuners_lora_monkey_patch import (
@ -314,14 +327,22 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
def encode(text, add_bos_token):
result = shared.tokenizer.encode(text, truncation=True, max_length=cutoff_len)
# Check if the first two tokens are BOS
if len(result) >= 2 and result[:2] == [shared.tokenizer.bos_token_id, shared.tokenizer.bos_token_id]:
result = result[1:]
if not add_bos_token and result[0] == shared.tokenizer.bos_token_id:
result = result[1:]
return result
def tokenize(prompt):
def tokenize(prompt, append_eos_token=False):
if train_only_after == '' or train_only_after not in prompt:
input_ids = encode(prompt, True)
if append_eos_token and input_ids[-1] != shared.tokenizer.eos_token_id and len(input_ids) < cutoff_len:
input_ids.append(shared.tokenizer.eos_token_id)
input_ids = [shared.tokenizer.pad_token_id] * (cutoff_len - len(input_ids)) + input_ids
labels = [1] * len(input_ids)
@ -330,6 +351,9 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
before_tokens = encode(prompt[:ind], True)
after_tokens = encode(prompt[ind:], False)
if append_eos_token and after_tokens[-1] != shared.tokenizer.eos_token_id:
after_tokens.append(shared.tokenizer.eos_token_id)
full_length = len(after_tokens) + len(before_tokens)
if full_length > cutoff_len:
after_tokens = after_tokens[:cutoff_len - len(before_tokens)]
@ -350,31 +374,46 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
# == Prep the dataset, format, etc ==
if raw_text_file not in ['None', '']:
logger.info("Loading raw text file dataset...")
train_template["template_type"] = "raw_text"
logger.info("Loading raw text file dataset...")
fullpath = clean_path('training/datasets', f'{raw_text_file}')
fullpath = Path(fullpath)
if fullpath.is_dir():
logger.info('Training path directory {}'.format(raw_text_file))
raw_text = ""
file_paths = sorted(fullpath.glob('*.txt'), key=lambda path: natural_keys(path.name))
for file_path in file_paths:
if file_path.is_file():
with file_path.open('r', encoding='utf-8') as file:
raw_text += file.read().replace('\r', '')
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
raw_text = file.read().replace('\r', '')
logger.info(f"Loaded training file: {file_path.name}")
else:
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
raw_text = file.read().replace('\r', '')
cut_string = hard_cut_string.replace('\\n', '\n')
eos_added = 0
out_tokens = []
for text_part in raw_text.split(cut_string):
if text_part.strip() == '':
if len(text_part.strip()) <= min_chars:
continue
tokens = shared.tokenizer.encode(text_part)
if add_eos_token:
tokens.append(shared.tokenizer.eos_token_id)
eos_added += 1
step = cutoff_len - overlap_len
if step <= 0:
yield f"Error: overlap_len ({overlap_len}) cannot be greater than or equal to cutoff_len ({cutoff_len})"
return
tokens = list(split_chunks(tokens, step))
for i in range(1, len(tokens)):
tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i]
out_tokens.extend(split_chunks(tokens, cutoff_len, step))
out_tokens.extend(tokens)
del tokens
if eos_added > 0:
print(f"EOS added to {eos_added} text blocks")
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
text_chunks = [shared.tokenizer.decode(x) for x in out_tokens]
@ -415,7 +454,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
def generate_and_tokenize_prompt(data_point):
prompt = generate_prompt(data_point)
return tokenize(prompt)
return tokenize(prompt, add_eos_token)
logger.info("Loading JSON datasets...")
data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
@ -427,11 +466,33 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
eval_data = load_dataset("json", data_files=clean_path('training/datasets', f'{eval_dataset}.json'))
eval_data = eval_data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
# == We MUST reload model if it went through any previous training, even failed one ==
if shared.model_dirty_from_training:
selected_model = shared.model_name
if selected_model:
print("\033[1;31;1m(Model has been modified by previous training, it needs to be reloaded...)\033[0;37;0m")
try:
yield f"Reloading {selected_model}..."
unload_model()
shared.model, shared.tokenizer = load_model(shared.model_name, None)
if shared.model is not None:
print("Model reloaded OK, continue with training.")
else:
return f"Failed to load {selected_model}."
except:
exc = traceback.format_exc()
logger.error('Failed to reload the model.')
print(exc)
return exc
# == Start prepping the model itself ==
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
logger.info("Getting model ready...")
prepare_model_for_int8_training(shared.model)
# base model is now frozen and should not be reused for any other LoRA training than this one
shared.model_dirty_from_training = True
logger.info("Prepping for training...")
config = LoraConfig(
r=lora_rank,
@ -518,6 +579,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
train_dataset=train_data,
eval_dataset=eval_data,
args=transformers.TrainingArguments(
report_to=report_to if report_to != "None" else None,
per_device_train_batch_size=micro_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
warmup_steps=math.ceil(warmup_steps / gradient_accumulation_steps),
@ -534,7 +596,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
load_best_model_at_end=eval_data is not None,
# TODO: Enable multi-device support
ddp_find_unused_parameters=None,
no_cuda=shared.args.cpu
no_cuda=shared.args.cpu,
),
data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
callbacks=list([Callbacks()])
@ -559,15 +621,19 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
yield "Starting..."
lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model)
if lora_all_param>0:
print(f"Trainable params: {lora_trainable_param:,d} ({100 * lora_trainable_param / lora_all_param:.4f} %), All params: {lora_all_param:,d} (Model: {model_all_params:,d})")
projections_string = ", ".join([projection.replace("_proj", "") for projection in model_to_lora_modules[model_id]])
print(f"Training '{model_id}' model using ({projections_string}) projections")
if lora_all_param > 0:
print(f"Trainable params: {lora_trainable_param:,d} ({100 * lora_trainable_param / lora_all_param:.4f} %), All params: {lora_all_param:,d} (Model: {model_all_params:,d})")
train_log.update({"base_model_name": shared.model_name})
train_log.update({"base_model_class": shared.model.__class__.__name__})
train_log.update({"base_loaded_in_4bit": getattr(lora_model, "is_loaded_in_4bit", False)})
train_log.update({"base_loaded_in_8bit": getattr(lora_model, "is_loaded_in_8bit", False)})
train_log.update({"projections": projections_string})
if stop_at_loss > 0:
print(f"Monitoring loss \033[1;31;1m(Auto-Stop at: {stop_at_loss})\033[0;37;0m")
@ -576,7 +642,26 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
yield "Interrupted before start."
return
def log_train_dataset(trainer):
    """Dump a decoded sample of the training set to 'logs/train_dataset_sample.json'.

    Up to the first 10 entries of `trainer.train_dataset` have their
    'input_ids' decoded with `shared.tokenizer` and are written as a JSON list
    of {"value": <decoded text>} objects. This is best-effort: any exception
    is logged and swallowed, and nothing is raised to the caller.
    """
    try:
        # Decode at most 10 entries (fewer if the dataset is smaller).
        sample_count = min(10, len(trainer.train_dataset))
        decoded_entries = [
            {"value": shared.tokenizer.decode(trainer.train_dataset[idx]['input_ids'])}
            for idx in range(sample_count)
        ]

        # Make sure the target directory exists before writing the sample.
        Path('logs').mkdir(exist_ok=True)
        with open(Path('logs/train_dataset_sample.json'), 'w') as json_file:
            json.dump(decoded_entries, json_file, indent=4)

        logger.info("Log file 'train_dataset_sample.json' created in the 'logs' directory.")
    except Exception as e:
        logger.error(f"Failed to create log file due to error: {e}")
def threaded_run():
log_train_dataset(trainer)
trainer.train()
# Note: save in the thread in case the gradio thread breaks (eg browser closed)
lora_model.save_pretrained(lora_file_path)
@ -625,9 +710,9 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
yield f"Done! LoRA saved to `{lora_file_path}`"
def split_chunks(arr, size, step):
    """Yield consecutive windows of `arr`.

    A window starts every `step` items and holds up to `size` items, so
    windows overlap when `size > step`. Trailing windows may be shorter than
    `size`. `step` must be positive: `range` raises ValueError for a zero
    step and yields nothing for a negative one.
    """
    for i in range(0, len(arr), step):
        yield arr[i:i + size]
def cut_chunk_for_newline(chunk: str, max_length: int):

View file

@ -57,6 +57,7 @@ def list_model_elements():
'threads',
'n_batch',
'no_mmap',
'low_vram',
'mlock',
'n_gpu_layers',
'n_ctx',
@ -67,7 +68,6 @@ def list_model_elements():
'alpha_value'
]
for i in range(torch.cuda.device_count()):
elements.append(f'gpu_memory_{i}')
@ -76,7 +76,6 @@ def list_model_elements():
def list_interface_input_elements():
elements = [
'preset_menu',
'max_new_tokens',
'seed',
'temperature',
@ -158,11 +157,10 @@ def apply_interface_values(state, use_persistent=False):
if len(state) == 0:
return [gr.update() for k in elements] # Dummy, do nothing
else:
ans = [state[k] if k in state else gr.update() for k in elements]
return ans
return [state[k] if k in state else gr.update() for k in elements]
class ToolButton(gr.Button, gr.components.FormComponent):
class ToolButton(gr.Button, gr.components.IOComponent):
"""Small button with single emoji as text, fits inside gradio forms"""
def __init__(self, **kwargs):

View file

@ -114,6 +114,10 @@ def get_available_loras():
def get_datasets(path: str, ext: str):
    """List the dataset names available under `path` for the given extension.

    Returns 'None' followed by the unique file stems, ordered with
    `natural_keys`. The placeholder entry 'put-trainer-datasets-here' is
    always excluded.

    For ext == "txt", subdirectories of `path` are listed as well, so that a
    folder of raw text files can be selected as a single dataset.
    """
    root = Path(path)
    if ext == "txt":
        # Match '*.txt' explicitly: a bare 'txt' pattern only matches an entry
        # literally named "txt". Filter directories with is_dir() because
        # glob('*/') returns every entry before Python 3.11.
        candidates = list(root.glob('*.txt')) + [k for k in root.glob('*') if k.is_dir()]
    else:
        candidates = list(root.glob(f'*.{ext}'))

    names = {k.stem for k in candidates if k.stem != 'put-trainer-datasets-here'}
    return ['None'] + sorted(names, key=natural_keys)

View file

@ -2,6 +2,7 @@ accelerate==0.20.3
colorama
datasets
einops
fastapi==0.95.2
flexgen==0.1.7
gradio_client==0.2.5
gradio==3.33.1
@ -15,13 +16,15 @@ safetensors==0.3.1
sentencepiece
tqdm
scipy
tensorboard
wandb
transformers==4.30.2
git+https://github.com/huggingface/peft@03eb378eb914fbee709ff7c86ba5b1d033b89524
bitsandbytes==0.39.1; platform_system != "Windows"
https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl; platform_system == "Windows"
llama-cpp-python==0.1.68; platform_system != "Windows"
https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.68/llama_cpp_python-0.1.68-cp310-cp310-win_amd64.whl; platform_system == "Windows"
bitsandbytes==0.40.0.post4; platform_system != "Windows"
https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.40.0.post4-py3-none-win_amd64.whl; platform_system == "Windows"
llama-cpp-python==0.1.70; platform_system != "Windows"
https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.70/llama_cpp_python-0.1.70-cp310-cp310-win_amd64.whl; platform_system == "Windows"
https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.2/auto_gptq-0.2.2+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"
https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.2/auto_gptq-0.2.2+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/jllllll/exllama/releases/download/0.0.5/exllama-0.0.5+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"
https://github.com/jllllll/exllama/releases/download/0.0.5/exllama-0.0.5+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/jllllll/exllama/releases/download/0.0.6/exllama-0.0.6+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"
https://github.com/jllllll/exllama/releases/download/0.0.6/exllama-0.0.6+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"

View file

@ -2,7 +2,7 @@ import os
import warnings
from modules.logging_colors import logger
from modules.block_requests import RequestBlocker
from modules.block_requests import OpenMonkeyPatch, RequestBlocker
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
os.environ['BITSANDBYTES_NOWELCOME'] = '1'
@ -54,7 +54,7 @@ from modules.utils import gradio
def load_model_wrapper(selected_model, loader, autoload=False):
if not autoload:
yield f"The settings for {selected_model} have been updated.\nClick on \"Load the model\" to load it."
yield f"The settings for {selected_model} have been updated.\nClick on \"Load\" to load it."
return
if selected_model == 'None':
@ -145,7 +145,8 @@ def download_model_wrapper(repo_id, progress=gr.Progress()):
links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
yield ("Getting the output folder")
output_folder = downloader.get_output_folder(model, branch, is_lora)
base_folder = shared.args.lora_dir if is_lora else shared.args.model_dir
output_folder = downloader.get_output_folder(model, branch, is_lora, base_folder=base_folder)
if check:
progress(0.5)
@ -218,8 +219,8 @@ def create_model_menus():
shared.gradio['n_batch'] = gr.Slider(label="n_batch", minimum=1, maximum=2048, value=shared.args.n_batch)
shared.gradio['n_gpu_layers'] = gr.Slider(label="n-gpu-layers", minimum=0, maximum=128, value=shared.args.n_gpu_layers)
shared.gradio['n_ctx'] = gr.Slider(minimum=0, maximum=16384, step=256, label="n_ctx", value=shared.args.n_ctx)
shared.gradio['wbits'] = gr.Dropdown(label="wbits", choices=["None", 1, 2, 3, 4, 8], value=shared.args.wbits if shared.args.wbits > 0 else "None")
shared.gradio['groupsize'] = gr.Dropdown(label="groupsize", choices=["None", 32, 64, 128, 1024], value=shared.args.groupsize if shared.args.groupsize > 0 else "None")
shared.gradio['wbits'] = gr.Dropdown(label="wbits", choices=["None", 1, 2, 3, 4, 8], value=str(shared.args.wbits) if shared.args.wbits > 0 else "None")
shared.gradio['groupsize'] = gr.Dropdown(label="groupsize", choices=["None", 32, 64, 128, 1024], value=str(shared.args.groupsize) if shared.args.groupsize > 0 else "None")
shared.gradio['model_type'] = gr.Dropdown(label="model_type", choices=["None", "llama", "opt", "gptj"], value=shared.args.model_type or "None")
shared.gradio['pre_layer'] = gr.Slider(label="pre_layer", minimum=0, maximum=100, value=shared.args.pre_layer[0] if shared.args.pre_layer is not None else 0)
shared.gradio['autogptq_info'] = gr.Markdown('On some systems, AutoGPTQ can be 2x slower than GPTQ-for-LLaMa. You can manually select the GPTQ-for-LLaMa loader above.')
@ -242,6 +243,7 @@ def create_model_menus():
shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit", value=shared.args.load_in_4bit)
shared.gradio['use_double_quant'] = gr.Checkbox(label="use_double_quant", value=shared.args.use_double_quant)
shared.gradio['no_mmap'] = gr.Checkbox(label="no-mmap", value=shared.args.no_mmap)
shared.gradio['low_vram'] = gr.Checkbox(label="low-vram", value=shared.args.low_vram)
shared.gradio['mlock'] = gr.Checkbox(label="mlock", value=shared.args.mlock)
shared.gradio['llama_cpp_seed'] = gr.Number(label='Seed (0 for random)', value=shared.args.llama_cpp_seed)
shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='Make sure to inspect the .py files inside the model folder before loading it with this option enabled.')
@ -846,7 +848,7 @@ def create_interface():
# Reset interface event
shared.gradio['reset_interface'].click(
set_interface_arguments, gradio('interface_modes_menu', 'extensions_menu', 'bool_menu'), None).then(
lambda: None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')
lambda: None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;background:var(--body-background-fill)">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')
shared.gradio['toggle_dark_mode'].click(lambda: None, None, None, _js='() => {document.getElementsByTagName("body")[0].classList.toggle("dark")}')
@ -976,7 +978,7 @@ def create_interface():
lambda: 'characters/instruction-following/', None, gradio('delete_root')).then(
lambda: gr.update(visible=True), None, gradio('file_deleter'))
shared.gradio['download_button'].click(chat.save_history, gradio('history'), gradio('download'))
shared.gradio['download_button'].click(chat.save_history_at_user_request, gradio('history', 'character_menu', 'mode'), gradio('download'))
shared.gradio['Submit character'].click(chat.upload_character, gradio('upload_json', 'upload_img_bot'), gradio('character_menu'))
shared.gradio['upload_json'].upload(lambda: gr.update(interactive=True), None, gradio('Submit character'))
shared.gradio['upload_json'].clear(lambda: gr.update(interactive=False), None, gradio('Submit character'))
@ -1068,10 +1070,11 @@ def create_interface():
# Launch the interface
shared.gradio['interface'].queue()
if shared.args.listen:
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name=shared.args.listen_host or '0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth)
else:
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth)
with OpenMonkeyPatch():
if shared.args.listen:
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name=shared.args.listen_host or '0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth)
else:
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth)
if __name__ == "__main__":

View file

@ -0,0 +1 @@
To load multiple raw text files, create a subdirectory and put them all there.

View file

@ -0,0 +1,3 @@
{
"instruction,output": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\nUSER: %instruction%\n\nASSISTANT: %output%"
}