Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2024-09-20 10:35:10 +02:00
Updated hf version too
commit 3e44373e8d, parent 377018eb22
1 changed file with 28 additions and 14 deletions
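In short: the commit adds tensor-parallel (TP) support to the Exllamav2HF loader. When shared.args.enable_tp is set, the model is loaded with load_tp() and the KV cache is allocated through the new ExLlamaV2Cache_TP wrapper; the existing autosplit and manual gpu_split paths are kept for the non-TP case.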
@@ -9,7 +9,8 @@ from exllamav2 import (
     ExLlamaV2Cache,
     ExLlamaV2Cache_8bit,
     ExLlamaV2Cache_Q4,
-    ExLlamaV2Config
+    ExLlamaV2Config,
+    ExLlamaV2Cache_TP,
 )
 from torch.nn import CrossEntropyLoss
 from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
@@ -23,16 +24,15 @@ try:
 except ModuleNotFoundError:
     logger.warning(
         'You are running ExLlamaV2 without flash-attention. This will cause the VRAM usage '
-        'to be a lot higher than it could be.\n'
+        'to be a lot higher than it could be.\\n'
         'Try installing flash-attention following the instructions here: '
         'https://github.com/Dao-AILab/flash-attention#installation-and-features'
     )
     pass
 except Exception:
-    logger.warning('Failed to load flash-attention due to the following error:\n')
+    logger.warning('Failed to load flash-attention due to the following error:\\n')
     traceback.print_exc()
 
-
 class Exllamav2HF(PreTrainedModel):
     def __init__(self, config: ExLlamaV2Config):
         super().__init__(PretrainedConfig())
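A note on the two string edits in this hunk: '\n' is a single newline character, while the added '\\n' is two characters (a literal backslash followed by n), so the rewritten warnings print a visible "\n" instead of breaking the line. A minimal check in plain Python:

print(len('\n'), len('\\n'))   # 1 2
print('warning\n')             # ends with a real line break
print('warning\\n')            # ends with the two characters \n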
@@ -42,21 +42,34 @@ class Exllamav2HF(PreTrainedModel):
 
         self.ex_model = ExLlamaV2(config)
 
-        if not shared.args.autosplit:
-            split = None
-            if shared.args.gpu_split:
-                split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
-
-            self.ex_model.load(split)
+        # Check if TP is enabled and load model with TP
+        if shared.args.enable_tp:
+            split = None
+            if shared.args.gpu_split:
+                split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
+            self.ex_model.load_tp(split)  # Ensure TP loading is used
+        else:
+            if not shared.args.autosplit:
+                split = None
+                if shared.args.gpu_split:
+                    split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
+                self.ex_model.load(split)
 
+        # Determine the correct cache type
         if shared.args.cache_8bit:
-            self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=shared.args.autosplit)
+            cache_type = ExLlamaV2Cache_8bit
         elif shared.args.cache_4bit:
-            self.ex_cache = ExLlamaV2Cache_Q4(self.ex_model, lazy=shared.args.autosplit)
+            cache_type = ExLlamaV2Cache_Q4
         else:
-            self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=shared.args.autosplit)
+            cache_type = ExLlamaV2Cache
 
-        if shared.args.autosplit:
+        # Use TP if specified
+        if shared.args.enable_tp:
+            self.ex_cache = ExLlamaV2Cache_TP(self.ex_model, base=cache_type)
+        else:
+            self.ex_cache = cache_type(self.ex_model, lazy=shared.args.autosplit)
+
+        if shared.args.autosplit and not shared.args.enable_tp:
             self.ex_model.load_autosplit(self.ex_cache)
 
         self.past_seq = None
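For context, a minimal sketch of the two loading paths introduced above, runnable outside the webui. It assumes an exllamav2 build that ships load_tp() and ExLlamaV2Cache_TP (0.1.9 or later); the model directory is a placeholder:

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Cache,
    ExLlamaV2Cache_TP,
    ExLlamaV2Config,
)

config = ExLlamaV2Config('/path/to/model')  # placeholder model directory
model = ExLlamaV2(config)

gpu_split = '16,24'  # example --gpu-split value: GB to allocate per GPU
split = [float(alloc) for alloc in gpu_split.split(',')]  # -> [16.0, 24.0]

enable_tp = True
if enable_tp:
    # Tensor-parallel path: shard the weights across the listed GPUs and
    # wrap a base cache class so each device holds its own cache shard.
    model.load_tp(split)
    cache = ExLlamaV2Cache_TP(model, base=ExLlamaV2Cache)
else:
    # Conventional path: load with a manual split, then a plain FP16 cache.
    model.load(split)
    cache = ExLlamaV2Cache(model, lazy=False)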
@@ -181,3 +194,4 @@ class Exllamav2HF(PreTrainedModel):
         config.num_experts_per_token = int(shared.args.num_experts_per_token)
 
         return Exllamav2HF(config)
+