Add the option to use samplers in the logit viewer

This commit is contained in:
oobabooga 2023-08-22 20:18:16 -07:00
parent 25e5eaa6a6
commit 8545052c9d
8 changed files with 75 additions and 18 deletions

View file

@ -237,6 +237,11 @@ audio {
border-radius: 0.4em;
}
.no-background {
background: var(--background-fill-primary) !important;
padding: 0px !important;
}
/*****************************************************/
/*************** Chat UI declarations ****************/
/*****************************************************/

View file

@ -82,3 +82,12 @@ observer.observe(targetElement, config);
//------------------------------------------------
document.getElementById('chat-input').parentNode.style.background = 'transparent';
document.getElementById('chat-input').parentNode.style.border = 'none';
//------------------------------------------------
// Remove some backgrounds
//------------------------------------------------
const noBackgroundelements = document.querySelectorAll('.no-background');
for(i = 0; i < noBackgroundelements.length; i++) {
noBackgroundelements[i].parentNode.style.border = 'none';
noBackgroundelements[i].parentNode.parentNode.parentNode.style.alignItems = 'center';
}

View file

@ -24,6 +24,7 @@ class Stream(transformers.StoppingCriteria):
def __call__(self, input_ids, scores) -> bool:
if self.callback_func is not None:
self.callback_func(input_ids[0])
return False

View file

@ -1,19 +1,30 @@
import torch
from modules import shared
from modules import sampler_hijack, shared
from modules.text_generation import generate_reply
global_scores = None
def get_next_logits(prompt):
def get_next_logits(prompt, state, use_samplers, previous):
if use_samplers:
state['max_new_tokens'] = 1
state['auto_max_new_tokens'] = False
for _ in generate_reply(prompt, state):
pass
scores = sampler_hijack.global_scores[-1]
else:
tokens = shared.tokenizer.encode(prompt, return_tensors='pt').cuda()
output = shared.model(input_ids=tokens)
scores = output['logits'][-1][-1]
probs = torch.softmax(scores, dim=-1, dtype=torch.float)
probs = torch.softmax(scores, dim=-1, dtype=torch.float)
topk_values, topk_indices = torch.topk(probs, k=20, largest=True, sorted=True)
topk_values = [f"{float(i):.5f}" % i for i in topk_values]
topk_values = [f"{float(i):.5f}" for i in topk_values]
output = ''
for row in list(zip(topk_values, shared.tokenizer.convert_ids_to_tokens(topk_indices))):
output += f"{row[0]} {row[1]}\n"
output += f"{row[0]} - {row[1]}\n"
return output
return output, previous

View file

@ -10,6 +10,8 @@ from transformers.generation.logits_process import (
TemperatureLogitsWarper
)
global_scores = None
class TailFreeLogitsWarper(LogitsWarper):
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
@ -122,6 +124,16 @@ class MirostatLogitsWarper(LogitsWarper):
return scores
class SpyLogitsWarper(LogitsWarper):
def __init__(self):
pass
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
global global_scores
global_scores = scores
return scores
class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
'''
Copied from the transformers library
@ -168,6 +180,7 @@ def get_logits_warper_patch(self, generation_config):
else:
warpers += warpers_to_add
warpers.append(SpyLogitsWarper())
return warpers

View file

@ -64,7 +64,7 @@ def create_ui():
with gr.Column(scale=5):
lora_name = gr.Textbox(label='Name', info='The name of your new LoRA file')
with gr.Column():
always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name is the same, checking will replace the existing file, and unchecking will load and continue from it (the rank must be the same).')
always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name is the same, checking will replace the existing file, and unchecking will load and continue from it (the rank must be the same).', elem_classes=['no-background'])
with gr.Row():
with gr.Column():

View file

@ -44,8 +44,15 @@ def create_ui():
shared.gradio['html-default'] = gr.HTML()
with gr.Tab('Logits'):
with gr.Row():
with gr.Column(scale=10):
shared.gradio['get_logits-default'] = gr.Button('Get next token probabilities')
with gr.Column(scale=1):
shared.gradio['use_samplers-default'] = gr.Checkbox(label='Use samplers', value=True, elem_classes=['no-background'])
with gr.Row():
shared.gradio['logits-default'] = gr.Textbox(lines=23, label='Output', elem_classes=['textbox_logits', 'add_scrollbar'])
shared.gradio['logits-default-previous'] = gr.Textbox(lines=23, label='Previous output', elem_classes=['textbox_logits', 'add_scrollbar'])
def create_event_handlers():
@ -84,4 +91,6 @@ def create_event_handlers():
lambda: gr.update(visible=True), None, gradio('file_deleter'))
shared.gradio['textbox-default'].change(lambda x: f"<span>{count_tokens(x)}</span>", gradio('textbox-default'), gradio('token-counter-default'), show_progress=False)
shared.gradio['get_logits-default'].click(logits.get_next_logits, gradio('textbox-default'), gradio('logits-default'), show_progress=False)
shared.gradio['get_logits-default'].click(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
logits.get_next_logits, gradio('textbox-default', 'interface_state', 'use_samplers-default', 'logits-default'), gradio('logits-default', 'logits-default-previous'), show_progress=False)

View file

@ -30,8 +30,15 @@ def create_ui():
shared.gradio['html-notebook'] = gr.HTML()
with gr.Tab('Logits'):
with gr.Row():
with gr.Column(scale=10):
shared.gradio['get_logits-notebook'] = gr.Button('Get next token probabilities')
shared.gradio['logits-notebook'] = gr.Textbox(lines=23, label='Output', elem_classes=['textbox_logits_notebook', 'add_scrollbar'])
with gr.Column(scale=1):
shared.gradio['use_samplers-notebook'] = gr.Checkbox(label='Use samplers', value=True, elem_classes=['no-background'])
with gr.Row():
shared.gradio['logits-notebook'] = gr.Textbox(lines=23, label='Output', elem_classes=['textbox_logits', 'add_scrollbar'])
shared.gradio['logits-notebook-previous'] = gr.Textbox(lines=23, label='Previous output', elem_classes=['textbox_logits', 'add_scrollbar'])
with gr.Row():
shared.gradio['Generate-notebook'] = gr.Button('Generate', variant='primary', elem_classes='small-button')
@ -86,4 +93,6 @@ def create_event_handlers():
lambda: gr.update(visible=True), None, gradio('file_deleter'))
shared.gradio['textbox-notebook'].input(lambda x: f"<span>{count_tokens(x)}</span>", gradio('textbox-notebook'), gradio('token-counter-notebook'), show_progress=False)
shared.gradio['get_logits-notebook'].click(logits.get_next_logits, gradio('textbox-notebook'), gradio('logits-notebook'))
shared.gradio['get_logits-notebook'].click(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
logits.get_next_logits, gradio('textbox-notebook', 'interface_state', 'use_samplers-notebook', 'logits-notebook'), gradio('logits-notebook', 'logits-notebook-previous'), show_progress=False)