# modules/grammar.py
"""Grammar-constrained sampling: a logits processor backed by torch_grammar."""
from torch_grammar import GrammarSampler
from transformers.generation.logits_process import LogitsProcessor

from modules import shared

# Module-level cache shared by every GrammarLogitsProcessor instance, so the
# sampler is only rebuilt when the grammar text (or the tokenizer) changes.
sampler = None          # GrammarSampler built from `grammar_string`, or None
grammar = None          # callable produced by the most recent construction
grammar_string = ''     # raw (unstripped) grammar text `sampler` was built from

# Tokenizer object `sampler` was built with; used to detect a tokenizer swap.
_sampler_tokenizer = None


class GrammarLogitsProcessor(LogitsProcessor):
    """Logits processor that applies a grammar to the scores.

    The expensive ``GrammarSampler`` is cached in module globals and reused
    across instances; each instance asks the cached sampler for a fresh
    ``logits_processor()`` callable and applies it in ``__call__``.

    Args:
        string: Grammar text. ``None``, an empty string, or a
            whitespace-only string disables the grammar, making the
            processor a no-op.

    Raises:
        Whatever ``GrammarSampler`` raises for the given text. In that case
        the cached sampler and ``grammar_string`` are left untouched, so a
        failed build never pairs a new string with an old sampler.
    """

    def __init__(self, string):
        global sampler, grammar, grammar_string, _sampler_tokenizer

        super().__init__()

        # Treat a missing grammar like an empty one instead of crashing on
        # `.strip()`.
        if string is None:
            string = ''

        tokenizer = shared.tokenizer

        # Rebuild when the text changed, or when a live sampler was built
        # with a tokenizer object that is no longer the current one.
        stale_tokenizer = (
            sampler is not None and tokenizer is not _sampler_tokenizer
        )
        if string != grammar_string or stale_tokenizer:
            text = string.strip()
            if text:
                # 'root' is passed as the second argument -- presumably the
                # name of the start rule; confirm against torch_grammar.
                rebuilt = GrammarSampler(text + '\n', 'root', tokenizer)
            else:
                rebuilt = None

            # Commit the cache only after the build succeeded, so the key
            # and the sampler always describe the same grammar.
            sampler = rebuilt
            grammar_string = string
            _sampler_tokenizer = tokenizer

        # A new callable is requested for every instance, as before.
        grammar = sampler.logits_processor() if sampler is not None else None

        # Keep a per-instance reference so that constructing another
        # processor later cannot change what this one applies.
        self._grammar = grammar

    def __call__(self, input_ids, scores):
        """Return ``scores`` after applying the grammar, if one is active.

        Args:
            input_ids: Passed through unchanged to the grammar callable.
            scores: Scores to filter; returned as-is when no grammar is set.
        """
        if self._grammar is None:
            return scores

        return self._grammar(input_ids, scores)
= { 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', + 'grammar_file_row', + 'grammar_string', 'guidance_scale', 'negative_prompt', 'ban_eos_token', diff --git a/modules/text_generation.py b/modules/text_generation.py index ab556a94..a7f3509b 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -18,6 +18,7 @@ from modules.callbacks import ( _StopEverythingStoppingCriteria ) from modules.extensions import apply_extensions +from modules.grammar import GrammarLogitsProcessor from modules.html_generator import generate_4chan_html, generate_basic_html from modules.logging_colors import logger from modules.models import clear_torch_cache, local_rank @@ -319,6 +320,7 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings # In case a processor is passed by itself. if not isinstance(processor, LogitsProcessorList): processor = LogitsProcessorList([processor]) + processor.append(GrammarLogitsProcessor(state['grammar_string'])) apply_extensions('logits_processor', processor, input_ids) generate_params['logits_processor'] = processor diff --git a/requirements.txt b/requirements.txt index 9d519b97..b651c5e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,6 +25,7 @@ tqdm wandb git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7 +git+https://github.com/oobabooga/torch-grammar.git # bitsandbytes bitsandbytes==0.41.1; platform_system != "Windows" diff --git a/requirements_amd.txt b/requirements_amd.txt index b15e4d06..2ea6e355 100644 --- a/requirements_amd.txt +++ b/requirements_amd.txt @@ -25,6 +25,7 @@ tqdm wandb git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7 +git+https://github.com/oobabooga/torch-grammar.git # bitsandbytes bitsandbytes==0.38.1; platform_system != "Windows" diff --git a/requirements_amd_noavx2.txt b/requirements_amd_noavx2.txt index d567d798..18a81336 100644 --- a/requirements_amd_noavx2.txt +++ 
b/requirements_amd_noavx2.txt @@ -25,6 +25,7 @@ tqdm wandb git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7 +git+https://github.com/oobabooga/torch-grammar.git # bitsandbytes bitsandbytes==0.38.1; platform_system != "Windows" diff --git a/requirements_apple_intel.txt b/requirements_apple_intel.txt index 6a37726e..3d3896ab 100644 --- a/requirements_apple_intel.txt +++ b/requirements_apple_intel.txt @@ -25,6 +25,7 @@ tqdm wandb git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7 +git+https://github.com/oobabooga/torch-grammar.git # bitsandbytes bitsandbytes==0.41.1; platform_system != "Windows" diff --git a/requirements_apple_silicon.txt b/requirements_apple_silicon.txt index 76024e2f..fb8598e1 100644 --- a/requirements_apple_silicon.txt +++ b/requirements_apple_silicon.txt @@ -25,6 +25,7 @@ tqdm wandb git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7 +git+https://github.com/oobabooga/torch-grammar.git # bitsandbytes bitsandbytes==0.41.1; platform_system != "Windows" diff --git a/requirements_cpu_only.txt b/requirements_cpu_only.txt index 39f7e7d9..cde13ef4 100644 --- a/requirements_cpu_only.txt +++ b/requirements_cpu_only.txt @@ -25,6 +25,7 @@ tqdm wandb git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7 +git+https://github.com/oobabooga/torch-grammar.git # bitsandbytes bitsandbytes==0.41.1; platform_system != "Windows" diff --git a/requirements_cpu_only_noavx2.txt b/requirements_cpu_only_noavx2.txt index ec665001..bf7a5fda 100644 --- a/requirements_cpu_only_noavx2.txt +++ b/requirements_cpu_only_noavx2.txt @@ -25,6 +25,7 @@ tqdm wandb git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7 +git+https://github.com/oobabooga/torch-grammar.git # bitsandbytes bitsandbytes==0.41.1; platform_system != "Windows" diff --git a/requirements_noavx2.txt b/requirements_noavx2.txt index 
f0e3383d..6057d46d 100644 --- a/requirements_noavx2.txt +++ b/requirements_noavx2.txt @@ -25,6 +25,7 @@ tqdm wandb git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7 +git+https://github.com/oobabooga/torch-grammar.git # bitsandbytes bitsandbytes==0.41.1; platform_system != "Windows" diff --git a/requirements_nowheels.txt b/requirements_nowheels.txt index d71d82df..2984bd32 100644 --- a/requirements_nowheels.txt +++ b/requirements_nowheels.txt @@ -25,6 +25,7 @@ tqdm wandb git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7 +git+https://github.com/oobabooga/torch-grammar.git # bitsandbytes bitsandbytes==0.41.1; platform_system != "Windows"