Training PRO extension update (#4036)

FartyPants 2023-09-22 17:51:31 -04:00 committed by GitHub
parent c5e0ab7174
commit 26f10854f3
WARNING! Although there is a key with this ID in the database it does not verify this commit! This commit is SUSPICIOUS.
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 536 additions and 69 deletions


@ -0,0 +1,56 @@
# Training_PRO
This is an expanded Training tab
Maintained by FP
https://github.com/FartyPants/Training_PRO
- Chunking: the precise raw text slicer (PRTS) slices on sentence boundaries, making sure chunks are clean at both ends
- overlap chunking - this special overlapping adds extra overlap blocks based on logical rules (i.e. no overlap block across a hard cut)
- custom scheduler (follow the code to make your own): in LR Scheduler select FP_low_epoch_annealing - this scheduler keeps the LR constant for the first epoch, then uses cosine annealing for the rest - this part would be best moved into a separate py file
- saves graph png file at the end with learning rate and loss per epoch
- adding EOS to each block or to hard cut only
- automatically lowers gradient accumulation if you go overboard and set it higher than the actual data allows - transformers would then throw an error (or it used to, not sure if that is still true), but either way this corrects the bad setting
- turn BOS on and OFF
- target selector
- DEMENTOR LEARNING (experimental) Deep Memorization Enforcement Through Overlapping and Repetition. This is an experiment for long-text learning using low epochs (basically use 1 epoch with constant LR or 2 epochs with FP_low_epoch_annealing LR scheduler)
- Getting rid of the micro batch size/batch size confusion. Now there is a True Batch Size and a Gradient Accumulation slider, consistent with all the other training tools out there
- Ability to save Checkpoint during training with a button
- Ability to change Stop Loss during training
- different modes of checkpoint auto saving
- Function to Check Dataset and suggest parameters such as warmup and checkpoint save frequency before training
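
As a rough illustration of the `FP_low_epoch_annealing` idea from the list above, the sketch below computes a learning-rate multiplier: linear warmup, hold at 1.0 until the end of the first epoch, then cosine decay to 0. The function name `low_epoch_annealing_factor` and its parameters are hypothetical and chosen for this example; this is not the extension's own code.

```python
import math


def low_epoch_annealing_factor(step, warmup_steps, first_epoch_steps, total_steps):
    """Return an LR multiplier in [0, 1] for the given step (illustrative only)."""
    # Assumption: warmup may not run past the end of the first epoch.
    warmup_steps = min(warmup_steps, first_epoch_steps)
    if step < warmup_steps:
        # Linear ramp from 0 up to the full learning rate.
        return step / max(1, warmup_steps)
    if step < first_epoch_steps:
        # Hold the full learning rate for the remainder of the first epoch.
        return 1.0
    # Cosine decay from 1 to 0 over the remaining steps.
    progress = (step - first_epoch_steps) / max(1, total_steps - first_epoch_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
```

A function like this could be wrapped with `functools.partial` and handed to `torch.optim.lr_scheduler.LambdaLR`, which multiplies the base learning rate by the returned factor at each step.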
### Notes:
This uses its own chunking code for raw text, based on sentence splitting. This avoids weird cuts in the chunks: each chunk should now start at the beginning of a sentence and end at the end of one. It works hand in hand with Hard Cut. The proper use is to structure your text into logical blocks (ideas) separated by three \n, then set three \n as the hard cut string. This way each chunk will contain only one flow of ideas and not derail in the thoughts. The overlapping code also creates overlapped blocks on a sentence basis, but never across a hard cut, and thus never across different ideas either. Does it make any sense? No? Hmmmm...
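
To make the idea concrete, here is a deliberately simplified sketch of sentence-based chunking that never crosses a hard cut. It is not the extension's slicer: it assumes a character budget (`max_chars`) instead of a token cutoff, uses a naive regex to find sentence ends, and produces no overlap blocks. The name `chunk_by_sentences` is hypothetical.

```python
import re


def chunk_by_sentences(text, max_chars, hard_cut="\n\n\n"):
    """Pack whole sentences into chunks without crossing a hard cut."""
    chunks = []
    for block in text.split(hard_cut):
        # Naive sentence split: break after ., ! or ? followed by whitespace.
        sentences = [s.strip() for s in re.split(r"(?<=[.!?])\s+", block.strip()) if s.strip()]
        current = ""
        for sentence in sentences:
            candidate = f"{current} {sentence}".strip()
            if current and len(candidate) > max_chars:
                # The chunk is full: close it and start a new one with this sentence.
                chunks.append(current)
                current = sentence
            else:
                # Note: a single sentence longer than max_chars is kept whole.
                current = candidate
        if current:
            # Flush at the hard cut so ideas from different blocks never mix.
            chunks.append(current)
    return chunks


text = "A one. A two. A three.\n\n\nB one. B two."
print(chunk_by_sentences(text, max_chars=14))
```

Every chunk starts and ends on a sentence boundary, and the `B` sentences never share a chunk with the `A` sentences because they sit on opposite sides of the hard cut.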
### Targets
Normal LORA is q, v and that's what you should use. You can use (q k v o) or (q k v) and it will give you a lot more trainable parameters. The benefit is that you can keep rank lower and still attain the same coherency as q v with high rank. Guanaco has been trained with QLORA and q k v o for example and they swear by it.
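
As a small sketch of what the target selector amounts to, the snippet below turns a choice such as `q-v` or `q-k-v-o` into a list of LLaMA-style projection module names. The helper `target_modules_for` and the lookup table are hypothetical and written for this example, and the `all` choice is intentionally left out because its exact module list is not described here.

```python
# Assumed LLaMA-style module names for each short projection label.
PROJECTION_NAMES = {
    "q": "q_proj",
    "k": "k_proj",
    "v": "v_proj",
    "o": "o_proj",
    "down": "down_proj",
}


def target_modules_for(choice):
    """Map a hyphenated choice like 'q-k-v-o' to a list of module names."""
    return [PROJECTION_NAMES[part] for part in choice.split("-")]


print(target_modules_for("q-v"))      # the usual LoRA targets
print(target_modules_for("q-k-v-o"))  # more trainable parameters at the same rank
```

A list like this is what would typically be passed as `target_modules` when building a PEFT `LoraConfig`.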
### DEMENTOR LEARNING (experimental) Deep Memorization Enforcement Through Overlapping and Repetition
This is an experimental chunking method for training on long-form text in a low number of epochs (basically 1) with sliding repetition. The depth of learning directly depends on the cutoff_length. Increasing the cutoff length will also increase the number of blocks created from long-form text (which is contrary to normal training). It is based on my own wild experiments.
### Getting rid of batch size and micro batch size
Keeping consistency with everyone else.
Listen, there is only ONE batch size - the True Batch Size (previously called micro-batch size in the WebUI) - this is how many blocks are processed at once (during a single step). It eats GPU memory, but it really helps training quality (in fact the ideal batch size would be the same as the number of blocks - which is unrealistic) - so the idea is to cram in as much True Batch Size as you can before your GPU blows up with OOM. On 24GB this is about 10 for 13b (loaded in 4-bit).
So no micro batch size - it is now called True Batch Size, because that's what it is.
The other thing is Gradient Accumulation - this is an emulation of the above batch size - a virtual batch size, if you will. If your GPU can't handle the real batch size then you may fake it using Gradient Accumulation. This accumulates the gradients over the number of steps defined here and then updates the weights at the end, without increasing GPU memory use.
Gradient accumulation is like a virtual Batch size multiplier without the GPU penalty.
If your batch size is 4 and your gradient accumulation is 2 then it sort of behaves as if we had batch size 8. *Sort of* because batch size 4 with GA 2 is NOT the same as batch size 2 with GA 4. (It produces different weights - hence it's not equivalent.) The idea is that if you don't have enough GPU memory, using GA to extend the batch size is the next best thing (good enough), since you have no other choice.
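
The bookkeeping behind this can be sketched in a few lines. The helper `accumulation_summary` is hypothetical and only counts steps; it says nothing about how the resulting weights differ between the two setups.

```python
import math


def accumulation_summary(total_blocks, batch_size, grad_accumulation):
    """Count forward passes and weight updates for one epoch (illustrative only)."""
    forward_passes = math.ceil(total_blocks / batch_size)
    return {
        # Blocks that contribute to a single weight update.
        "effective_batch": batch_size * grad_accumulation,
        # Each forward/backward pass processes `batch_size` blocks at once.
        "forward_passes_per_epoch": forward_passes,
        # Weights change only once every `grad_accumulation` passes.
        "weight_updates_per_epoch": math.ceil(forward_passes / grad_accumulation),
    }


print(accumulation_summary(total_blocks=80, batch_size=4, grad_accumulation=2))
print(accumulation_summary(total_blocks=80, batch_size=2, grad_accumulation=4))
```

Both settings feed 8 blocks into each weight update, but the second one needs twice as many forward passes, each holding only 2 blocks in GPU memory at a time.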
If all you can afford is 1 batch size, then increasing GA will likely make the learning better in some range of GA (it's not always more is better).
However - GA is not some golden goose. As said, it isn't the same as batch size. In fact GA may worsen your learning as well.
I would suggest a series of experiments where you set the batch size as high as possible without OOM, set GA to 1, then repeat the training while increasing GA (2, 4...), and see how the model changes. It's likely to follow some sort of curve where GA seems to help before it makes things worse. Some people believe that if you can squeeze in a batch size of 6, then you should not bother with GA at all... YMMV
High Batch Size vs High GA would also likely produce different results in terms of learning words vs style. How? Hmmmm... good question.
One visual "benefit" of GA is that the loss will fluctuate less (because of all the gradient accumulation, which also works as a form of noise smoothing).
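
A toy calculation shows the smoothing effect. It assumes the loss reported per update is the mean over each accumulation window, which is an assumption about logging made for this example; `windowed_mean` is a hypothetical helper.

```python
def windowed_mean(losses, grad_accumulation):
    """Average per-pass losses over consecutive accumulation windows."""
    means = []
    for start in range(0, len(losses), grad_accumulation):
        window = losses[start:start + grad_accumulation]
        means.append(sum(window) / len(window))
    return means


noisy = [2.0, 1.0, 2.5, 0.5]
print(windowed_mean(noisy, 1))  # GA 1: the raw, jumpy values
print(windowed_mean(noisy, 2))  # GA 2: each pair is averaged
```

With GA 2 the four jumpy values collapse into two identical averages, so the curve looks calmer even though the underlying per-block losses are unchanged.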


@ -8,6 +8,39 @@ from torch.optim.lr_scheduler import LambdaLR
#FPHAM custom training scheduler block - should be extracted to separate file
last_print_label = ''
# hold constant to the half of epochs then cosine down to 0
def _get_fp_half_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_firstepoch_steps: int):
global last_print_label
print_label = ''
half_steps = num_training_steps//2
num_warmup_steps = min(num_warmup_steps,half_steps)
if current_step < num_warmup_steps:
print_label = 'Scheduler: Warmup'
elif current_step < half_steps:
print_label = 'Scheduler: Hold'
else:
print_label = 'Scheduler: Annealing'
if print_label != last_print_label:
print(print_label)
last_print_label = print_label
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
if current_step < half_steps:
return 1.0
progress = float(current_step - half_steps) / float(max(1, num_training_steps - half_steps))
num_cycles = 0.5
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
# constant to the first epochs then cosine down to 0 over the rest epochs
def _get_fp_cosine_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_firstepoch_steps: int):
global last_print_label
@ -38,7 +71,7 @@ def _get_fp_cosine_schedule_with_warmup_lr_lambda(current_step: int, *, num_warm
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
def custom_scheduler_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_firstepoch_steps, last_epoch=-1):
def custom_cosine_scheduler_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_firstepoch_steps, last_epoch=-1):
"""
Args:
optimizer ([`~torch.optim.Optimizer`]):
@ -62,6 +95,30 @@ def custom_scheduler_with_warmup(optimizer, num_warmup_steps, num_training_steps
)
return LambdaLR(optimizer, lr_lambda, last_epoch)
def custom_half_scheduler_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_firstepoch_steps, last_epoch=-1):
"""
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (`int`):
The number of steps for the warmup phase.
num_training_steps (`int`):
The total number of training steps.
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
lr_lambda = partial(
_get_fp_half_schedule_with_warmup_lr_lambda,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
num_firstepoch_steps = num_firstepoch_steps,
)
return LambdaLR(optimizer, lr_lambda, last_epoch)
class FPSchedulerTrainer(transformers.Trainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -69,22 +126,44 @@ class FPSchedulerTrainer(transformers.Trainer):
def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
#Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or passed as an argument.
num_train_epochs = self.args.num_train_epochs
num_warmup_steps=self.args.get_warmup_steps(num_training_steps)
num_firstepoch_steps = math.ceil(num_training_steps/num_train_epochs)
num_warmup_acc = num_warmup_steps*self.args.gradient_accumulation_steps
num_firstepoch_steps_acc = num_firstepoch_steps*self.args.gradient_accumulation_steps
num_training_steps_acc = num_training_steps*self.args.gradient_accumulation_steps
print (f"Warm-up steps aligned to Gradient accumulation ({self.args.gradient_accumulation_steps}) = {num_warmup_acc} actual warmup steps")
if self.args.lr_scheduler_type == 'cosine':
num_train_epochs = self.args.num_train_epochs
num_warmup_steps=self.args.get_warmup_steps(num_training_steps)
num_firstepoch_steps = math.ceil(num_training_steps/num_train_epochs)
num_warmup_acc = num_warmup_steps*self.args.gradient_accumulation_steps
num_firstepoch_steps_acc = num_firstepoch_steps*self.args.gradient_accumulation_steps
num_training_steps_acc = num_training_steps*self.args.gradient_accumulation_steps
num_warmup_acc_min = min(num_warmup_acc, num_firstepoch_steps_acc)
num_warmup_acc_min = min(num_warmup_acc, num_firstepoch_steps_acc)
if num_warmup_acc>num_firstepoch_steps_acc:
print(f"\033[1;31;1mWARNING: The number of warmup steps is set too high! It will be clamped to 1 epoch, essentially going from warmup to annealing.\033[0;37;0m")
print (f"FP Scheduler Warmup: 0-[{num_warmup_acc_min}], Hold [{num_warmup_acc_min}]-{num_firstepoch_steps_acc}, Annealing {num_firstepoch_steps_acc}-{num_training_steps_acc}")
else:
print (f"FP Scheduler Warmup: 0-{num_warmup_acc_min}, Hold {num_warmup_acc_min}-{num_firstepoch_steps_acc}, Annealing {num_firstepoch_steps_acc}-{num_training_steps_acc}")
self.lr_scheduler = custom_scheduler_with_warmup(
self.lr_scheduler = custom_cosine_scheduler_with_warmup(
optimizer=self.optimizer if optimizer is None else optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
num_firstepoch_steps = num_firstepoch_steps,
)
self._created_lr_scheduler = True
return self.lr_scheduler
elif self.args.lr_scheduler_type == 'constant':
half_step_acc = num_training_steps_acc//2
num_warmup_acc_min = min(num_warmup_acc, half_step_acc)
if num_warmup_acc>half_step_acc:
print(f"\033[1;31;1mWARNING: The number of warmup steps is set too high! It will be clamped to half of all epochs, essentially going from warmup to annealing in the middle.\033[0;37;0m")
print (f"FP Scheduler Warmup: 0-[{num_warmup_acc_min}], Hold [{num_warmup_acc_min}]-{half_step_acc}, Annealing {half_step_acc}-{num_training_steps_acc}")
else:
print (f"FP Scheduler Warmup: 0-{num_warmup_acc_min}, Hold {num_warmup_acc_min}-{half_step_acc}, Annealing {half_step_acc}-{num_training_steps_acc}")
self.lr_scheduler = custom_half_scheduler_with_warmup(
optimizer=self.optimizer if optimizer is None else optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,


@ -20,7 +20,7 @@ import transformers
from .custom_scheduler import FPSchedulerTrainer
from .matplotgraph import create_graph
from .train_utils import get_available_loras_local, precise_cut
from .train_utils import get_available_loras_local, precise_cut, sliding_block_cut
from datasets import Dataset, load_dataset
from peft import (
@ -53,14 +53,23 @@ params = {
"is_tab": True
}
non_serialized_params = {
"debug_slicer": False,
"Lora_sortedByTime": False,
"stop_at_loss": 0,
"save_steps_under_loss": 0.0,
"save_checkpoint_now": False,
"training_loop": False,
"current_stability": 0,
}
MODEL_CLASSES = {v[1]: v[0] for v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.items()}
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss", "add_eos_token", "min_chars", "report_to", "precize_slicing_overlap", "add_eos_token_type", "save_steps_under_loss", "add_bos_token", "training_projection"]
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss", "add_eos_token", "min_chars", "report_to", "precize_slicing_overlap", "add_eos_token_type", "save_steps_under_loss", "add_bos_token", "training_projection","sliding_window","warmup_ratio","grad_accumulation"]
WANT_INTERRUPT = False
train_log = {}
train_template = {}
train_log_graph = []
Lora_sortedByTime = False
train_choices = ["all","q-k-v-o","q-k-v","k-v-down","q-v"]
@ -70,13 +79,14 @@ def ui():
tmp = gr.State('')
with gr.Row():
with gr.Column():
gr.Markdown("This is enhanced version of Lora Training with a sentence based RAW text chunking code")
# YY.MM.DD
gr.Markdown("`Ver: 23.09.22` This is enhanced version of QLora Training. [Maintained by FP](https://github.com/FartyPants/Training_PRO/tree/main)")
with gr.Row():
with gr.Column(scale=5):
with gr.Row():
copy_from = gr.Dropdown(label='Copy parameters from', value='None', choices=get_available_loras_local(Lora_sortedByTime), elem_classes=['slim-dropdown'])
create_refresh_button(copy_from, lambda: None, lambda: {'choices': get_available_loras_local(Lora_sortedByTime)}, 'refresh-button')
copy_from = gr.Dropdown(label='Copy parameters from', value='None', choices=get_available_loras_local(non_serialized_params['Lora_sortedByTime']), elem_classes=['slim-dropdown'])
create_refresh_button(copy_from, lambda: None, lambda: {'choices': get_available_loras_local(non_serialized_params['Lora_sortedByTime'])}, 'refresh-button')
with gr.Column():
sort_byTime = gr.Checkbox(label='Sort list by Date', value=False, info='Sorts Loras by date created.', elem_classes=['no-background'])
@ -91,28 +101,38 @@ def ui():
with gr.Column():
lora_rank = gr.Slider(label='LoRA Rank', value=32, minimum=0, maximum=1024, step=4, info='Also called dimension count. Higher values = larger file, more content control. Smaller values = smaller file, less control. Use 4 or 8 for style, 128 or 256 to teach, 1024+ for fine-detail on big data. More VRAM is needed for higher ranks.')
lora_alpha = gr.Slider(label='LoRA Alpha', value=64, minimum=0, maximum=2048, step=4, info='This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.')
batch_size = gr.Slider(label='Batch Size', value=128, minimum=0, maximum=1024, step=4, info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.')
micro_batch_size = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.')
batch_size = gr.Slider(visible= False, label='Batch Size', value=0, minimum=0, maximum=1024, step=4, info='Now Replaced with Gradient accumulation. Keeping it for sake of old saved data')
micro_batch_size = gr.Slider(label='True Batch Size', value=4, minimum=1, maximum=128, step=1, info='Specifies how many text blocks per step will be trained. The higher value, the better the concept of training will be, but it requires more GPU memory and it reduces speed.')
grad_accumulation = gr.Slider(label='Gradient Accumulation Steps', value=1, minimum=1, maximum=256, step=1, info="Virtually multiplies the Batch Size by averaging the learning over more than one step. Evens out loss fluctuations but also increases number of total steps.")
cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=2048, value=256, step=32, info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.')
with gr.Column():
save_steps = gr.Number(label='Save every n steps', value=0, info='If above 0, a checkpoint of the LoRA will be saved every time this many steps pass.')
save_steps_under_loss = gr.Slider(label='Save Loss Threshold', value=1.9, minimum=0.0, maximum=3.0, step=0.1, info='Save checkpoints only if the loss is less than or equal to the Threshold loss. (0 = save all)')
stop_at_loss = gr.Slider(label='Stop at loss (Can be changed during training)', minimum=0.0, maximum=3.0, step=0.1, value=0.00, info='The process will automatically stop once the desired loss value is reached.')
gr.Markdown(" ")
epochs = gr.Number(label='Epochs', value=3, info='Number of times every entry in the dataset should be fed into training. So 1 means feed each item in once, 5 means feed it in five times, etc.')
learning_rate = gr.Textbox(label='Learning Rate', value='3e-4', info='In scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.')
lr_scheduler_type = gr.Dropdown(label='LR Scheduler', value='linear', choices=['linear', 'constant', 'constant_with_warmup', 'cosine', 'cosine_with_restarts', 'polynomial', 'inverse_sqrt', 'FP_low_epoch_annealing'], info='Learning rate scheduler - defines how the learning rate changes over time. "Constant" means never change, "linear" means to go in a straight line from the learning rate down to 0, cosine follows a curve, etc.', elem_classes=['slim-dropdown'])
lr_scheduler_type = gr.Dropdown(label='LR Scheduler', value='linear', choices=['linear', 'constant', 'constant_with_warmup', 'cosine', 'cosine_with_restarts', 'polynomial', 'inverse_sqrt', 'FP_low_epoch_annealing', 'FP_half_time_annealing'], info='Learning rate scheduler - defines how the learning rate changes over time. Custom schedulers: `FP_low_epoch_annealing` constant for 1 epoch then cosine anneal. `FP_half_time_annealing` constant for half time then cosine anneal', elem_classes=['slim-dropdown'])
with gr.Accordion(label='Checkpoints', open=True):
with gr.Row():
with gr.Column():
save_steps = gr.Number(label='Save every n steps', value=0, info='A checkpoint will be saved every n steps. (0 = OFF)')
with gr.Column():
save_steps_under_loss = gr.Slider(label='Save at 10% Loss change', value=1.8, minimum=0.0, maximum=3.0, step=0.1, info="Saves checkpoints at (or below) this loss and then each time loss falls by at least 10%. This works independently from 'Save every n steps'")
with gr.Row():
save_chackpoint_now = gr.Button('Queue Checkpoint Now')
with gr.Accordion(label='Advanced Options', open=True):
with gr.Row():
with gr.Column():
lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.05, info='Percentage probability for dropout of LoRA layers. This can help reduce overfitting. Most users should leave at default.')
stop_at_loss = gr.Slider(label='Stop at loss', minimum=0.0, maximum=3.0, step=0.1, value=0.00, info='The process will automatically stop once the desired loss value is reached. (reasonable numbers are 1.5-1.8)')
warmup_steps = gr.Number(label='Warmup Steps', value=100, info='Maximum number of steps used for a linear warmup. A value other than 0 takes precedence over Warmup Ratio. The actual number of steps will be the closest multiple of gradient accumulation')
warmup_ratio = gr.Slider(label='Warmup Ratio', minimum=0.0, maximum=0.2, step=0.025, value=0.0, info='Ratio of total training steps that will be used for a linear warmup. It applies only if Warmup Step is 0.')
training_projection = gr.Radio(value = train_choices[4], label='LLaMA Target Projections', info='Change the targets (LORA is typically q-v)', choices=train_choices)
lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.05, info='Percentage probability for dropout of LoRA layers. This can help reduce overfitting. Most users should leave at default.')
optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Different optimizer implementation options, for advanced users. Effects of different options are not well documented yet.', elem_classes=['slim-dropdown'])
with gr.Column():
warmup_steps = gr.Number(label='Warmup Steps', value=100, info='For this many steps at the start, the learning rate will be lower than normal. This helps the trainer prepare the model and precompute statistics to improve the quality of training after the start.')
train_only_after = gr.Textbox(label='Train Only After', value='', info='Only consider text *after* this string in any given chunk for training. For Alpaca datasets, use "### Response:" to only train the response and ignore the input.')
add_bos_token = gr.Checkbox(label='Add BOS token', value=True, info="Adds BOS token for each dataset item")
add_eos_token = gr.Checkbox(label='Add EOS token', value=False, info="Adds EOS token for each dataset item")
@ -124,18 +144,20 @@ def ui():
with gr.Column():
with gr.Tab(label='Formatted Dataset'):
with gr.Row():
format = gr.Dropdown(choices=utils.get_datasets('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.', elem_classes=['slim-dropdown'])
create_refresh_button(format, lambda: None, lambda: {'choices': utils.get_datasets('training/formats', 'json')}, 'refresh-button')
with gr.Column():
with gr.Row():
dataset = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.', elem_classes=['slim-dropdown'])
create_refresh_button(dataset, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'json')}, 'refresh-button')
with gr.Row():
eval_dataset = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.', elem_classes=['slim-dropdown'])
create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'json')}, 'refresh-button')
with gr.Row():
dataset = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.', elem_classes=['slim-dropdown'])
create_refresh_button(dataset, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'json')}, 'refresh-button')
with gr.Row():
eval_dataset = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.', elem_classes=['slim-dropdown'])
create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'json')}, 'refresh-button')
eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')
with gr.Column():
with gr.Row():
format = gr.Dropdown(choices=utils.get_datasets('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.', elem_classes=['slim-dropdown'])
create_refresh_button(format, lambda: None, lambda: {'choices': utils.get_datasets('training/formats', 'json')}, 'refresh-button')
with gr.Row():
eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')
with gr.Tab(label="Raw text file"):
with gr.Row():
@ -144,10 +166,17 @@ def ui():
with gr.Row():
with gr.Column():
precize_slicing_overlap = gr.Checkbox(label='Create Overlapping blocks', value = True)
precize_slicing_overlap = gr.Checkbox(label='Add Overlapping blocks', value = True)
sliding_window = gr.Checkbox(label='DEMENTOR Long-form Learning by FP (Highly Experimental, use low epochs)', value = False, info='Deep Memorization Enforcement Through Overlapping and Repetition. (I named it, so shush). Special process for learning long-form text using low amount of epochs.')
#debug_slicer = gr.Checkbox(label='Dump sentencelist.json to logs', value = non_serialized_params['debug_slicer'], info='Debug Slicer')
with gr.Column():
hard_cut_string = gr.Textbox(label='Hard Cut String', value='\\n\\n\\n', info='String that indicates a cut between logical blocks of text (ex. Ideas or Chapters). Helps prevent unwanted overlap between unrelated ideas.')
min_chars = gr.Number(label='Ignore small blocks', value=0, info='Ignore Text blocks that have less or equal characters than this number.')
with gr.Row():
with gr.Column():
check_dataset_btn = gr.Button('Load and Check Dataset and suggest data entries')
check_dataset_txt = gr.Textbox(label='Dataset info', value='')
with gr.Row():
start_button = gr.Button("Start LoRA Training", variant='primary')
@ -181,13 +210,150 @@ def ui():
refresh_table = gr.Button('Refresh the table', elem_classes="small-button")
# Training events
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss, add_eos_token, min_chars, report_to, precize_slicing_overlap, add_eos_token_type, save_steps_under_loss, add_bos_token, training_projection]
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss, add_eos_token, min_chars, report_to, precize_slicing_overlap, add_eos_token_type, save_steps_under_loss, add_bos_token, training_projection,sliding_window,warmup_ratio,grad_accumulation]
copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
def fix_old_version(batch_size_val,micro_batch_size_val, grad_accumulation_val):
if batch_size_val>0:
gradient_acc = batch_size_val // micro_batch_size_val
print(f"Using Old version of Batch Size ({batch_size_val}) to set Gradient Accumulation: {gradient_acc}")
return gradient_acc
return grad_accumulation_val
copy_from.change(do_copy_params, [copy_from] + all_params, all_params).then(fix_old_version,[batch_size,micro_batch_size, grad_accumulation],grad_accumulation)
start_button.click(do_train, all_params, output)
stop_button.click(do_interrupt, None, None, queue=False)
higher_rank_limit.change(change_rank_limit, [higher_rank_limit], [lora_rank, lora_alpha])
def trigger_stop_at_loss(stop_at_loss_value):
non_serialized_params.update({"stop_at_loss": stop_at_loss_value})
if non_serialized_params['training_loop']:
print(f"Queue: [Stop at loss Change] to {stop_at_loss_value}")
stop_at_loss.change(trigger_stop_at_loss, stop_at_loss, None)
def trigger_save_checkpoint():
non_serialized_params.update({"save_checkpoint_now": True})
if non_serialized_params['training_loop']:
print("Queue: [Save checkpoint] Checkpoint will be saved after the current step is finished.")
else:
print("Use during the training to save the checkpoint at any time.")
save_chackpoint_now.click(trigger_save_checkpoint, None, None)
dataset_calc_params = [save_steps,micro_batch_size, epochs, cutoff_len, dataset, format, raw_text_file, warmup_steps, hard_cut_string, min_chars, precize_slicing_overlap,sliding_window,warmup_ratio,grad_accumulation]
def check_dataset(save_steps:int, micro_batch_size: int, epochs: int, cutoff_len: int, dataset:str, format:str, raw_text_file:str, warmup_steps:int, hard_cut_string:str, min_chars:int, precize_slicing_overlap:bool,sliding_window:bool,warmup_ratio:float,grad_accumulation:int):
result = "Specify JSON dataset or raw text file"
total_blocks = 0
if shared.tokenizer is None:
yield "Tokenizer is not available. Please Load some Model first."
return
if raw_text_file not in ['None', '']:
logger.info("Loading raw text file dataset...")
fullpath = clean_path('training/datasets', f'{raw_text_file}')
fullpath = Path(fullpath)
if fullpath.is_dir():
logger.info('Training path directory {}'.format(raw_text_file))
raw_text = ""
file_paths = sorted(fullpath.glob('*.txt'), key=lambda path: natural_keys(path.name))
for file_path in file_paths:
if file_path.is_file():
with file_path.open('r', encoding='utf-8') as file:
raw_text += file.read().replace('\r', '')
logger.info(f"Loaded training file: {file_path.name}")
else:
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
raw_text = file.read().replace('\r', '')
if min_chars<0:
min_chars = 0
# == New more precise slicing on sentence boundary ==
if sliding_window:
text_chunks = sliding_block_cut(raw_text, min_chars, False, cutoff_len, hard_cut_string,non_serialized_params['debug_slicer'])
else:
text_chunks = precise_cut(raw_text, precize_slicing_overlap, min_chars, False, cutoff_len, hard_cut_string,non_serialized_params['debug_slicer'])
total_blocks = len(text_chunks)
result = f"Raw Text: ({raw_text_file}.txt) has {total_blocks} blocks (with cutoff length = {cutoff_len})"
del text_chunks
else:
if dataset in ['None', '']:
yield "Select dataset or Raw text."
return
if format in ['None', '']:
yield "Select format choice for dataset."
return
with open(clean_path('training/formats', f'{format}.json'), 'r', encoding='utf-8-sig') as formatFile:
format_data: dict[str, str] = json.load(formatFile)
def generate_prompt(data_point: dict[str, str]):
for options, data in format_data.items():
if set(options.split(',')) == set(x[0] for x in data_point.items() if (type(x[1]) is str and len(x[1].strip()) > 0)):
for key, val in data_point.items():
if type(val) is str:
data = data.replace(f'%{key}%', val)
return data
raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"')
def tokenize_dummy(prompt):
input_ids = shared.tokenizer.encode(prompt, truncation=True, max_length=cutoff_len)
labels = [1] * len(input_ids)
input_ids = torch.tensor(input_ids)
return {
"input_ids": input_ids,
"labels": labels,
"attention_mask": input_ids.ne(shared.tokenizer.pad_token_id),
}
def generate_and_tokenize_prompt(data_point):
prompt = generate_prompt(data_point)
return tokenize_dummy(prompt)
logger.info("Loading JSON datasets...")
data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
total_blocks = train_data.num_rows
result = f"Dataset: ({dataset}.json) has {total_blocks} blocks (with cutoff length = {cutoff_len})"
if total_blocks>0:
number_ofSteps = int(math.ceil(total_blocks / micro_batch_size) * epochs)
num_stepsPer_epoch = int(math.ceil(number_ofSteps/epochs))
min_warm = math.ceil(100 / grad_accumulation)
warmup_steps_suggest = min(int(min_warm*grad_accumulation), int(math.ceil(number_ofSteps * 0.1)))
warmup_steps_suggest = min(warmup_steps_suggest,num_stepsPer_epoch)
save_each_n_min = int(math.ceil(number_ofSteps/10))
save_each_n_max = int(math.ceil(number_ofSteps/5))
gradient_accumulation_max = int(total_blocks)//micro_batch_size
result += f"\n[Batch Size: {micro_batch_size}, Epochs: {epochs}, Gradient Accumulation: {grad_accumulation}]\n"
result += f"Total number of steps: {number_ofSteps}\n"
result += f"Steps per each Epoch: {num_stepsPer_epoch}\n"
result += f"Warmup steps suggestion: {warmup_steps_suggest} (Current: {int(warmup_steps)})\n"
result += f"Checkpoint suggestion: Save every {save_each_n_min} - {save_each_n_max} steps (Current: {int(save_steps)})"
if gradient_accumulation_max < grad_accumulation:
result += f"\n\nWARNING: Gradient Accumulation {grad_accumulation} is too high: It should be below {gradient_accumulation_max}"
yield result
return
check_dataset_btn.click(check_dataset, dataset_calc_params ,check_dataset_txt)
# Evaluation events. For some reason, the interrupt event
# doesn't work with the .then() syntax, so I write them one
# by one in this ugly but functional way.
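Reviewer note: the suggestion arithmetic in check_dataset above can be tried outside the UI. The following is a minimal sketch that mirrors the formulas in this hunk; the function name `suggest_schedule` and the returned dictionary keys are assumptions made for illustration and do not appear in the extension.

```python
import math


def suggest_schedule(total_blocks, micro_batch_size, epochs, grad_accumulation):
    """Reproduce the step, warmup and checkpoint suggestions from check_dataset.

    Hypothetical helper: the name and the return shape are illustrative only.
    """
    # Steps are counted per micro batch, exactly as in the hunk above.
    total_steps = int(math.ceil(total_blocks / micro_batch_size) * epochs)
    steps_per_epoch = int(math.ceil(total_steps / epochs))

    # Aim for roughly 100 warmup steps, rounded up to a multiple of the
    # gradient accumulation, but never more than 10% of the run or one epoch.
    min_warm = math.ceil(100 / grad_accumulation)
    warmup = min(int(min_warm * grad_accumulation), int(math.ceil(total_steps * 0.1)))
    warmup = min(warmup, steps_per_epoch)

    return {
        "total_steps": total_steps,
        "steps_per_epoch": steps_per_epoch,
        "warmup_steps": warmup,
        "save_every_min": int(math.ceil(total_steps / 10)),
        "save_every_max": int(math.ceil(total_steps / 5)),
        "grad_accumulation_max": total_blocks // micro_batch_size,
    }


if __name__ == "__main__":
    print(suggest_schedule(4000, 4, 2, 8))
```

The warning at the end of check_dataset fires when `grad_accumulation` exceeds the `grad_accumulation_max` value computed here.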
@ -205,15 +371,12 @@ def ui():
lambda: "Comments saved.", None, evaluation_log, show_progress=False)
def reload_lora():
global Lora_sortedByTime
return gr.Dropdown.update(choices=get_available_loras_local(Lora_sortedByTime))
return gr.Dropdown.update(choices=get_available_loras_local(non_serialized_params['Lora_sortedByTime']))
def global_lora_time(sort_byTime):
global Lora_sortedByTime
Lora_sortedByTime = sort_byTime
sort_byTime.change(global_lora_time, sort_byTime, None).then(reload_lora,None,copy_from)
# nonserialized items
sort_byTime.change(lambda x: non_serialized_params.update({"Lora_sortedByTime": x}), sort_byTime, None).then(reload_lora,None,copy_from)
#debug_slicer.change(lambda x: non_serialized_params.update({"debug_slicer": x}), debug_slicer, None)
def do_interrupt():
@ -299,7 +462,7 @@ def calc_trainable_parameters(model):
return trainable_params, all_param
def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float, add_eos_token: bool, min_chars: int, report_to: str, precize_slicing_overlap: bool, add_eos_token_type: str, save_steps_under_loss: float, add_bos_token: bool, training_projection: str):
def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float, add_eos_token: bool, min_chars: int, report_to: str, precize_slicing_overlap: bool, add_eos_token_type: str, save_steps_under_loss: float, add_bos_token: bool, training_projection: str,sliding_window:bool,warmup_ratio:float, grad_accumulation: int):
if shared.args.monkey_patch:
from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
@ -342,11 +505,15 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
yield "LoRA training with GPTQ-for-LLaMa requires loading with `--monkey-patch`"
return
if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
if cutoff_len <= 0 or micro_batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
yield "Cannot input zeroes."
return
gradient_accumulation_steps = batch_size // micro_batch_size
#in new version we dumped this in favor of grad_accumulation
#set it to zero for new save
batch_size = 0
gradient_accumulation_steps = grad_accumulation #batch_size // micro_batch_size
shared.tokenizer.pad_token_id = 0
shared.tokenizer.padding_side = "left"
@ -401,6 +568,11 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
print(f"*** LoRA: {lora_name} ***")
non_serialized_params.update({"stop_at_loss": stop_at_loss})
non_serialized_params.update({"save_steps_under_loss": save_steps_under_loss+0.01})
non_serialized_params.update({"save_checkpoint_now": False})
non_serialized_params.update({"training_loop": False})
non_serialized_params.update({"current_stability": 0})
# END OF FPHAM SENTENCE SPLIT functions ===================
@ -434,11 +606,17 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
#print (f"add_eos_token {add_eos_token}, add_EOS_to_all {add_EOS_to_all}, add_EOS_to_HC {add_EOS_to_HC}")
# == New more precise slicing on sentence boundary ==
text_chunks = precise_cut(raw_text, precize_slicing_overlap, min_chars, add_EOS_to_HC, cutoff_len, hard_cut_string)
if sliding_window:
text_chunks = sliding_block_cut(raw_text, min_chars, add_EOS_to_HC, cutoff_len, hard_cut_string,non_serialized_params['debug_slicer'])
else:
text_chunks = precise_cut(raw_text, precize_slicing_overlap, min_chars, add_EOS_to_HC, cutoff_len, hard_cut_string,non_serialized_params['debug_slicer'])
train_data = Dataset.from_list([tokenize(x, add_EOS_to_all, add_bos_token) for x in text_chunks])
if add_EOS_to_all:
print(f"Added EOS to {len(text_chunks)} blocks")
print(f"All Data Blocks: {len(text_chunks)}")
del text_chunks
eval_data = None
else:
@ -478,6 +656,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
print(f"BOS: {add_bos_token} EOS: {add_eos_token}")
print(f"Data Blocks: {train_data.num_rows}")
if eval_dataset == 'None':
eval_data = None
@ -575,17 +754,63 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
if WANT_INTERRUPT:
control.should_epoch_stop = True
control.should_training_stop = True
elif state.global_step > 0 and actual_save_steps > 0 and state.global_step % actual_save_steps == 0:
else:
current_loss = float(train_log.get('loss', 0.0))
if current_loss <= save_steps_under_loss or save_steps_under_loss==0.0:
lora_model.save_pretrained(f"{lora_file_path}/checkpoint-{tracked.current_steps}/")
print(f"\033[1;30;40mStep: {tracked.current_steps:6} \033[0;37;0m Checkpoint-{tracked.current_steps} saved")
current_epoch = float(train_log.get('epoch', 0.0))
force_save = False
folder_save = f"checkpoint-{tracked.current_steps}"
if non_serialized_params['save_checkpoint_now']:
force_save = True
non_serialized_params.update({"save_checkpoint_now": False})
print(f"\033[1;31;1mSave Checkpoint manually triggered.\033[0;37;0m")
folder_save = f"checkpoint-{tracked.current_steps}-user"
patience = 3 # Set the number of consecutive steps for tracking stability
if gradient_accumulation_steps==1:
patience = 5
min_steps = 10
if current_loss < non_serialized_params['save_steps_under_loss'] and current_loss > 0 and state.global_step > min_steps:
current_stability = non_serialized_params['current_stability']
current_stability += 1
non_serialized_params.update({"current_stability": current_stability})
if current_stability >= patience:
current_stability = 0
non_serialized_params.update({"current_stability": current_stability})
current_loss_dec = round(current_loss, 2)
loss_str = f"{current_loss_dec:.2f}"
loss_str = loss_str.replace('.', '_')
new_save = (current_loss_dec-0.1) + 0.01
non_serialized_params.update({"save_steps_under_loss": new_save})
folder_save = f"checkpoint-{tracked.current_steps}-loss-{loss_str}"
force_save = True
else:
# Reset stability if the loss goes above the threshold
non_serialized_params.update({"current_stability": 0})
if state.global_step > 0 and actual_save_steps > 0 and state.global_step % actual_save_steps == 0:
folder_save = f"checkpoint-{tracked.current_steps}"
force_save = True
if force_save:
lora_model.save_pretrained(f"{lora_file_path}/{folder_save}/")
print(f"\033[1;30;40mStep: {tracked.current_steps:6} \033[0;37;0m Saved: [{folder_save}]")
# Save log
with open(f"{lora_file_path}/checkpoint-{tracked.current_steps}/training_log.json", 'w', encoding='utf-8') as file:
with open(f"{lora_file_path}/{folder_save}/training_log.json", 'w', encoding='utf-8') as file:
json.dump(train_log, file, indent=2)
# == Save training prompt ==
with open(f"{lora_file_path}/checkpoint-{tracked.current_steps}/training_prompt.json", 'w', encoding='utf-8') as file:
with open(f"{lora_file_path}/{folder_save}/training_prompt.json", 'w', encoding='utf-8') as file:
json.dump(train_template, file, indent=2)
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
tracked.current_steps += 1
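Reviewer note: the loss-triggered saving added in this hunk is easier to follow in isolation. Below is a simplified sketch, assuming a plain list of per-step losses and a single step counter (the callback itself distinguishes `state.global_step` from `tracked.current_steps`). The function name `track_loss_checkpoints` is hypothetical.

```python
def track_loss_checkpoints(losses, threshold, patience=3, min_steps=10):
    """Return (step, folder_name) pairs for every loss-triggered save.

    Hypothetical helper that imitates the stability counter in the callback:
    a save happens once the loss has stayed under the threshold for
    `patience` consecutive steps, then the threshold drops by about 0.1.
    """
    saves = []
    stability = 0
    for step, loss in enumerate(losses, start=1):
        if 0 < loss < threshold and step > min_steps:
            stability += 1
            if stability >= patience:
                stability = 0
                loss_dec = round(loss, 2)
                loss_str = f"{loss_dec:.2f}".replace(".", "_")
                # Next save requires the loss to fall roughly 0.1 lower.
                threshold = (loss_dec - 0.1) + 0.01
                saves.append((step, f"checkpoint-{step}-loss-{loss_str}"))
        else:
            # Any step at or above the threshold resets the counter.
            stability = 0
    return saves


if __name__ == "__main__":
    history = [2.0] * 10 + [1.5, 1.4, 1.3, 1.29, 1.28, 1.27, 1.1, 1.1, 1.1]
    for step, folder in track_loss_checkpoints(history, threshold=1.81):
        print(step, folder)
```

The moving threshold is what keeps a slowly falling loss from producing a checkpoint on every step.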
@ -623,19 +848,19 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
print(f"\033[1;31;1mStop Loss {stop_at_loss} reached.\033[0;37;0m")
# FPHAM SAMPLE REQ Transformers error handling
sample_req = int(train_data.num_rows)//micro_batch_size
gradient_accumulation_max = int(train_data.num_rows)//micro_batch_size
if sample_req < gradient_accumulation_steps:
if gradient_accumulation_max < gradient_accumulation_steps:
print(f"\033[1;31;1mWARNING: Current gradient accumulation is too high for the amount of training data.\033[0;37;0m")
print(f"Gradient accumulation: {gradient_accumulation_steps} should be less than: {sample_req}. \033[1;31;1mThis could crash Accelerate/Transformers\033[0;37;0m")
min_batchSize = sample_req*micro_batch_size
print(f"Gradient accumulation: {gradient_accumulation_steps} should be less than: {gradient_accumulation_max}. \033[1;31;1mThis could crash Accelerate/Transformers\033[0;37;0m")
#min_batchSize = sample_req*micro_batch_size
print(f"Preferable fix: \033[1;31;1mIncrease the size of dataset\033[0;37;0m")
print(f"... or Decrease Batch Size \033[1;31;1m{batch_size}\033[0;37;0m to below {min_batchSize}")
gradient_accumulation_steps = max(1,sample_req-1)
print(f"... or Decrease Gradient Accumulation \033[1;31;1m{gradient_accumulation_steps}\033[0;37;0m to below {gradient_accumulation_max}")
gradient_accumulation_steps = max(1,gradient_accumulation_max-1)
print(f"Last resort fix for this run: Lowering Gradient accumulation to {gradient_accumulation_steps}. [Good luck]")
else:
print(f"Data Size Check: Gradient accumulation: {gradient_accumulation_steps} <= Data/Batch {sample_req} ... [OK]")
print(f"Data Size Check: Gradient accumulation: {gradient_accumulation_steps} <= Blocks/Batch {gradient_accumulation_max} ... [OK]")
#END OF FPHAM SAMPLE REQ
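Reviewer note: the data size check above reduces to a small clamp. This is a minimal sketch of that rule with the console output left out; the function name `clamp_gradient_accumulation` is an assumption for illustration.

```python
def clamp_gradient_accumulation(num_rows, micro_batch_size, gradient_accumulation_steps):
    """Lower gradient accumulation when the dataset is too small for it.

    Hypothetical helper following the check in this hunk: the limit is the
    number of micro batches in the data, and the fallback is one below that
    limit, but never less than 1.
    """
    gradient_accumulation_max = num_rows // micro_batch_size
    if gradient_accumulation_max < gradient_accumulation_steps:
        return max(1, gradient_accumulation_max - 1)
    return gradient_accumulation_steps


if __name__ == "__main__":
    # 40 blocks at micro batch size 4 give 10 micro batches.
    print(clamp_gradient_accumulation(40, 4, 16))
    print(clamp_gradient_accumulation(40, 4, 8))
```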
@ -646,12 +871,16 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
if lr_scheduler_type == 'FP_low_epoch_annealing':
custom_scheduller = True
lr_scheduler_type_arg = 'cosine'
elif lr_scheduler_type == 'FP_half_time_annealing':
custom_scheduller = True
lr_scheduler_type_arg = 'constant'
args=transformers.TrainingArguments(
report_to=report_to if report_to != "None" else None,
per_device_train_batch_size=micro_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
warmup_steps=math.ceil(warmup_steps / gradient_accumulation_steps),
warmup_ratio=warmup_ratio,
num_train_epochs=epochs,
learning_rate=actual_lr,
fp16=False if shared.args.cpu else True,
@ -770,6 +999,15 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
elif tracked.current_steps != last_step:
last_step = tracked.current_steps
time_elapsed = time.perf_counter() - start_time
lastloss = float(train_log.get('loss', 0.0))
non_serialized_params.update({"training_loop": True})
if lastloss > 0:
lastloss_str = f", ... Current Loss: `{lastloss:.2f}`"
else:
lastloss_str = ""
if time_elapsed <= 0:
timer_info = ""
total_time_estimate = 999
@ -782,16 +1020,23 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
total_time_estimate = (1.0 / its) * (tracked.max_steps)
yield f"Running... **{tracked.current_steps}** / **{tracked.max_steps}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
if stop_at_loss != non_serialized_params['stop_at_loss']:
stop_at_loss = non_serialized_params['stop_at_loss']
print(f"Stop at loss changed \033[1;31;1m(Auto-Stop at: {stop_at_loss})\033[0;37;0m")
yield f"Running... **{tracked.current_steps}** / **{tracked.max_steps}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining {lastloss_str}"
# Saving in the train thread might fail if an error occurs, so save here if so.
non_serialized_params.update({"training_loop": False})
if not tracked.did_save:
logger.info("Training complete, saving...")
lora_model.save_pretrained(lora_file_path)
if WANT_INTERRUPT:
logger.info("Training interrupted.")
yield f"Interrupted. Incomplete LoRA saved to `{lora_file_path}`."
yield f"Interrupted by user. LoRA saved to `{lora_file_path}`."
else:
logger.info("Training complete!")
yield f"Done! LoRA saved to `{lora_file_path}`.\n\nBefore testing your new LoRA, make sure to first reload the model, as it is currently dirty from training."

@ -86,9 +86,8 @@ def split_sentences(text: str, cutoff_len: int):
# hard cut defined by hard_cut_string or </s> will always end at the end of data block
# no overlapping blocks will be created across hard cut or across </s> token
def precise_cut(text: str, overlap: bool, min_chars_cut: int, eos_to_hc: bool, cutoff_len: int, hard_cut_string: str):
def precise_cut(text: str, overlap: bool, min_chars_cut: int, eos_to_hc: bool, cutoff_len: int, hard_cut_string: str, debug_slicer:bool):
debug_slicer = False
EOSX_str = '<//>' #hardcut placeholder
EOS_str = '</s>'
print("Precise raw text slicer: ON")
@ -187,6 +186,94 @@ def precise_cut(text: str, overlap: bool, min_chars_cut: int, eos_to_hc: bool, c
output_file = "logs/sentencelist.json"
with open(output_file, 'w') as f:
json.dump(sentencelist_dict, f,indent=2)
print("Saved sentencelist.json in logs folder")
return sentencelist
return sentencelist
def sliding_block_cut(text: str, min_chars_cut: int, eos_to_hc: bool, cutoff_len: int, hard_cut_string: str, debug_slicer:bool):
EOSX_str = '<//>' #hardcut placeholder
EOS_str = '</s>'
print("Mega Block Overlap: ON")
cut_string = hard_cut_string.replace('\\n', '\n')
text = text.replace(cut_string, EOSX_str)
sentences = split_sentences(text, cutoff_len)
print(f"Sentences: {len(sentences)}")
sentencelist = []
max_cut = cutoff_len-1
#print(f"max_cut: {max_cut}")
advancing_to = 0
prev_block_lastsentence = ""
for i in range(len(sentences)):
totalLength = 0
currentSentence = ''
lastsentence = ""
if i >= advancing_to:
for k in range(i, len(sentences)):
current_length = sentences[k]['size']
if totalLength + current_length <= max_cut and not currentSentence.endswith(EOSX_str):
currentSentence += sentences[k]['text']
totalLength += current_length
lastsentence = sentences[k]['text']
else:
if len(currentSentence.strip()) > min_chars_cut:
if prev_block_lastsentence!=lastsentence:
sentencelist.append(currentSentence.strip())
prev_block_lastsentence = lastsentence
advancing_to = 0
if currentSentence.endswith(EOSX_str):
advancing_to = k
currentSentence = ""
totalLength = 0
break
if currentSentence != "":
if len(currentSentence.strip()) > min_chars_cut:
sentencelist.append(currentSentence.strip())
unique_blocks = len(sentencelist)
print(f"Text Blocks: {unique_blocks}")
num_EOS = 0
for i in range(len(sentencelist)):
if eos_to_hc:
sentencelist[i] = sentencelist[i].replace(EOSX_str, EOS_str)
else:
sentencelist[i] = sentencelist[i].replace(EOSX_str, '')
#someone may have had stop strings in the raw text...
sentencelist[i] = sentencelist[i].replace("</s></s>", EOS_str)
num_EOS += sentencelist[i].count(EOS_str)
if num_EOS > 0:
print(f"+ EOS count: {num_EOS}")
#final check for useless lines
sentencelist = [item for item in sentencelist if item.strip() != "</s>"]
sentencelist = [item for item in sentencelist if item.strip() != ""]
if debug_slicer:
# Write the log file
Path('logs').mkdir(exist_ok=True)
sentencelist_dict = {index: sentence for index, sentence in enumerate(sentencelist)}
output_file = "logs/sentencelist.json"
with open(output_file, 'w') as f:
json.dump(sentencelist_dict, f,indent=2)
print("Saved sentencelist.json in logs folder")
return sentencelist
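Reviewer note: the core idea of sliding_block_cut is that a new block starts at every sentence and grows until the token budget is used up, so consecutive blocks overlap heavily. The sketch below is a deliberately simplified illustration, not the function from this diff: it leaves out the hard-cut placeholder, the `min_chars_cut` filter and the EOS handling, and it also de-duplicates the tail blocks. The name `sliding_blocks` and the `(text, size)` tuple input are assumptions.

```python
def sliding_blocks(sentences, max_size):
    """Build overlapping blocks from a list of (text, size) sentence tuples.

    Hypothetical, simplified variant: a block starts at each sentence and
    takes following sentences while the summed size fits in max_size. A block
    is skipped when it ends on the same sentence as the previous block, which
    avoids emitting shrinking copies of the same tail.
    """
    blocks = []
    prev_last = None
    for i in range(len(sentences)):
        block, total, last = "", 0, None
        for text, size in sentences[i:]:
            if total + size > max_size:
                break
            block += text
            total += size
            last = text
        if block and last != prev_last:
            blocks.append(block.strip())
            prev_last = last
    return blocks


if __name__ == "__main__":
    demo = [("A. ", 2), ("B. ", 2), ("C. ", 2), ("D. ", 2)]
    for block in sliding_blocks(demo, max_size=4):
        print(block)
```

With a budget of two sentences per block, each sentence except the first and last ends up in two blocks, which is the repetition the sliding window mode relies on.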