diff --git a/download-model.py b/download-model.py index 09bc9a86..d7cf9273 100644 --- a/download-model.py +++ b/download-model.py @@ -26,13 +26,16 @@ base = "https://huggingface.co" class ModelDownloader: def __init__(self, max_retries=5): - self.session = requests.Session() - if max_retries: - self.session.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=max_retries)) - self.session.mount('https://huggingface.co', HTTPAdapter(max_retries=max_retries)) + self.max_retries = max_retries + + def get_session(self): + session = requests.Session() + if self.max_retries: + session.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=self.max_retries)) + session.mount('https://huggingface.co', HTTPAdapter(max_retries=self.max_retries)) if os.getenv('HF_USER') is not None and os.getenv('HF_PASS') is not None: - self.session.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS')) + session.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS')) try: from huggingface_hub import get_token @@ -41,7 +44,9 @@ class ModelDownloader: token = os.getenv("HF_TOKEN") if token is not None: - self.session.headers = {'authorization': f'Bearer {token}'} + session.headers = {'authorization': f'Bearer {token}'} + + return session def sanitize_model_and_branch_names(self, model, branch): if model[-1] == '/': @@ -65,6 +70,7 @@ class ModelDownloader: return model, branch def get_download_links_from_huggingface(self, model, branch, text_only=False, specific_file=None): + session = self.get_session() page = f"/api/models/{model}/tree/{branch}" cursor = b"" @@ -78,7 +84,7 @@ class ModelDownloader: is_lora = False while True: url = f"{base}{page}" + (f"?cursor={cursor.decode()}" if cursor else "") - r = self.session.get(url, timeout=10) + r = session.get(url, timeout=10) r.raise_for_status() content = r.content @@ -171,6 +177,7 @@ class ModelDownloader: return output_folder def get_single_file(self, url, output_folder, start_from_scratch=False): + session = self.get_session() filename = Path(url.rsplit('/', 1)[1]) output_path = output_folder / filename headers = {} @@ -178,7 +185,7 @@ class ModelDownloader: if output_path.exists() and not start_from_scratch: # Check if the file has already been downloaded completely - r = self.session.get(url, stream=True, timeout=10) + r = session.get(url, stream=True, timeout=10) total_size = int(r.headers.get('content-length', 0)) if output_path.stat().st_size >= total_size: return @@ -187,7 +194,7 @@ class ModelDownloader: headers = {'Range': f'bytes={output_path.stat().st_size}-'} mode = 'ab' - with self.session.get(url, stream=True, headers=headers, timeout=10) as r: + with session.get(url, stream=True, headers=headers, timeout=10) as r: r.raise_for_status() # Do not continue the download if the request was unsuccessful total_size = int(r.headers.get('content-length', 0)) block_size = 1024 * 1024 # 1MB