diff --git a/download-model.py b/download-model.py
index 90711884..46aa9d77 100644
--- a/download-model.py
+++ b/download-model.py
@@ -72,13 +72,35 @@ if __name__ == '__main__':
     soup = BeautifulSoup(page.content, 'html.parser')
     links = soup.find_all('a')
     downloads = []
+    classifications = []
     for link in links:
-        href = link.get('href')[1:]
+        # An <a> tag without an href attribute yields None; treat it as an empty path
+        href = (link.get('href') or '')[1:]
         if href.startswith(f'{model}/resolve/{branch}'):
-            is_pytorch = href.endswith('.bin') and 'pytorch_model' in href
-            is_safetensors = href.endswith('.safetensors') and 'model' in href
-            if href.endswith(('.json', '.txt')) or is_pytorch or is_safetensors:
+            fname = Path(href).name
+            # Classify on the file name alone. fullmatch anchors both ends, so
+            # names like "config.json.bak" or "pytorch_model.bin.tmp" are skipped.
+            if re.fullmatch(r".*\.(txt|json)", fname):
+                file_kind = 'text'
+            elif re.fullmatch(r"model.*\.safetensors", fname):
+                file_kind = 'safetensors'
+            elif re.fullmatch(r"pytorch_model.*\.bin", fname):
+                file_kind = 'pytorch'
+            else:
+                file_kind = None
+
+            if file_kind is not None:
                 downloads.append(f'https://huggingface.co/{href}')
+                classifications.append(file_kind)
+
+    # If both pytorch and safetensors are available, download safetensors only
+    if 'pytorch' in classifications and 'safetensors' in classifications:
+        # classifications[i] describes downloads[i]; filter in place so the
+        # same list object is kept
+        downloads[:] = [
+            url for url, label in zip(downloads, classifications)
+            if label != 'pytorch'
+        ]
 
     # Downloading the files
     print(f"Downloading the model to {output_folder}")