From ae8cd449ae3e0236ecb3775892bb1eea23f9ed68 Mon Sep 17 00:00:00 2001
From: turboderp <11859846+turboderp@users.noreply.github.com>
Date: Thu, 19 Oct 2023 04:16:05 +0200
Subject: [PATCH] ExLlamav2_HF: Convert logits to FP32 (#4310)

---
 modules/exllamav2_hf.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py
index e12a0717..952d7172 100644
--- a/modules/exllamav2_hf.py
+++ b/modules/exllamav2_hf.py
@@ -108,10 +108,10 @@ class Exllamav2HF(PreTrainedModel):
             if len(seq_tensor) > 1:
                 self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)
 
-            logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device)
+            logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device).float()
         else:
             ex_cache.current_seq_len = 0
-            logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras)
+            logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras).float()
 
         if is_negative:
             self.past_seq_negative = seq_tensor
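
Context on the change (not part of the patch itself): ExLlamaV2 returns logits in the model's half-precision compute dtype, and the two `.float()` calls added above upcast them to FP32 before they reach the sampling code. Below is a minimal illustrative sketch, not the webui's code, of why that matters: FP16 carries only 11 significand bits, so closely spaced logits can round to the same value and low-probability tokens can underflow to zero after softmax. The three-element logits tensor is hypothetical.

    import torch

    # Hypothetical logits for a 3-token vocabulary; the top two differ by 0.003.
    logits = torch.tensor([10.000, 10.003, -2.0])

    # Near 10.0 the FP16 rounding step is about 0.0078, so the 0.003 gap
    # between the top two logits is lost and both round to the same value;
    # a sampler can no longer distinguish them:
    print(logits.half())  # the two leading entries print as the same 10.0

    # Upcasting to FP32 first (what the patch does) preserves the gap:
    print(torch.softmax(logits.float(), dim=-1))

The same upcast helps the rest of the sampler chain: temperature scaling divides the logits and top-p filtering cumulatively sums probabilities, both of which are more stable in FP32.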