Merge pull request #5534 from oobabooga/dev

Merge dev branch
This commit is contained in:
oobabooga 2024-02-17 18:09:40 -03:00 committed by GitHub
commit 7838075990
WARNING! Although there is a key with this ID in the database it does not verify this commit! This commit is SUSPICIOUS.
GPG key ID: B5690EEEBB952194
2 changed files with 19 additions and 16 deletions

View file

@ -51,20 +51,21 @@ class Exllamav2Model:
model = ExLlamaV2(config)
if shared.args.cache_8bit:
cache = ExLlamaV2Cache_8bit(model, lazy=True)
else:
cache = ExLlamaV2Cache(model, lazy=True)
if shared.args.autosplit:
model.load_autosplit(cache)
else:
if not shared.args.autosplit:
split = None
if shared.args.gpu_split:
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
model.load(split)
if shared.args.cache_8bit:
cache = ExLlamaV2Cache_8bit(model, lazy=shared.args.autosplit)
else:
cache = ExLlamaV2Cache(model, lazy=shared.args.autosplit)
if shared.args.autosplit:
model.load_autosplit(cache)
tokenizer = ExLlamaV2Tokenizer(config)
generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)

View file

@ -36,24 +36,26 @@ class Exllamav2HF(PreTrainedModel):
def __init__(self, config: ExLlamaV2Config):
super().__init__(PretrainedConfig())
self.ex_config = config
self.ex_model = ExLlamaV2(config)
self.loras = None
self.generation_config = GenerationConfig()
if shared.args.cache_8bit:
self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=True)
else:
self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=True)
self.ex_model = ExLlamaV2(config)
if shared.args.autosplit:
self.ex_model.load_autosplit(self.ex_cache)
else:
if not shared.args.autosplit:
split = None
if shared.args.gpu_split:
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
self.ex_model.load(split)
if shared.args.cache_8bit:
self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=shared.args.autosplit)
else:
self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=shared.args.autosplit)
if shared.args.autosplit:
self.ex_model.load_autosplit(self.ex_cache)
self.past_seq = None
if shared.args.cfg_cache:
if shared.args.cache_8bit: