diff --git a/server.py b/server.py
index 07bcd220..943fb5d2 100644
--- a/server.py
+++ b/server.py
@@ -177,7 +177,7 @@ def load_soft_prompt(name):
 
     return name
 
-def upload_softprompt_event(file):
+def upload_soft_prompt(file):
     with zipfile.ZipFile(io.BytesIO(file)) as zf:
         zf.extract('meta.json')
         j = json.loads(open('meta.json', 'r').read())
@@ -276,6 +276,13 @@ def formatted_outputs(reply, model_name):
     else:
         return reply
 
+def generate_softprompt_input_tensors(input_ids):
+    inputs_embeds = model.transformer.wte(input_ids)
+    inputs_embeds = torch.cat((soft_prompt_tensor, inputs_embeds), dim=1)
+    filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(model.device)
+    filler_input_ids += model.config.bos_token_id # setting dummy input_ids to bos tokens
+    return inputs_embeds, filler_input_ids
+
 def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
     global model_name, model, tokenizer, soft_prompt, soft_prompt_tensor
 
@@ -326,43 +333,36 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
         generate_params.append(f"max_new_tokens=8")
 
     if soft_prompt:
-        inputs_embeds = model.transformer.wte(input_ids)
-        inputs_embeds = torch.cat((soft_prompt_tensor, inputs_embeds), dim=1)
-        filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(model.device)
-        filler_input_ids += model.config.bos_token_id # setting dummy input_ids to bos tokens
+        inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
         generate_params.insert(0, "inputs_embeds=inputs_embeds")
         generate_params.insert(0, "filler_input_ids")
     else:
-        filler_input_ids = None
         generate_params.insert(0, "input_ids")
 
     # Generate the entire reply at once
     if args.no_stream:
         t0 = time.time()
         with torch.no_grad():
-            output = eval(f"model.generate({','.join(generate_params)}){cuda}")
+            output = eval(f"model.generate({','.join(generate_params)}){cuda}")[0]
         if soft_prompt:
-            output = torch.cat((input_ids[0], output[0][filler_input_ids.shape[1]:]))
-        else:
-            output = output[0]
+            output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
+
         reply = decode(output)
-        t1 = time.time()
-        print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output)-len(input_ids[0])} tokens)")
         if not (args.chat or args.cai_chat):
             reply = original_question + apply_extensions(reply[len(question):], "output")
         yield formatted_outputs(reply, model_name)
 
+        t1 = time.time()
+        print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output)-len(input_ids[0])} tokens)")
+
     # Generate the reply 1 token at a time
     else:
         yield formatted_outputs(original_question, model_name)
         for i in tqdm(range(tokens//8+1)):
             with torch.no_grad():
-                output = eval(f"model.generate({','.join(generate_params)}){cuda}")
-
+                output = eval(f"model.generate({','.join(generate_params)}){cuda}")[0]
             if soft_prompt:
-                output = torch.cat((input_ids[0], output[0][filler_input_ids.shape[1]:]))
-            else:
-                output = output[0]
+                output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
             reply = decode(output)
 
             if not (args.chat or args.cai_chat):
@@ -371,10 +371,7 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
 
             input_ids = torch.reshape(output, (1, output.shape[0]))
             if soft_prompt:
-                inputs_embeds = model.transformer.wte(input_ids)
-                inputs_embeds = torch.cat((soft_prompt_tensor, inputs_embeds), dim=1)
-                filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(model.device)
-                filler_input_ids += model.config.bos_token_id # setting dummy input_ids to bos tokens
+                inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
 
             if output[-1] == n:
                 break
@@ -486,7 +483,7 @@ def create_settings_menus():
     model_menu.change(load_model_wrapper, [model_menu], [model_menu], show_progress=True)
     preset_menu.change(load_preset_values, [preset_menu], [do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping])
     softprompts_menu.change(load_soft_prompt, [softprompts_menu], [softprompts_menu], show_progress=True)
-    upload_softprompt.upload(upload_softprompt_event, [upload_softprompt], [softprompts_menu])
+    upload_softprompt.upload(upload_soft_prompt, [upload_softprompt], [softprompts_menu])
     return preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping
 
 # This gets the new line characters right.