From cf0697936a1f1434f6064747f80a6acb3d861fb9 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Fri, 8 Mar 2024 21:39:02 -0800
Subject: [PATCH] Optimize StreamingLLM by over 10x

---
 modules/cache_utils.py     | 14 ++++++++++----
 modules/text_generation.py |  2 +-
 2 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/modules/cache_utils.py b/modules/cache_utils.py
index 3a200d8e..3f5a0f31 100644
--- a/modules/cache_utils.py
+++ b/modules/cache_utils.py
@@ -1,10 +1,13 @@
 import torch
+from numba import njit
 
 from modules import shared
-from modules.logging_colors import logger
 
 
 def process_llamacpp_cache(model, new_sequence, past_sequence):
+    if len(past_sequence) == 0 or len(new_sequence) == 0:
+        return past_sequence
+
     i1, i2, j1, j2 = find_longest_common_substring_indices(past_sequence, new_sequence)
     overlap_length = i2 - i1 + 1
 
@@ -65,6 +68,7 @@ def find_prefix_length(past_seq, seq_tensor):
     return prefix_length
 
 
+@njit
 def find_longest_common_substring_indices(list1, list2):
     '''
     Given two lists, solves the Longest Common Substring problem.
@@ -86,11 +90,13 @@ def find_longest_common_substring_indices(list1, list2):
     start_index_list1, end_index_list1 = 0, -1
     start_index_list2, end_index_list2 = 0, -1
 
-    for index1 in range(len_list1):
+    # for index1 in tqdm(range(0, len_list1), desc="StreamingLLM prompt comparison", leave=False):
+    for index1 in range(0, len_list1):
         try:
             index2 = list2.index(list1[index1])
-        except ValueError:
+        except:
             continue
+
         while index2 >= 0:
             temp_index1, temp_index2 = index1, index2
             while temp_index1 < len_list1 and temp_index2 < len_list2 and list2[temp_index2] == list1[temp_index1]:
@@ -102,7 +108,7 @@ def find_longest_common_substring_indices(list1, list2):
                 temp_index2 += 1
             try:
                 index2 = list2.index(list1[index1], index2 + 1)
-            except ValueError:
+            except:
                 break
 
     return start_index_list1, end_index_list1, start_index_list2, end_index_list2
diff --git a/modules/text_generation.py b/modules/text_generation.py
index dc9c63ea..d1a59a9d 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -367,7 +367,7 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
 
     # Handle StreamingLLM for llamacpp_HF
     if shared.model.__class__.__name__ == 'LlamacppHF' and shared.args.streaming_llm:
-        tmp = process_llamacpp_cache(shared.model.model, input_ids[-1].tolist(), shared.model.model._input_ids)
+        tmp = process_llamacpp_cache(shared.model.model, input_ids[-1].tolist(), shared.model.model._input_ids.tolist())
         shared.model.past_seq = torch.tensor(tmp)
         shared.model.save_cache()
 