diff --git a/download-model.py b/download-model.py index 0014b689..306784a3 100644 --- a/download-model.py +++ b/download-model.py @@ -29,6 +29,7 @@ base = os.environ.get("HF_ENDPOINT") or "https://huggingface.co" class ModelDownloader: def __init__(self, max_retries=5): self.max_retries = max_retries + self.session = self.get_session() def get_session(self): session = requests.Session() @@ -72,7 +73,7 @@ class ModelDownloader: return model, branch def get_download_links_from_huggingface(self, model, branch, text_only=False, specific_file=None): - session = self.get_session() + session = self.session page = f"/api/models/{model}/tree/{branch}" cursor = b"" @@ -192,7 +193,7 @@ class ModelDownloader: attempt = 0 while attempt < max_retries: attempt += 1 - session = self.get_session() + session = self.session headers = {} mode = 'wb'