Failed attempt at evaluating exllama_hf perplexity

This commit is contained in:
oobabooga 2023-06-24 12:02:25 -03:00
parent e356f69b36
commit cec5fb0ef6
2 changed files with 21 additions and 10 deletions

View file

@ -100,7 +100,7 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
target_ids[:, :-trg_len] = -100
with torch.no_grad():
outputs = shared.model(input_ids, labels=target_ids)
outputs = shared.model(input_ids=input_ids, labels=target_ids)
# loss is calculated using CrossEntropyLoss which averages over valid labels
# N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels

View file

@ -1,15 +1,10 @@
import os
import sys
from pathlib import Path
from typing import *
from typing import Any, Dict, Optional, Union
import torch
from transformers import (
GenerationConfig,
LlamaTokenizer,
PretrainedConfig,
PreTrainedModel
)
from torch.nn import CrossEntropyLoss
from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from modules import shared
@ -43,13 +38,29 @@ class ExllamaHF(PreTrainedModel):
def __call__(self, *args, **kwargs):
# TODO: Some decoding methods (such as Contrastive Search) may not work at this time
assert len(args) == 0, 'no *args should be passed to forward'
use_cache = kwargs['use_cache']
use_cache = kwargs.get('use_cache', True)
labels = kwargs.get('labels', None)
seq = kwargs['input_ids'][0].tolist()
cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None
if cache is None:
cache = ExLlamaCache(self.ex_model)
self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True)
logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache).to(kwargs['input_ids'].device)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, logits.shape[-1])
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
return CausalLMOutputWithPast(logits=logits, past_key_values=cache if use_cache else None)
@classmethod