From 8545052c9d994370b110047e634c4593d02d50f9 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Tue, 22 Aug 2023 20:18:16 -0700
Subject: [PATCH] Add the option to use samplers in the logit viewer

---
 css/main.css              |  5 +++++
 js/main.js                |  9 +++++++++
 modules/callbacks.py      |  1 +
 modules/logits.py         | 29 ++++++++++++++++++++---------
 modules/sampler_hijack.py | 13 +++++++++++++
 modules/training.py       |  2 +-
 modules/ui_default.py     | 17 +++++++++++++----
 modules/ui_notebook.py    | 17 +++++++++++++----
 8 files changed, 75 insertions(+), 18 deletions(-)

diff --git a/css/main.css b/css/main.css
index 3408375c..405b57e0 100644
--- a/css/main.css
+++ b/css/main.css
@@ -237,6 +237,11 @@ audio {
   border-radius: 0.4em;
 }
 
+.no-background {
+  background: var(--background-fill-primary) !important;
+  padding: 0px !important;
+}
+
 /*****************************************************/
 /*************** Chat UI declarations ****************/
 /*****************************************************/
diff --git a/js/main.js b/js/main.js
index 6a27c3b4..e409cc3d 100644
--- a/js/main.js
+++ b/js/main.js
@@ -82,3 +82,12 @@ observer.observe(targetElement, config);
 //------------------------------------------------
 document.getElementById('chat-input').parentNode.style.background = 'transparent';
 document.getElementById('chat-input').parentNode.style.border = 'none';
+
+//------------------------------------------------
+// Remove some backgrounds
+//------------------------------------------------
+const noBackgroundelements = document.querySelectorAll('.no-background');
+for(i = 0; i < noBackgroundelements.length; i++) {
+    noBackgroundelements[i].parentNode.style.border = 'none';
+    noBackgroundelements[i].parentNode.parentNode.parentNode.style.alignItems = 'center';
+}
diff --git a/modules/callbacks.py b/modules/callbacks.py
index 1fa95e47..e29e397d 100644
--- a/modules/callbacks.py
+++ b/modules/callbacks.py
@@ -24,6 +24,7 @@ class Stream(transformers.StoppingCriteria):
     def __call__(self, input_ids, scores) -> bool:
         if self.callback_func is not None:
             self.callback_func(input_ids[0])
+
         return False
 
 
diff --git a/modules/logits.py b/modules/logits.py
index 99cb336f..3bfeb6b0 100644
--- a/modules/logits.py
+++ b/modules/logits.py
@@ -1,19 +1,30 @@
 import torch
 
-from modules import shared
+from modules import sampler_hijack, shared
+from modules.text_generation import generate_reply
+
+global_scores = None
 
 
-def get_next_logits(prompt):
-    tokens = shared.tokenizer.encode(prompt, return_tensors='pt').cuda()
-    output = shared.model(input_ids=tokens)
+def get_next_logits(prompt, state, use_samplers, previous):
+    if use_samplers:
+        state['max_new_tokens'] = 1
+        state['auto_max_new_tokens'] = False
+        for _ in generate_reply(prompt, state):
+            pass
+
+        scores = sampler_hijack.global_scores[-1]
+    else:
+        tokens = shared.tokenizer.encode(prompt, return_tensors='pt').cuda()
+        output = shared.model(input_ids=tokens)
+        scores = output['logits'][-1][-1]
 
-    scores = output['logits'][-1][-1]
     probs = torch.softmax(scores, dim=-1, dtype=torch.float)
-
     topk_values, topk_indices = torch.topk(probs, k=20, largest=True, sorted=True)
-    topk_values = [f"{float(i):.5f}" % i for i in topk_values]
+    topk_values = [f"{float(i):.5f}" for i in topk_values]
+
     output = ''
     for row in list(zip(topk_values, shared.tokenizer.convert_ids_to_tokens(topk_indices))):
-        output += f"{row[0]} {row[1]}\n"
+        output += f"{row[0]} - {row[1]}\n"
 
-    return output
+    return output, previous
diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py
index d5ebbb76..0a724f47 100644
--- a/modules/sampler_hijack.py
+++ b/modules/sampler_hijack.py
@@ -10,6 +10,8 @@ from transformers.generation.logits_process import (
     TemperatureLogitsWarper
 )
 
+global_scores = None
+
 
 class TailFreeLogitsWarper(LogitsWarper):
     def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
@@ -122,6 +124,16 @@ class MirostatLogitsWarper(LogitsWarper):
         return scores
 
 
+class SpyLogitsWarper(LogitsWarper):
+    def __init__(self):
+        pass
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        global global_scores
+        global_scores = scores
+        return scores
+
+
 class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
     '''
     Copied from the transformers library
@@ -168,6 +180,7 @@ def get_logits_warper_patch(self, generation_config):
     else:
         warpers += warpers_to_add
 
+    warpers.append(SpyLogitsWarper())
     return warpers
 
 
diff --git a/modules/training.py b/modules/training.py
index 7be0d24f..a993f6f0 100644
--- a/modules/training.py
+++ b/modules/training.py
@@ -64,7 +64,7 @@ def create_ui():
             with gr.Column(scale=5):
                 lora_name = gr.Textbox(label='Name', info='The name of your new LoRA file')
             with gr.Column():
-                always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name is the same, checking will replace the existing file, and unchecking will load and continue from it (the rank must be the same).')
+                always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name is the same, checking will replace the existing file, and unchecking will load and continue from it (the rank must be the same).', elem_classes=['no-background'])
 
         with gr.Row():
             with gr.Column():
diff --git a/modules/ui_default.py b/modules/ui_default.py
index 5470a6ad..29b9bee5 100644
--- a/modules/ui_default.py
+++ b/modules/ui_default.py
@@ -44,8 +44,15 @@ def create_ui():
                     shared.gradio['html-default'] = gr.HTML()
 
                 with gr.Tab('Logits'):
-                    shared.gradio['get_logits-default'] = gr.Button('Get next token probabilities')
-                    shared.gradio['logits-default'] = gr.Textbox(lines=23, label='Output', elem_classes=['textbox_logits', 'add_scrollbar'])
+                    with gr.Row():
+                        with gr.Column(scale=10):
+                            shared.gradio['get_logits-default'] = gr.Button('Get next token probabilities')
+                        with gr.Column(scale=1):
+                            shared.gradio['use_samplers-default'] = gr.Checkbox(label='Use samplers', value=True, elem_classes=['no-background'])
+
+                    with gr.Row():
+                        shared.gradio['logits-default'] = gr.Textbox(lines=23, label='Output', elem_classes=['textbox_logits', 'add_scrollbar'])
+                        shared.gradio['logits-default-previous'] = gr.Textbox(lines=23, label='Previous output', elem_classes=['textbox_logits', 'add_scrollbar'])
 
 
 def create_event_handlers():
@@ -83,5 +90,7 @@
         lambda x: x + '.txt', gradio('prompt_menu-default'), gradio('delete_filename')).then(
         lambda: gr.update(visible=True), None, gradio('file_deleter'))
 
-    shared.gradio['textbox-default'].change(lambda x : f"{count_tokens(x)}", gradio('textbox-default'), gradio('token-counter-default'), show_progress=False)
-    shared.gradio['get_logits-default'].click(logits.get_next_logits, gradio('textbox-default'), gradio('logits-default'), show_progress=False)
+    shared.gradio['textbox-default'].change(lambda x: f"{count_tokens(x)}", gradio('textbox-default'), gradio('token-counter-default'), show_progress=False)
+    shared.gradio['get_logits-default'].click(
+        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
+        logits.get_next_logits, gradio('textbox-default', 'interface_state', 'use_samplers-default', 'logits-default'), gradio('logits-default', 'logits-default-previous'), show_progress=False)
diff --git a/modules/ui_notebook.py b/modules/ui_notebook.py
index 7fbf7a85..9ff0c3fe 100644
--- a/modules/ui_notebook.py
+++ b/modules/ui_notebook.py
@@ -30,8 +30,15 @@ def create_ui():
                     shared.gradio['html-notebook'] = gr.HTML()
 
                 with gr.Tab('Logits'):
-                    shared.gradio['get_logits-notebook'] = gr.Button('Get next token probabilities')
-                    shared.gradio['logits-notebook'] = gr.Textbox(lines=23, label='Output', elem_classes=['textbox_logits_notebook', 'add_scrollbar'])
+                    with gr.Row():
+                        with gr.Column(scale=10):
+                            shared.gradio['get_logits-notebook'] = gr.Button('Get next token probabilities')
+                        with gr.Column(scale=1):
+                            shared.gradio['use_samplers-notebook'] = gr.Checkbox(label='Use samplers', value=True, elem_classes=['no-background'])
+
+                    with gr.Row():
+                        shared.gradio['logits-notebook'] = gr.Textbox(lines=23, label='Output', elem_classes=['textbox_logits', 'add_scrollbar'])
+                        shared.gradio['logits-notebook-previous'] = gr.Textbox(lines=23, label='Previous output', elem_classes=['textbox_logits', 'add_scrollbar'])
 
                 with gr.Row():
                     shared.gradio['Generate-notebook'] = gr.Button('Generate', variant='primary', elem_classes='small-button')
@@ -85,5 +92,7 @@
         lambda x: x + '.txt', gradio('prompt_menu-notebook'), gradio('delete_filename')).then(
         lambda: gr.update(visible=True), None, gradio('file_deleter'))
 
-    shared.gradio['textbox-notebook'].input(lambda x : f"{count_tokens(x)}", gradio('textbox-notebook'), gradio('token-counter-notebook'), show_progress=False)
-    shared.gradio['get_logits-notebook'].click(logits.get_next_logits, gradio('textbox-notebook'), gradio('logits-notebook'))
+    shared.gradio['textbox-notebook'].input(lambda x: f"{count_tokens(x)}", gradio('textbox-notebook'), gradio('token-counter-notebook'), show_progress=False)
+    shared.gradio['get_logits-notebook'].click(
+        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
+        logits.get_next_logits, gradio('textbox-notebook', 'interface_state', 'use_samplers-notebook', 'logits-notebook'), gradio('logits-notebook', 'logits-notebook-previous'), show_progress=False)
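
Note (not part of the patch): a minimal, self-contained sketch of the "spy warper" idea used above. The SpyLogitsWarper below mirrors the class added to modules/sampler_hijack.py; the warper list, the temperature/top-k values, the captured_scores name, and the dummy tensors are illustrative stand-ins for the list built by get_logits_warper_patch() and for real model logits, not the webui's actual values.

# Illustrative sketch: appending a "spy" warper at the end of the warper list
# captures the post-sampler scores that the logit viewer displays.
import torch
from transformers import (
    LogitsProcessorList,
    LogitsWarper,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
)

captured_scores = None  # plays the role of global_scores in the patch


class SpyLogitsWarper(LogitsWarper):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        global captured_scores
        captured_scores = scores  # remember the last scores seen; pass them through unchanged
        return scores


# Appending the spy last means it sees the logits after every other warper has run,
# which is why get_logits_warper_patch() appends it after the stock warpers.
warpers = LogitsProcessorList([
    TemperatureLogitsWarper(0.7),  # illustrative values, not the webui defaults
    TopKLogitsWarper(top_k=20),
    SpyLogitsWarper(),
])

input_ids = torch.tensor([[1, 2, 3]])  # dummy prompt token ids (batch of 1)
logits = torch.randn(1, 32000)         # dummy next-token logits (batch, vocab)
warpers(input_ids, logits)

# captured_scores now holds the warped logits; softmax turns them into the
# probabilities shown when "Use samplers" is checked.
probs = torch.softmax(captured_scores[-1], dim=-1, dtype=torch.float)
topk_values, topk_indices = torch.topk(probs, k=20, largest=True, sorted=True)
print([f"{float(v):.5f}" for v in topk_values])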