From c4f4f413897ff41266e66c35f253405ecafccbfb Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Fri, 21 Apr 2023 00:20:33 -0300
Subject: [PATCH] Add an "Evaluate" tab to calculate the perplexities of models (#1322)

---
 modules/evaluate.py | 140 ++++++++++++++++++++++++++++++++++++++++++++
 modules/models.py   |  14 ++---
 modules/training.py |  65 ++++++++++++++++----
 modules/ui.py       |   3 +-
 requirements.txt    |   3 +-
 5 files changed, 203 insertions(+), 22 deletions(-)
 create mode 100644 modules/evaluate.py

diff --git a/modules/evaluate.py b/modules/evaluate.py
new file mode 100644
index 00000000..9822ddea
--- /dev/null
+++ b/modules/evaluate.py
@@ -0,0 +1,140 @@
+import datetime
+import traceback
+from pathlib import Path
+
+import pandas as pd
+import torch
+from datasets import load_dataset
+from tqdm import tqdm
+
+from modules import shared
+from modules.models import load_model, unload_model
+from modules.text_generation import encode
+from server import get_model_specific_settings, update_model_parameters
+
+
+def load_past_evaluations():
+    if Path('logs/evaluations.csv').exists():
+        df = pd.read_csv(Path('logs/evaluations.csv'), dtype=str)
+        df['Perplexity'] = pd.to_numeric(df['Perplexity'])
+        return df
+    else:
+        return pd.DataFrame(columns=['Model', 'LoRAs', 'Dataset', 'Perplexity', 'stride', 'max_length', 'Date', 'Comment'])
+past_evaluations = load_past_evaluations()
+
+
+def save_past_evaluations(df):
+    df.to_csv(Path('logs/evaluations.csv'), index=False)
+
+
+def calculate_perplexity(models, input_dataset, stride, _max_length):
+    '''
+    Based on:
+    https://huggingface.co/docs/transformers/perplexity#calculating-ppl-with-fixedlength-models
+    '''
+
+    global past_evaluations
+    cumulative_log = ''
+    cumulative_log += "Loading the input dataset...\n"
+    yield cumulative_log
+
+    # Copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/triton/utils/datautils.py
+    if input_dataset == 'wikitext':
+        data = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
+        text = "\n\n".join(data['text'])
+    elif input_dataset == 'ptb':
+        data = load_dataset('ptb_text_only', 'penn_treebank', split='validation')
+        text = "\n\n".join(data['sentence'])
+    elif input_dataset == 'ptb_new':
+        data = load_dataset('ptb_text_only', 'penn_treebank', split='test')
+        text = " ".join(data['sentence'])
+    else:
+        with open(Path(f'training/datasets/{input_dataset}.txt'), 'r', encoding='utf-8') as f:
+            text = f.read()
+
+    for model in models:
+        if is_in_past_evaluations(model, input_dataset, stride, _max_length):
+            cumulative_log += f"{model} has already been tested. Ignoring.\n"
+            yield cumulative_log
+            continue
+
+        if model != 'current model':
+            try:
+                yield cumulative_log + f"Loading {model}...\n"
+                model_settings = get_model_specific_settings(model)
+                shared.settings.update(model_settings)  # hijacking the interface defaults
+                update_model_parameters(model_settings)  # hijacking the command-line arguments
+                shared.model_name = model
+                unload_model()
+                shared.model, shared.tokenizer = load_model(shared.model_name)
+            except:
+                cumulative_log += f"Failed to load {model}. Moving on.\n"
+                yield cumulative_log
+                continue
+
+        cumulative_log += f"Processing {model}...\n"
+        yield cumulative_log + "Tokenizing the input dataset...\n"
+        encodings = encode(text, add_special_tokens=False)
+        seq_len = encodings.shape[1]
+        max_length = _max_length or shared.model.config.max_position_embeddings
+        nlls = []
+        prev_end_loc = 0
+        for begin_loc in tqdm(range(0, seq_len, stride)):
+            yield cumulative_log + f"Evaluating... {100*begin_loc/seq_len:.2f}%"
+            end_loc = min(begin_loc + max_length, seq_len)
+            trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
+            input_ids = encodings[:, begin_loc:end_loc]
+            target_ids = input_ids.clone()
+            target_ids[:, :-trg_len] = -100
+
+            with torch.no_grad():
+                outputs = shared.model(input_ids, labels=target_ids)
+
+                # loss is calculated using CrossEntropyLoss which averages over valid labels
+                # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
+                # to the left by 1.
+                neg_log_likelihood = outputs.loss
+
+            nlls.append(neg_log_likelihood)
+
+            prev_end_loc = end_loc
+            if end_loc == seq_len:
+                break
+
+        ppl = torch.exp(torch.stack(nlls).mean())
+        add_entry_to_past_evaluations(float(ppl), shared.model_name, input_dataset, stride, _max_length)
+        save_past_evaluations(past_evaluations)
+        cumulative_log += f"Done. The perplexity is: {float(ppl)}\n\n"
+        yield cumulative_log
+
+
+def add_entry_to_past_evaluations(perplexity, model, dataset, stride, max_length):
+    global past_evaluations
+    entry = {
+        'Model': model,
+        'LoRAs': ', '.join(shared.lora_names) or '-',
+        'Dataset': dataset,
+        'Perplexity': perplexity,
+        'stride': str(stride),
+        'max_length': str(max_length),
+        'Date': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
+        'Comment': ''
+    }
+    past_evaluations = pd.concat([past_evaluations, pd.DataFrame([entry])], ignore_index=True)
+
+
+def is_in_past_evaluations(model, dataset, stride, max_length):
+    entries = past_evaluations[(past_evaluations['Model'] == model) &
+                               (past_evaluations['Dataset'] == dataset) &
+                               (past_evaluations['max_length'] == str(max_length)) &
+                               (past_evaluations['stride'] == str(stride))]
+
+    if entries.shape[0] > 0:
+        return True
+    else:
+        return False
+
+
+def generate_markdown_table():
+    sorted_df = past_evaluations.sort_values(by=['Dataset', 'stride', 'Perplexity', 'Date'])
+    return sorted_df
diff --git a/modules/models.py b/modules/models.py
index d639ca65..800d0be2 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -53,7 +53,7 @@ def load_model(model_name):
 
     # Load the model in simple 16-bit mode by default
     if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV, shared.is_llamacpp]):
-        model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16, trust_remote_code=trust_remote_code)
+        model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16, trust_remote_code=trust_remote_code)
         if torch.has_mps:
             device = torch.device('mps')
             model = model.to(device)
@@ -81,11 +81,11 @@ def load_model(model_name):
                             num_bits=4, group_size=64,
                             group_dim=2, symmetric=False))
 
-        model = OptLM(f"facebook/{shared.model_name}", env, shared.args.model_dir, policy)
+        model = OptLM(f"facebook/{model_name}", env, shared.args.model_dir, policy)
 
     # DeepSpeed ZeRO-3
     elif shared.args.deepspeed:
-        model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
+        model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
         model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
         model.module.eval()  # Inference
         print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
@@ -169,7 +169,7 @@ def load_model(model_name):
         if shared.args.disk:
             params["offload_folder"] = shared.args.disk_cache_dir
 
-        checkpoint = Path(f'{shared.args.model_dir}/{shared.model_name}')
+        checkpoint = Path(f'{shared.args.model_dir}/{model_name}')
 
         if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
             config = AutoConfig.from_pretrained(checkpoint)
@@ -190,7 +190,7 @@ def load_model(model_name):
         llama_attn_hijack.hijack_llama_attention()
 
     # Loading the tokenizer
-    if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
+    if any((k in model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
         tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
     elif type(model) is transformers.LlamaForCausalLM:
         tokenizer = None
@@ -205,7 +205,7 @@ def load_model(model_name):
         # Otherwise, load it from the model folder and hope that these
         # are not outdated tokenizer files.
         if tokenizer is None:
-            tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
+            tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), clean_up_tokenization_spaces=True)
             try:
                 tokenizer.eos_token_id = 2
                 tokenizer.bos_token_id = 1
@@ -213,7 +213,7 @@ def load_model(model_name):
             except:
                 pass
     else:
-        tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), trust_remote_code=trust_remote_code)
+        tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), trust_remote_code=trust_remote_code)
 
     print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
     return model, tokenizer
diff --git a/modules/training.py b/modules/training.py
index 1a12a0e4..000a1cea 100644
--- a/modules/training.py
+++ b/modules/training.py
@@ -10,9 +10,12 @@ import gradio as gr
 import torch
 import transformers
 from datasets import Dataset, load_dataset
-from peft import LoraConfig, get_peft_model, set_peft_model_state_dict, prepare_model_for_int8_training
+from peft import (LoraConfig, get_peft_model, prepare_model_for_int8_training,
+                  set_peft_model_state_dict)
 
 from modules import shared, ui
+from modules.evaluate import calculate_perplexity, generate_markdown_table, save_past_evaluations
+from server import get_available_loras, get_available_models
 
 # This mapping is from a very recent commit, not yet released.
 # If not available, default to a backup map for the 3 safe model types.
@@ -40,10 +43,6 @@ def get_datasets(path: str, ext: str):
     return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower)
 
 
-def get_available_loras():
-    return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
-
-
 def create_train_interface():
     with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
         with gr.Row():
@@ -82,9 +81,9 @@ def create_train_interface():
 
                 eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')
 
-            with gr.Tab(label='Raw Text File'):
+            with gr.Tab(label="Raw text file"):
                 with gr.Row():
-                    raw_text_file = gr.Dropdown(choices=get_datasets('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
+                    raw_text_file = gr.Dropdown(choices=get_datasets('training/datasets', 'txt'), value='None', label='Text file', info='The raw text file to use for training.')
                     ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'txt')}, 'refresh-button')
 
                 with gr.Row():
@@ -106,11 +105,48 @@ def create_train_interface():
 
             output = gr.Markdown(value="Ready")
 
-    all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, do_shuffle, higher_rank_limit, warmup_steps, optimizer]
-    copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
-    start_button.click(do_train, all_params, output)
-    stop_button.click(do_interrupt, None, None, queue=False)
-    higher_rank_limit.change(change_rank_limit, [higher_rank_limit], [lora_rank, lora_alpha])
+    with gr.Tab('Perplexity evaluation', elem_id='evaluate-tab'):
+        with gr.Row():
+            with gr.Column():
+                models = gr.Dropdown(get_available_models(), label='Models', multiselect=True)
+                evaluate_text_file = gr.Dropdown(choices=['wikitext', 'ptb', 'ptb_new'] + get_datasets('training/datasets', 'txt')[1:], value='wikitext', label='Input dataset', info='The raw text file on which the model will be evaluated. The first options are automatically downloaded: wikitext, ptb, and ptb_new. The next options are your local text files under training/datasets.')
+                with gr.Row():
+                    stride_length = gr.Slider(label='Stride', minimum=1, maximum=2048, value=512, step=1, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.')
+                    max_length = gr.Slider(label='max_length', minimum=1, maximum=8096, value=0, step=1, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.')
+
+                with gr.Row():
+                    start_current_evaluation = gr.Button("Evaluate loaded model")
+                    start_evaluation = gr.Button("Evaluate selected models")
+                    stop_evaluation = gr.Button("Interrupt")
+
+            with gr.Column():
+                evaluation_log = gr.Markdown(value = '')
+
+        evaluation_table = gr.Dataframe(value=generate_markdown_table(), interactive=True)
+        save_comments = gr.Button('Save comments')
+
+    # Training events
+    all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, do_shuffle, higher_rank_limit, warmup_steps, optimizer]
+    copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
+    start_button.click(do_train, all_params, output)
+    stop_button.click(do_interrupt, None, None, queue=False)
+    higher_rank_limit.change(change_rank_limit, [higher_rank_limit], [lora_rank, lora_alpha])
+
+    # Evaluation events. For some reason, the interrupt event
+    # doesn't work with the .then() syntax, so I write them one
+    # by one in this ugly but functional way.
+    ev = start_evaluation.click(calculate_perplexity, [models, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False)
+    start_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False)
+
+    tmp = gr.State('')
+    start_current_evaluation.click(lambda: ['current model'], None, tmp)
+    ev_cur = start_current_evaluation.click(calculate_perplexity, [tmp, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False)
+    start_current_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False)
+
+    stop_evaluation.click(None, None, None, cancels=[ev, ev_cur], queue=False)
+    save_comments.click(
+        save_past_evaluations, evaluation_table, None).then(
+        lambda: "Comments saved.", None, evaluation_log, show_progress=False)
 
 
 def do_interrupt():
@@ -133,6 +169,7 @@ def do_copy_params(lora_name: str, *args):
             result.append(params[key])
         else:
             result.append(args[i])
+
     return result
 
 
@@ -155,7 +192,8 @@ def clean_path(base_path: str, path: str):
 
 def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, do_shuffle: bool, higher_rank_limit: bool, warmup_steps: int, optimizer: str):
     if shared.args.monkey_patch:
-        from monkeypatch.peft_tuners_lora_monkey_patch import replace_peft_model_with_gptq_lora_model
+        from monkeypatch.peft_tuners_lora_monkey_patch import \
+            replace_peft_model_with_gptq_lora_model
         replace_peft_model_with_gptq_lora_model()
 
     global WANT_INTERRUPT
@@ -300,6 +338,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
                 if '4bit' in str(type(m)):
                     if m.is_v1_model:
                         m.zeros = m.zeros.half()
+
                     m.scales = m.scales.half()
 
     class Tracked():
diff --git a/modules/ui.py b/modules/ui.py
index 121b6c5a..d84cbacc 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -20,7 +20,8 @@ theme = gr.themes.Default(
     font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
 ).set(
     border_color_primary='#c5c5d2',
-    button_large_padding='6px 12px'
+    button_large_padding='6px 12px',
+    body_text_color_subdued='#484848'
 )
 
 def list_model_elements():
diff --git a/requirements.txt b/requirements.txt
index 6c7e22ec..e5f0a8f7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,12 +5,13 @@ flexgen==0.1.7
 gradio==3.25.0
 markdown
 numpy
+pandas
 Pillow>=9.5.0
+pyyaml
 requests
 rwkv==0.7.3
 safetensors==0.3.0
 sentencepiece
-pyyaml
 tqdm
 git+https://github.com/huggingface/peft
 transformers==4.28.1
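
For reference, the perplexity the new tab reports is the standard sliding-window estimate from the Hugging Face guide cited in modules/evaluate.py. Below is a minimal standalone sketch of the same loop, run outside the web UI; the model name "gpt2", the wikitext-2 test split, and stride 512 are illustrative assumptions, not part of the patch:

    import torch
    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "gpt2"  # placeholder model, purely illustrative
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()

    data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    encodings = tokenizer("\n\n".join(data["text"]), return_tensors="pt")

    stride = 512  # same meaning as the tab's "Stride" slider
    max_length = model.config.max_position_embeddings  # what the tab uses when max_length is 0
    seq_len = encodings.input_ids.size(1)

    nlls = []
    prev_end_loc = 0
    for begin_loc in range(0, seq_len, stride):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # tokens that have not been scored yet
        input_ids = encodings.input_ids[:, begin_loc:end_loc]
        target_ids = input_ids.clone()
        target_ids[:, :-trg_len] = -100  # ignore the overlapping context in the loss

        with torch.no_grad():
            neg_log_likelihood = model(input_ids, labels=target_ids).loss

        nlls.append(neg_log_likelihood)
        prev_end_loc = end_loc
        if end_loc == seq_len:
            break

    ppl = torch.exp(torch.stack(nlls).mean())
    print(f"Perplexity: {float(ppl):.2f}")

A smaller stride re-scores more tokens with fuller context and gets closer to the true per-token perplexity at the cost of more forward passes, while a stride equal to max_length evaluates disjoint windows and is fastest but least accurate; that is the trade-off the "Stride" slider exposes.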