Downloader: use HF get_token function (#5381)

This commit is contained in:
Anthony Guijarro 2024-01-27 14:13:09 -06:00 committed by GitHub
parent de387069da
commit 828be63f2c
WARNING! Although there is a key with this ID in the database it does not verify this commit! This commit is SUSPICIOUS.
GPG key ID: B5690EEEBB952194

View file

@ -20,6 +20,7 @@ import requests
import tqdm
from requests.adapters import HTTPAdapter
from tqdm.contrib.concurrent import thread_map
from huggingface_hub import get_token
base = "https://huggingface.co"
@ -32,8 +33,8 @@ class ModelDownloader:
self.session.mount('https://huggingface.co', HTTPAdapter(max_retries=max_retries))
if os.getenv('HF_USER') is not None and os.getenv('HF_PASS') is not None:
self.session.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS'))
if os.getenv('HF_TOKEN') is not None:
self.session.headers = {'authorization': f'Bearer {os.getenv("HF_TOKEN")}'}
if get_token() is not None:
self.session.headers = {'authorization': f'Bearer {get_token()}'}
def sanitize_model_and_branch_names(self, model, branch):
if model[-1] == '/':