From 3abafee69605144a0edc9f53dd318853d3e80e68 Mon Sep 17 00:00:00 2001 From: Belladore <135602125+belladoreai@users.noreply.github.com> Date: Thu, 13 Jun 2024 05:39:11 +0300 Subject: [PATCH] DRY sampler improvements (#6053) --- modules/sampler_hijack.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py index bf8ecf3a..ad74d658 100644 --- a/modules/sampler_hijack.py +++ b/modules/sampler_hijack.py @@ -204,21 +204,25 @@ class DRYLogitsProcessor(LogitsProcessor): input_ids = input_ids[:, -self._range:] for input_ids_row, scores_row in zip(input_ids, scores): - # Raw integer must be extracted here to check for set membership. - last_token = input_ids_row[-1].item() + # Use normal Python data types for improved performance + input_ids = input_ids_row.tolist() + last_token = input_ids[-1] if last_token in self.sequence_breakers: continue # Exclude the last token as it always matches. - match_indices = (input_ids_row[:-1] == last_token).nonzero() + match_indices = [] + for idx, val in enumerate(input_ids[:-1]): + if val == last_token: + match_indices.append(idx) # Stores the maximum matching sequence length # for each token immediately following the sequence in the input. match_lengths = {} for i in match_indices: - next_token = input_ids_row[i + 1].item() + next_token = input_ids[i + 1] if next_token in self.sequence_breakers: continue @@ -227,15 +231,15 @@ class DRYLogitsProcessor(LogitsProcessor): # so the match is at least of length 1. match_length = 1 - # Extend the match backwards as far as possible. - while True: + # Extend the match backwards (at most to 50 to prevent exponent overflow at penalty calculation) (this cap also improves performance on worst case) + while match_length < 50: j = i - match_length if j < 0: # Start of input reached. break - previous_token = input_ids_row[-(match_length + 1)].item() - if input_ids_row[j] != previous_token: + previous_token = input_ids[-(match_length + 1)] + if input_ids[j] != previous_token: # Start of match reached. break