mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-09-20 10:35:10 +02:00
Check for EOS and \n
This commit is contained in:
parent
af7b57cce0
commit
0f62744df1
1 changed files with 13 additions and 0 deletions
|
@ -198,6 +198,12 @@ class XTCLogitsWarper(LogitsWarper):
|
|||
self.threshold = threshold
|
||||
self.probability = probability
|
||||
self.filter_value = filter_value
|
||||
self.special_token_ids = [
|
||||
shared.tokenizer.encode("\n")[-1],
|
||||
]
|
||||
|
||||
if shared.tokenizer.eos_token_id is not None:
|
||||
self.special_token_ids.append(shared.tokenizer.eos_token_id)
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# `random` returns values in the half-open range [0, 1), so setting `probability`
|
||||
|
@ -221,7 +227,14 @@ class XTCLogitsWarper(LogitsWarper):
|
|||
# of all tokens that meet the threshold, *except* the least probable one.
|
||||
sorted_indices_to_remove[..., :-1] = probs[..., 1:] >= self.threshold
|
||||
|
||||
# Convert sorted_indices_to_remove to the original indices
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||
|
||||
# If newline or EOS tokens would be removed, return the original scores
|
||||
if indices_to_remove[:, self.special_token_ids].any()
|
||||
return scores
|
||||
|
||||
# Otherwise, remove tokens with the mask
|
||||
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
||||
return scores
|
||||
|
||||
|
|
Loading…
Reference in a new issue