From 5770e06c4875797fc144a135a082c754cab7a92a Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sat, 27 Apr 2024 12:25:28 -0300
Subject: [PATCH] Add a retry mechanism to the model downloader (#5943)

---
 download-model.py | 89 ++++++++++++++++++++++++++++-------------------
 1 file changed, 53 insertions(+), 36 deletions(-)

diff --git a/download-model.py b/download-model.py
index eca17c43..c38e79fb 100644
--- a/download-model.py
+++ b/download-model.py
@@ -15,10 +15,12 @@ import os
 import re
 import sys
 from pathlib import Path
+from time import sleep
 
 import requests
 import tqdm
 from requests.adapters import HTTPAdapter
+from requests.exceptions import ConnectionError, RequestException, Timeout
 from tqdm.contrib.concurrent import thread_map
 
 base = os.environ.get("HF_ENDPOINT") or "https://huggingface.co"
@@ -177,50 +179,65 @@ class ModelDownloader:
 
         return output_folder
 
     def get_single_file(self, url, output_folder, start_from_scratch=False):
-        session = self.get_session()
         filename = Path(url.rsplit('/', 1)[1])
         output_path = output_folder / filename
-        headers = {}
-        mode = 'wb'
-        if output_path.exists() and not start_from_scratch:
-            # Check if the file has already been downloaded completely
-            r = session.get(url, stream=True, timeout=10)
-            total_size = int(r.headers.get('content-length', 0))
-            if output_path.stat().st_size >= total_size:
-                return
+        max_retries = 7
+        attempt = 0
+        while attempt < max_retries:
+            attempt += 1
+            session = self.get_session()
+            headers = {}
+            mode = 'wb'
 
-            # Otherwise, resume the download from where it left off
-            headers = {'Range': f'bytes={output_path.stat().st_size}-'}
-            mode = 'ab'
+            if output_path.exists() and not start_from_scratch:
+                # Resume download
+                r = session.get(url, stream=True, timeout=20)
+                total_size = int(r.headers.get('content-length', 0))
+                if output_path.stat().st_size >= total_size:
+                    return
 
-        with session.get(url, stream=True, headers=headers, timeout=10) as r:
-            r.raise_for_status()  # Do not continue the download if the request was unsuccessful
-            total_size = int(r.headers.get('content-length', 0))
-            block_size = 1024 * 1024  # 1MB
+                headers = {'Range': f'bytes={output_path.stat().st_size}-'}
+                mode = 'ab'
 
-            tqdm_kwargs = {
-                'total': total_size,
-                'unit': 'iB',
-                'unit_scale': True,
-                'bar_format': '{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}'
-            }
+            try:
+                with session.get(url, stream=True, headers=headers, timeout=30) as r:
+                    r.raise_for_status()  # If status is not 2xx, raise an error
+                    total_size = int(r.headers.get('content-length', 0))
+                    block_size = 1024 * 1024  # 1MB
 
-            if 'COLAB_GPU' in os.environ:
-                tqdm_kwargs.update({
-                    'position': 0,
-                    'leave': True
-                })
+                    tqdm_kwargs = {
+                        'total': total_size,
+                        'unit': 'iB',
+                        'unit_scale': True,
+                        'bar_format': '{l_bar}{bar}| {n_fmt}/{total_fmt} {rate_fmt}'
+                    }
 
-            with open(output_path, mode) as f:
-                with tqdm.tqdm(**tqdm_kwargs) as t:
-                    count = 0
-                    for data in r.iter_content(block_size):
-                        t.update(len(data))
-                        f.write(data)
-                        if total_size != 0 and self.progress_bar is not None:
-                            count += len(data)
-                            self.progress_bar(float(count) / float(total_size), f"{filename}")
+                    if 'COLAB_GPU' in os.environ:
+                        tqdm_kwargs.update({
+                            'position': 0,
+                            'leave': True
+                        })
+
+                    with open(output_path, mode) as f:
+                        with tqdm.tqdm(**tqdm_kwargs) as t:
+                            count = 0
+                            for data in r.iter_content(block_size):
+                                f.write(data)
+                                t.update(len(data))
+                                if total_size != 0 and self.progress_bar is not None:
+                                    count += len(data)
+                                    self.progress_bar(float(count) / float(total_size), f"{filename}")
+
+                break  # Exit loop if successful
+            except (RequestException, ConnectionError, Timeout) as e:
+                print(f"Error downloading {filename}: {e}.")
+                print(f"That was attempt {attempt}/{max_retries}.", end=' ')
+                if attempt < max_retries:
+                    print(f"Retry begins in {2 ** attempt} seconds.")
+                    sleep(2 ** attempt)
+                else:
+                    print("Failed to download after the maximum number of attempts.")
 
     def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=4):
         thread_map(lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True)
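Note: the heart of this change is a bounded retry loop with exponential backoff (sleep(2 ** attempt)) wrapped around the streaming download. A minimal, self-contained sketch of that pattern follows; the function name download_with_retries and its parameters are illustrative only and do not appear in the patch.

    # Sketch of the retry-with-exponential-backoff pattern used in the patch.
    # Hypothetical helper; names are not from download-model.py.
    from pathlib import Path
    from time import sleep

    import requests
    from requests.exceptions import RequestException


    def download_with_retries(url: str, dest: Path, max_retries: int = 7) -> bool:
        for attempt in range(1, max_retries + 1):
            try:
                with requests.get(url, stream=True, timeout=30) as r:
                    r.raise_for_status()  # abort this attempt on a non-2xx status
                    with open(dest, 'wb') as f:
                        for chunk in r.iter_content(1024 * 1024):  # 1MB blocks
                            f.write(chunk)
                return True  # success, stop retrying
            except RequestException as e:
                print(f"Error downloading {dest.name}: {e} (attempt {attempt}/{max_retries})")
                if attempt < max_retries:
                    sleep(2 ** attempt)  # back off: 2, 4, 8, ... seconds
        return False

In the patch itself, the loop additionally re-creates the session on every attempt and re-checks the partially downloaded file, so an interrupted transfer resumes via a Range header instead of restarting from zero.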