Fix Multi-GPU not working on exllama_hf (#2803)

Panchovix authored on 2023-06-22 15:05:25 -04:00; committed by GitHub
parent d94ea31d54
commit b4a38c24b7


@@ -38,7 +38,6 @@ class ExllamaHF(PreTrainedModel):
 
     @property
     def device(self) -> torch.device:
-        # TODO: May cause problem on multi-gpu inference?
         return torch.device(0)
 
     def __call__(self, *args, **kwargs):
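
Aside (not part of the commit): with --gpu-split, ExLlama may place the later layers, and therefore the raw logits, on a GPU other than cuda:0, while this property still reports torch.device(0). A minimal illustrative sketch of the mismatch, using hypothetical tensors rather than the repo's code:

import torch

# Hypothetical illustration: the wrapper reports cuda:0 unconditionally ...
reported = torch.device(0)                 # what the hard-coded property returns
# actual = torch.device("cuda", 1)         # where a split model may emit its logits
# Mixing tensors from those two devices without an explicit .to(...) can raise
# PyTorch's "Expected all tensors to be on the same device" error during sampling.
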
@@ -50,7 +49,7 @@ class ExllamaHF(PreTrainedModel):
         if cache is None:
             cache = ExLlamaCache(self.ex_model)
             self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True)
-        logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache).to(self.device)
+        logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache).to(kwargs['input_ids'].device)
         return CausalLMOutputWithPast(logits=logits, past_key_values=cache if use_cache else None)
 
     @classmethod
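
This is the core of the fix: instead of pinning the returned logits to torch.device(0), they now follow the device of the input_ids that transformers passed in, which is also where its sampling and logits-processor code typically operates. A hedged, stand-alone sketch of the same pattern (hypothetical helper, not the repo's API):

import torch

def align_to_inputs(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    # Return logits on whatever device the caller's input_ids live on, so downstream
    # code never has to guess which GPU a split model used for its last layers.
    return logits.to(input_ids.device)

# Usage sketch (CPU tensors, so it runs anywhere):
logits = torch.randn(1, 1, 32000)
input_ids = torch.tensor([[1, 2, 3]])
assert align_to_inputs(logits, input_ids).device == input_ids.device
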
@@ -73,10 +72,13 @@ class ExllamaHF(PreTrainedModel):
         config.model_path = str(weight_path)
 
         if shared.args.gpu_split:
             config.set_auto_map(shared.args.gpu_split)
+            config.gpu_peer_fix = True
+
         # This slows things down a bit but aligns better with autogptq generation.
         # TODO: Should give the user a choice to tune the exllama config
         config.act_order = True
-        config.fused_attn = False
-        config.fused_mlp_thd = 0
+        # config.fused_attn = False
+        # config.fused_mlp_thd = 0
+
         return ExllamaHF(config)
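
For reference on the config side: in text-generation-webui, --gpu-split is a comma-separated list of VRAM budgets in GB per GPU, which set_auto_map hands to ExLlama, and gpu_peer_fix is meant to make ExLlama move tensors between GPUs via system RAM, working around setups without peer-to-peer access. A small hedged sketch of how such a split string could be interpreted (illustrative parser, not the project's code):

# Hedged sketch (assumed format): "--gpu-split 16,21" means up to 16 GB of weights on
# GPU 0 and up to 21 GB on GPU 1. Hypothetical parser, for illustration only.
def parse_gpu_split(gpu_split: str) -> list[float]:
    return [float(part) for part in gpu_split.split(",") if part.strip()]

print(parse_gpu_split("16,21"))  # -> [16.0, 21.0]
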