From 566898a79a0915879273f3d77017908bcf7d62ab Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Sat, 25 Mar 2023 12:08:26 -0700 Subject: [PATCH] initial lora training tab --- modules/training.py | 139 ++++++++++++++++++ requirements.txt | 2 + server.py | 7 +- .../datasets/put-trainer-datasets-here.txt | 0 training/formats/alpaca-chatbot-format.json | 4 + training/formats/alpaca-format.json | 4 + training/formats/put-trainer-formats-here.txt | 0 7 files changed, 153 insertions(+), 3 deletions(-) create mode 100644 modules/training.py create mode 100644 training/datasets/put-trainer-datasets-here.txt create mode 100644 training/formats/alpaca-chatbot-format.json create mode 100644 training/formats/alpaca-format.json create mode 100644 training/formats/put-trainer-formats-here.txt diff --git a/modules/training.py b/modules/training.py new file mode 100644 index 00000000..96cd6e7c --- /dev/null +++ b/modules/training.py @@ -0,0 +1,139 @@ +import sys, torch, json +from pathlib import Path +import gradio as gr +from datasets import load_dataset +import transformers +from modules import ui, shared +from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model, get_peft_model_state_dict + +def get_json_dataset(path: str): + def get_set(): + return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path(path).glob('*.json'))), key=str.lower) + return get_set + +def create_train_interface(): + with gr.Tab('Train LoRA', elem_id='lora-train-tab'): + loraName = gr.Textbox(label="Name", info="The name of your new LoRA file") + # TODO: Add explanations of batch sizes and recommendations. Note that batch/microBatch determines gradient accumulation and explain what that means. Note the effects on VRAM usage from changing these values. + microBatchSize = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='(TODO)') + batchSize = gr.Slider(label='Batch Size', value=128, minimum=1, maximum=1024, step=4, info='(TODO)') + epochs = gr.Slider(label='Epochs', value=1, minimum=1, maximum=1000, info='Number of times every entry in the dataset should be fed into training. So 1 means feed each item in once, 5 means feed it in five times, etc.') + learningRate = gr.Textbox(label='Learning Rate', value='3e-4', info='Learning rate, in scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.') + # TODO: What is the actual maximum rank? Likely distinct per model. This might be better to somehow be on a log scale. + loraRank = gr.Slider(label='LoRA Rank', value=8, minimum=1, maximum=1024, step=4, info='LoRA Rank, or dimension count. Higher values produce a larger file with better control over the model\'s content. Smaller values produce a smaller file with less overall control. Small values like 4 or 8 are great for stylistic guidance, high values like 128 or 256 are good for teaching content upgrades. Higher ranks also require higher VRAM.') + loraAlpha = gr.Slider(label='LoRA Alpha', value=16, minimum=1, maximum=2048, step=4, info='LoRA Alpha. This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.') + # TODO: Better explain what this does. + loraDropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.05, info='Percentage probability for dropout of LoRA layers.') + cutoffLen = gr.Slider(label='Cutoff Length', minimum=1,maximum=2048, value=256, step=32, info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.') + with gr.Row(): + datasetFunction = get_json_dataset('training/datasets') + dataset = gr.Dropdown(choices=datasetFunction(), value='None', label='Dataset') + ui.create_refresh_button(dataset, lambda : None, lambda : {'choices': datasetFunction()}, 'refresh-button') + with gr.Row(): + evalDataset = gr.Dropdown(choices=datasetFunction(), value='None', label='Evaluation Dataset') + ui.create_refresh_button(evalDataset, lambda : None, lambda : {'choices': datasetFunction()}, 'refresh-button') + with gr.Row(): + formatsFunction = get_json_dataset('training/formats') + format = gr.Dropdown(choices=formatsFunction(), value='None', label='Data Format') + ui.create_refresh_button(format, lambda : None, lambda : {'choices': formatsFunction()}, 'refresh-button') + startButton = gr.Button("Start LoRA Training") + output = gr.Markdown(value="(...)") + startButton.click(do_train, [loraName, microBatchSize, batchSize, epochs, learningRate, loraRank, loraAlpha, loraDropout, cutoffLen, dataset, evalDataset, format], [output]) + +def cleanPath(basePath: str, path: str): + """"Strips unusual symbols and forcibly builds a path as relative to the intended directory.""" + # TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path. + # Or swap it to a strict whitelist of [a-zA-Z_0-9] + path = path.replace('\\', '/').replace('..', '_') + if basePath is None: + return path + return f'{Path(basePath).absolute()}/{path}' + +def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, learningRate: float, loraRank: int, loraAlpha: int, loraDropout: float, cutoffLen: int, dataset: str, evalDataset: str, format: str): + # Input validation / processing + # TODO: --lora-dir PR once pulled will need to be applied here + loraName = f"loras/{cleanPath(None, loraName)}" + if dataset is None: + return "**Missing dataset choice input, cannot continue.**" + if format is None: + return "**Missing format choice input, cannot continue.**" + gradientAccumulationSteps = batchSize // microBatchSize + actualLR = float(learningRate) + model = shared.model + tokenizer = shared.tokenizer + tokenizer.pad_token = 0 + tokenizer.padding_side = "left" + # Prep the dataset, format, etc + with open(cleanPath('training/formats', f'{format}.json'), 'r') as formatFile: + formatData: dict[str, str] = json.load(formatFile) + def tokenize(prompt): + result = tokenizer(prompt, truncation=True, max_length=cutoffLen + 1, padding="max_length") + return { + "input_ids": result["input_ids"][:-1], + "attention_mask": result["attention_mask"][:-1], + } + def generate_prompt(data_point: dict[str, str]): + for options, data in formatData.items(): + if set(options.split(',')) == set(data_point.keys()): + for key, val in data_point.items(): + data = data.replace(f'%{key}%', val) + return data + raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(formatData.keys())}"') + def generate_and_tokenize_prompt(data_point): + prompt = generate_prompt(data_point) + return tokenize(prompt) + data = load_dataset("json", data_files=cleanPath('training/datasets', f'{dataset}.json')) + train_data = data['train'].shuffle().map(generate_and_tokenize_prompt) + if evalDataset == 'None': + evalData = None + else: + evalData = load_dataset("json", data_files=cleanPath('training/datasets', f'{evalDataset}.json')) + evalData = evalData['train'].shuffle().map(generate_and_tokenize_prompt) + # Start prepping the model itself + model = prepare_model_for_int8_training(model) + config = LoraConfig( + r=loraRank, + lora_alpha=loraAlpha, + # TODO: Should target_modules be configurable? + target_modules=[ "q_proj", "v_proj" ], + lora_dropout=loraDropout, + bias="none", + task_type="CAUSAL_LM" + ) + model = get_peft_model(model, config) + trainer = transformers.Trainer( + model=model, + train_dataset=train_data, + eval_dataset=evalData, + args=transformers.TrainingArguments( + per_device_train_batch_size=microBatchSize, + gradient_accumulation_steps=gradientAccumulationSteps, + # TODO: Should more of these be configurable? Probably. + warmup_steps=100, + num_train_epochs=epochs, + learning_rate=actualLR, + fp16=True, + logging_steps=20, + evaluation_strategy="steps" if evalData is not None else "no", + save_strategy="steps", + eval_steps=200 if evalData is not None else None, + save_steps=200, + output_dir=loraName, + save_total_limit=3, + load_best_model_at_end=True if evalData is not None else False, + # TODO: Enable multi-device support + ddp_find_unused_parameters=None, + ), + data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False), + ) + model.config.use_cache = False + old_state_dict = model.state_dict + model.state_dict = ( + lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict()) + ).__get__(model, type(model)) + if torch.__version__ >= "2" and sys.platform != "win32": + model = torch.compile(model) + # Actually start and run and save at the end + trainer.train() + model.save_pretrained(loraName) + return "Done!" diff --git a/requirements.txt b/requirements.txt index e5b3de69..c93ce671 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,6 @@ rwkv==0.7.0 safetensors==0.3.0 sentencepiece tqdm +peft +datasets git+https://github.com/huggingface/transformers diff --git a/server.py b/server.py index f423e368..cd95d5ef 100644 --- a/server.py +++ b/server.py @@ -8,10 +8,8 @@ from pathlib import Path import gradio as gr -import modules.chat as chat +from modules import chat, shared, ui, training import modules.extensions as extensions_module -import modules.shared as shared -import modules.ui as ui from modules.html_generator import generate_chat_html from modules.LoRA import add_lora_to_model from modules.models import load_model, load_soft_prompt @@ -443,6 +441,9 @@ def create_interface(): shared.gradio['reset_interface'].click(set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'cmd_arguments_menu']], None) shared.gradio['reset_interface'].click(lambda : None, None, None, _js='() => {document.body.innerHTML=\'

Reloading...

\'; setTimeout(function(){location.reload()},2500)}') + + with gr.Tab("Training", elem_id="training-tab"): + training.create_train_interface() if shared.args.extensions is not None: extensions_module.create_extensions_block() diff --git a/training/datasets/put-trainer-datasets-here.txt b/training/datasets/put-trainer-datasets-here.txt new file mode 100644 index 00000000..e69de29b diff --git a/training/formats/alpaca-chatbot-format.json b/training/formats/alpaca-chatbot-format.json new file mode 100644 index 00000000..4b38103f --- /dev/null +++ b/training/formats/alpaca-chatbot-format.json @@ -0,0 +1,4 @@ +{ + "instruction,output": "User: %instruction%\nAssistant: %output%", + "instruction,input,output": "User: %instruction%: %input%\nAssistant: %output%" +} diff --git a/training/formats/alpaca-format.json b/training/formats/alpaca-format.json new file mode 100644 index 00000000..dd6df956 --- /dev/null +++ b/training/formats/alpaca-format.json @@ -0,0 +1,4 @@ +{ + "instruction,output": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n%instruction%\n\n### Response:\n%output%", + "instruction,input,output": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n%instruction%\n\n### Input:\n%input%\n\n### Response:\n%output%" +} diff --git a/training/formats/put-trainer-formats-here.txt b/training/formats/put-trainer-formats-here.txt new file mode 100644 index 00000000..e69de29b