From 73a0def4af8e05aff3af05732f5d1e787cbd8616 Mon Sep 17 00:00:00 2001
From: practicaldreamer <78515588+practicaldreamer@users.noreply.github.com>
Date: Wed, 12 Jul 2023 09:26:45 -0500
Subject: [PATCH] Add Feature to Log Sample of Training Dataset for Inspection
 (#1711)

---
 modules/training.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/modules/training.py b/modules/training.py
index fa9281bb..442b92b3 100644
--- a/modules/training.py
+++ b/modules/training.py
@@ -579,8 +579,27 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
     if WANT_INTERRUPT:
         yield "Interrupted before start."
         return
+
+    def log_train_dataset(trainer):
+        decoded_entries = []
+        # Try to decode the entries and write the log file
+        try:
+            # Iterate over the first 10 elements in the dataset (or fewer if there are less than 10)
+            for i in range(min(10, len(trainer.train_dataset))):
+                decoded_text = shared.tokenizer.decode(trainer.train_dataset[i]['input_ids'])
+                decoded_entries.append({"value": decoded_text})
+
+            # Write the log file
+            Path('logs').mkdir(exist_ok=True)
+            with open(Path('logs/train_dataset_sample.json'), 'w') as json_file:
+                json.dump(decoded_entries, json_file, indent=4)
+
+            logger.info("Log file 'train_dataset_sample.json' created in the 'logs' directory.")
+        except Exception as e:
+            logger.error(f"Failed to create log file due to error: {e}")

     def threaded_run():
+        log_train_dataset(trainer)
         trainer.train()
         # Note: save in the thread in case the gradio thread breaks (eg browser closed)
         lora_model.save_pretrained(lora_file_path)