diff --git a/download-model.py b/download-model.py index a5ee0223..a26ab8f4 100644 --- a/download-model.py +++ b/download-model.py @@ -6,6 +6,7 @@ python download-model.py facebook/opt-1.3b ''' import argparse +import json import multiprocessing import re import sys @@ -13,7 +14,6 @@ from pathlib import Path import requests import tqdm -from bs4 import BeautifulSoup parser = argparse.ArgumentParser() parser.add_argument('MODEL', type=str, default=None, nargs='?') @@ -90,6 +90,49 @@ facebook/opt-1.3b return model, branch +def get_download_links_from_huggingface(model, branch): + base = "https://huggingface.co" + page = f"/api/models/{model}/tree/{branch}?cursor=" + + links = [] + classifications = [] + has_pytorch = False + has_safetensors = False + while page is not None: + content = requests.get(f"{base}{page}").content + dict = json.loads(content) + + for i in range(len(dict['items'])): + fname = dict['items'][i]['path'] + + is_pytorch = re.match("pytorch_model.*\.bin", fname) + is_safetensors = re.match("model.*\.safetensors", fname) + is_text = re.match(".*\.(txt|json)", fname) + + if is_text or is_safetensors or is_pytorch: + if is_text: + links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}") + classifications.append('text') + continue + if not args.text_only: + links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}") + if is_safetensors: + has_safetensors = True + classifications.append('safetensors') + elif is_pytorch: + has_pytorch = True + classifications.append('pytorch') + + page = dict['nextUrl'] + + # If both pytorch and safetensors are available, download safetensors only + if has_pytorch and has_safetensors: + for i in range(len(classifications)-1, -1, -1): + if classifications[i] == 'pytorch': + links.pop(i) + + return links + if __name__ == '__main__': model = args.MODEL branch = args.branch @@ -107,7 +150,6 @@ if __name__ == '__main__': except ValueError as err_branch: print(f"Error: {err_branch}") sys.exit() - url = f'https://huggingface.co/{model}/tree/{branch}' if branch != 'main': output_folder = Path("models") / (model.split('/')[-1] + f'_{branch}') else: @@ -115,45 +157,11 @@ if __name__ == '__main__': if not output_folder.exists(): output_folder.mkdir() - # Finding the relevant files to download - page = requests.get(url) - soup = BeautifulSoup(page.content, 'html.parser') - links = soup.find_all('a') - downloads = [] - classifications = [] - has_pytorch = False - has_safetensors = False - for link in links: - href = link.get('href')[1:] - if href.startswith(f'{model}/resolve/{branch}'): - fname = Path(href).name - is_pytorch = re.match("pytorch_model.*\.bin", fname) - is_safetensors = re.match("model.*\.safetensors", fname) - is_text = re.match(".*\.(txt|json)", fname) - - if is_text or is_safetensors or is_pytorch: - if is_text: - downloads.append(f'https://huggingface.co/{href}') - classifications.append('text') - continue - if not args.text_only: - downloads.append(f'https://huggingface.co/{href}') - if is_safetensors: - has_safetensors = True - classifications.append('safetensors') - elif is_pytorch: - has_pytorch = True - classifications.append('pytorch') - - # If both pytorch and safetensors are available, download safetensors only - if has_pytorch and has_safetensors: - for i in range(len(classifications)-1, -1, -1): - if classifications[i] == 'pytorch': - downloads.pop(i) + links = get_download_links_from_huggingface(model, branch) # Downloading the files print(f"Downloading the model to {output_folder}") pool = multiprocessing.Pool(processes=args.threads) - results = pool.map(get_file, [[downloads[i], output_folder, i+1, len(downloads)] for i in range(len(downloads))]) + results = pool.map(get_file, [[links[i], output_folder, i+1, len(links)] for i in range(len(links))]) pool.close() pool.join() diff --git a/requirements.txt b/requirements.txt index 7a3108fe..89825316 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ accelerate==0.16.0 -beautifulsoup4 bitsandbytes==0.37.0 gradio==3.18.0 numpy