diff --git a/modules/exllama_hf.py b/modules/exllama_hf.py
index 27cac374..64de7a5f 100644
--- a/modules/exllama_hf.py
+++ b/modules/exllama_hf.py
@@ -38,7 +38,6 @@ class ExllamaHF(PreTrainedModel):
 
     @property
     def device(self) -> torch.device:
-        # TODO: May cause problem on multi-gpu inference?
         return torch.device(0)
 
     def __call__(self, *args, **kwargs):
@@ -50,7 +49,7 @@ class ExllamaHF(PreTrainedModel):
         if cache is None:
             cache = ExLlamaCache(self.ex_model)
             self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True)
-        logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache).to(self.device)
+        logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache).to(kwargs['input_ids'].device)
         return CausalLMOutputWithPast(logits=logits, past_key_values=cache if use_cache else None)
 
     @classmethod
@@ -72,11 +71,14 @@ class ExllamaHF(PreTrainedModel):
         assert weight_path is not None, f'could not find weight in "{pretrained_model_name_or_path}"'
 
         config.model_path = str(weight_path)
+
+        if shared.args.gpu_split:
+            config.set_auto_map(shared.args.gpu_split)
+            config.gpu_peer_fix = True
 
         # This slowes down a bit but align better with autogptq generation.
         # TODO: Should give user choice to tune the exllama config
-        config.act_order = True
-        config.fused_attn = False
-        config.fused_mlp_thd = 0
+        # config.fused_attn = False
+        # config.fused_mlp_thd = 0
 
-        return ExllamaHF(config)
\ No newline at end of file
+        return ExllamaHF(config)
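
For context, a minimal sketch of what the new multi-GPU path amounts to when driving exllama directly, outside the webui wrapper. The model paths, the "17,24" split, and the flat `from model import ...` layout are assumptions for illustration only; `set_auto_map`, `gpu_peer_fix`, and the `forward(..., preprocess_only=True)` call are the `ExLlamaConfig` / `ExLlama` members the diff relies on.

```python
# Sketch (not part of the diff): same gpu_split plumbing that
# ExllamaHF.from_pretrained now wires up, applied to exllama directly.
# Assumed: exllama repo on sys.path, hypothetical model directory, 2-GPU split.
import torch
from model import ExLlama, ExLlamaCache, ExLlamaConfig  # exllama repo modules

config = ExLlamaConfig('models/llama-70b-gptq/config.json')
config.model_path = 'models/llama-70b-gptq/model.safetensors'

gpu_split = '17,24'                 # GB of VRAM to allocate per GPU (like shared.args.gpu_split)
if gpu_split:
    config.set_auto_map(gpu_split)  # spread layers across cuda:0 / cuda:1
    config.gpu_peer_fix = True      # route cross-device copies through system RAM

model = ExLlama(config)
cache = ExLlamaCache(model)

# With layers spread over several devices, logits should follow the input's
# device instead of a hard-coded torch.device(0) -- the point of the __call__ change.
input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
model.forward(input_ids[:, :-1], cache, preprocess_only=True)  # prefill the cache
logits = model.forward(input_ids[:, -1:], cache).to(input_ids.device)
```

In the webui itself, the split string comes from the GPU-split option surfaced as `shared.args.gpu_split` in the hunk above; when it is unset, the config is left untouched and behavior matches single-GPU loading.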