Fix getting Phi-3-small-128k-instruct logits

This commit is contained in:
oobabooga 2024-05-21 10:35:00 -07:00
parent bd7cc4234d
commit ae86292159

View file

@ -1,4 +1,5 @@
import time
import traceback
import torch
from transformers import is_torch_npu_available, is_torch_xpu_available
@ -18,7 +19,8 @@ def get_next_logits(*args, **kwargs):
shared.generation_lock.acquire()
try:
result = _get_next_logits(*args, **kwargs)
except:
except Exception:
traceback.print_exc()
result = None
models.last_generation_time = time.time()
@ -84,7 +86,14 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur
topk_values = [float(i) for i in topk_values]
output = {}
for row in list(zip(topk_values, tokens)):
output[row[1]] = row[0]
key = row[1]
if isinstance(key, bytes):
try:
key = key.decode()
except:
key = key.decode('latin')
output[key] = row[0]
return output
else: