Optimize StreamingLLM by over 10x

This commit is contained in:
oobabooga 2024-03-08 21:39:02 -08:00
parent afb51bd5d6
commit cf0697936a
2 changed files with 11 additions and 5 deletions


@@ -1,10 +1,13 @@
 import torch
+from numba import njit

 from modules import shared
 from modules.logging_colors import logger


 def process_llamacpp_cache(model, new_sequence, past_sequence):
+    if len(past_sequence) == 0 or len(new_sequence) == 0:
+        return past_sequence
+
     i1, i2, j1, j2 = find_longest_common_substring_indices(past_sequence, new_sequence)
     overlap_length = i2 - i1 + 1
@@ -65,6 +68,7 @@ def find_prefix_length(past_seq, seq_tensor):
     return prefix_length


+@njit
 def find_longest_common_substring_indices(list1, list2):
     '''
     Given two lists, solves the Longest Common Substring problem.
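
Note (separate from the diff above): the added @njit decorator is where most of the claimed speedup should come from. Numba compiles the pure-Python search loop to machine code the first time it is called with a given argument type, and later calls skip the interpreter entirely. A minimal sketch of that behaviour, using a simplified stand-in loop and made-up token lists rather than the repository's actual function:

import time

from numba import njit


@njit
def count_matches(past_tokens, new_tokens):
    # Simplified stand-in for the real search: count positions where the two
    # token lists agree. The point is only to show compile-once-then-fast.
    matches = 0
    for i in range(min(len(past_tokens), len(new_tokens))):
        if past_tokens[i] == new_tokens[i]:
            matches += 1
    return matches


past = list(range(200_000))
new = list(range(100_000)) + list(range(100_000))

t0 = time.time()
count_matches(past, new)   # first call: numba compiles the function, then runs it
t1 = time.time()
count_matches(past, new)   # second call: reuses the compiled machine code
t2 = time.time()
print(f"compile + run: {t1 - t0:.3f}s, compiled run: {t2 - t1:.6f}s")
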
@@ -86,11 +90,13 @@ def find_longest_common_substring_indices(list1, list2):
     start_index_list1, end_index_list1 = 0, -1
     start_index_list2, end_index_list2 = 0, -1
-    for index1 in range(len_list1):
+    # for index1 in tqdm(range(0, len_list1), desc="StreamingLLM prompt comparison", leave=False):
+    for index1 in range(0, len_list1):
         try:
             index2 = list2.index(list1[index1])
-        except ValueError:
+        except:
             continue

         while index2 >= 0:
             temp_index1, temp_index2 = index1, index2
             while temp_index1 < len_list1 and temp_index2 < len_list2 and list2[temp_index2] == list1[temp_index1]:
@@ -102,7 +108,7 @@ def find_longest_common_substring_indices(list1, list2):
                 temp_index2 += 1

             try:
                 index2 = list2.index(list1[index1], index2 + 1)
-            except ValueError:
+            except:
                 break

     return start_index_list1, end_index_list1, start_index_list2, end_index_list2
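
Note (separate from the diff): besides the decorator, the hunks above swap the `except ValueError:` clauses for bare `except:`. My understanding is that numba's nopython mode only compiles the bare `except` (or `except Exception`) form, so the narrower clauses could not stay once the function was compiled with @njit. A small sketch of the same try/except shape around list.index, with illustrative values:

from numba import njit


@njit
def safe_index(tokens, target):
    # Same shape as the patched helper: list.index raises when the target is
    # missing, and a bare except is used because nopython mode does not accept
    # an exception class such as ValueError in the except clause.
    try:
        return tokens.index(target)
    except:
        return -1


print(safe_index([3, 1, 4, 1, 5], 4))   # prints 2
print(safe_index([3, 1, 4, 1, 5], 9))   # prints -1
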


@@ -367,7 +367,7 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
     # Handle StreamingLLM for llamacpp_HF
     if shared.model.__class__.__name__ == 'LlamacppHF' and shared.args.streaming_llm:
-        tmp = process_llamacpp_cache(shared.model.model, input_ids[-1].tolist(), shared.model.model._input_ids)
+        tmp = process_llamacpp_cache(shared.model.model, input_ids[-1].tolist(), shared.model.model._input_ids.tolist())
         shared.model.past_seq = torch.tensor(tmp)
         shared.model.save_cache()