From 5cb59707f316102967ca05dc869801fbe510ff51 Mon Sep 17 00:00:00 2001
From: A0nameless0man <1395943920@qq.com>
Date: Mon, 20 May 2024 07:10:39 +0800
Subject: [PATCH] fix: grammar not support utf-8 (#5900)

---
 modules/grammar/grammar_utils.py | 51 +++++++++++++++++++-------------
 1 file changed, 31 insertions(+), 20 deletions(-)

diff --git a/modules/grammar/grammar_utils.py b/modules/grammar/grammar_utils.py
index 4b37c24a..140748d5 100644
--- a/modules/grammar/grammar_utils.py
+++ b/modules/grammar/grammar_utils.py
@@ -60,7 +60,7 @@ def hex_to_int(c):
         return int(c)
     elif "a" <= c.lower() <= "f":
         return ord(c.lower()) - ord("a") + 10
-    return -1
+    raise RuntimeError("unknown hex char " + c)
 
 
 def remove_leading_white_space(src, newline_ok):
@@ -100,6 +100,13 @@ def parse_name(src):
     return src[:pos], src[pos:]
 
 
+def read_hex(s):
+    val = 0
+    for c in s:
+        val = (val << 4) + hex_to_int(c)
+    return chr(val)
+
+
 def parse_char(src):
     """
     parse the leading char from the input string
@@ -111,13 +118,12 @@ def parse_char(src):
     if src[0] == "\\":
         esc = src[1]
         if esc == "x":
-            first = hex_to_int(src[2])
-            if first > -1:
-                second = hex_to_int(src[3])
-                if second > -1:
-                    return (first << 4) + second, src[4:]
-            raise RuntimeError("expecting \\xNN at " + src)
-        elif esc in ('"', "[", "]"):
+            return read_hex(src[2:4]), src[4:]
+        elif esc == "u":
+            return read_hex(src[2:6]), src[6:]
+        elif esc == "U":
+            return read_hex(src[2:10]), src[10:]
+        elif esc in ('"', "[", "]", "\\", "-"):
             return esc, src[2:]
         elif esc == "r":
             return "\r", src[2:]
@@ -454,7 +460,8 @@ class IncrementalGrammarConstraint(GrammarConstraint):
     def __init__(self, grammar_str, start_rule_name, tokenizer):
         super().__init__(grammar_str, start_rule_name, tokenizer)
 
-    def accept_char(self, byte, stacks):
+    def accept_char(self, char, stacks):
+        byte = ord(char)
         new_stacks = []
         for stack in stacks:
             # stack is empty
@@ -471,6 +478,9 @@ class IncrementalGrammarConstraint(GrammarConstraint):
                 if self.grammar_encoding[pos + i] <= byte and byte <= self.grammar_encoding[pos + i + 1]:
                     found = True
                     break
+                if self.grammar_encoding[pos + i] >= byte and byte >= self.grammar_encoding[pos + i + 1]:
+                    found = True
+                    break
             if not found:
                 continue
 
@@ -483,9 +493,8 @@ class IncrementalGrammarConstraint(GrammarConstraint):
         return new_stacks
 
     def accept_string(self, string: str, stacks: List[List[int]]):
-        _bytes = bytes(string, "utf-8")
-        for byte in _bytes:
-            stacks = self.accept_char(byte, stacks)
+        for char in string:
+            stacks = self.accept_char(char, stacks)
         return stacks
 
     def accept_token_id(self, token_id: int, stacks: List[List[int]]):
@@ -537,16 +546,18 @@ class IncrementalGrammarConstraint(GrammarConstraint):
 
     # For each sub-rule in the grammar, cache whether each byte is accepted.
     @lru_cache(maxsize=None)
-    def pos_char_acceptance(self, pos):
-        acceptance = [False] * 256
+    def pos_char_acceptance(self, pos, char):
+        byte = ord(char)
         num_chars = self.grammar_encoding[pos]
         pos += 1
         for i in range(0, num_chars, 2):
             start = self.grammar_encoding[pos + i]
             end = self.grammar_encoding[pos + i + 1]
-            for j in range(start, end + 1):
-                acceptance[j] = True
-        return acceptance
+            if byte >= start and byte <= end:
+                return True
+            if byte <= start and byte >= end:
+                return True
+        return False
 
     # Probably this should be configurable. If the grammar has an exceedingly
     # large number of states, the correct setting is a tradeoff between GPU
@@ -580,7 +591,7 @@ class IncrementalGrammarConstraint(GrammarConstraint):
             pos = stk[-1]
             num_chars = self.grammar_encoding[pos]
 
-            if not self.pos_char_acceptance(pos)[byte]:
+            if not self.pos_char_acceptance(pos, byte):
                 continue
 
             pos += num_chars + 1
@@ -657,14 +668,14 @@ class TokenTrie:
             token = tokenizer.convert_ids_to_tokens(id)
             token = re.sub(r"<0x([0-9a-fA-F]{2})>", replace_hex, token)
             token = token.replace("▁", " ")
-            return bytes(token, "utf-8")
+            return token
 
     else:
         print("Warning: unrecognized tokenizer: using default token formatting")
 
        def fmt_token(id):
             token = tokenizer.convert_ids_to_tokens(id)
-            return bytes(token, "utf-8")
+            return token
 
     # note: vocab_size doesn't work here because there are also
     # get_added_vocab() tokens
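For illustration, here is a minimal standalone sketch of the escape handling this patch introduces. read_hex is copied from the diff above; the first branch of hex_to_int is reconstructed from context, since the diff only shows the tail of that function, and the sample escapes are made up for this note:

    def hex_to_int(c):
        # Map one hex digit to its integer value; the patched version raises
        # instead of returning -1, so malformed escapes fail loudly.
        if c.isdigit():
            return int(c)
        elif "a" <= c.lower() <= "f":
            return ord(c.lower()) - ord("a") + 10
        raise RuntimeError("unknown hex char " + c)

    def read_hex(s):
        # Fold a fixed-width run of hex digits into one Unicode character.
        val = 0
        for c in s:
            val = (val << 4) + hex_to_int(c)
        return chr(val)

    # \xNN, \uNNNN and \UNNNNNNNN all resolve to a single character, so the
    # grammar stacks can compare code points instead of raw UTF-8 bytes; the
    # old 256-entry acceptance table could never match ord("中") == 0x4E2D.
    assert read_hex("41") == "A"          # \x41
    assert read_hex("4e2d") == "中"        # \u4e2d
    assert read_hex("0001f600") == "😀"    # \U0001F600

This is also why pos_char_acceptance now takes (pos, char) and memoizes per pair via lru_cache: a precomputed 256-slot byte table cannot cover the full Unicode range.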