diff --git a/download-model.py b/download-model.py index 5e62036f..82e956d6 100644 --- a/download-model.py +++ b/download-model.py @@ -20,7 +20,6 @@ import requests import tqdm from requests.adapters import HTTPAdapter from tqdm.contrib.concurrent import thread_map -from huggingface_hub import get_token base = "https://huggingface.co" @@ -31,10 +30,18 @@ class ModelDownloader: if max_retries: self.session.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=max_retries)) self.session.mount('https://huggingface.co', HTTPAdapter(max_retries=max_retries)) + if os.getenv('HF_USER') is not None and os.getenv('HF_PASS') is not None: self.session.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS')) - if get_token() is not None: - self.session.headers = {'authorization': f'Bearer {get_token()}'} + + try: + from huggingface_hub import get_token + token = get_token() + except ImportError: + token = os.getenv("HF_TOKEN") + + if token is not None: + self.session.headers = {'authorization': f'Bearer {token}'} def sanitize_model_and_branch_names(self, model, branch): if model[-1] == '/':