From e356f69b366e62bc0d104acab391ccc4889628a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BF=AB=E4=B9=90=E7=9A=84=E6=88=91531?= <2302004040@qq.com> Date: Sat, 24 Jun 2023 22:19:16 +0800 Subject: [PATCH] Make stop_everything work with non-streamed generation (#2848) --- modules/callbacks.py | 8 ++++++++ modules/text_generation.py | 6 ++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/modules/callbacks.py b/modules/callbacks.py index c61bddf8..1fa95e47 100644 --- a/modules/callbacks.py +++ b/modules/callbacks.py @@ -9,6 +9,14 @@ import transformers import modules.shared as shared +class _StopEverythingStoppingCriteria(transformers.StoppingCriteria): + def __init__(self): + transformers.StoppingCriteria.__init__(self) + + def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool: + return shared.stop_everything + + class Stream(transformers.StoppingCriteria): def __init__(self, callback_func=None): self.callback_func = callback_func diff --git a/modules/text_generation.py b/modules/text_generation.py index 78d74ed7..5e876ae8 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -9,7 +9,8 @@ import torch import transformers import modules.shared as shared -from modules.callbacks import Iteratorize, Stream +from modules.callbacks import (Iteratorize, Stream, + _StopEverythingStoppingCriteria) from modules.extensions import apply_extensions from modules.html_generator import generate_4chan_html, generate_basic_html from modules.logging_colors import logger @@ -252,10 +253,11 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings if inputs_embeds is not None: generate_params.update({'inputs_embeds': inputs_embeds}) - # Find the eos tokens + # Stopping criteria / eos token eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else [] generate_params['eos_token_id'] = eos_token_ids generate_params['stopping_criteria'] = transformers.StoppingCriteriaList() + generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria()); t0 = time.time() try: