Add grammar to transformers and _HF loaders (#4091)

This commit is contained in:
oobabooga 2023-10-05 10:01:36 -03:00 committed by GitHub
parent 0197fdddf1
commit ae4ba3007f
WARNING! Although there is a key with this ID in the database, it does not verify this commit. This commit is SUSPICIOUS.
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 56 additions and 0 deletions

33
modules/grammar.py Normal file
View file

@ -0,0 +1,33 @@
from torch_grammar import GrammarSampler
from transformers.generation.logits_process import LogitsProcessor
from modules import shared
# Module-level cache: building a GrammarSampler requires parsing the GBNF
# grammar text, so the sampler is rebuilt only when the grammar string changes.
sampler = None
grammar = None
grammar_string = ''


class GrammarLogitsProcessor(LogitsProcessor):
    """Constrain generation to a GBNF grammar via torch-grammar.

    Instantiate once per generation with the grammar text; an empty (or
    whitespace-only) string disables grammar constraints and makes the
    processor a no-op.
    """

    def __init__(self, string):
        global sampler, grammar, grammar_string

        # Rebuild the (expensive) sampler only when the grammar text changed.
        if string != grammar_string:
            grammar_string = string
            if string.strip() != '':
                string = string.strip() + '\n'
                sampler = GrammarSampler(string, 'root', shared.tokenizer)
            else:
                sampler = None

        # A fresh logits processor is created on every instantiation so each
        # generation starts from the grammar's root with clean parse state.
        if sampler is not None:
            grammar = sampler.logits_processor()
        else:
            grammar = None

        # Bind to the instance: __call__ must keep using the grammar this
        # processor was built with, even if another GrammarLogitsProcessor is
        # constructed later (which rewrites the module-level globals).
        self.grammar = grammar

    def __call__(self, input_ids, scores):
        # No-op when no grammar is active; otherwise mask scores so only
        # grammar-legal next tokens remain viable.
        if self.grammar is not None:
            scores = self.grammar(input_ids, scores)

        return scores

View file

@ -156,6 +156,8 @@ loaders_samplers = {
'mirostat_mode', 'mirostat_mode',
'mirostat_tau', 'mirostat_tau',
'mirostat_eta', 'mirostat_eta',
'grammar_file_row',
'grammar_string',
'guidance_scale', 'guidance_scale',
'negative_prompt', 'negative_prompt',
'ban_eos_token', 'ban_eos_token',
@ -183,6 +185,8 @@ loaders_samplers = {
'mirostat_mode', 'mirostat_mode',
'mirostat_tau', 'mirostat_tau',
'mirostat_eta', 'mirostat_eta',
'grammar_file_row',
'grammar_string',
'guidance_scale', 'guidance_scale',
'negative_prompt', 'negative_prompt',
'ban_eos_token', 'ban_eos_token',
@ -236,6 +240,8 @@ loaders_samplers = {
'mirostat_mode', 'mirostat_mode',
'mirostat_tau', 'mirostat_tau',
'mirostat_eta', 'mirostat_eta',
'grammar_file_row',
'grammar_string',
'guidance_scale', 'guidance_scale',
'negative_prompt', 'negative_prompt',
'ban_eos_token', 'ban_eos_token',
@ -267,6 +273,8 @@ loaders_samplers = {
'mirostat_mode', 'mirostat_mode',
'mirostat_tau', 'mirostat_tau',
'mirostat_eta', 'mirostat_eta',
'grammar_file_row',
'grammar_string',
'guidance_scale', 'guidance_scale',
'negative_prompt', 'negative_prompt',
'ban_eos_token', 'ban_eos_token',
@ -298,6 +306,8 @@ loaders_samplers = {
'mirostat_mode', 'mirostat_mode',
'mirostat_tau', 'mirostat_tau',
'mirostat_eta', 'mirostat_eta',
'grammar_file_row',
'grammar_string',
'guidance_scale', 'guidance_scale',
'negative_prompt', 'negative_prompt',
'ban_eos_token', 'ban_eos_token',
@ -339,6 +349,8 @@ loaders_samplers = {
'mirostat_mode', 'mirostat_mode',
'mirostat_tau', 'mirostat_tau',
'mirostat_eta', 'mirostat_eta',
'grammar_file_row',
'grammar_string',
'guidance_scale', 'guidance_scale',
'negative_prompt', 'negative_prompt',
'ban_eos_token', 'ban_eos_token',

View file

@ -18,6 +18,7 @@ from modules.callbacks import (
_StopEverythingStoppingCriteria _StopEverythingStoppingCriteria
) )
from modules.extensions import apply_extensions from modules.extensions import apply_extensions
from modules.grammar import GrammarLogitsProcessor
from modules.html_generator import generate_4chan_html, generate_basic_html from modules.html_generator import generate_4chan_html, generate_basic_html
from modules.logging_colors import logger from modules.logging_colors import logger
from modules.models import clear_torch_cache, local_rank from modules.models import clear_torch_cache, local_rank
@ -319,6 +320,7 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
# In case a processor is passed by itself. # In case a processor is passed by itself.
if not isinstance(processor, LogitsProcessorList): if not isinstance(processor, LogitsProcessorList):
processor = LogitsProcessorList([processor]) processor = LogitsProcessorList([processor])
processor.append(GrammarLogitsProcessor(state['grammar_string']))
apply_extensions('logits_processor', processor, input_ids) apply_extensions('logits_processor', processor, input_ids)
generate_params['logits_processor'] = processor generate_params['logits_processor'] = processor

View file

@ -25,6 +25,7 @@ tqdm
wandb wandb
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7 git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
git+https://github.com/oobabooga/torch-grammar.git
# bitsandbytes # bitsandbytes
bitsandbytes==0.41.1; platform_system != "Windows" bitsandbytes==0.41.1; platform_system != "Windows"

View file

@ -25,6 +25,7 @@ tqdm
wandb wandb
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7 git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
git+https://github.com/oobabooga/torch-grammar.git
# bitsandbytes # bitsandbytes
bitsandbytes==0.38.1; platform_system != "Windows" bitsandbytes==0.38.1; platform_system != "Windows"

View file

@ -25,6 +25,7 @@ tqdm
wandb wandb
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7 git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
git+https://github.com/oobabooga/torch-grammar.git
# bitsandbytes # bitsandbytes
bitsandbytes==0.38.1; platform_system != "Windows" bitsandbytes==0.38.1; platform_system != "Windows"

View file

@ -25,6 +25,7 @@ tqdm
wandb wandb
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7 git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
git+https://github.com/oobabooga/torch-grammar.git
# bitsandbytes # bitsandbytes
bitsandbytes==0.41.1; platform_system != "Windows" bitsandbytes==0.41.1; platform_system != "Windows"

View file

@ -25,6 +25,7 @@ tqdm
wandb wandb
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7 git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
git+https://github.com/oobabooga/torch-grammar.git
# bitsandbytes # bitsandbytes
bitsandbytes==0.41.1; platform_system != "Windows" bitsandbytes==0.41.1; platform_system != "Windows"

View file

@ -25,6 +25,7 @@ tqdm
wandb wandb
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7 git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
git+https://github.com/oobabooga/torch-grammar.git
# bitsandbytes # bitsandbytes
bitsandbytes==0.41.1; platform_system != "Windows" bitsandbytes==0.41.1; platform_system != "Windows"

View file

@ -25,6 +25,7 @@ tqdm
wandb wandb
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7 git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
git+https://github.com/oobabooga/torch-grammar.git
# bitsandbytes # bitsandbytes
bitsandbytes==0.41.1; platform_system != "Windows" bitsandbytes==0.41.1; platform_system != "Windows"

View file

@ -25,6 +25,7 @@ tqdm
wandb wandb
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7 git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
git+https://github.com/oobabooga/torch-grammar.git
# bitsandbytes # bitsandbytes
bitsandbytes==0.41.1; platform_system != "Windows" bitsandbytes==0.41.1; platform_system != "Windows"

View file

@ -25,6 +25,7 @@ tqdm
wandb wandb
git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7 git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7
git+https://github.com/oobabooga/torch-grammar.git
# bitsandbytes # bitsandbytes
bitsandbytes==0.41.1; platform_system != "Windows" bitsandbytes==0.41.1; platform_system != "Windows"