From b227e65d86ac36cd09134f72c37d17f22e00be12 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 24 Sep 2023 07:08:41 -0700 Subject: [PATCH] Add grammar to llama.cpp loader (closes #4019) --- api-examples/api-example-chat-stream.py | 1 + api-examples/api-example-chat.py | 1 + api-examples/api-example-stream.py | 1 + api-examples/api-example.py | 1 + extensions/api/util.py | 1 + extensions/openai/defaults.py | 1 + grammars/arithmetic.gbnf | 6 ++++ grammars/c.gbnf | 42 +++++++++++++++++++++++++ grammars/chess.gbnf | 13 ++++++++ grammars/japanese.gbnf | 7 +++++ grammars/json.gbnf | 25 +++++++++++++++ grammars/json_arr.gbnf | 34 ++++++++++++++++++++ grammars/list.gbnf | 4 +++ modules/llamacpp_model.py | 16 ++++++++++ modules/loaders.py | 1 + modules/ui.py | 1 + modules/ui_parameters.py | 3 ++ modules/utils.py | 4 +++ 18 files changed, 162 insertions(+) create mode 100644 grammars/arithmetic.gbnf create mode 100644 grammars/c.gbnf create mode 100644 grammars/chess.gbnf create mode 100644 grammars/japanese.gbnf create mode 100644 grammars/json.gbnf create mode 100644 grammars/json_arr.gbnf create mode 100644 grammars/list.gbnf diff --git a/api-examples/api-example-chat-stream.py b/api-examples/api-example-chat-stream.py index bf4201ca..abf05e11 100644 --- a/api-examples/api-example-chat-stream.py +++ b/api-examples/api-example-chat-stream.py @@ -63,6 +63,7 @@ async def run(user_input, history): 'mirostat_mode': 0, 'mirostat_tau': 5, 'mirostat_eta': 0.1, + 'grammar_file': '', 'guidance_scale': 1, 'negative_prompt': '', diff --git a/api-examples/api-example-chat.py b/api-examples/api-example-chat.py index 42ba0a62..f53dbe4c 100644 --- a/api-examples/api-example-chat.py +++ b/api-examples/api-example-chat.py @@ -57,6 +57,7 @@ def run(user_input, history): 'mirostat_mode': 0, 'mirostat_tau': 5, 'mirostat_eta': 0.1, + 'grammar_file': '', 'guidance_scale': 1, 'negative_prompt': '', diff --git a/api-examples/api-example-stream.py b/api-examples/api-example-stream.py index 53822162..fbae4a6c 100644 --- a/api-examples/api-example-stream.py +++ b/api-examples/api-example-stream.py @@ -46,6 +46,7 @@ async def run(context): 'mirostat_mode': 0, 'mirostat_tau': 5, 'mirostat_eta': 0.1, + 'grammar_file': '', 'guidance_scale': 1, 'negative_prompt': '', diff --git a/api-examples/api-example.py b/api-examples/api-example.py index e6d79f9b..aae2a1ae 100644 --- a/api-examples/api-example.py +++ b/api-examples/api-example.py @@ -38,6 +38,7 @@ def run(prompt): 'mirostat_mode': 0, 'mirostat_tau': 5, 'mirostat_eta': 0.1, + 'grammar_file': '', 'guidance_scale': 1, 'negative_prompt': '', diff --git a/extensions/api/util.py b/extensions/api/util.py index e4f7738f..53111141 100644 --- a/extensions/api/util.py +++ b/extensions/api/util.py @@ -44,6 +44,7 @@ def build_parameters(body, chat=False): 'mirostat_mode': int(body.get('mirostat_mode', 0)), 'mirostat_tau': float(body.get('mirostat_tau', 5)), 'mirostat_eta': float(body.get('mirostat_eta', 0.1)), + 'grammar_file': str(body.get('grammar_file', '')), 'guidance_scale': float(body.get('guidance_scale', 1)), 'negative_prompt': str(body.get('negative_prompt', '')), 'seed': int(body.get('seed', -1)), diff --git a/extensions/openai/defaults.py b/extensions/openai/defaults.py index 7bc5ab2a..4c1da893 100644 --- a/extensions/openai/defaults.py +++ b/extensions/openai/defaults.py @@ -34,6 +34,7 @@ default_req_params = { 'mirostat_mode': 0, 'mirostat_tau': 5.0, 'mirostat_eta': 0.1, + 'grammar_file': '', 'guidance_scale': 1, 'negative_prompt': '', 'ban_eos_token': False, diff --git a/grammars/arithmetic.gbnf b/grammars/arithmetic.gbnf new file mode 100644 index 00000000..3aa95a9d --- /dev/null +++ b/grammars/arithmetic.gbnf @@ -0,0 +1,6 @@ +root ::= (expr "=" ws term "\n")+ +expr ::= term ([-+*/] term)* +term ::= ident | num | "(" ws expr ")" ws +ident ::= [a-z] [a-z0-9_]* ws +num ::= [0-9]+ ws +ws ::= [ \t\n]* diff --git a/grammars/c.gbnf b/grammars/c.gbnf new file mode 100644 index 00000000..4a0331dd --- /dev/null +++ b/grammars/c.gbnf @@ -0,0 +1,42 @@ +root ::= (declaration)* + +declaration ::= dataType identifier "(" parameter? ")" "{" statement* "}" + +dataType ::= "int" ws | "float" ws | "char" ws +identifier ::= [a-zA-Z_] [a-zA-Z_0-9]* + +parameter ::= dataType identifier + +statement ::= + ( dataType identifier ws "=" ws expression ";" ) | + ( identifier ws "=" ws expression ";" ) | + ( identifier ws "(" argList? ")" ";" ) | + ( "return" ws expression ";" ) | + ( "while" "(" condition ")" "{" statement* "}" ) | + ( "for" "(" forInit ";" ws condition ";" ws forUpdate ")" "{" statement* "}" ) | + ( "if" "(" condition ")" "{" statement* "}" ("else" "{" statement* "}")? ) | + ( singleLineComment ) | + ( multiLineComment ) + +forInit ::= dataType identifier ws "=" ws expression | identifier ws "=" ws expression +forUpdate ::= identifier ws "=" ws expression + +condition ::= expression relationOperator expression +relationOperator ::= ("<=" | "<" | "==" | "!=" | ">=" | ">") + +expression ::= term (("+" | "-") term)* +term ::= factor(("*" | "/") factor)* + +factor ::= identifier | number | unaryTerm | funcCall | parenExpression +unaryTerm ::= "-" factor +funcCall ::= identifier "(" argList? ")" +parenExpression ::= "(" ws expression ws ")" + +argList ::= expression ("," ws expression)* + +number ::= [0-9]+ + +singleLineComment ::= "//" [^\n]* "\n" +multiLineComment ::= "/*" ( [^*] | ("*" [^/]) )* "*/" + +ws ::= ([ \t\n]+) diff --git a/grammars/chess.gbnf b/grammars/chess.gbnf new file mode 100644 index 00000000..ef0fc1b0 --- /dev/null +++ b/grammars/chess.gbnf @@ -0,0 +1,13 @@ +# Specifies chess moves as a list in algebraic notation, using PGN conventions + +# Force first move to "1. ", then any 1-2 digit number after, relying on model to follow the pattern +root ::= "1. " move " " move "\n" ([1-9] [0-9]? ". " move " " move "\n")+ +move ::= (pawn | nonpawn | castle) [+#]? + +# piece type, optional file/rank, optional capture, dest file & rank +nonpawn ::= [NBKQR] [a-h]? [1-8]? "x"? [a-h] [1-8] + +# optional file & capture, dest file & rank, optional promotion +pawn ::= ([a-h] "x")? [a-h] [1-8] ("=" [NBKQR])? + +castle ::= "O-O" "-O"? diff --git a/grammars/japanese.gbnf b/grammars/japanese.gbnf new file mode 100644 index 00000000..43f25ab5 --- /dev/null +++ b/grammars/japanese.gbnf @@ -0,0 +1,7 @@ +# A probably incorrect grammar for Japanese +root ::= jp-char+ ([ \t\n] jp-char+)* +jp-char ::= hiragana | katakana | punctuation | cjk +hiragana ::= [ぁ-ゟ] +katakana ::= [ァ-ヿ] +punctuation ::= [、-〾] +cjk ::= [一-鿿] diff --git a/grammars/json.gbnf b/grammars/json.gbnf new file mode 100644 index 00000000..a9537cdf --- /dev/null +++ b/grammars/json.gbnf @@ -0,0 +1,25 @@ +root ::= object +value ::= object | array | string | number | ("true" | "false" | "null") ws + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws + +string ::= + "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes + )* "\"" ws + +number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws + +# Optional space: by convention, applied in this grammar after literal chars when allowed +ws ::= ([ \t\n] ws)? diff --git a/grammars/json_arr.gbnf b/grammars/json_arr.gbnf new file mode 100644 index 00000000..ef53e77a --- /dev/null +++ b/grammars/json_arr.gbnf @@ -0,0 +1,34 @@ +# This is the same as json.gbnf but we restrict whitespaces at the end of the root array +# Useful for generating JSON arrays + +root ::= arr +value ::= object | array | string | number | ("true" | "false" | "null") ws + +arr ::= + "[\n" ws ( + value + (",\n" ws value)* + )? "]" + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws + +string ::= + "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes + )* "\"" ws + +number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws + +# Optional space: by convention, applied in this grammar after literal chars when allowed +ws ::= ([ \t\n] ws)? diff --git a/grammars/list.gbnf b/grammars/list.gbnf new file mode 100644 index 00000000..51e6c9c4 --- /dev/null +++ b/grammars/list.gbnf @@ -0,0 +1,4 @@ +root ::= item+ + +# Excludes various line break characters +item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n" diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py index db2b8f3b..72b04aaf 100644 --- a/modules/llamacpp_model.py +++ b/modules/llamacpp_model.py @@ -1,5 +1,6 @@ import re from functools import partial +from pathlib import Path import numpy as np import torch @@ -42,6 +43,8 @@ def custom_token_ban_logits_processor(token_ids, input_ids, logits): class LlamaCppModel: def __init__(self): self.initialized = False + self.grammar_file = 'None' + self.grammar = None def __del__(self): self.model.__del__() @@ -107,6 +110,17 @@ class LlamaCppModel: logits = np.expand_dims(logits, 0) # batch dim is expected return torch.tensor(logits, dtype=torch.float32) + def load_grammar(self, fname): + if fname != self.grammar_file: + self.grammar_file = fname + p = Path(f'grammars/{fname}') + print(p) + if p.exists(): + logger.info(f'Loading the following grammar file: {p}') + self.grammar = llama_cpp_lib().LlamaGrammar.from_file(str(p)) + else: + self.grammar = None + def generate(self, prompt, state, callback=None): LogitsProcessorList = llama_cpp_lib().LogitsProcessorList @@ -118,6 +132,7 @@ class LlamaCppModel: prompt = prompt[-get_max_prompt_length(state):] prompt = self.decode(prompt) + self.load_grammar(state['grammar_file']) logit_processors = LogitsProcessorList() if state['ban_eos_token']: logit_processors.append(partial(ban_eos_logits_processor, self.model.token_eos())) @@ -140,6 +155,7 @@ class LlamaCppModel: mirostat_eta=state['mirostat_eta'], stream=True, logits_processor=logit_processors, + grammar=self.grammar ) output = "" diff --git a/modules/loaders.py b/modules/loaders.py index b7187e5f..460d6a27 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -305,6 +305,7 @@ loaders_samplers = { 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', + 'grammar_file', 'ban_eos_token', 'custom_token_bans', }, diff --git a/modules/ui.py b/modules/ui.py index 0a19b231..d6da5cfc 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -114,6 +114,7 @@ def list_interface_input_elements(): 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', + 'grammar_file', 'negative_prompt', 'guidance_scale', 'add_bos_token', diff --git a/modules/ui_parameters.py b/modules/ui_parameters.py index 9fbe6456..f2537e8b 100644 --- a/modules/ui_parameters.py +++ b/modules/ui_parameters.py @@ -108,6 +108,9 @@ def create_ui(default_preset): shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams', info='For Beam Search, along with length_penalty and early_stopping.') shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty') shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping') + with gr.Row(): + shared.gradio['grammar_file'] = gr.Dropdown(value='None', choices=utils.get_available_grammars(), label='Grammar file (GBNF)', elem_classes='slim-dropdown') + ui.create_refresh_button(shared.gradio['grammar_file'], lambda: None, lambda: {'choices': utils.get_available_grammars()}, 'refresh-button') with gr.Box(): with gr.Row(): diff --git a/modules/utils.py b/modules/utils.py index f60597a6..e6449052 100644 --- a/modules/utils.py +++ b/modules/utils.py @@ -124,3 +124,7 @@ def get_datasets(path: str, ext: str): def get_available_chat_styles(): return sorted(set(('-'.join(k.stem.split('-')[1:]) for k in Path('css').glob('chat_style*.css'))), key=natural_keys) + + +def get_available_grammars(): + return ['None'] + sorted([item.name for item in list(Path('grammars').glob('*.gbnf'))], key=natural_keys)