"""
Example extension.

As written it does nothing: every hook hands back what it was given. Add
your own transformations ahead of the return statements to customize the
webui behavior.

The hooks from history_modifier through output_modifier are declared in
the same order in which they are called at generation time.
"""

import torch
from modules import chat
from modules.text_generation import (
    decode,
    encode,
    generate_reply,
)
from transformers import LogitsProcessor

# Extension-level parameters.
# NOTE(review): presumably read by the webui to label the extension and to
# decide whether it is rendered as its own tab -- confirm against the
# extensions documentation.
params = {
    "display_name": "Example Extension",
    "is_tab": False,
}


class MyLogits(LogitsProcessor):
    """
    Logits processor that can adjust the probabilities of the next token
    before it gets sampled.

    As written, the scores pass through untouched. An instance of this
    class is registered by logits_processor_modifier below.
    """

    def __init__(self):
        pass

    def __call__(self, input_ids, scores):
        """
        Return the scores for the next token; currently a pass-through.
        """
        # Example manipulation, left disabled on purpose:
        # probs = torch.softmax(scores, dim=-1, dtype=torch.float)
        # probs[0] /= probs[0].sum()
        # scores = torch.log(probs / (1 - probs))
        return scores


def history_modifier(history):
    """
    Hook for altering the chat history.

    Only used in chat mode. Returns the history unchanged.
    """
    return history


def state_modifier(state):
    """
    Hook for altering the state variable, a dictionary that holds the
    input values from the UI such as sliders and checkboxes.

    Returns the state unchanged.
    """
    return state


def chat_input_modifier(text, visible_text, state):
    """
    Hook for altering the internal and the visible input strings in chat
    mode.

    Returns the pair (text, visible_text) unchanged.
    """
    return (text, visible_text)


def input_modifier(string, state):
    """
    Hook for altering the input text.

    In chat mode, this modifies the user input: the modified version goes
    into history['internal'] and the original version goes into
    history['visible'].

    In default/notebook modes, this modifies the whole prompt.

    Returns the string unchanged.
    """
    return string


def bot_prefix_modifier(string, state):
    """
    Hook for altering the prefix of the next bot reply in chat mode.

    By default, the prefix will be something like "Bot Name:".
    Returns the string unchanged.
    """
    return string


def tokenizer_modifier(state, prompt, input_ids, input_embeds):
    """
    Hook for altering the input ids and embeds.

    Used by the multimodal extension to put image embeddings in the
    prompt. Only used by loaders that use the transformers library for
    sampling.

    Returns the triple (prompt, input_ids, input_embeds) unchanged.
    """
    return (prompt, input_ids, input_embeds)


def logits_processor_modifier(processor_list, input_ids):
    """
    Hook for adding logits processors to the list.

    Only used by loaders that use the transformers library for sampling.
    Appends a MyLogits instance to processor_list in place and returns
    that same list.
    """
    extra_processor = MyLogits()
    processor_list.append(extra_processor)
    return processor_list


def output_modifier(string, state):
    """
    Hook for altering the LLM output before it gets presented.

    In chat mode, the modified version goes into history['internal'], and
    the original version goes into history['visible'].

    NOTE(review): the internal/visible wording above is identical to
    input_modifier's and may be swapped for the output case -- confirm
    against the chat module before relying on it.

    Returns the string unchanged.
    """
    return string


def custom_generate_chat_prompt(user_input, state, **kwargs):
    """
    Replacement for the function that generates the prompt from the chat
    history.

    Only used in chat mode. Delegates to chat.generate_chat_prompt and
    returns its result as is.
    """
    return chat.generate_chat_prompt(user_input, state, **kwargs)


def custom_css():
    """
    Return a CSS string that gets appended to the CSS for the webui.

    Empty by default.
    """
    return ""


def custom_js():
    """
    Return a javascript string that gets appended to the javascript for
    the webui.

    Empty by default.
    """
    return ""


def setup():
    """
    Gets executed only once, when the extension is imported.

    Does nothing by default.
    """


def ui():
    """
    Gets executed when the UI is drawn.

    Custom gradio elements and their corresponding event handlers should
    be defined here. Does nothing by default.
    """