Training PRO a month worth of updates (#4345)

This commit is contained in:
FartyPants (FP HAM) 2023-10-22 11:38:09 -04:00 committed by GitHub
parent c18504f369
commit 6a61158adf
WARNING! Although there is a key with this ID in the database it does not verify this commit! This commit is SUSPICIOUS.
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 803 additions and 99 deletions

View file

@ -1,10 +1,27 @@
# Training_PRO
This is an expanded Training tab
This is an expanded and reworked Training tab
Maintained by FP
[![ko-fi](https://ko-fi.com/img/githubbutton_sm.svg)](https://ko-fi.com/Q5Q5MOB4M)
Repo home:
https://github.com/FartyPants/Training_PRO
In general the repo above is ahead of the extension included in text WebUi.
## News
- NEFtune: add noise to help with generalization
- Loss Graph in interface.
- Supports Mistral training
- some roundabout around pytorch and transformers version desync
![image](https://github.com/FartyPants/Training_PRO/assets/23346289/e389ec69-d7ad-4922-9ad9-865625997479)
## Features/Changes
- Chunking: precise raw text slicer (PRTS) uses sentence slicing and making sure things are clean on all ends
- overlap chunking - this special overlapping will make additional overlap block based on logical rules (aka no overlap block on hard cut)
- custom scheduler (follow the code to make your own) In LR Scheduler select FP_low_epoch_annealing - this scheduler will keep the LR constant for first epoch then use cosine for the rest - this part would be best to spawn into a new py file
@ -19,11 +36,30 @@ https://github.com/FartyPants/Training_PRO
- Ability to change Stop Loss during training
- different modes of checkpoint auto saving
- Function to Check Dataset and suggest parameters such as warmup and checkpoint save frequency before training
- Graph Training Loss in interface
- more custom schedulers
### Notes:
This uses it's own chunking code for raw text based on sentence splitting. This will avoid weird cuts in the chunks and each chunk should now start with sentence and end on some sentence. It works hand in hand with Hard Cut. A propper use is to structure your text into logical blocks (ideas) separated by three \n then use three \n in hard cut. This way each chunk will contain only one flow of ideas and not derail in the thoughts. And Overlapping code will create overlapped blocks on sentence basis too, but not cross hard cut, thus not cross different ideas either. Does it make any sense? No? Hmmmm...
### Custom schedulers
A bunch of custom (combination) schedulers are added to the LR schedule. These are based on my own experiments
**FP_low_epoch_annealing**
Uses constant LR (with warmup) for 1 epoch only. The rest of the epoch(s) is cosine annealing. So 10 epochs - 1 will be constant 9 will be nose dive down. However a typical usage would be 2 epochs (hence low epoch in name). 1st is constant, the second is annealing. Simple. I use it 90% of time.
**FP_half_time_annealing**
Like the low epoch, but now the total number of steps is divided by 2. First half is constant, second half is annealing. So 10 epochs - 5 will be constant, 5 will be cosine nose down.
**FP_raise_fall_creative**
This is a sine raise till half of the total steps then cosine fall the rest. (Or you may think of the curve as sine in its entirety. The most learning is done in the hump, in the middle. The warmup entry has no effect, since sine is automatically warm up.
The idea is to start very mildly as not to overfit with the first blocks of dataset. It seems to broaden the scope of the model making it less strict for tight dataset.
### Targets
Normal LORA is q, v and that's what you should use. You can use (q k v o) or (q k v) and it will give you a lot more trainable parameters. The benefit is that you can keep rank lower and still attain the same coherency as q v with high rank. Guanaco has been trained with QLORA and q k v o for example and they swear by it.

View file

@ -4,10 +4,35 @@ import transformers
import math
from torch.optim.lr_scheduler import LambdaLR
from peft import (
PeftModel,
)
RED = "\033[91m"
YELLOW = "\033[93m"
GREEN = "\033[92m"
RESET = "\033[0m"
#FPHAM custom training scheduller block - should be extracted to separate file
last_print_label = ''
custom_scheduler_params = {'trigger_loss': 0.0, 'ramp_down_ratio':1.0, 'current_loss': 0.0,'dynamic_scheduler_stop': False, 'calc_ramp_down_at_step': 0, 'calc_num_training_steps': 0}
def custom_scheduler_global_update(current_loss: float):
custom_scheduler_params.update({'current_loss': current_loss})
def custom_scheduler_global_setup(trigger_loss: float, ramp_down_ratio: float):
custom_scheduler_params.update({'trigger_loss': trigger_loss})
custom_scheduler_params.update({'ramp_down_ratio': ramp_down_ratio})
# calculates the total num steps after trigger
custom_scheduler_params.update({'calc_num_training_steps': 0})
#calculates steps when the ramp_down trigger occured
custom_scheduler_params.update({'calc_ramp_down_at_step': 0})
# triggers scheduler stopping after it reached calc_num_training_steps
custom_scheduler_params.update({'dynamic_scheduler_stop': False})
# hold constant to the half of epochs then cosine down to 0
def _get_fp_half_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_firstepoch_steps: int):
@ -40,6 +65,35 @@ def _get_fp_half_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup
num_cycles = 0.5
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
# raise up in cosine, then fall back in cosine
def _get_fp_cosine_raise_and_fall_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_firstepoch_steps: int):
global last_print_label
print_label = ''
half_steps = num_training_steps//2
#num_warmup_steps = min(num_warmup_steps,half_steps)
if current_step < half_steps:
print_label = 'Scheduler: Raise'
else:
print_label = 'Scheduler: Fall'
if print_label != last_print_label:
print(print_label)
last_print_label = print_label
# linear
# return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - half_steps) / float(max(1, num_training_steps - half_steps))
num_cycles = 0.5
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
# constant to the first epochs then cosine down to 0 over the rest epochs
def _get_fp_cosine_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_firstepoch_steps: int):
@ -70,6 +124,43 @@ def _get_fp_cosine_schedule_with_warmup_lr_lambda(current_step: int, *, num_warm
num_cycles = 0.5
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
# halve lr each epoch
def _get_fp_cdrop_rate_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_firstepoch_steps: int):
global last_print_label
print_label = ''
num_warmup_steps = min(num_warmup_steps, num_firstepoch_steps)
current_epoch = (current_step // num_firstepoch_steps) + 1
if current_step < num_warmup_steps:
print_label = 'Scheduler: Warmup'
elif current_step < num_firstepoch_steps:
print_label = 'Scheduler: Hold'
else:
print_label = 'Scheduler: Drop Rate'
if print_label != last_print_label:
print(print_label)
last_print_label = print_label
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
if current_step < num_firstepoch_steps:
return 1.0
# Compute the learning rate for the annealing phase
learning_rate = 1.0 / float(2 ** (current_epoch - 1))
return learning_rate
# epoch decay: 1/(1 + decay * epoch)
def custom_cosine_scheduler_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_firstepoch_steps, last_epoch=-1):
"""
@ -119,9 +210,157 @@ def custom_half_scheduler_with_warmup(optimizer, num_warmup_steps, num_training_
)
return LambdaLR(optimizer, lr_lambda, last_epoch)
def custom_raise_fall_scheduler_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_firstepoch_steps, last_epoch=-1):
"""
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (`int`):
The number of steps for the warmup phase.
num_training_steps (`int`):
The total number of training steps.
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
lr_lambda = partial(
_get_fp_cosine_raise_and_fall_lr_lambda,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
num_firstepoch_steps = num_firstepoch_steps,
)
return LambdaLR(optimizer, lr_lambda, last_epoch)
def neftune_forward(self, input: torch.Tensor):
"""
Implements the NEFTune forward pass for the model. Note this works only for
torch.nn.Embedding layers. This method is slightly adapted from the original source code
that can be found here: https://github.com/neelsjain/NEFTune
Args:
input (`torch.Tensor`):
The input tensor to the model.
noise_alpha (`float`):
The noise alpha value to use for the NEFTune forward pass.
"""
embeddings = torch.nn.functional.embedding(
input, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse
)
if self.training:
# Add noise to the embeddings
dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
mag_norm = self.neftune_noise_alpha / torch.sqrt(dims)
embeddings = embeddings + torch.zeros_like(embeddings).uniform_(-mag_norm, mag_norm)
return embeddings
class FPNEFtuneTrainer(transformers.Trainer):
def __init__(self,neftune_noise_alpha:float = 0.0, model = None, *args, **kwargs):
self.neftune_noise_alpha = neftune_noise_alpha
if self.neftune_noise_alpha > 0.0:
model = self._activate_neftune(model)
super().__init__(model = model, *args, **kwargs)
def _activate_neftune(self, model):
r"""
Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper: https://arxiv.org/abs/2310.05914
"""
print(f"Activating {RED}NEFtune{RESET} with scale: {self.neftune_noise_alpha}")
if isinstance(model, transformers.PreTrainedModel):
embeddings = model.get_input_embeddings()
elif isinstance(model, PeftModel):
embeddings = model.base_model.get_input_embeddings()
embeddings.neftune_noise_alpha = self.neftune_noise_alpha
old_forward = embeddings.forward
# This hack seems to be needed to properly use a custom forward pass
# all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
bound_method = neftune_forward.__get__(embeddings, embeddings.__class__)
setattr(embeddings, "forward", bound_method)
# embeddings.forward = neftune_forward
embeddings._trl_old_forward = old_forward
return model
def train(self, *args, **kwargs):
output = super().train(*args, **kwargs)
# After training we make sure to retrieve back the original forward pass method
# for the embedding layer
if self.neftune_noise_alpha is not None:
if isinstance(self.model, transformers.PreTrainedModel):
embeddings = self.model.get_input_embeddings()
elif isinstance(self.model, PeftModel):
embeddings = self.model.base_model.get_input_embeddings()
if hasattr(embeddings, "_trl_old_forward"):
embeddings.forward = embeddings._trl_old_forward
del embeddings._trl_old_forward
del embeddings.neftune_noise_alpha
return output
class FPSchedulerTrainer(transformers.Trainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self,neftune_noise_alpha:float = 0.0, model = None, *args, **kwargs):
self.neftune_noise_alpha = neftune_noise_alpha
if self.neftune_noise_alpha > 0.0:
model = self._activate_neftune(model)
super().__init__(model = model, *args, **kwargs)
def _activate_neftune(self, model):
r"""
Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper: https://arxiv.org/abs/2310.05914
"""
print(f"Activating {RED}NEFtune{RESET} with scale: {self.neftune_noise_alpha}")
if isinstance(model, transformers.PreTrainedModel):
embeddings = model.get_input_embeddings()
elif isinstance(model, PeftModel):
embeddings = model.base_model.get_input_embeddings()
embeddings.neftune_noise_alpha = self.neftune_noise_alpha
old_forward = embeddings.forward
# This hack seems to be needed to properly use a custom forward pass
# all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
bound_method = neftune_forward.__get__(embeddings, embeddings.__class__)
setattr(embeddings, "forward", bound_method)
# embeddings.forward = neftune_forward
embeddings._trl_old_forward = old_forward
return model
def train(self, *args, **kwargs):
output = super().train(*args, **kwargs)
# After training we make sure to retrieve back the original forward pass method
# for the embedding layer
if self.neftune_noise_alpha is not None:
if isinstance(self.model, transformers.PreTrainedModel):
embeddings = self.model.get_input_embeddings()
elif isinstance(self.model, PeftModel):
embeddings = self.model.base_model.get_input_embeddings()
if hasattr(embeddings, "_trl_old_forward"):
embeddings.forward = embeddings._trl_old_forward
del embeddings._trl_old_forward
del embeddings.neftune_noise_alpha
return output
def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
#Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or passed as an argument.
@ -133,6 +372,8 @@ class FPSchedulerTrainer(transformers.Trainer):
num_firstepoch_steps_acc = num_firstepoch_steps*self.args.gradient_accumulation_steps
num_training_steps_acc = num_training_steps*self.args.gradient_accumulation_steps
custom_scheduler_params.update({'dynamic_scheduler_stop': False})
print (f"Warm-up steps aligned to Gradient accumulation ({self.args.gradient_accumulation_steps}) = {num_warmup_acc} actual warmup steps")
if self.args.lr_scheduler_type == 'cosine':
@ -171,5 +412,22 @@ class FPSchedulerTrainer(transformers.Trainer):
)
self._created_lr_scheduler = True
return self.lr_scheduler
elif self.args.lr_scheduler_type == 'constant_with_warmup':
half_step_acc = num_training_steps_acc//2
if num_warmup_steps>0:
print(f"Warmup doesn't apply to this scheduler [Raise-Fall]")
print (f"Scheduler Raise: 0-{half_step_acc}, Fall {half_step_acc}-{num_training_steps_acc}")
self.lr_scheduler = custom_raise_fall_scheduler_with_warmup(
optimizer=self.optimizer if optimizer is None else optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
num_firstepoch_steps = num_firstepoch_steps,
)
self._created_lr_scheduler = True
return self.lr_scheduler
else:
return super().create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)

View file

@ -15,12 +15,16 @@ from datetime import datetime
from pathlib import Path
import gradio as gr
import pandas as pd
import torch
import transformers
from .custom_scheduler import FPSchedulerTrainer
from functools import partial
from .custom_scheduler import FPSchedulerTrainer, FPNEFtuneTrainer
from .matplotgraph import create_graph
from .train_utils import get_available_loras_local, precise_cut, sliding_block_cut
from .train_utils import get_available_loras_local, precise_cut, sliding_block_cut, download_file_from_url
from datasets import Dataset, load_dataset
from peft import (
@ -48,6 +52,59 @@ from modules.models import reload_model
from modules.utils import natural_keys
## just temporary to avoid warning
import inspect
from typing import Callable, Optional, Tuple, ContextManager
if hasattr(torch.utils.checkpoint, 'noop_context_fn'):
def my_checkpoint(
function,
*args,
use_reentrant: Optional[bool] = None,
context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = torch.utils.checkpoint.noop_context_fn,
determinism_check: str = torch.utils.checkpoint._DEFAULT_DETERMINISM_MODE,
debug: bool = False,
**kwargs
):
if use_reentrant is None:
#print ("reentran = NONE")
use_reentrant = True
# Hack to mix *args with **kwargs in a python 2.7-compliant way
preserve = kwargs.pop("preserve_rng_state", True)
if kwargs and use_reentrant:
raise ValueError(
"Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
)
if use_reentrant:
if context_fn is not torch.utils.checkpoint.noop_context_fn or debug is not False:
raise ValueError(
"Passing `context_fn` or `debug` is only supported when "
"use_reentrant=False."
)
return torch.utils.checkpoint.CheckpointFunction.apply(function, preserve, *args)
else:
print ("reentran = FALSE")
gen = torch.utils.checkpoint._checkpoint_without_reentrant_generator(
function, preserve, context_fn, determinism_check, debug, *args, **kwargs
)
# Runs pre-forward logic
next(gen)
ret = function(*args, **kwargs)
# Runs post-forward logic
try:
next(gen)
except StopIteration:
return ret
params = {
"display_name": "Training PRO",
"is_tab": True
@ -61,10 +118,14 @@ non_serialized_params = {
"save_checkpoint_now": False,
"training_loop": False,
"current_stability": 0,
"save_epochs": 0,
"checkpoint_offset": 0,
"epoch_offset":0,
}
MODEL_CLASSES = {v[1]: v[0] for v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.items()}
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss", "add_eos_token", "min_chars", "report_to", "precize_slicing_overlap", "add_eos_token_type", "save_steps_under_loss", "add_bos_token", "training_projection","sliding_window","warmup_ratio","grad_accumulation"]
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss", "add_eos_token", "min_chars", "report_to", "precize_slicing_overlap", "add_eos_token_type", "save_steps_under_loss", "add_bos_token", "training_projection","sliding_window","warmup_ratio","grad_accumulation","neft_noise_alpha"]
WANT_INTERRUPT = False
train_log = {}
@ -72,15 +133,24 @@ train_template = {}
train_log_graph = []
train_choices = ["all","q-k-v-o","q-k-v","k-v-down","q-v"]
statistics = {
'loss': [],
'lr': [],
}
RED = "\033[91m"
YELLOW = "\033[93m"
GREEN = "\033[92m"
RESET = "\033[0m"
def ui():
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
tmp = gr.State('')
with gr.Row():
with gr.Column():
# YY.MM.DD
gr.Markdown("`Ver: 23.09.22` This is enhanced version of QLora Training. [Maintained by FP](https://github.com/FartyPants/Training_PRO/tree/main)")
gr.Markdown("`Ver: 23.10.20` This is enhanced version of QLora Training. [Maintained by FP](https://github.com/FartyPants/Training_PRO/tree/main)")
with gr.Row():
with gr.Column(scale=5):
@ -103,20 +173,19 @@ def ui():
lora_alpha = gr.Slider(label='LoRA Alpha', value=64, minimum=0, maximum=2048, step=4, info='This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.')
batch_size = gr.Slider(visible= False, label='Batch Size', value=0, minimum=0, maximum=1024, step=4, info='Now Replaced with Gradient accumulation. Keeping it for sake of old saved data')
micro_batch_size = gr.Slider(label='True Batch Size', value=4, minimum=1, maximum=128, step=1, info='Specifies how many text blocks per step will be trained. The higher value, the better the concept of training will be, but it requires more GPU memory and it reduces speed.')
grad_accumulation = gr.Slider(label='Gradient Accumulation Steps', value=1, minimum=1, maximum=256, step=1, info="Virtually multiplies the Batch Size by averaging the learning over more than one step. Evens out loss fluctuations but also increases number of total steps.")
cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=2048, value=256, step=32, info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.')
grad_accumulation = gr.Slider(label='Gradient Accumulation Steps', value=1, minimum=1, maximum=256, step=1, info="Virtually multiplies the Batch Size by averaging the learning over more than one step. VRAM friendly. Evens out loss fluctuations but can also degrade training fidelity.")
with gr.Column():
stop_at_loss = gr.Slider(label='Stop at loss (Can be changed during training)', minimum=0.0, maximum=3.0, step=0.1, value=0.00, info='The process will automatically stop once the desired loss value is reached.')
gr.Markdown(" ")
epochs = gr.Number(label='Epochs', value=3, info='Number of times every entry in the dataset should be fed into training. So 1 means feed each item in once, 5 means feed it in five times, etc.')
learning_rate = gr.Textbox(label='Learning Rate', value='3e-4', info='In scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.')
lr_scheduler_type = gr.Dropdown(label='LR Scheduler', value='linear', choices=['linear', 'constant', 'constant_with_warmup', 'cosine', 'cosine_with_restarts', 'polynomial', 'inverse_sqrt', 'FP_low_epoch_annealing', 'FP_half_time_annealing'], info='Learning rate scheduler - defines how the learning rate changes over time. Custom schedulers: `FP_low_epoch_annealing` constant for 1 epoch then cosine anneal. `FP_half_time_annealing` constant for half time then cosine anneal', elem_classes=['slim-dropdown'])
lr_scheduler_type = gr.Dropdown(label='LR Scheduler', value='linear', choices=['linear', 'constant', 'constant_with_warmup', 'cosine', 'cosine_with_restarts', 'polynomial', 'inverse_sqrt', 'FP_low_epoch_annealing', 'FP_half_time_annealing','FP_raise_fall_creative'], info='Learning rate scheduler - defines how the learning rate changes over time. Custom schedulers: FP_low_epoch_annealing, FP_half_time_annealing, FP_raise_fall_creative (see README)', elem_classes=['slim-dropdown'])
with gr.Accordion(label='Checkpoints', open=True):
with gr.Row():
with gr.Column():
save_steps = gr.Number(label='Save every n steps', value=0, info='A checkpoint will be saved every n steps. (0 = OFF)')
save_steps = gr.Number(label='Save every n steps', value=0, info='A checkpoint will be saved every n steps and at each Epoch boundary. (0 = OFF)')
with gr.Column():
save_steps_under_loss = gr.Slider(label='Save at 10% Loss change', value=1.8, minimum=0.0, maximum=3.0, step=0.1, info="Saves checkpoints at (or bellow) this loss and then each time loss falls by at least 10% This works independently from 'Save every n steps'")
with gr.Row():
@ -125,9 +194,9 @@ def ui():
with gr.Accordion(label='Advanced Options', open=True):
with gr.Row():
with gr.Column():
warmup_steps = gr.Number(label='Warmup Steps', value=100, info='Number of max steps used for a linear warmup. Value different than 0 has precedent over Warmup Ratio. The actual number of steps will be the closest multiple of graddient accumulation')
warmup_steps = gr.Number(label='Warmup Steps', value=100, info='Number of max steps used for a linear warmup. Reduces early over-fitting by the first training blocks. Value has precedent over Warmup Ratio. Aligns to the closest multiple of graddient accumulation')
warmup_ratio = gr.Slider(label='Warmup Ratio', minimum=0.0, maximum=0.2, step=0.025, value=0.0, info='Ratio of total training steps that will be used for a linear warmup. It applies only if Warmup Step is 0.')
neft_noise_alpha = gr.Slider(label='NEFtune noise scale', minimum=0.0, maximum=15, step=1, value=0.0, info='Add noise to the training to improve generalization. [0 - OFF, Starting value to experiment: 5]')
training_projection = gr.Radio(value = train_choices[4], label='LLaMA Target Projections', info='Change the targets (LORA is typically q-v)', choices=train_choices)
lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.05, info='Percentage probability for dropout of LoRA layers. This can help reduce overfitting. Most users should leave at default.')
optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Different optimizer implementation options, for advanced users. Effects of different options are not well documented yet.', elem_classes=['slim-dropdown'])
@ -136,33 +205,42 @@ def ui():
train_only_after = gr.Textbox(label='Train Only After', value='', info='Only consider text *after* this string in any given chunk for training. For Alpaca datasets, use "### Response:" to only train the response and ignore the input.')
add_bos_token = gr.Checkbox(label='Add BOS token', value=True, info="Adds BOS token for each dataset item")
add_eos_token = gr.Checkbox(label='Add EOS token', value=False, info="Adds EOS token for each dataset item")
add_eos_token_type = gr.Dropdown(label='EOS placement (raw text)', choices=['Every Block', 'Hard Cut Blocks Only'], value='Every Block', info='', allow_custom_value = False)
add_eos_token_type = gr.Dropdown(label='EOS placement (Text file)', choices=['Every Block', 'Hard Cut Blocks Only'], value='Every Block', info='', allow_custom_value = False)
higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.')
report_to = gr.Radio(label="Save detailed logs with", value="None", choices=["None", "wandb", "tensorboard"], interactive=True)
# for future
#with gr.Accordion(label='Dynamic Scheduler', open = False):
# ds_min_epochs = gr.Number(label='Minimum Epochs', value='1', info='Minimum epochs that will be always performed before ramp down can be triggered')
# ds_max_epochs = gr.Number(label='Maximum Epochs (fallback)', value='50', info='Maximum Epochs before the training will bail out completely (should be a large number)')
# ds_loss_trigger = gr.Slider(label='Trigger Loss', minimum=0.0, maximum=2.8, step=0.1, value=1.6, info='Loss at which the ramp down schedule will be triggered')
# ds_loss_rolling_window = gr.Number(label='Loss rolling average', value='4', info='Calculate loss by averaging last x numbers to avoid jumps and noise')
# ds_epochs_to_ramp = gr.Slider(label='Ramp down ratio', minimum=0.0, maximum=2.0, step=0.1, value=1.00, info='How long the ramp down will last relative to ellapsed steps (before trigger)')
# gr.Markdown('These are settings for FP_dynamic_loss_trigger scheduler. The scheduler will do warm up, then hold constant untill a loss falls under Trigger Loss, then it will commence linear ramp down schedule and stop. The length of ramp down is set by Ramp down ratio where (ramp down steps) = ratio * (elapsed steps). (The time to completition shown will be very high untill ramp down is triggered.)')
with gr.Column():
with gr.Tab(label='Formatted Dataset'):
with gr.Row():
with gr.Column():
with gr.Row():
dataset = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.', elem_classes=['slim-dropdown'])
create_refresh_button(dataset, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'json')}, 'refresh-button')
dataset = gr.Dropdown(choices=get_datasets('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.', elem_classes=['slim-dropdown'])
create_refresh_button(dataset, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'json')}, 'refresh-button')
with gr.Row():
eval_dataset = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.', elem_classes=['slim-dropdown'])
create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'json')}, 'refresh-button')
eval_dataset = gr.Dropdown(choices=get_datasets('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.', elem_classes=['slim-dropdown'])
create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'json')}, 'refresh-button')
with gr.Column():
with gr.Row():
format = gr.Dropdown(choices=utils.get_datasets('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.', elem_classes=['slim-dropdown'])
create_refresh_button(format, lambda: None, lambda: {'choices': utils.get_datasets('training/formats', 'json')}, 'refresh-button')
format = gr.Dropdown(choices=get_datasets('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.', elem_classes=['slim-dropdown'])
create_refresh_button(format, lambda: None, lambda: {'choices': get_datasets('training/formats', 'json')}, 'refresh-button')
with gr.Row():
eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')
with gr.Tab(label="Raw text file"):
with gr.Tab(label="Text file"):
with gr.Row():
raw_text_file = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'txt'), value='None', label='Text file', info='The raw text file to use for training.', elem_classes=['slim-dropdown'])
create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'txt')}, 'refresh-button')
raw_text_file = gr.Dropdown(choices=get_datasets('training/datasets', 'txt'), value='None', label='Text file', info='The text file to use for training.', elem_classes=['slim-dropdown'])
create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'txt')}, 'refresh-button')
with gr.Row():
with gr.Column():
@ -173,22 +251,40 @@ def ui():
with gr.Column():
hard_cut_string = gr.Textbox(label='Hard Cut String', value='\\n\\n\\n', info='String that indicates a cut between logical blocks of text (ex. Ideas or Chapters). Helps prevent unwanted overlap between unrelated ideas.')
min_chars = gr.Number(label='Ignore small blocks', value=0, info='Ignore Text blocks that have less or equal characters than this number.')
with gr.Tab(label="URL"):
with gr.Row():
with gr.Column():
check_dataset_btn = gr.Button('Load and Check Dataset and suggest data entries')
download_file_url = gr.Textbox(label='Download JSON or txt file to datasets (or formats) folder', value='',info='The URL of a file to download. If on github, make sure you get url of the raw file (https://raw.githubusercontent.com/...). If huggin face, make sure the url has /resolve/ in it not /blob/')
with gr.Row():
download_check_overwrite = gr.Checkbox(label='Overwrite', value=False, info='Overwrite if file exist')
download_folder = gr.Radio(label="Destination", value='training/datasets', choices=['training/datasets', 'training/formats'], interactive=True)
download_button = gr.Button('Download')
download_status = gr.Textbox(label='Download Status', value='', interactive=False)
with gr.Row():
with gr.Column():
with gr.Row():
cutoff_len = gr.Slider(label='Chunk Length (Cutoff Length)', minimum=32, maximum=2048, value=256, step=32, info='The maximum length of a chunk (in tokens). Applies to both JSON dataset and text files. Higher values require much more VRAM.')
with gr.Row():
with gr.Column():
check_dataset_btn = gr.Button('Verify Dataset/Text File and suggest data entries')
check_dataset_txt = gr.Textbox(label='Dataset info', value='')
with gr.Row():
start_button = gr.Button("Start LoRA Training", variant='primary')
stop_button = gr.Button("Interrupt")
with gr.Accordion(label="Graph", open=True):
with gr.Row():
# show_actions_button = False - we use old gradio
plot_graph = gr.LinePlot(x="epoch", y="value", title="Loss Metrics", overlay_point=True, tooltip=["epoch", "value"], x_lim=[0, 1], y_lim=[0, 3.5], width=500, height=250)
output = gr.Markdown(value="Ready")
with gr.Tab('Perplexity evaluation', elem_id='evaluate-tab'):
with gr.Row():
with gr.Column():
models = gr.Dropdown(utils.get_available_models(), label='Models', multiselect=True)
evaluate_text_file = gr.Dropdown(choices=['wikitext', 'ptb', 'ptb_new'] + utils.get_datasets('training/datasets', 'txt')[1:], value='wikitext', label='Input dataset', info='The raw text file on which the model will be evaluated. The first options are automatically downloaded: wikitext, ptb, and ptb_new. The next options are your local text files under training/datasets.')
evaluate_text_file = gr.Dropdown(choices=['wikitext', 'ptb', 'ptb_new'] + get_datasets('training/datasets', 'txt')[1:], value='wikitext', label='Input dataset', info='The text file on which the model will be evaluated. The first options are automatically downloaded: wikitext, ptb, and ptb_new. The next options are your local text files under training/datasets.')
with gr.Row():
with gr.Column():
stride_length = gr.Slider(label='Stride', minimum=1, maximum=2048, value=512, step=1, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.')
@ -210,7 +306,7 @@ def ui():
refresh_table = gr.Button('Refresh the table', elem_classes="small-button")
# Training events
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss, add_eos_token, min_chars, report_to, precize_slicing_overlap, add_eos_token_type, save_steps_under_loss, add_bos_token, training_projection,sliding_window,warmup_ratio,grad_accumulation]
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss, add_eos_token, min_chars, report_to, precize_slicing_overlap, add_eos_token_type, save_steps_under_loss, add_bos_token, training_projection,sliding_window,warmup_ratio,grad_accumulation, neft_noise_alpha]
def fix_old_version(batch_size_val,micro_batch_size_val, grad_accumulation_val):
if batch_size_val>0:
@ -220,8 +316,9 @@ def ui():
return grad_accumulation_val
copy_from.change(do_copy_params, [copy_from] + all_params, all_params).then(fix_old_version,[batch_size,micro_batch_size, grad_accumulation],grad_accumulation)
start_button.click(do_train, all_params, output)
copy_from.change(partial(do_copy_params, all_params= all_params), copy_from, all_params).then(fix_old_version,[batch_size,micro_batch_size, grad_accumulation],grad_accumulation)
start_button.click(do_train, all_params, [output,plot_graph])
stop_button.click(do_interrupt, None, None, queue=False)
higher_rank_limit.change(change_rank_limit, [higher_rank_limit], [lora_rank, lora_alpha])
@ -241,20 +338,27 @@ def ui():
print("Use during the training to save the checkpoint at any time.")
def update_button():
return gr.Button.update('[Checkpoint in Queue]', variant='stop', interactive=True)
save_chackpoint_now.click(trigger_save_checkpoint, None, None)
def update_button2():
time.sleep(1.0)
return gr.Button.update('Queue Checkpoint Now', variant='secondary',interactive = True)
save_chackpoint_now.click(trigger_save_checkpoint, None, None).then(update_button, None,save_chackpoint_now).then(update_button2, None,save_chackpoint_now)
dataset_calc_params = [save_steps,micro_batch_size, epochs, cutoff_len, dataset, format, raw_text_file, warmup_steps, hard_cut_string, min_chars, precize_slicing_overlap,sliding_window,warmup_ratio,grad_accumulation]
def check_dataset(save_steps:int, micro_batch_size: int, epochs: int, cutoff_len: int, dataset:str, format:str, raw_text_file:str, warmup_steps:int, hard_cut_string:str, min_chars:int, precize_slicing_overlap:bool,sliding_window:bool,warmup_ratio:float,grad_accumulation:int):
result = "Specify JSON dastaset or raw text file"
result = "Specify JSON dastaset or Text file"
total_blocks = 0
if shared.tokenizer is None:
yield "Tokenizer is not available. Please Load some Model first."
return
if raw_text_file not in ['None', '']:
logger.info("Loading raw text file dataset...")
logger.info("Loading Text file...")
fullpath = clean_path('training/datasets', f'{raw_text_file}')
fullpath = Path(fullpath)
if fullpath.is_dir():
@ -268,8 +372,12 @@ def ui():
logger.info(f"Loaded training file: {file_path.name}")
else:
try:
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
raw_text = file.read().replace('\r', '')
except:
yield f"{raw_text_file}.txt doesn't seem to exsist anymore... check your training/datasets folder"
return
if min_chars<0:
@ -282,12 +390,12 @@ def ui():
text_chunks = precise_cut(raw_text, precize_slicing_overlap, min_chars, False, cutoff_len, hard_cut_string,non_serialized_params['debug_slicer'])
total_blocks = len(text_chunks)
result = f"Raw Text: ({raw_text_file}.txt) has {total_blocks} blocks (with cutoff length = {cutoff_len})"
result = f"Text: ({raw_text_file}.txt) has {total_blocks} blocks (Block Size {cutoff_len} tokens)"
del text_chunks
else:
if dataset in ['None', '']:
yield "Select dataset or Raw text."
yield "Select dataset or text file."
return
if format in ['None', '']:
@ -323,10 +431,26 @@ def ui():
logger.info("Loading JSON datasets...")
data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
data_keys = []
if data:
if 'train' in data: # Check if the 'train' split exists in the dataset
data_keys = list(data['train'][0].keys())
print("Data Keys:", data_keys)
else:
print("The dataset is empty.")
train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
total_blocks = train_data.num_rows
result = f"Dataset: ({dataset}.json) has {total_blocks} blocks (with cutoff length = {cutoff_len})"
result = f"Dataset: ({dataset}.json) has {total_blocks} blocks @ length = {cutoff_len} tokens\n(Keys: {data_keys} - Format: {format}.json): "
#for options, data in format_data.items():
# format_keys = options.split(',')
# result += f"{format_keys}, "
#result = result.rstrip()
#result = result.rstrip(',')
if total_blocks>0:
number_ofSteps = int(math.ceil(total_blocks / micro_batch_size) * epochs)
@ -340,11 +464,13 @@ def ui():
save_each_n_max = int(math.ceil(number_ofSteps/5))
gradient_accumulation_max = int(total_blocks)//micro_batch_size
result += f"\n[Batch Size: {micro_batch_size}, Epochs: {epochs}, Gradient Accumulation: {grad_accumulation}]\n"
result += f"Total number of steps: {number_ofSteps}\n"
result += f"Steps per each Epoch: {num_stepsPer_epoch}\n"
result += f"Warmup steps suggestion: {warmup_steps_suggest} (Current: {int(warmup_steps)})\n"
result += f"Checkpoint suggestion: Save every {save_each_n_min} - {save_each_n_max} steps (Current: {int(save_steps)})"
result += f"Suggestions:\n"
result += f"Checkpoints: Save every {save_each_n_min} - {save_each_n_max} steps (Current: {int(save_steps)})\n"
result += f"Warmup steps: {warmup_steps_suggest} (Current: {int(warmup_steps)})"
if gradient_accumulation_max < grad_accumulation:
result += f"\n\nWARNING: Gradient Accumulation {grad_accumulation} is too high: It should be below {gradient_accumulation_max}"
@ -378,19 +504,34 @@ def ui():
sort_byTime.change(lambda x: non_serialized_params.update({"Lora_sortedByTime": x}), sort_byTime, None).then(reload_lora,None,copy_from)
#debug_slicer.change(lambda x: non_serialized_params.update({"debug_slicer": x}), debug_slicer, None)
def update_dataset():
return gr.update(choices=get_datasets('training/datasets', 'json')), gr.update(choices=get_datasets('training/datasets', 'txt'))
download_button.click(download_file_from_url, [download_file_url,download_check_overwrite,download_folder] , download_status).then(update_dataset,None,[dataset , raw_text_file])
def get_datasets(path: str, ext: str):
# include subdirectories for raw txt files to allow training from a subdirectory of txt files
#if ext == "txt":
# return ['None'] + sorted(set([k.stem for k in list(Path(path).glob('txt')) + list(Path(path).glob('*/')) if k.stem != 'put-trainer-datasets-here']), key=natural_keys)
return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=natural_keys)
def do_interrupt():
global WANT_INTERRUPT
WANT_INTERRUPT = True
def do_copy_params(lora_name: str, *args):
def do_copy_params(lora_name: str, all_params):
if lora_name:
f_name = f"{shared.args.lora_dir}/{clean_path(None, lora_name)}/training_parameters.json"
if Path(f_name).is_file():
with open(f_name, 'r', encoding='utf-8') as format_file:
params: dict[str, str] = json.load(format_file)
else:
params = {}
else:
params = {}
result = list()
for i in range(0, len(PARAMETERS)):
@ -398,7 +539,7 @@ def do_copy_params(lora_name: str, *args):
if key in params:
result.append(params[key])
else:
result.append(args[i])
result.append(all_params[i])
return result
@ -462,7 +603,8 @@ def calc_trainable_parameters(model):
return trainable_params, all_param
def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float, add_eos_token: bool, min_chars: int, report_to: str, precize_slicing_overlap: bool, add_eos_token_type: str, save_steps_under_loss: float, add_bos_token: bool, training_projection: str,sliding_window:bool,warmup_ratio:float, grad_accumulation: int):
def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float, add_eos_token: bool, min_chars: int, report_to: str, precize_slicing_overlap: bool, add_eos_token_type: str, save_steps_under_loss: float, add_bos_token: bool, training_projection: str,sliding_window:bool,warmup_ratio:float, grad_accumulation: int,neft_noise_alpha:float):
if shared.args.monkey_patch:
from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
@ -470,14 +612,20 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
)
replace_peft_model_with_int4_lora_model()
global train_log_graph
global WANT_INTERRUPT
WANT_INTERRUPT = False
statistics['loss'] = []
statistics['loss'].append({'epoch': 0, 'value': 0})
zero_pd = pd.DataFrame(statistics['loss'])
# == Input validation / processing ==
yield "Preparing the input..."
yield "Preparing the input...", zero_pd
lora_file_path = clean_path(None, lora_name)
if lora_file_path.strip() == '':
yield "Missing or invalid LoRA file name input."
yield "Missing or invalid LoRA file name input.", zero_pd
return
lora_file_path = f"{Path(shared.args.lora_dir)}/{lora_file_path}"
@ -490,23 +638,23 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
model_id = "llama"
if model_type == "PeftModelForCausalLM":
if len(shared.lora_names) > 0:
yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*", zero_pd
logger.warning("Training LoRA over top of another LoRA. May have unexpected effects.")
else:
yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*", zero_pd
logger.warning("Model ID not matched due to LoRA loading. Consider reloading base model.")
else:
yield "LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
yield "LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*", zero_pd
logger.warning(f"LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})")
time.sleep(5)
if shared.args.loader == 'GPTQ-for-LLaMa' and not shared.args.monkey_patch:
yield "LoRA training with GPTQ-for-LLaMa requires loading with `--monkey-patch`"
yield "LoRA training with GPTQ-for-LLaMa requires loading with `--monkey-patch`", zero_pd
return
if cutoff_len <= 0 or micro_batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
yield "Cannot input zeroes."
yield "Cannot input zeroes.", zero_pd
return
#in new version we dumped this in favor of grad_accumulation
@ -566,20 +714,40 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
train_template.clear()
#reset stuff
print(f"*** LoRA: {lora_name} ***")
non_serialized_params.update({"stop_at_loss": stop_at_loss})
non_serialized_params.update({"save_steps_under_loss": save_steps_under_loss+0.01})
non_serialized_params.update({"save_checkpoint_now": False})
non_serialized_params.update({"training_loop": False})
non_serialized_params.update({"current_stability": 0})
non_serialized_params.update({"save_epochs": 0})
non_serialized_params.update({"checkpoint_offset": 0})
non_serialized_params.update({"epoch_offset": 0})
train_log_graph.clear()
# === once fixed, this can be removed ==============================
if hasattr(torch.utils.checkpoint, 'noop_context_fn'):
print("Testing Pytorch...")
old_checkpoint_signature = inspect.signature(torch.utils.checkpoint.checkpoint)
# Get the signature of your new checkpoint function
my_checkpoint_signature = inspect.signature(my_checkpoint)
# Check if the signatures match
if old_checkpoint_signature.parameters == my_checkpoint_signature.parameters:
print(F"{RED}Overriding Torch checkpoint function to avoid repeated 'use_reentrant not explicitly set' warnings{RESET}")
#print(" - Note: Transformers need to pass use_reentrant in llama.modeling_llama in def forward, layer_outputs = torch.utils.checkpoint.checkpoint")
#print(" Once they do, this function can be removed")
torch.utils.checkpoint.checkpoint = my_checkpoint
# END OF FPHAM SENTENCE SPLIT functions ===================
# == Prep the dataset, format, etc ==
if raw_text_file not in ['None', '']:
train_template["template_type"] = "raw_text"
logger.info("Loading raw text file dataset...")
logger.info("Loading text file...")
fullpath = clean_path('training/datasets', f'{raw_text_file}')
fullpath = Path(fullpath)
if fullpath.is_dir():
@ -621,11 +789,11 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
eval_data = None
else:
if dataset in ['None', '']:
yield "Missing dataset choice input, cannot continue."
yield "Missing dataset choice input, cannot continue.", zero_pd
return
if format in ['None', '']:
yield "Missing format choice input, cannot continue."
yield "Missing format choice input, cannot continue.", zero_pd
return
train_template["template_type"] = "dataset"
@ -670,8 +838,11 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
if selected_model:
print("\033[1;31;1m(Model has been modified by previous training, it needs to be reloaded...)\033[0;37;0m")
try:
yield f"Reloading {selected_model}..."
yield f"Reloading {selected_model}...", zero_pd
reload_model()
shared.tokenizer.pad_token_id = 0
shared.tokenizer.padding_side = "left"
if shared.model is not None:
print("Model reloaded OK, continue with training.")
else:
@ -685,20 +856,23 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
# == Start prepping the model itself ==
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
logger.info("Getting model ready...")
# here we can disable gradient checkpoint, by default = true, use_gradient_checkpointing=True
prepare_model_for_kbit_training(shared.model)
# base model is now frozen and should not be reused for any other LoRA training than this one
shared.model_dirty_from_training = True
print(f"Transformers Model Type: {YELLOW}{model_type}{RESET}")
if training_projection==train_choices[0]:
model_to_lora_modules["llama"] = ["gate_proj","down_proj","up_proj","q_proj","k_proj","v_proj","o_proj"]
model_to_lora_modules[model_id] = ["gate_proj","down_proj","up_proj","q_proj","k_proj","v_proj","o_proj"]
elif training_projection==train_choices[1]:
model_to_lora_modules["llama"] = ["q_proj","k_proj", "v_proj", "o_proj"]
model_to_lora_modules[model_id] = ["q_proj","k_proj", "v_proj", "o_proj"]
elif training_projection==train_choices[2]:
model_to_lora_modules["llama"] = ["q_proj","k_proj", "v_proj"]
model_to_lora_modules[model_id] = ["q_proj","k_proj", "v_proj"]
elif training_projection==train_choices[3]:
model_to_lora_modules["llama"] = ["k_proj", "v_proj", "down_proj"]
model_to_lora_modules[model_id] = ["k_proj", "v_proj", "down_proj"]
else:
model_to_lora_modules["llama"] = ["q_proj", "v_proj"]
model_to_lora_modules[model_id] = ["q_proj", "v_proj"]
logger.info("Preparing for training...")
@ -725,8 +899,34 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
logger.info("Loading existing LoRA data...")
state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin")
set_peft_model_state_dict(lora_model, state_dict_peft)
print(f" + Continue Training on {RED}{lora_file_path}/adapter_model.bin{RESET}")
#load training_log.json if exist
if Path(f"{lora_file_path}/training_log.json").is_file():
with open(f"{lora_file_path}/training_log.json", 'r') as json_file:
json_ilog = json.load(json_file)
for key, value in json_ilog.items():
if key=='current_steps':
non_serialized_params.update({"checkpoint_offset": int(value+1)})
print(f" + Checkpoints will be saved with offset: {RED}{non_serialized_params['checkpoint_offset']}{RESET}")
if key=='epoch':
non_serialized_params.update({"epoch_offset": value})
print(f" + Epoch offset: {RED}{non_serialized_params['epoch_offset']}{RESET}")
if Path(f"{lora_file_path}/training_graph.json").is_file():
try:
with open(f"{lora_file_path}/training_graph.json", 'r') as json_file:
train_log_graph = json.load(json_file)
print(" + Training Graph loaded")
except:
yield traceback.format_exc().replace('\n', '\n\n')
print(f"Can't read training_graph")
except:
yield traceback.format_exc().replace('\n', '\n\n'), zero_pd
return
if shared.args.monkey_patch:
@ -751,30 +951,36 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
tracked.current_steps = state.global_step * gradient_accumulation_steps
tracked.max_steps = state.max_steps * gradient_accumulation_steps
ssteps10 = int(max(2,(state.max_steps/epochs)*0.1))
if WANT_INTERRUPT:
control.should_epoch_stop = True
control.should_training_stop = True
else:
current_loss = float(train_log.get('loss', 0.0))
current_epoch = float(train_log.get('epoch', 0.0))
current_epoch_int = int(float(train_log.get('epoch', 0.0)))
force_save = False
folder_save = f"checkpoint-{tracked.current_steps}"
current_steps_offset = tracked.current_steps + non_serialized_params['checkpoint_offset']
folder_save = f"checkpoint-{current_steps_offset}"
# save if triggered by user
if non_serialized_params['save_checkpoint_now']:
force_save = True
non_serialized_params.update({"save_checkpoint_now": False})
print(f"\033[1;31;1mSave Checkpoint manually trigerred.\033[0;37;0m")
folder_save = f"checkpoint-{tracked.current_steps}-user"
folder_save = f"checkpoint-{current_steps_offset}-user"
patience = 3 # Set the number of consecutive steps for tracking stability
if gradient_accumulation_steps==1:
patience = 5
patience = 4
min_steps = 10
min_steps = ssteps10
# Save each time the loss is below the threshold
if current_loss < non_serialized_params['save_steps_under_loss'] and current_loss > 0 and state.global_step > min_steps:
current_stability = non_serialized_params['current_stability']
current_stability += 1
@ -789,7 +995,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
new_save = (current_loss_dec-0.1) + 0.01
non_serialized_params.update({"save_steps_under_loss": new_save})
folder_save = f"checkpoint-{tracked.current_steps}-loss-{loss_str}"
folder_save = f"checkpoint-{current_steps_offset}-loss-{loss_str}"
force_save = True
@ -797,8 +1003,25 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
# Reset stability if the loss goes above the threshold
non_serialized_params.update({"current_stability": 0})
# Save full epochs
if actual_save_steps>0 and current_epoch_int > non_serialized_params['save_epochs'] and state.global_step > min_steps:
current_epoch_offset = current_epoch_int
if non_serialized_params['epoch_offset'] > 0:
current_epoch_offset = current_epoch_int + round(non_serialized_params['epoch_offset'], 2)
ep_off_str = f"{current_epoch_offset}"
ep_off_str = ep_off_str.replace('.', '_')
folder_save = f"checkpoint-{current_steps_offset}-epoch-{ep_off_str}"
non_serialized_params.update({"save_epochs": current_epoch_int})
force_save = True
# save each actual_save_steps
if state.global_step > 0 and actual_save_steps > 0 and state.global_step % actual_save_steps == 0:
folder_save = f"checkpoint-{tracked.current_steps}"
folder_save = f"checkpoint-{current_steps_offset}"
force_save = True
if force_save:
@ -820,21 +1043,45 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
def on_log(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, logs, **kwargs):
train_log.update(logs)
current_steps_offset = tracked.current_steps + non_serialized_params['checkpoint_offset']
current_epoch_offset = train_log.get('epoch', 0.0) + non_serialized_params['epoch_offset']
train_log.update({"current_steps": tracked.current_steps})
train_log.update({"current_steps_adjusted": current_steps_offset})
train_log.update({"epoch_adjusted": current_epoch_offset})
if WANT_INTERRUPT:
print("\033[1;31;1mInterrupted by user\033[0;37;0m")
if non_serialized_params['checkpoint_offset']>0:
print(f"\033[1;30;40mStep: {tracked.current_steps:6} [+{non_serialized_params['checkpoint_offset']}] \033[0;37;0m", end='')
else:
print(f"\033[1;30;40mStep: {tracked.current_steps:6} \033[0;37;0m", end='')
entry = {
'current_steps': int(train_log.get('current_steps',0)),
graphentry = {
'current_steps': int(train_log.get('current_steps_adjusted',0)),
'loss': float(train_log.get('loss', 0.0)),
'learning_rate': float(train_log.get('learning_rate', 0.0)),
'epoch': float(train_log.get('epoch', 0.0))
'epoch': float(train_log.get('epoch_adjusted', 0.0))
}
cur_loss = float(train_log.get('loss', 0.0))
cur_lr = float(train_log.get('learning_rate', 0.0))
cur_epoch = float(train_log.get('epoch', 0.0))
if len(statistics['loss']) == 1:
first_epoch = statistics['loss'][0]['epoch']
first_value = statistics['loss'][0]['value']
if first_value ==0:
statistics['loss'] = []
statistics['loss'].append({'epoch': cur_epoch, 'value': cur_loss})
statistics['lr'].append({'epoch': cur_epoch, 'value': cur_lr})
# Add the entry to the continuous log
train_log_graph.append(entry)
train_log_graph.append(graphentry)
# Save the graph log for now, we can later generate full graph
with open(f"{lora_file_path}/training_graph.json", 'w') as file:
@ -845,22 +1092,22 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
if loss <= stop_at_loss:
control.should_epoch_stop = True
control.should_training_stop = True
print(f"\033[1;31;1mStop Loss {stop_at_loss} reached.\033[0;37;0m")
print(f"{RED}Stop Loss {stop_at_loss} reached.{RESET}")
# FPHAM SAMPLE REQ Transformers error handling
gradient_accumulation_max = int(train_data.num_rows)//micro_batch_size
if gradient_accumulation_max < gradient_accumulation_steps:
print(f"\033[1;31;1mWARNING: Current gradient accumulation is too high for the amount of training data.\033[0;37;0m")
print(f"Gradient accumulation: {gradient_accumulation_steps} should be less than: {gradient_accumulation_max}. \033[1;31;1mThis could crash Accelerate/Transformers\033[0;37;0m")
print(f"{RED}WARNING:{RESET} Current gradient accumulation is {RED}too high{RESET} for the amount of training data.")
print(f"Gradient accumulation: {gradient_accumulation_steps} should be less than: {gradient_accumulation_max}. {RED}This could crash Accelerate/Transformers{RESET}")
#min_batchSize = sample_req*micro_batch_size
print(f"Preferable fix: \033[1;31;1mIncrease the size of dataset\033[0;37;0m")
print(f"... or Decrerase Gradient Accumulation \033[1;31;1m{gradient_accumulation_steps}\033[0;37;0m to below {gradient_accumulation_max}")
print(f"Preferable fix: {RED}Increase the size of dataset{RESET}")
print(f"... or Decrerase Gradient Accumulation {RED}{gradient_accumulation_steps}{RESET} to below {GREEN}{gradient_accumulation_max}{RESET}")
gradient_accumulation_steps = max(1,gradient_accumulation_max-1)
print(f"Last resort fix for this run: Lowering Gradient accumulation to {gradient_accumulation_steps}. [Good luck]")
print(f"Last resort fix for this run: Lowering Gradient accumulation to {GREEN}{gradient_accumulation_steps}{RESET} [Good luck]")
else:
print(f"Data Size Check: Gradient accumulation: {gradient_accumulation_steps} <= Blocks/Batch {gradient_accumulation_max} ... [OK]")
print(f"Data Size Check: Gradient accumulation: {YELLOW}{gradient_accumulation_steps}{RESET} <= Blocks/Batch {gradient_accumulation_max} ... {GREEN}[OK]{RESET}")
#END OF FPHAM SAMPLE REQ
@ -874,6 +1121,11 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
elif lr_scheduler_type == 'FP_half_time_annealing':
custom_scheduller = True
lr_scheduler_type_arg = 'constant'
elif lr_scheduler_type =='FP_raise_fall_creative':
custom_scheduller = True
lr_scheduler_type_arg = 'constant_with_warmup'
#gradient_checkpointing=True
args=transformers.TrainingArguments(
report_to=report_to if report_to != "None" else None,
@ -899,6 +1151,17 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
if custom_scheduller:
trainer = FPSchedulerTrainer(
neftune_noise_alpha=neft_noise_alpha,
model=lora_model,
train_dataset=train_data,
eval_dataset=eval_data,
args=args,
data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
callbacks=list([Callbacks()])
)
elif neft_noise_alpha > 0:
trainer = FPNEFtuneTrainer(
neftune_noise_alpha=neft_noise_alpha,
model=lora_model,
train_dataset=train_data,
eval_dataset=eval_data,
@ -934,28 +1197,37 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
# == Main run and monitor loop ==
logger.info("Starting training...")
yield "Starting..."
yield "Starting...", zero_pd
lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model)
projections_string = ", ".join([projection.replace("_proj", "") for projection in model_to_lora_modules[model_id]])
print(f"Training '{model_id}' model using ({projections_string}) projections")
print(f"Training '{model_id}' model using {YELLOW}({projections_string}){RESET} projections")
if lora_all_param > 0:
print(f"Trainable params: {lora_trainable_param:,d} ({100 * lora_trainable_param / lora_all_param:.4f} %), All params: {lora_all_param:,d} (Model: {model_all_params:,d})")
print(f"Trainable params: {lora_trainable_param:,d} ({RED}{100 * lora_trainable_param / lora_all_param:.4f} %{RESET}), All params: {lora_all_param:,d} (Model: {model_all_params:,d})")
train_log.update({"base_model_name": shared.model_name})
train_log.update({"base_model_class": shared.model.__class__.__name__})
train_log.update({"base_loaded_in_4bit": getattr(lora_model, "is_loaded_in_4bit", False)})
train_log.update({"base_loaded_in_8bit": getattr(lora_model, "is_loaded_in_8bit", False)})
train_log.update({"projections": projections_string})
if non_serialized_params['checkpoint_offset'] > 0:
train_log.update({"last_run_steps_offset": non_serialized_params['checkpoint_offset']})
train_log.update({"last_run_epoch_offset": non_serialized_params['epoch_offset']})
if non_serialized_params['checkpoint_offset'] > 0:
print(f"Continue training on {RED}previous adapter{RESET} from epoch: {RED}{non_serialized_params['epoch_offset']}{RESET}")
if stop_at_loss > 0:
print(f"Monitoring loss \033[1;31;1m(Auto-Stop at: {stop_at_loss})\033[0;37;0m")
print(f"Monitoring loss {RED}(Auto-Stop at: {stop_at_loss}){RESET}")
if WANT_INTERRUPT:
yield "Interrupted before start."
yield "Interrupted before start.", zero_pd
return
def log_train_dataset(trainer):
@ -993,8 +1265,28 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
while thread.is_alive():
time.sleep(0.5)
if statistics['loss']:
max_value_dict = max(statistics['loss'], key=lambda x: x['value'])
max_value = max_value_dict['value']+0.4
first_epoch = statistics['loss'][0]['epoch']
last_epoch = statistics['loss'][-1]['epoch']
else:
max_value = 3.5
last_epoch = 0
first_epoch = 0
if WANT_INTERRUPT:
yield "Interrupting, please wait... *(Run will stop after the current training step completes.)*"
losses = gr.LinePlot.update(
value = pd.DataFrame(statistics['loss']),
x="epoch", y="value",
title="Loss Metrics",
overlay_point=True, tooltip=["epoch", "value"],
x_lim=[first_epoch,last_epoch], y_lim=[0,max_value],
width=500, height=250 )
yield "Interrupting, please wait... *(Run will stop after the current training step completes.)*", losses
elif tracked.current_steps != last_step:
last_step = tracked.current_steps
@ -1022,12 +1314,41 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
if stop_at_loss != non_serialized_params['stop_at_loss']:
stop_at_loss = non_serialized_params['stop_at_loss']
print(f"Stop at loss changed \033[1;31;1m(Auto-Stop at: {stop_at_loss})\033[0;37;0m")
print(f"Stop at loss changed {RED}(Auto-Stop at: {stop_at_loss}){RESET}")
yield f"Running... **{tracked.current_steps}** / **{tracked.max_steps}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining {lastloss_str}"
losses = gr.LinePlot.update(
value = pd.DataFrame(statistics['loss']),
x="epoch", y="value",
title="Loss Metrics",
overlay_point=True, tooltip=["epoch", "value"],
x_lim=[first_epoch,last_epoch], y_lim=[0,max_value],
width=500, height=250 )
yield f"Running... **{tracked.current_steps}** / **{tracked.max_steps}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining {lastloss_str}", losses
# Saving in the train thread might fail if an error occurs, so save here if so.
#return_pd = pd.DataFrame(statistics['loss'])
if statistics['loss']:
max_value_dict = max(statistics['loss'], key=lambda x: x['value'])
max_value = max_value_dict['value']+0.4
first_epoch = statistics['loss'][0]['epoch']
last_epoch = statistics['loss'][-1]['epoch']
else:
max_value = 3.5
last_epoch = 0
first_epoch = 0
return_pd = gr.LinePlot.update(
value = pd.DataFrame(statistics['loss']),
x="epoch", y="value",
title="Loss Metrics",
overlay_point=True, tooltip=["epoch", "value"],
x_lim=[first_epoch,last_epoch], y_lim=[0,max_value],
width=500, height=250)
non_serialized_params.update({"training_loop": False})
if not tracked.did_save:
@ -1036,10 +1357,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
if WANT_INTERRUPT:
logger.info("Training interrupted.")
yield f"Interrupted by user. LoRA saved to `{lora_file_path}`."
yield f"Interrupted by user. LoRA saved to `{lora_file_path}`.", return_pd
else:
logger.info("Training complete!")
yield f"Done! LoRA saved to `{lora_file_path}`.\n\nBefore testing your new LoRA, make sure to first reload the model, as it is currently dirty from training."
yield f"Done! LoRA saved to `{lora_file_path}`.\n\nBefore testing your new LoRA, make sure to first reload the model, as it is currently dirty from training.", return_pd
create_graph(lora_file_path, lora_name)

View file

@ -1,13 +1,26 @@
import os
from modules import shared, utils
from pathlib import Path
import requests
import tqdm
import json
'''
def get_gpu_memory_usage(rank):
return {
'total': round(torch.cuda.get_device_properties(rank).total_memory / (1024**3), 2),
'max': round(torch.cuda.max_memory_allocated(rank) / (1024**3), 2),
'reserved': round(torch.cuda.memory_reserved(rank) / (1024**3), 2),
'allocated': round(torch.cuda.memory_allocated(rank) / (1024**3), 2)
}
'''
def list_subfoldersByTime(directory):
if not directory.endswith('/'):
directory += '/'
subfolders = []
subfolders.append('None')
path = directory
name_list = os.listdir(path)
full_list = [os.path.join(path,i) for i in name_list]
@ -277,3 +290,79 @@ def sliding_block_cut(text: str, min_chars_cut: int, eos_to_hc: bool, cutoff_len
print("Saved sentencelist.json in logs folder")
return sentencelist
# Example usage:
# download_file_from_url('https://example.com/path/to/your/file.ext', '/output/directory')
def download_file_from_url(url, overwrite, output_dir_in, valid_extensions = {'.txt', '.json'}):
try:
# Validate and sanitize the URL
#parsed_url = urllib.parse.urlparse(url)
#if not parsed_url.netloc:
# raise ValueError("Invalid URL")
#filename = os.path.basename(parsed_url.path)
# Get the filename from the URL
session = requests.Session()
headers = {}
mode = 'wb'
filename = url.split('/')[-1]
output_dir = str(output_dir_in)
# Construct the full path to the output file
local_filename = os.path.join(output_dir, filename)
# Check if the local file already exists
overw = ''
if os.path.exists(local_filename):
if not overwrite:
yield f"File '{local_filename}' already exists. Aborting."
return
else:
overw = ' [Overwrite existing]'
filename_lower = filename.lower()
# Send an HTTP GET request to the URL with a timeout
file_extension = os.path.splitext(filename_lower)[-1]
if file_extension not in valid_extensions:
yield f"Invalid file extension: {file_extension}. Only {valid_extensions} files are supported."
return
with session.get(url, stream=True, headers=headers, timeout=10) as r:
r.raise_for_status()
# total size can be wildly inaccurate
#total_size = int(r.headers.get('content-length', 0))
block_size = 1024 * 4
with open(local_filename, mode) as f:
count = 0
for data in r.iter_content(block_size):
f.write(data)
count += len(data)
yield f"Downloaded: {count} " + overw
# Verify file size if possible
if os.path.exists(local_filename):
downloaded_size = os.path.getsize(local_filename)
if downloaded_size > 0:
yield f"File '{filename}' downloaded to '{output_dir}' ({downloaded_size} bytes)."
print("File Downloaded")
else:
print("Downloaded file is zero")
yield f"Failed. Downloaded file size is zero)."
else:
print(f"Error: {local_filename} failed to download.")
yield f"Error: {local_filename} failed to download"
except Exception as e:
print(f"An error occurred: {e}")
yield f"An error occurred: {e}"
finally:
# Close the session to release resources
session.close()