DRY sampler improvements (#6053)

This commit is contained in:
Belladore 2024-06-13 05:39:11 +03:00 committed by GitHub
parent b675151f25
commit 3abafee696
WARNING! Although there is a key with this ID in the database it does not verify this commit! This commit is SUSPICIOUS.
GPG key ID: B5690EEEBB952194

View file

@ -204,21 +204,25 @@ class DRYLogitsProcessor(LogitsProcessor):
input_ids = input_ids[:, -self._range:]
for input_ids_row, scores_row in zip(input_ids, scores):
# Raw integer must be extracted here to check for set membership.
last_token = input_ids_row[-1].item()
# Use normal Python data types for improved performance
input_ids = input_ids_row.tolist()
last_token = input_ids[-1]
if last_token in self.sequence_breakers:
continue
# Exclude the last token as it always matches.
match_indices = (input_ids_row[:-1] == last_token).nonzero()
match_indices = []
for idx, val in enumerate(input_ids[:-1]):
if val == last_token:
match_indices.append(idx)
# Stores the maximum matching sequence length
# for each token immediately following the sequence in the input.
match_lengths = {}
for i in match_indices:
next_token = input_ids_row[i + 1].item()
next_token = input_ids[i + 1]
if next_token in self.sequence_breakers:
continue
@ -227,15 +231,15 @@ class DRYLogitsProcessor(LogitsProcessor):
# so the match is at least of length 1.
match_length = 1
# Extend the match backwards as far as possible.
while True:
# Extend the match backwards, capped at a length of 50: the cap prevents exponent overflow in the penalty calculation and also improves worst-case performance.
while match_length < 50:
j = i - match_length
if j < 0:
# Start of input reached.
break
previous_token = input_ids_row[-(match_length + 1)].item()
if input_ids_row[j] != previous_token:
previous_token = input_ids[-(match_length + 1)]
if input_ids[j] != previous_token:
# Start of match reached.
break