From 3e44373e8d50643ed963488bb21c13acf2482b92 Mon Sep 17 00:00:00 2001
From: RandoInternetPreson
Date: Sat, 31 Aug 2024 11:17:48 -0400
Subject: [PATCH] Updated hf version too

---
 modules/exllamav2_hf.py | 42 ++++++++++++++++++++++++++++--------------
 1 file changed, 28 insertions(+), 14 deletions(-)

diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py
index 53143d9a..79c63346 100644
--- a/modules/exllamav2_hf.py
+++ b/modules/exllamav2_hf.py
@@ -9,7 +9,8 @@ from exllamav2 import (
     ExLlamaV2Cache,
     ExLlamaV2Cache_8bit,
     ExLlamaV2Cache_Q4,
-    ExLlamaV2Config
+    ExLlamaV2Config,
+    ExLlamaV2Cache_TP,
 )
 from torch.nn import CrossEntropyLoss
 from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
@@ -23,16 +24,15 @@ try:
 except ModuleNotFoundError:
     logger.warning(
         'You are running ExLlamaV2 without flash-attention. This will cause the VRAM usage '
-        'to be a lot higher than it could be.\n'
+        'to be a lot higher than it could be.\\n'
         'Try installing flash-attention following the instructions here: '
         'https://github.com/Dao-AILab/flash-attention#installation-and-features'
     )
     pass
 except Exception:
-    logger.warning('Failed to load flash-attention due to the following error:\n')
+    logger.warning('Failed to load flash-attention due to the following error:\\n')
     traceback.print_exc()
 
-
 class Exllamav2HF(PreTrainedModel):
     def __init__(self, config: ExLlamaV2Config):
         super().__init__(PretrainedConfig())
@@ -42,21 +42,34 @@ class Exllamav2HF(PreTrainedModel):
 
         self.ex_model = ExLlamaV2(config)
 
-        if not shared.args.autosplit:
+        # Check if TP is enabled and load model with TP
+        if shared.args.enable_tp:
             split = None
             if shared.args.gpu_split:
                 split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
-
-            self.ex_model.load(split)
-
-        if shared.args.cache_8bit:
-            self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=shared.args.autosplit)
-        elif shared.args.cache_4bit:
-            self.ex_cache = ExLlamaV2Cache_Q4(self.ex_model, lazy=shared.args.autosplit)
+            self.ex_model.load_tp(split)  # Ensure TP loading is used
         else:
-            self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=shared.args.autosplit)
+            if not shared.args.autosplit:
+                split = None
+                if shared.args.gpu_split:
+                    split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
+                self.ex_model.load(split)
 
-        if shared.args.autosplit:
+        # Determine the correct cache type
+        if shared.args.cache_8bit:
+            cache_type = ExLlamaV2Cache_8bit
+        elif shared.args.cache_4bit:
+            cache_type = ExLlamaV2Cache_Q4
+        else:
+            cache_type = ExLlamaV2Cache
+
+        # Use TP if specified
+        if shared.args.enable_tp:
+            self.ex_cache = ExLlamaV2Cache_TP(self.ex_model, base=cache_type)
+        else:
+            self.ex_cache = cache_type(self.ex_model, lazy=shared.args.autosplit)
+
+        if shared.args.autosplit and not shared.args.enable_tp:
             self.ex_model.load_autosplit(self.ex_cache)
 
         self.past_seq = None
@@ -181,3 +194,4 @@ class Exllamav2HF(PreTrainedModel):
         config.num_experts_per_token = int(shared.args.num_experts_per_token)
 
         return Exllamav2HF(config)
+
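
Note: below is a minimal, standalone sketch of the load path this patch introduces, outside the Exllamav2HF wrapper. It assumes an exllamav2 build that ships ExLlamaV2.load_tp and ExLlamaV2Cache_TP (the same calls the patch uses); the model directory and flag values are placeholders standing in for shared.args.enable_tp, shared.args.gpu_split and shared.args.cache_4bit.

# Hedged sketch of the TP-aware loading logic added in this patch.
# Placeholder values below stand in for the webui's shared.args settings.
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Cache,
    ExLlamaV2Cache_Q4,
    ExLlamaV2Cache_TP,
    ExLlamaV2Config,
)

enable_tp = True          # stands in for shared.args.enable_tp
gpu_split = "20,24"       # stands in for shared.args.gpu_split, or None
cache_4bit = True         # stands in for shared.args.cache_4bit

config = ExLlamaV2Config()
config.model_dir = "/path/to/model"   # placeholder model directory
config.prepare()

model = ExLlamaV2(config)
split = [float(alloc) for alloc in gpu_split.split(",")] if gpu_split else None

if enable_tp:
    # Tensor-parallel load: weights are sharded across the listed GPUs.
    model.load_tp(split)
else:
    model.load(split)

cache_type = ExLlamaV2Cache_Q4 if cache_4bit else ExLlamaV2Cache

if enable_tp:
    # The TP cache wraps a base cache class and shards it to match the model.
    cache = ExLlamaV2Cache_TP(model, base=cache_type)
else:
    cache = cache_type(model)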