Obtain the EOT token from the jinja template (attempt)

To use as a stopping string.
This commit is contained in:
oobabooga 2024-06-30 15:09:22 -07:00
parent 3e3f8637d6
commit ed01322763

View file

@ -3,6 +3,7 @@ import copy
import functools
import html
import json
import pprint
import re
from datetime import datetime
from functools import partial
@ -259,10 +260,27 @@ def get_stopping_strings(state):
suffix_bot + prefix_user,
]
# Try to find the EOT token
for item in stopping_strings.copy():
item = item.strip()
if item.startswith("<") and ">" in item:
stopping_strings.append(item.split(">")[0] + ">")
elif item.startswith("[") and "]" in item:
stopping_strings.append(item.split("]")[0] + "]")
if 'stopping_strings' in state and isinstance(state['stopping_strings'], list):
stopping_strings += state.pop('stopping_strings')
return list(set(stopping_strings))
# Remove redundant items that start with another item
result = [item for item in stopping_strings if not any(item.startswith(other) and item != other for other in stopping_strings)]
result = list(set(result))
if shared.args.verbose:
logger.info("STOPPING_STRINGS=")
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(result)
print()
return result
def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_message=True, for_ui=False):