diff --git a/modules/training.py b/modules/training.py index 62ba181c..5ba8d352 100644 --- a/modules/training.py +++ b/modules/training.py @@ -20,7 +20,7 @@ MAX_STEPS = 0 CURRENT_GRADIENT_ACCUM = 1 def get_dataset(path: str, ext: str): - return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path(path).glob(f'*.{ext}'))), key=str.lower) + return ['None'] + sorted(set((k.stem for k in Path(path).glob(f'*.{ext}'))), key=str.lower) def create_train_interface(): with gr.Tab('Train LoRA', elem_id='lora-train-tab'): @@ -104,7 +104,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int actual_lr = float(learning_rate) if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0: - yield f"Cannot input zeroes." + yield "Cannot input zeroes." return gradient_accumulation_steps = batch_size // micro_batch_size diff --git a/server.py b/server.py index 44b135e9..2755e892 100644 --- a/server.py +++ b/server.py @@ -36,12 +36,12 @@ def get_available_models(): return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) def get_available_presets(): - return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower) + return sorted(set((k.stem for k in Path('presets').glob('*.txt'))), key=str.lower) def get_available_prompts(): prompts = [] - prompts += sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('prompts').glob('[0-9]*.txt'))), key=str.lower, reverse=True) - prompts += sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('prompts').glob('*.txt'))), key=str.lower) + prompts += sorted(set((k.stem for k in Path('prompts').glob('[0-9]*.txt'))), key=str.lower, reverse=True) + prompts += sorted(set((k.stem for k in Path('prompts').glob('*.txt'))), key=str.lower) prompts += ['None'] return prompts @@ -53,7 +53,7 @@ def get_available_extensions(): return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower) def get_available_softprompts(): - return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower) + return ['None'] + sorted(set((k.stem for k in Path('softprompts').glob('*.zip'))), key=str.lower) def get_available_loras(): return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)