Don't download if --check is specified

This commit is contained in:
oobabooga 2023-03-31 01:31:47 -03:00
parent 0cc89e7755
commit 92c7068daf

View file

@ -9,6 +9,7 @@ python download-model.py facebook/opt-1.3b
import argparse
import base64
import datetime
import hashlib
import json
import re
import sys
@ -17,7 +18,6 @@ from pathlib import Path
import requests
import tqdm
from tqdm.contrib.concurrent import thread_map
import hashlib
parser = argparse.ArgumentParser()
parser.add_argument('MODEL', type=str, default=None, nargs='?')
@ -212,22 +212,32 @@ if __name__ == '__main__':
if sha256_str != '':
f.write(f'sha256sum:\n{sha256_str}')
# Downloading the files
print(f"Downloading the model to {output_folder}")
download_files(links, output_folder, args.threads)
if args.check:
# Validate the checksums
validated = True
for i in range(len(sha256)):
fpath = (output_folder / sha256[i][0])
if not fpath.exists():
print(f"The following file is missing: {fpath}")
validated = False
continue
with open(output_folder / sha256[i][0], "rb") as f:
bytes = f.read()
file_hash = hashlib.sha256(bytes).hexdigest()
if file_hash != sha256[i][1]:
print(f'[!] Checksum for {sha256[i][0]} failed!')
print(f'Checksum failed: {sha256[i][0]} {sha256[i][1]}')
validated = False
else:
print(f'Checksum validated: {sha256[i][0]} {sha256[i][1]}')
if validated:
print('[+] Validated checksums of all model files!')
else:
print('[-] Rerun the download-model.py with --clean flag')
print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')
else:
# Downloading the files
print(f"Downloading the model to {output_folder}")
download_files(links, output_folder, args.threads)