Add a retry mechanism to the model downloader (#5943)

This commit is contained in:
oobabooga 2024-04-27 12:25:28 -03:00 committed by GitHub
parent dfdb6fee22
commit 5770e06c48
WARNING! Although there is a key with this ID in the database it does not verify this commit! This commit is SUSPICIOUS.
GPG key ID: B5690EEEBB952194

View file

@ -15,10 +15,12 @@ import os
import re
import sys
from pathlib import Path
from time import sleep
import requests
import tqdm
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError, RequestException, Timeout
from tqdm.contrib.concurrent import thread_map
base = os.environ.get("HF_ENDPOINT") or "https://huggingface.co"
@ -177,50 +179,65 @@ class ModelDownloader:
return output_folder
def get_single_file(self, url, output_folder, start_from_scratch=False):
session = self.get_session()
filename = Path(url.rsplit('/', 1)[1])
output_path = output_folder / filename
headers = {}
mode = 'wb'
if output_path.exists() and not start_from_scratch:
# Check if the file has already been downloaded completely
r = session.get(url, stream=True, timeout=10)
total_size = int(r.headers.get('content-length', 0))
if output_path.stat().st_size >= total_size:
return
max_retries = 7
attempt = 0
while attempt < max_retries:
attempt += 1
session = self.get_session()
headers = {}
mode = 'wb'
# Otherwise, resume the download from where it left off
headers = {'Range': f'bytes={output_path.stat().st_size}-'}
mode = 'ab'
if output_path.exists() and not start_from_scratch:
# Resume download
r = session.get(url, stream=True, timeout=20)
total_size = int(r.headers.get('content-length', 0))
if output_path.stat().st_size >= total_size:
return
with session.get(url, stream=True, headers=headers, timeout=10) as r:
r.raise_for_status() # Do not continue the download if the request was unsuccessful
total_size = int(r.headers.get('content-length', 0))
block_size = 1024 * 1024 # 1MB
headers = {'Range': f'bytes={output_path.stat().st_size}-'}
mode = 'ab'
tqdm_kwargs = {
'total': total_size,
'unit': 'iB',
'unit_scale': True,
'bar_format': '{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}'
}
try:
with session.get(url, stream=True, headers=headers, timeout=30) as r:
r.raise_for_status() # If status is not 2xx, raise an error
total_size = int(r.headers.get('content-length', 0))
block_size = 1024 * 1024 # 1MB
if 'COLAB_GPU' in os.environ:
tqdm_kwargs.update({
'position': 0,
'leave': True
})
tqdm_kwargs = {
'total': total_size,
'unit': 'iB',
'unit_scale': True,
'bar_format': '{l_bar}{bar}| {n_fmt}/{total_fmt} {rate_fmt}'
}
with open(output_path, mode) as f:
with tqdm.tqdm(**tqdm_kwargs) as t:
count = 0
for data in r.iter_content(block_size):
t.update(len(data))
f.write(data)
if total_size != 0 and self.progress_bar is not None:
count += len(data)
self.progress_bar(float(count) / float(total_size), f"{filename}")
if 'COLAB_GPU' in os.environ:
tqdm_kwargs.update({
'position': 0,
'leave': True
})
with open(output_path, mode) as f:
with tqdm.tqdm(**tqdm_kwargs) as t:
count = 0
for data in r.iter_content(block_size):
f.write(data)
t.update(len(data))
if total_size != 0 and self.progress_bar is not None:
count += len(data)
self.progress_bar(float(count) / float(total_size), f"{filename}")
break # Exit loop if successful
except (RequestException, ConnectionError, Timeout) as e:
print(f"Error downloading {filename}: {e}.")
print(f"That was attempt {attempt}/{max_retries}.", end=' ')
if attempt < max_retries:
print(f"Retry begins in {2 ** attempt} seconds.")
sleep(2 ** attempt)
else:
print("Failed to download after the maximum number of attempts.")
def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=4):
thread_map(lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True)