Small style changes

This commit is contained in:
oobabooga 2023-03-27 21:24:39 -03:00
parent c2cad30772
commit 2f0571bfa4
3 changed files with 20 additions and 7 deletions

View file

@ -41,7 +41,7 @@ ol li p, ul li p {
display: inline-block;
}
#main, #parameters, #chat-settings, #interface-mode, #lora {
#main, #parameters, #chat-settings, #interface-mode, #lora, #training-tab {
border: 0;
}

View file

@ -1,10 +1,17 @@
import sys, torch, json, threading, time
import json
import sys
import threading
import time
from pathlib import Path
import gradio as gr
from datasets import load_dataset
import torch
import transformers
from modules import ui, shared
from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model, get_peft_model_state_dict
from datasets import load_dataset
from peft import (LoraConfig, get_peft_model, get_peft_model_state_dict,
prepare_model_for_int8_training)
from modules import shared, ui
WANT_INTERRUPT = False
CURRENT_STEPS = 0
@ -44,7 +51,7 @@ def create_train_interface():
with gr.Row():
startButton = gr.Button("Start LoRA Training")
stopButton = gr.Button("Interrupt")
output = gr.Markdown(value="(...)")
output = gr.Markdown(value="Ready")
startEvent = startButton.click(do_train, [loraName, microBatchSize, batchSize, epochs, learningRate, loraRank, loraAlpha, loraDropout, cutoffLen, dataset, evalDataset, format], [output])
stopButton.click(doInterrupt, [], [], cancels=[], queue=False)
@ -169,16 +176,20 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
).__get__(loraModel, type(loraModel))
if torch.__version__ >= "2" and sys.platform != "win32":
loraModel = torch.compile(loraModel)
# == Main run and monitor loop ==
# TODO: save/load checkpoints to resume from?
print("Starting training...")
yield "Starting..."
def threadedRun():
trainer.train()
thread = threading.Thread(target=threadedRun)
thread.start()
lastStep = 0
startTime = time.perf_counter()
while thread.is_alive():
time.sleep(0.5)
if WANT_INTERRUPT:
@ -197,8 +208,10 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
timerInfo = f"`{1.0/its:.2f}` s/it"
totalTimeEstimate = (1.0/its) * (MAX_STEPS)
yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.0f}`/`{totalTimeEstimate:.0f}` seconds"
print("Training complete, saving...")
loraModel.save_pretrained(loraName)
if WANT_INTERRUPT:
print("Training interrupted.")
yield f"Interrupted. Incomplete LoRA saved to `{loraName}`"

View file

@ -9,8 +9,8 @@ from pathlib import Path
import gradio as gr
from modules import chat, shared, ui, training
import modules.extensions as extensions_module
from modules import chat, shared, training, ui
from modules.html_generator import generate_chat_html
from modules.LoRA import add_lora_to_model
from modules.models import load_model, load_soft_prompt