Add a retry mechanism to the model downloader (#5943)

oobabooga 2024-04-27 12:25:28 -03:00 committed by GitHub
parent dfdb6fee22
commit 5770e06c48
GPG key ID: B5690EEEBB952194


@@ -15,10 +15,12 @@ import os
 import re
 import sys
 from pathlib import Path
+from time import sleep
 
 import requests
 import tqdm
 from requests.adapters import HTTPAdapter
+from requests.exceptions import ConnectionError, RequestException, Timeout
 from tqdm.contrib.concurrent import thread_map
 
 base = os.environ.get("HF_ENDPOINT") or "https://huggingface.co"
@@ -177,50 +179,65 @@ class ModelDownloader:
         return output_folder
 
     def get_single_file(self, url, output_folder, start_from_scratch=False):
-        session = self.get_session()
         filename = Path(url.rsplit('/', 1)[1])
         output_path = output_folder / filename
-        headers = {}
-        mode = 'wb'
-        if output_path.exists() and not start_from_scratch:
-
-            # Check if the file has already been downloaded completely
-            r = session.get(url, stream=True, timeout=10)
-            total_size = int(r.headers.get('content-length', 0))
-            if output_path.stat().st_size >= total_size:
-                return
-
-            # Otherwise, resume the download from where it left off
-            headers = {'Range': f'bytes={output_path.stat().st_size}-'}
-            mode = 'ab'
-
-        with session.get(url, stream=True, headers=headers, timeout=10) as r:
-            r.raise_for_status()  # Do not continue the download if the request was unsuccessful
-            total_size = int(r.headers.get('content-length', 0))
-            block_size = 1024 * 1024  # 1MB
-            tqdm_kwargs = {
-                'total': total_size,
-                'unit': 'iB',
-                'unit_scale': True,
-                'bar_format': '{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}'
-            }
-
-            if 'COLAB_GPU' in os.environ:
-                tqdm_kwargs.update({
-                    'position': 0,
-                    'leave': True
-                })
-
-            with open(output_path, mode) as f:
-                with tqdm.tqdm(**tqdm_kwargs) as t:
-                    count = 0
-                    for data in r.iter_content(block_size):
-                        t.update(len(data))
-                        f.write(data)
-                        if total_size != 0 and self.progress_bar is not None:
-                            count += len(data)
-                            self.progress_bar(float(count) / float(total_size), f"{filename}")
+
+        max_retries = 7
+        attempt = 0
+        while attempt < max_retries:
+            attempt += 1
+            session = self.get_session()
+            headers = {}
+            mode = 'wb'
+
+            if output_path.exists() and not start_from_scratch:
+                # Resume download
+                r = session.get(url, stream=True, timeout=20)
+                total_size = int(r.headers.get('content-length', 0))
+                if output_path.stat().st_size >= total_size:
+                    return
+
+                headers = {'Range': f'bytes={output_path.stat().st_size}-'}
+                mode = 'ab'
+
+            try:
+                with session.get(url, stream=True, headers=headers, timeout=30) as r:
+                    r.raise_for_status()  # If status is not 2xx, raise an error
+                    total_size = int(r.headers.get('content-length', 0))
+                    block_size = 1024 * 1024  # 1MB
+
+                    tqdm_kwargs = {
+                        'total': total_size,
+                        'unit': 'iB',
+                        'unit_scale': True,
+                        'bar_format': '{l_bar}{bar}| {n_fmt}/{total_fmt} {rate_fmt}'
+                    }
+
+                    if 'COLAB_GPU' in os.environ:
+                        tqdm_kwargs.update({
+                            'position': 0,
+                            'leave': True
+                        })
+
+                    with open(output_path, mode) as f:
+                        with tqdm.tqdm(**tqdm_kwargs) as t:
+                            count = 0
+                            for data in r.iter_content(block_size):
+                                f.write(data)
+                                t.update(len(data))
+                                if total_size != 0 and self.progress_bar is not None:
+                                    count += len(data)
+                                    self.progress_bar(float(count) / float(total_size), f"{filename}")
+
+                break  # Exit loop if successful
+            except (RequestException, ConnectionError, Timeout) as e:
+                print(f"Error downloading {filename}: {e}.")
+                print(f"That was attempt {attempt}/{max_retries}.", end=' ')
+                if attempt < max_retries:
+                    print(f"Retry begins in {2 ** attempt} seconds.")
+                    sleep(2 ** attempt)
+                else:
+                    print("Failed to download after the maximum number of attempts.")
 
     def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=4):
         thread_map(lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True)
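For readers who want the retry idea in isolation, here is a minimal, self-contained sketch of the same retry-with-exponential-backoff pattern: catch the requests exceptions, wait 2 ** attempt seconds, and give up after a fixed number of attempts. The helper name download_with_retries and its parameters are illustrative only and are not part of the repository.

# Minimal sketch of the retry-with-exponential-backoff pattern (illustrative, not project code).
from time import sleep

import requests
from requests.exceptions import ConnectionError, RequestException, Timeout


def download_with_retries(url, output_path, max_retries=7, timeout=30):
    """Try to stream `url` to `output_path`, retrying with exponential backoff."""
    attempt = 0
    while attempt < max_retries:
        attempt += 1
        try:
            with requests.get(url, stream=True, timeout=timeout) as r:
                r.raise_for_status()  # Treat non-2xx responses as failures
                with open(output_path, 'wb') as f:
                    for chunk in r.iter_content(1024 * 1024):  # 1 MB blocks
                        f.write(chunk)

            return True  # Success: stop retrying
        except (RequestException, ConnectionError, Timeout) as e:
            if attempt < max_retries:
                wait = 2 ** attempt  # 2, 4, 8, ... seconds between attempts
                print(f"Attempt {attempt}/{max_retries} failed ({e}); retrying in {wait} seconds.")
                sleep(wait)
            else:
                print("Failed to download after the maximum number of attempts.")

    return False

With max_retries=7 the backoff waits are 2, 4, 8, 16, 32 and 64 seconds, so a persistently failing download gives up after roughly two minutes of accumulated waiting. The real method in the diff above additionally resumes partial files via a Range header instead of starting over on each attempt.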