Only download safetensors if both pytorch and safetensors are present

This commit is contained in:
oobabooga 2023-02-12 00:06:22 -03:00
parent 5d3f15b915
commit 66862203fc

View file

@ -72,13 +72,33 @@ if __name__ == '__main__':
soup = BeautifulSoup(page.content, 'html.parser')
links = soup.find_all('a')
downloads = []
classifications = []
has_pytorch = False
has_safetensors = False
for link in links:
href = link.get('href')[1:]
if href.startswith(f'{model}/resolve/{branch}'):
is_pytorch = href.endswith('.bin') and 'pytorch_model' in href
is_safetensors = href.endswith('.safetensors') and 'model' in href
if href.endswith(('.json', '.txt')) or is_pytorch or is_safetensors:
fname = Path(href).name
is_pytorch = re.match("pytorch_model.*\.bin", fname)
is_safetensors = re.match("model.*\.safetensors", fname)
is_text = re.match(".*\.(txt|json)", fname)
if is_text or is_safetensors or is_pytorch:
downloads.append(f'https://huggingface.co/{href}')
if is_text:
classifications.append('text')
elif is_safetensors:
has_safetensors = True
classifications.append('safetensors')
elif is_pytorch:
has_pytorch = True
classifications.append('pytorch')
# If both pytorch and safetensors are available, download safetensors only
if has_pytorch and has_safetensors:
for i in range(len(classifications)-1, -1, -1):
if classifications[i] == 'pytorch':
downloads.pop(i)
# Downloading the files
print(f"Downloading the model to {output_folder}")