Pin PyTorch version to 2.1 (#5056)

This commit is contained in:
oobabooga 2024-01-04 23:50:23 -03:00 committed by GitHub
parent c9c31f71b8
commit 3d854ee516
WARNING! Although there is a key with this ID in the database it does not verify this commit! This commit is SUSPICIOUS.
GPG key ID: 4AEE18F83AFDEB23

View file

@ -89,6 +89,7 @@ def torch_version():
torver = [line for line in torch_version_file if '__version__' in line][0].split('__version__ = ')[1].strip("'") torver = [line for line in torch_version_file if '__version__' in line][0].split('__version__ = ')[1].strip("'")
else: else:
from torch import __version__ as torver from torch import __version__ as torver
return torver return torver
@ -203,7 +204,7 @@ def install_webui():
# Find the proper Pytorch installation command # Find the proper Pytorch installation command
install_git = "conda install -y -k ninja git" install_git = "conda install -y -k ninja git"
install_pytorch = "python -m pip install torch torchvision torchaudio" install_pytorch = "python -m pip install torch==2.1.* torchvision==0.16.* torchaudio==2.1.* "
use_cuda118 = "N" use_cuda118 = "N"
if any((is_windows(), is_linux())) and selected_gpu == "NVIDIA": if any((is_windows(), is_linux())) and selected_gpu == "NVIDIA":
@ -219,20 +220,20 @@ def install_webui():
if use_cuda118 == 'Y': if use_cuda118 == 'Y':
print("CUDA: 11.8") print("CUDA: 11.8")
install_pytorch += "--index-url https://download.pytorch.org/whl/cu118"
else: else:
print("CUDA: 12.1") print("CUDA: 12.1")
install_pytorch += "--index-url https://download.pytorch.org/whl/cu121"
install_pytorch = f"python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/{'cu121' if use_cuda118 == 'N' else 'cu118'}"
elif not is_macos() and selected_gpu == "AMD": elif not is_macos() and selected_gpu == "AMD":
if is_linux(): if is_linux():
install_pytorch = "python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.6" install_pytorch += "--index-url https://download.pytorch.org/whl/rocm5.6"
else: else:
print("AMD GPUs are only supported on Linux. Exiting...") print("AMD GPUs are only supported on Linux. Exiting...")
sys.exit(1) sys.exit(1)
elif is_linux() and selected_gpu in ["APPLE", "NONE"]: elif is_linux() and selected_gpu in ["APPLE", "NONE"]:
install_pytorch = "python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu" install_pytorch += "--index-url https://download.pytorch.org/whl/cpu"
elif selected_gpu == "INTEL": elif selected_gpu == "INTEL":
install_pytorch = "python -m pip install torch==2.1.0a0 torchvision==0.16.0a0 intel_extension_for_pytorch==2.1.10+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/" install_pytorch += "intel_extension_for_pytorch==2.1.* --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
# Install Git and then Pytorch # Install Git and then Pytorch
print_big_message("Installing PyTorch.") print_big_message("Installing PyTorch.")