Check for EOS and \n

This commit is contained in:
oobabooga 2024-09-02 19:54:47 -07:00
parent af7b57cce0
commit 0f62744df1

View file

@@ -198,6 +198,12 @@ class XTCLogitsWarper(LogitsWarper):
self.threshold = threshold
self.probability = probability
self.filter_value = filter_value
self.special_token_ids = [
shared.tokenizer.encode("\n")[-1],
]
if shared.tokenizer.eos_token_id is not None:
self.special_token_ids.append(shared.tokenizer.eos_token_id)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# `random` returns values in the half-open range [0, 1), so setting `probability`
@@ -221,7 +227,14 @@ class XTCLogitsWarper(LogitsWarper):
# of all tokens that meet the threshold, *except* the least probable one.
sorted_indices_to_remove[..., :-1] = probs[..., 1:] >= self.threshold
# Convert sorted_indices_to_remove to the original indices
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
# If newline or EOS tokens would be removed, return the original scores
if indices_to_remove[:, self.special_token_ids].any():
return scores
# Otherwise, remove tokens with the mask
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores