diff --git a/modules/llama_cpp_python_hijack.py b/modules/llama_cpp_python_hijack.py
index f3f3f560..c9813708 100644
--- a/modules/llama_cpp_python_hijack.py
+++ b/modules/llama_cpp_python_hijack.py
@@ -100,9 +100,11 @@ def eval_with_progress(self, tokens: Sequence[int]):
 
 
 def monkey_patch_llama_cpp_python(lib):
+    if getattr(lib.Llama, '_is_patched', False):
+        # If the patch is already applied, do nothing
+        return
 
     def my_generate(self, *args, **kwargs):
-
         if shared.args.streaming_llm:
             new_sequence = args[0]
             past_sequence = self._input_ids
@@ -116,3 +118,6 @@ def monkey_patch_llama_cpp_python(lib):
     lib.Llama.eval = eval_with_progress
     lib.Llama.original_generate = lib.Llama.generate
     lib.Llama.generate = my_generate
+
+    # Set the flag to indicate that the patch has been applied
+    lib.Llama._is_patched = True
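
Why the guard matters: without the _is_patched flag, a second call to monkey_patch_llama_cpp_python would store the already-wrapped generate into original_generate, so the wrapper would end up calling itself and recurse without bound. The following is a minimal, self-contained sketch of the same idempotency pattern; it is not part of the patch, and FakeLib plus the wrapper body are stand-ins for the real llama-cpp-python objects.

# Illustrative sketch only: FakeLib stands in for the real library module.
class Llama:
    def generate(self):
        return "original output"


class FakeLib:
    Llama = Llama


def monkey_patch(lib):
    # Same guard as in the patch: a repeated call becomes a no-op.
    if getattr(lib.Llama, '_is_patched', False):
        return

    def my_generate(self):
        # Without the guard, a second patch would make original_generate
        # point at this wrapper, and the call below would recurse forever.
        return "wrapped: " + lib.Llama.original_generate(self)

    lib.Llama.original_generate = lib.Llama.generate
    lib.Llama.generate = my_generate
    lib.Llama._is_patched = True


monkey_patch(FakeLib)
monkey_patch(FakeLib)                 # no-op thanks to _is_patched
print(FakeLib.Llama().generate())     # -> "wrapped: original output"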