diff --git a/modules/exllama.py b/modules/exllama.py
index ecfb10a4..00b37b9c 100644
--- a/modules/exllama.py
+++ b/modules/exllama.py
@@ -94,11 +94,15 @@ class ExllamaModel:
         # Tokenizing the input
         ids = self.generator.tokenizer.encode(prompt)
         ids = ids[:, -get_max_prompt_length(state):]
+        if state['auto_max_new_tokens']:
+            max_new_tokens = state['truncation_length'] - ids.shape[-1]
+        else:
+            max_new_tokens = state['max_new_tokens']
 
         self.generator.gen_begin_reuse(ids)
         initial_len = self.generator.sequence[0].shape[0]
         has_leading_space = False
-        for i in range(state['max_new_tokens']):
+        for i in range(max_new_tokens):
             token = self.generator.gen_single_token()
             if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
                 has_leading_space = True
diff --git a/modules/loaders.py b/modules/loaders.py
index 838ecc86..68b48204 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -151,6 +151,7 @@ loaders_samplers = {
         'repetition_penalty_range',
         'seed',
         'ban_eos_token',
+        'auto_max_new_tokens',
     },
     'AutoGPTQ': {
         'temperature',
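
The core of the change is how `max_new_tokens` is derived: when `auto_max_new_tokens` is enabled, generation fills whatever context remains after the (already truncated) prompt, instead of using the fixed `max_new_tokens` setting. Below is a minimal standalone sketch of that logic. The helper `effective_max_new_tokens` and the example `state` values are illustrative, not part of the patch; only the `state` keys and the subtraction come from the diff.

```python
# Sketch of the auto_max_new_tokens logic added in the diff above.
# `state` keys mirror the patch; concrete values here are made up.
def effective_max_new_tokens(state: dict, prompt_len: int) -> int:
    if state['auto_max_new_tokens']:
        # Use all context left over after the truncated prompt.
        return state['truncation_length'] - prompt_len
    return state['max_new_tokens']


state = {
    'auto_max_new_tokens': True,
    'truncation_length': 2048,
    'max_new_tokens': 200,
}
# With a 1500-token prompt, generation may run for up to 548 tokens
# rather than stopping at the fixed limit of 200.
print(effective_max_new_tokens(state, prompt_len=1500))  # -> 548
```

The second hunk simply registers `'auto_max_new_tokens'` in the ExLlama entry of `loaders_samplers`, so the UI exposes the option for that loader.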