From 3a9d90c3a1127ee3c402e2613f9b3c7abd3ca12e Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 10 Oct 2023 13:52:10 -0700 Subject: [PATCH] Download models with 4 threads by default --- download-model.py | 6 +++--- modules/ui_model_menu.py | 10 ++-------- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/download-model.py b/download-model.py index d37ae32c..8b8d7b25 100644 --- a/download-model.py +++ b/download-model.py @@ -177,10 +177,10 @@ class ModelDownloader: count += len(data) self.progress_bar(float(count) / float(total_size), f"{filename}") - def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=1): + def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=4): thread_map(lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True) - def download_model_files(self, model, branch, links, sha256, output_folder, progress_bar=None, start_from_scratch=False, threads=1, specific_file=None, is_llamacpp=False): + def download_model_files(self, model, branch, links, sha256, output_folder, progress_bar=None, start_from_scratch=False, threads=4, specific_file=None, is_llamacpp=False): self.progress_bar = progress_bar # Create the folder and writing the metadata @@ -236,7 +236,7 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('MODEL', type=str, default=None, nargs='?') parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.') - parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.') + parser.add_argument('--threads', type=int, default=4, help='Number of files to download simultaneously.') parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).') parser.add_argument('--specific-file', type=str, default=None, help='Name of the specific file to download (if not provided, downloads all).') parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.') diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py index bfa95c07..37746a03 100644 --- a/modules/ui_model_menu.py +++ b/modules/ui_model_menu.py @@ -225,16 +225,11 @@ def load_lora_wrapper(selected_loras): def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), return_links=False, check=False): try: - downloader_module = importlib.import_module("download-model") - downloader = downloader_module.ModelDownloader() - progress(0.0) - yield ("Cleaning up the model/branch names") + downloader = importlib.import_module("download-model").ModelDownloader() model, branch = downloader.sanitize_model_and_branch_names(repo_id, None) - yield ("Getting the download links from Hugging Face") links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(model, branch, text_only=False, specific_file=specific_file) - if return_links: yield '\n\n'.join([f"`{Path(link).name}`" for link in links]) return @@ -242,7 +237,6 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur yield ("Getting the output folder") base_folder = shared.args.lora_dir if is_lora else shared.args.model_dir output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp, base_folder=base_folder) - if check: progress(0.5) yield ("Checking previously downloaded files") @@ -250,7 +244,7 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur progress(1.0) else: yield (f"Downloading file{'s' if len(links) > 1 else ''} to `{output_folder}/`") - downloader.download_model_files(model, branch, links, sha256, output_folder, progress_bar=progress, threads=1, is_llamacpp=is_llamacpp) + downloader.download_model_files(model, branch, links, sha256, output_folder, progress_bar=progress, threads=4, is_llamacpp=is_llamacpp) yield ("Done!") except: progress(1.0)