diff --git a/modules/text_generation.py b/modules/text_generation.py index d539f6d4..e738cb21 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -258,6 +258,10 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi input_ids = np.reshape(output, (1, output.shape[0])) if shared.soft_prompt: inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids) + generate_params.update({"inputs_embeds": inputs_embeds}) + generate_params.update({"inputs": filler_input_ids}) + else: + generate_params.update({"inputs": input_ids}) yield formatted_outputs(reply, shared.model_name)