Updated hf version too

This commit is contained in:
RandoInternetPreson 2024-08-31 11:17:48 -04:00 committed by GitHub
parent 377018eb22
commit 3e44373e8d
WARNING! Although there is a key with this ID in the database it does not verify this commit! This commit is SUSPICIOUS.
GPG key ID: B5690EEEBB952194

View file

@ -9,7 +9,8 @@ from exllamav2 import (
ExLlamaV2Cache,
ExLlamaV2Cache_8bit,
ExLlamaV2Cache_Q4,
ExLlamaV2Config
ExLlamaV2Config,
ExLlamaV2Cache_TP,
)
from torch.nn import CrossEntropyLoss
from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
@ -23,16 +24,15 @@ try:
except ModuleNotFoundError:
logger.warning(
'You are running ExLlamaV2 without flash-attention. This will cause the VRAM usage '
'to be a lot higher than it could be.\n'
'to be a lot higher than it could be.\\n'
'Try installing flash-attention following the instructions here: '
'https://github.com/Dao-AILab/flash-attention#installation-and-features'
)
pass
except Exception:
logger.warning('Failed to load flash-attention due to the following error:\n')
logger.warning('Failed to load flash-attention due to the following error:\\n')
traceback.print_exc()
class Exllamav2HF(PreTrainedModel):
def __init__(self, config: ExLlamaV2Config):
super().__init__(PretrainedConfig())
@ -42,21 +42,34 @@ class Exllamav2HF(PreTrainedModel):
self.ex_model = ExLlamaV2(config)
# Check if TP is enabled and load model with TP
if shared.args.enable_tp:
split = None
if shared.args.gpu_split:
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
self.ex_model.load_tp(split) # Ensure TP loading is used
else:
if not shared.args.autosplit:
split = None
if shared.args.gpu_split:
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
self.ex_model.load(split)
# Determine the correct cache type
if shared.args.cache_8bit:
self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=shared.args.autosplit)
cache_type = ExLlamaV2Cache_8bit
elif shared.args.cache_4bit:
self.ex_cache = ExLlamaV2Cache_Q4(self.ex_model, lazy=shared.args.autosplit)
cache_type = ExLlamaV2Cache_Q4
else:
self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=shared.args.autosplit)
cache_type = ExLlamaV2Cache
if shared.args.autosplit:
# Use TP if specified
if shared.args.enable_tp:
self.ex_cache = ExLlamaV2Cache_TP(self.ex_model, base=cache_type)
else:
self.ex_cache = cache_type(self.ex_model, lazy=shared.args.autosplit)
if shared.args.autosplit and not shared.args.enable_tp:
self.ex_model.load_autosplit(self.ex_cache)
self.past_seq = None
@ -181,3 +194,4 @@ class Exllamav2HF(PreTrainedModel):
config.num_experts_per_token = int(shared.args.num_experts_per_token)
return Exllamav2HF(config)