text-generation-webui/extensions/openai/completions.py

554 lines
20 KiB
Python
Raw Normal View History

import base64
import copy
import re
import time
from collections import deque
from io import BytesIO
2023-09-16 05:11:16 +02:00
import requests
import tiktoken
import torch
import torch.nn.functional as F
from PIL import Image
2023-11-17 03:03:06 +01:00
from transformers import LogitsProcessor, LogitsProcessorList
2023-09-16 05:11:16 +02:00
from extensions.openai.errors import InvalidRequestError
from extensions.openai.utils import debug_msg
from modules import shared
from modules.chat import (
generate_chat_prompt,
generate_chat_reply,
load_character_memoized,
load_instruction_template_memoized
)
from modules.presets import load_preset_memoized
2024-01-22 12:25:55 +01:00
from modules.text_generation import (
decode,
encode,
generate_reply,
get_reply_from_output_ids
)
class LogitsBiasProcessor(LogitsProcessor):
    """Add a fixed bias to selected next-token logits (OpenAI ``logit_bias``).

    ``logit_bias`` maps token ids to additive bias values. Keys may be ``str``
    (as they arrive from JSON) or ``int``; each one is converted with ``int()``.
    Only the first row of the batch (``logits[0]``) is modified.
    """

    def __init__(self, logit_bias=None):
        # None instead of a shared mutable ``{}`` default; behaves the same.
        self.logit_bias = {} if logit_bias is None else logit_bias
        if self.logit_bias:
            # Read keys and values together so each value is fetched with the
            # caller's own key. Round-tripping int -> str broke for int keys and
            # for keys such as "007".
            items = list(self.logit_bias.items())
            self.keys = [int(key) for key, _ in items]
            values = [bias for _, bias in items]
            self.values = torch.tensor(values, dtype=torch.float, device=shared.model.device)
            debug_msg(repr(self))

    def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
        if self.logit_bias:
            # Follow the logits' device in case it differs from the model's
            # reported device; .to() is a no-op when they already match.
            values = self.values.to(logits.device)
            debug_msg(logits[0, self.keys], " + ", values)
            logits[0, self.keys] += values
            debug_msg(" --> ", logits[0, self.keys])
            debug_msg(" max/min ", float(torch.max(logits[0])), float(torch.min(logits[0])))

        return logits

    def __repr__(self):
        return f"<{self.__class__.__name__}(logit_bias={self.logit_bias})>"
class LogprobProcessor(LogitsProcessor):
    """Record the most likely next tokens and their log-probabilities.

    The logits are returned unmodified. After each call, ``token_alternatives``
    maps decoded token text to its natural-log probability for the
    ``logprobs + 1`` most likely tokens of the latest step, taken from the
    first batch row. ``logprobs=None`` disables recording.
    """

    def __init__(self, logprobs=None):
        self.logprobs = logprobs
        self.token_alternatives = {}

    def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
        if self.logprobs is not None:  # OpenAI documents 0-5; no upper cap is applied here
            log_e_probabilities = F.log_softmax(logits, dim=1)
            # topk() raises for k < 1 or k > vocabulary size. Keep k inside that
            # range so an out-of-range request value cannot abort generation.
            k = min(max(int(self.logprobs), 0) + 1, log_e_probabilities.shape[1])
            top_values, top_indices = torch.topk(log_e_probabilities, k=k)
            top_tokens = [get_reply_from_output_ids([tok]) for tok in top_indices[0]]
            top_probs = [float(x) for x in top_values[0]]
            # Tokens that decode to identical text collapse into one entry.
            self.token_alternatives = dict(zip(top_tokens, top_probs))
            debug_msg(repr(self))

        return logits

    def __repr__(self):
        return f"<{self.__class__.__name__}(logprobs={self.logprobs}, token_alternatives={self.token_alternatives})>"
def convert_logprobs_to_tiktoken(model, logprobs):
    """Return ``logprobs`` unchanged, keyed by the model's native tokens.

    Re-mapping each entry onto the tiktoken vocabulary of ``model`` was tried
    and judged more problems than it is worth: tokens can encode to several
    tiktoken ids and unknown model names need a fallback anyway. ``model`` is
    therefore unused and kept only so callers have a stable signature.
    """
    native_alternatives = logprobs
    return native_alternatives
def process_parameters(body, is_legacy=False):
    """Turn an OpenAI-style request body into generation parameters.

    The returned dict *is* ``body``, mutated in place.

    Args:
        body: request dict. It must contain the max-tokens key ('length' when
            ``is_legacy`` else 'max_tokens'), 'truncation_length',
            'temperature' and 'preset'.
        is_legacy: read the max-token count from 'length' instead of
            'max_tokens'.

    Returns:
        The same dict. 'max_new_tokens' and 'custom_stopping_strings' are
        always set; 'logprob_proc' is set only when ``logprobs`` was requested
        (non-None); 'logits_processor' is set when any logits processor is
        needed.
    """
    generate_params = body
    max_tokens_str = 'length' if is_legacy else 'max_tokens'
    generate_params['max_new_tokens'] = body.pop(max_tokens_str)
    if generate_params['truncation_length'] == 0:
        generate_params['truncation_length'] = shared.settings['truncation_length']

    if generate_params['temperature'] == 0:
        # Temperature 0 means greedy decoding.
        generate_params['do_sample'] = False
        generate_params['top_k'] = 1

    if body['preset'] is not None:
        preset = load_preset_memoized(body['preset'])
        generate_params.update(preset)

    generate_params['custom_stopping_strings'] = []
    if 'stop' in body:  # str or array, max len 4 (ignored)
        if isinstance(body['stop'], str):
            generate_params['custom_stopping_strings'] = [body['stop']]
        elif isinstance(body['stop'], list):
            generate_params['custom_stopping_strings'] = body['stop']

    logits_processor = []
    logit_bias = body.get('logit_bias', None)
    if logit_bias:  # {str: float, ...}
        logits_processor.append(LogitsBiasProcessor(logit_bias))

    # A present-but-None 'logprobs' means "not requested". Attaching a no-op
    # LogprobProcessor(None) here made every response report empty logprobs.
    logprobs = body.get('logprobs')  # maybe cap at topk? don't clamp 0-5.
    if logprobs is not None:
        generate_params['logprob_proc'] = LogprobProcessor(logprobs)
        logits_processor.append(generate_params['logprob_proc'])

    if logits_processor:  # requires logits_processor support
        generate_params['logits_processor'] = LogitsProcessorList(logits_processor)

    return generate_params
def _flatten_multimodal_history(history):
    """Convert OpenAI list-style ``content`` into flat entries for convert_history.

    For an entry whose ``content`` is a list of parts:
    - every ``image_url`` part becomes a ``{"image_url": url, "role": "user"}``
      entry;
    - all ``text`` parts are joined with newlines into one
      ``{"content": text, "role": <original role>}`` entry.
    Images are emitted before the text. Parts that are not dicts, or have an
    unknown type, are ignored. Entries whose content is not a list are passed
    through unchanged.
    """
    flattened = []
    for entry in history:
        parts = entry.get('content')
        if not isinstance(parts, list):
            flattened.append(entry)
            continue

        image_urls = []
        texts = []
        for item in parts:
            if not isinstance(item, dict):
                continue
            if item.get('type') == 'image_url' and isinstance(item.get('image_url'), dict):
                url = item['image_url'].get('url')
                if url:
                    image_urls.append(url)
            elif item.get('type') == 'text' and isinstance(item.get('text'), str):
                texts.append(item['text'])

        flattened.extend({"image_url": url, "role": "user"} for url in image_urls)
        if texts:
            flattened.append({"content": "\n".join(texts), "role": entry.get("role", "user")})

    return flattened


def _image_url_to_html(image_url, timeout=30):
    """Load the image behind ``image_url`` and return it as an inline JPEG ``<img>`` tag.

    ``image_url`` is either a ``data:`` URL carrying base64 data or an address
    downloaded with ``requests`` (waiting at most ``timeout`` seconds).

    Raises:
        InvalidRequestError: if the image cannot be decoded or downloaded.
    """
    if image_url.startswith('data:'):
        try:
            encoded = re.sub('^data:image/.+;base64,', '', image_url)
            img = Image.open(BytesIO(base64.b64decode(encoded)))
        except Exception as exc:
            raise InvalidRequestError(message='Image cannot be decoded from the base64 data!', param='messages') from exc
    else:
        try:
            # NOTE(review): this downloads an arbitrary client-supplied URL from
            # the server (SSRF risk); restrict schemes/hosts if the API is exposed.
            response = requests.get(image_url, timeout=timeout)
            img = Image.open(BytesIO(response.content))
        except Exception as exc:
            raise InvalidRequestError(message='Image cannot be loaded from the URL!', param='messages') from exc

    # JPEG can only store RGB / grayscale here; alpha and palette modes
    # (RGBA, P, LA, PA, ...) would make save() fail.
    if img.mode not in ("RGB", "L"):
        img = img.convert("RGB")

    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
    return f'<img src="data:image/jpeg;base64,{img_str}">'


def convert_history(history):
    '''
    Chat histories in this program are in the format [message, reply].
    This function converts OpenAI histories to that format.

    Returns:
        (user_input, system_message, history):
        - user_input is the content of the last user message.
        - system_message is the content of the last system message.
        - history is {'internal': pairs, 'visible': deep copy of pairs}.
        A user message still waiting for a reply when the list ends is
        returned only as user_input; it is not added to the pairs.
    '''
    chat_dialogue = []
    current_message = ""
    user_input = ""
    system_message = ""

    for entry in _flatten_multimodal_history(history):
        if "image_url" in entry:
            content = _image_url_to_html(entry['image_url'])
        else:
            content = entry.get("content")
            if content is None:
                # e.g. an assistant message with null content
                content = ""

        role = entry["role"]
        if role == "user":
            user_input = content
            if current_message:
                # Two user turns in a row: the earlier one gets an empty reply.
                chat_dialogue.append([current_message, ''])
            current_message = content
        elif role == "assistant":
            # With no pending user message this stores ['', reply].
            chat_dialogue.append([current_message, content])
            current_message = ""
        elif role == "system":
            system_message = content

    return user_input, system_message, {'internal': chat_dialogue, 'visible': copy.deepcopy(chat_dialogue)}
def chat_completions_common(body: dict, is_legacy: bool = False, stream=False) -> dict:
    """Run a chat completion request and yield OpenAI-style response dicts.

    Despite the ``dict`` annotation this is a generator.
    - With ``stream=False`` it yields one complete response.
    - With ``stream=True`` it yields an initial empty chunk, one chunk per
      piece of newly generated text, and a final chunk carrying
      ``finish_reason`` and ``usage``.

    Args:
        body: request dict, mutated in place. process_parameters pops
            'max_tokens'/'length'; 'model' and 'logprob_proc' are popped here.
        is_legacy: use legacy field names ('data' instead of 'choices',
            'length' instead of 'max_tokens').
        stream: yield incremental chunks instead of a single response.

    Raises:
        InvalidRequestError: raised when any of these hold:
            - function calling is requested;
            - 'messages' is missing;
            - a message has no role, uses the 'function' role, or has neither
              content nor image_url.
    """
    if body.get('functions', []):
        raise InvalidRequestError(message="functions is not supported.", param='functions')

    if body.get('function_call', ''):
        raise InvalidRequestError(message="function_call is not supported.", param='function_call')

    if 'messages' not in body:
        raise InvalidRequestError(message="messages is required", param='messages')

    messages = body['messages']
    for m in messages:
        if 'role' not in m:
            raise InvalidRequestError(message="messages: missing role", param='messages')
        elif m['role'] == 'function':
            raise InvalidRequestError(message="role: function is not supported.", param='messages')

        if 'content' not in m and "image_url" not in m:
            raise InvalidRequestError(message="messages: missing content", param='messages')

    # Chat Completions
    # OpenAI's object types are the singular 'chat.completion' / 'chat.completion.chunk'.
    object_type = 'chat.completion' if not stream else 'chat.completion.chunk'
    created_time = int(time.time())
    cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
    resp_list = 'data' if is_legacy else 'choices'

    # generation parameters
    generate_params = process_parameters(body, is_legacy=is_legacy)
    continue_ = body['continue_']

    # Instruction template
    if body['instruction_template_str']:
        instruction_template_str = body['instruction_template_str']
    elif body['instruction_template']:
        instruction_template = body['instruction_template']
        instruction_template = "Alpaca" if instruction_template == "None" else instruction_template
        instruction_template_str = load_instruction_template_memoized(instruction_template)
    else:
        instruction_template_str = shared.settings['instruction_template_str']

    chat_template_str = body['chat_template_str'] or shared.default_settings['chat_template_str']
    chat_instruct_command = body['chat_instruct_command'] or shared.default_settings['chat-instruct_command']

    # Chat character
    character = body['character'] or shared.default_settings['character']
    character = "Assistant" if character == "None" else character
    name1 = body['user_name'] or shared.default_settings['name1']
    name1, name2, _, greeting, context = load_character_memoized(character, name1, '')
    name2 = body['bot_name'] or name2
    context = body['context'] or context
    greeting = body['greeting'] or greeting
    user_bio = body['user_bio'] or ''

    # History
    user_input, custom_system_message, history = convert_history(messages)

    generate_params.update({
        'mode': body['mode'],
        'name1': name1,
        'name2': name2,
        'context': context,
        'greeting': greeting,
        'user_bio': user_bio,
        'instruction_template_str': instruction_template_str,
        'custom_system_message': custom_system_message,
        'chat_template_str': chat_template_str,
        'chat-instruct_command': chat_instruct_command,
        'history': history,
        'stream': stream
    })

    max_tokens = generate_params['max_new_tokens']
    if max_tokens in [None, 0]:
        generate_params['max_new_tokens'] = 512
        generate_params['auto_max_new_tokens'] = True

    requested_model = generate_params.pop('model')
    logprob_proc = generate_params.pop('logprob_proc', None)

    def chat_streaming_chunk(content):
        """Build one streaming chunk carrying ``content`` as the delta text."""
        chunk = {
            "id": cmpl_id,
            "object": object_type,
            "created": created_time,
            "model": shared.model_name,
            resp_list: [{
                "index": 0,
                "finish_reason": None,
                "delta": {'role': 'assistant', 'content': content},
            }],
        }

        if logprob_proc:  # not official for chat yet
            top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
            chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}

        return chunk

    if stream:
        yield chat_streaming_chunk('')

    # generate reply #######################################
    prompt = generate_chat_prompt(user_input, generate_params)
    token_count = len(encode(prompt)[0])
    debug_msg({'prompt': prompt, 'generate_params': generate_params})

    generator = generate_chat_reply(
        user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)

    answer = ''
    seen_content = ''
    for a in generator:
        answer = a['internal'][-1][1]
        if stream:
            new_content = answer[len(seen_content):]
            if not new_content or chr(0xfffd) in new_content:  # partial unicode character, don't send it yet.
                continue

            seen_content = answer
            yield chat_streaming_chunk(new_content)

    # Text held back above would otherwise never be sent: either a genuine
    # U+FFFD in the output (which blocks every later chunk) or a partial
    # character at the very end.
    if stream and len(answer) > len(seen_content):
        yield chat_streaming_chunk(answer[len(seen_content):])

    # Counted once on the final answer instead of re-encoding on every step.
    completion_token_count = len(encode(answer)[0])

    stop_reason = "stop"
    if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
        stop_reason = "length"

    if stream:
        chunk = chat_streaming_chunk('')
        chunk[resp_list][0]['finish_reason'] = stop_reason
        chunk['usage'] = {
            "prompt_tokens": token_count,
            "completion_tokens": completion_token_count,
            "total_tokens": token_count + completion_token_count
        }

        yield chunk
    else:
        resp = {
            "id": cmpl_id,
            "object": object_type,
            "created": created_time,
            "model": shared.model_name,
            resp_list: [{
                "index": 0,
                "finish_reason": stop_reason,
                "message": {"role": "assistant", "content": answer}
            }],
            "usage": {
                "prompt_tokens": token_count,
                "completion_tokens": completion_token_count,
                "total_tokens": token_count + completion_token_count
            }
        }

        if logprob_proc:  # not official for chat yet
            top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
            resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}

        yield resp
def _decode_token_prompt(tokens, requested_model):
    """Decode a prompt given as a list of token ids back into text.

    The loaded model's tokenizer is used when the request names the loaded
    model, or when tiktoken does not know ``requested_model``. Otherwise the
    ids are decoded with tiktoken's encoding for ``requested_model``.
    """
    # NOTE(review): ``decode(...)[0]`` is kept from the original code; if
    # ``decode`` returns a plain str this keeps only its first character —
    # confirm decode's return type.
    if requested_model == shared.model_name:
        return decode(tokens)[0]

    try:
        return tiktoken.encoding_for_model(requested_model).decode(tokens)
    except KeyError:
        return decode(tokens)[0]


def completions_common(body: dict, is_legacy: bool = False, stream=False):
    """Run a text completion request and yield OpenAI-style response dicts.

    This is a generator.
    - With ``stream=False`` it yields one response with a choice for every
      prompt in the (possibly batched) input.
    - With ``stream=True`` only a single prompt is accepted. It yields:
      a chunk with the echoed prompt (or ''), one chunk per piece of new
      text, and a final chunk with the suffix, ``finish_reason`` and ``usage``.

    Args:
        body: request dict, mutated in place by process_parameters.
        is_legacy: use legacy field names ('context' for the prompt, 'data'
            instead of 'choices', 'length' instead of 'max_tokens').
        stream: yield incremental chunks instead of a single response.

    Raises:
        InvalidRequestError: raised when any of these hold:
            - the prompt is missing or an empty list;
            - a batched prompt is used while streaming.
    """
    object_type = 'text_completion.chunk' if stream else 'text_completion'
    created_time = int(time.time())
    cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
    resp_list = 'data' if is_legacy else 'choices'

    prompt_str = 'context' if is_legacy else 'prompt'

    # ... encoded as a string, array of strings, array of tokens, or array of token arrays.
    if prompt_str not in body:
        raise InvalidRequestError("Missing required input", param=prompt_str)

    # common params
    generate_params = process_parameters(body, is_legacy=is_legacy)
    max_tokens = generate_params['max_new_tokens']
    generate_params['stream'] = stream
    requested_model = generate_params.pop('model')
    logprob_proc = generate_params.pop('logprob_proc', None)
    suffix = body['suffix'] if body['suffix'] else ''
    echo = body['echo']

    if not stream:
        prompt_arg = body[prompt_str]
        if isinstance(prompt_arg, list) and not prompt_arg:
            # prompt_arg[0] below would raise IndexError.
            raise InvalidRequestError("Missing required input", param=prompt_str)

        # A single string or a single token list is one prompt; wrap it so the
        # loop below always iterates over a batch of prompts.
        if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and isinstance(prompt_arg[0], int)):
            prompt_arg = [prompt_arg]

        resp_list_data = []
        total_completion_token_count = 0
        total_prompt_token_count = 0

        for idx, prompt in enumerate(prompt_arg):
            # Check for a non-empty list first: indexing an empty string prompt
            # ('') with prompt[0] used to raise IndexError.
            if isinstance(prompt, list) and prompt and isinstance(prompt[0], int):
                prompt = _decode_token_prompt(prompt, requested_model)

            prefix = prompt if echo else ''
            token_count = len(encode(prompt)[0])
            total_prompt_token_count += token_count

            # generate reply #######################################
            debug_msg({'prompt': prompt, 'generate_params': generate_params})
            generator = generate_reply(prompt, generate_params, is_chat=False)
            answer = ''

            for a in generator:
                answer = a

            completion_token_count = len(encode(answer)[0])
            total_completion_token_count += completion_token_count
            stop_reason = "stop"
            if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
                stop_reason = "length"

            resp_list_data.append({
                "index": idx,
                "finish_reason": stop_reason,
                "text": prefix + answer + suffix,
                "logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
            })

        resp = {
            "id": cmpl_id,
            "object": object_type,
            "created": created_time,
            "model": shared.model_name,
            resp_list: resp_list_data,
            "usage": {
                "prompt_tokens": total_prompt_token_count,
                "completion_tokens": total_completion_token_count,
                "total_tokens": total_prompt_token_count + total_completion_token_count
            }
        }

        yield resp
    else:
        prompt = body[prompt_str]
        if isinstance(prompt, list):
            if prompt and isinstance(prompt[0], int):
                # Same decoding rules as the non-streaming path.
                prompt = _decode_token_prompt(prompt, requested_model)
            else:
                raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)

        prefix = prompt if echo else ''
        token_count = len(encode(prompt)[0])

        def text_streaming_chunk(content):
            """Build one streaming chunk carrying ``content`` as its text."""
            chunk = {
                "id": cmpl_id,
                "object": object_type,
                "created": created_time,
                "model": shared.model_name,
                resp_list: [{
                    "index": 0,
                    "finish_reason": None,
                    "text": content,
                    "logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
                }],
            }

            return chunk

        yield text_streaming_chunk(prefix)

        # generate reply #######################################
        debug_msg({'prompt': prompt, 'generate_params': generate_params})
        generator = generate_reply(prompt, generate_params, is_chat=False)

        answer = ''
        seen_content = ''
        for a in generator:
            answer = a
            new_content = answer[len(seen_content):]
            if not new_content or chr(0xfffd) in new_content:  # partial unicode character, don't send it yet.
                continue

            seen_content = answer
            yield text_streaming_chunk(new_content)

        # Text held back above would otherwise never be sent: either a genuine
        # U+FFFD in the output (which blocks every later chunk) or a partial
        # character at the very end.
        if len(answer) > len(seen_content):
            yield text_streaming_chunk(answer[len(seen_content):])

        completion_token_count = len(encode(answer)[0])
        stop_reason = "stop"
        if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
            stop_reason = "length"

        chunk = text_streaming_chunk(suffix)
        chunk[resp_list][0]["finish_reason"] = stop_reason
        chunk["usage"] = {
            "prompt_tokens": token_count,
            "completion_tokens": completion_token_count,
            "total_tokens": token_count + completion_token_count
        }

        yield chunk
def chat_completions(body: dict, is_legacy: bool = False) -> dict:
    """Return the complete, non-streaming chat completion response for ``body``."""
    # Drain the generator and keep only its final item.
    responses = chat_completions_common(body, is_legacy, stream=False)
    return deque(responses, maxlen=1)[-1]
def stream_chat_completions(body: dict, is_legacy: bool = False):
    """Yield the streaming chat completion chunks for ``body`` one at a time."""
    yield from chat_completions_common(body, is_legacy, stream=True)
def completions(body: dict, is_legacy: bool = False) -> dict:
    """Return the complete, non-streaming text completion response for ``body``."""
    # Drain the generator and keep only its final item.
    responses = completions_common(body, is_legacy, stream=False)
    return deque(responses, maxlen=1)[-1]
def stream_completions(body: dict, is_legacy: bool = False):
    """Yield the streaming text completion chunks for ``body`` one at a time."""
    yield from completions_common(body, is_legacy, stream=True)