Merge branch 'main' into pt-path-changes

This commit is contained in:
oobabooga 2023-03-10 11:03:42 -03:00 committed by GitHub
commit e9dbdafb14
WARNING! Although there is a key with this ID in the database it does not verify this commit! This commit is SUSPICIOUS.
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 32 additions and 9 deletions

View file

@ -54,7 +54,7 @@ The third line assumes that you have an NVIDIA GPU.
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.2
```
* If you are running in CPU mode, replace the third command with this one:
* If you are running it in CPU mode, replace the third command with this one:
```
conda install pytorch torchvision torchaudio git -c pytorch

View file

@ -5,7 +5,9 @@ Example:
python download-model.py facebook/opt-1.3b
'''
import argparse
import base64
import json
import multiprocessing
import re
@ -93,23 +95,28 @@ facebook/opt-1.3b
def get_download_links_from_huggingface(model, branch):
base = "https://huggingface.co"
page = f"/api/models/{model}/tree/{branch}?cursor="
cursor = b""
links = []
classifications = []
has_pytorch = False
has_safetensors = False
while page is not None:
content = requests.get(f"{base}{page}").content
while True:
content = requests.get(f"{base}{page}{cursor.decode()}").content
dict = json.loads(content)
if len(dict) == 0:
break
for i in range(len(dict)):
fname = dict[i]['path']
is_pytorch = re.match("pytorch_model.*\.bin", fname)
is_safetensors = re.match("model.*\.safetensors", fname)
is_text = re.match(".*\.(txt|json)", fname)
is_tokenizer = re.match("tokenizer.*\.model", fname)
is_text = re.match(".*\.(txt|json)", fname) or is_tokenizer
if is_text or is_safetensors or is_pytorch:
if any((is_pytorch, is_safetensors, is_text, is_tokenizer)):
if is_text:
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
classifications.append('text')
@ -123,8 +130,9 @@ def get_download_links_from_huggingface(model, branch):
has_pytorch = True
classifications.append('pytorch')
#page = dict['nextUrl']
page = None
cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
cursor = base64.b64encode(cursor)
cursor = cursor.replace(b'=', b'%3D')
# If both pytorch and safetensors are available, download safetensors only
if has_pytorch and has_safetensors:

View file

@ -116,7 +116,22 @@ def load_model(model_name):
print(f"Could not find {pt_model}, exiting...")
exit()
model = load_quant(path_to_model, pt_path, 4)
model = load_quant(path_to_model, Path(f"models/{pt_model}"), 4)
# Multi-GPU setup
if shared.args.gpu_memory:
import accelerate
max_memory = {}
for i in range(len(shared.args.gpu_memory)):
max_memory[i] = f"{shared.args.gpu_memory[i]}GiB"
max_memory['cpu'] = f"{shared.args.cpu_memory or '99'}GiB"
device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LLaMADecoderLayer"])
model = accelerate.dispatch_model(model, device_map=device_map)
# Single GPU
else:
model = model.to(torch.device('cuda:0'))
# Custom