From 2c14df81a82bfbdaee5662812c80cc95a00cffdb Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 10 Apr 2023 11:36:39 -0300
Subject: [PATCH] Use download-model.py to download the model

---
 download-model.py |  22 +++++-----
 server.py         | 100 ++++++++++++++++++----------------------------
 2 files changed, 50 insertions(+), 72 deletions(-)

diff --git a/download-model.py b/download-model.py
index a48a1b8c..fc17e716 100644
--- a/download-model.py
+++ b/download-model.py
@@ -20,17 +20,6 @@
 import tqdm
 from tqdm.contrib.concurrent import thread_map
 
-parser = argparse.ArgumentParser()
-parser.add_argument('MODEL', type=str, default=None, nargs='?')
-parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
-parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
-parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
-parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
-parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
-parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
-args = parser.parse_args()
-
-
 def select_model_from_default_options():
     models = {
         "OPT 6.7B": ("facebook", "opt-6.7b", "main"),
@@ -244,6 +233,17 @@ def check_model_files(model, branch, links, sha256, output_folder):
 
 
 if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('MODEL', type=str, default=None, nargs='?')
+    parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
+    parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
+    parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
+    parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
+    parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
+    parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
+    args = parser.parse_args()
+
     branch = args.branch
     model = args.MODEL
     if model is None:
diff --git a/server.py b/server.py
index ae5a905f..5c0142c8 100644
--- a/server.py
+++ b/server.py
@@ -2,17 +2,21 @@ import os
 
 os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
 
+import importlib
 import io
 import json
+import os
 import re
 import sys
 import time
+import traceback
 import zipfile
 from datetime import datetime
 from pathlib import Path
-import os
-import requests
+
 import gradio as gr
+import requests
+from huggingface_hub import HfApi
 from PIL import Image
 
 import modules.extensions as extensions_module
@@ -21,7 +25,6 @@ from modules.html_generator import chat_html_wrapper
 from modules.LoRA import add_lora_to_model
 from modules.models import load_model, load_soft_prompt, unload_model
 from modules.text_generation import generate_reply, stop_everything_event
-from huggingface_hub import HfApi
 
 # Loading custom settings
 settings_file = None
@@ -175,59 +178,31 @@ def create_prompt_menus():
 
 
 def download_model_wrapper(repo_id):
-    print(repo_id)
-    if repo_id == '':
-        print("Please enter a valid repo ID. \nThis field cant be empty")
-    else:
-        try:
-            print('Downloading repo')
-            hf_api = HfApi()
-            # Get repo info
-            repo_info = hf_api.repo_info(
-                repo_id=repo_id,
-                repo_type="model",
-                revision="main"
-            )
-            # create model and repo folder and check for lora
-            is_lora = False
-            for file in repo_info.siblings:
-                if 'adapter_model.bin' in file.rfilename:
-                    is_lora = True
-            repo_dir_name = repo_id.replace("/", "--")
-            if is_lora is True:
-                models_dir = "loras"
-            else:
-                models_dir = "models"
-            if not os.path.exists(models_dir):
-                os.makedirs(models_dir)
-            repo_dir = os.path.join(models_dir, repo_dir_name)
-            if not os.path.exists(repo_dir):
-                os.makedirs(repo_dir)
+    try:
+        downloader = importlib.import_module("download-model")
 
-            for sibling in repo_info.siblings:
-                filename = sibling.rfilename
-                url = f"https://huggingface.co/{repo_id}/resolve/main/{filename}"
-                download_path = os.path.join(repo_dir, filename)
-                response = requests.get(url, stream=True)
-                # Get the total file size from the content-length header
-                total_size = int(response.headers.get('content-length', 0))
+        model = repo_id
+        branch = "main"
+        check = False
 
-                # Download the file in chunks and print progress
-                with open(download_path, 'wb') as f:
-                    downloaded_size = 0
-                    for data in response.iter_content(chunk_size=10000000):
-                        downloaded_size += len(data)
-                        f.write(data)
-                        progress = downloaded_size * 100 // total_size
-                        downloaded_size_mb = downloaded_size / (1024 * 1024)
-                        total_size_mb = total_size / (1024 * 1024)
-                        print(f"\rDownloading {filename}... {progress}% complete "
-                              f"({downloaded_size_mb:.2f}/{total_size_mb:.2f} MB)", end="", flush=True)
-                    print(f"\rDownloading {filename}... Complete!")
+        yield("Cleaning up the model/branch names")
+        model, branch = downloader.sanitize_model_and_branch_names(model, branch)
 
-            print('Repo Downloaded')
-        except ValueError as e:
-            raise ValueError("Please enter a valid repo ID. \nError: {}".format(e))
+        yield("Getting the download links from Hugging Face")
+        links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
+
+        yield("Getting the output folder")
+        output_folder = downloader.get_output_folder(model, branch, is_lora)
+
+        if check:
+            yield("Checking previously downloaded files")
+            downloader.check_model_files(model, branch, links, sha256, output_folder)
+        else:
+            yield("Downloading files")
+            downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1)
+            yield("Done!")
+    except:
+        yield traceback.format_exc()
 
 
 def create_model_menus():
@@ -241,17 +216,20 @@ def create_model_menus():
         shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
         ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
     with gr.Row():
-        with gr.Column(scale=0.5):
-            shared.gradio['custom_model_menu'] = gr.Textbox(label="Download Custom Model",
-                                                            info="Enter hugging face username/model path e.g: 'decapoda-research/llama-7b-hf'")
-        with gr.Row():
-            with gr.Column(scale=0.5):
-                shared.gradio['download_button'] = gr.Button("Download", show_progress=True)
-                shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'],
-                                                       show_progress=True)
+        with gr.Column():
+            with gr.Row():
+                with gr.Column():
+                    shared.gradio['custom_model_menu'] = gr.Textbox(label="Download Custom Model",
+                                                                    info="Enter Hugging Face username/model path e.g: facebook/galactica-125m")
+                with gr.Column():
+                    shared.gradio['download_button'] = gr.Button("Download", show_progress=True)
+                    shared.gradio['download_status'] = gr.Markdown()
+        with gr.Column():
+            pass
 
     shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
     shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
+    shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['download_status'], show_progress=False)
 
 
 def create_settings_menus(default_preset):