mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-09-20 18:45:09 +02:00
Some simplifications
This commit is contained in:
parent
3277b751f5
commit
7739a29524
1 changed file with 19 additions and 22 deletions
41
server.py
41
server.py
|
@ -177,7 +177,7 @@ def load_soft_prompt(name):
|
||||||
|
|
||||||
return name
|
return name
|
||||||
|
|
||||||
def upload_softprompt_event(file):
|
def upload_soft_prompt(file):
|
||||||
with zipfile.ZipFile(io.BytesIO(file)) as zf:
|
with zipfile.ZipFile(io.BytesIO(file)) as zf:
|
||||||
zf.extract('meta.json')
|
zf.extract('meta.json')
|
||||||
j = json.loads(open('meta.json', 'r').read())
|
j = json.loads(open('meta.json', 'r').read())
|
||||||
|
@ -276,6 +276,13 @@ def formatted_outputs(reply, model_name):
|
||||||
else:
|
else:
|
||||||
return reply
|
return reply
|
||||||
|
|
||||||
|
def generate_softprompt_input_tensors(input_ids):
    """Build the embedding and placeholder-id tensors used with a soft prompt.

    The token ids are embedded through ``model.transformer.wte`` and the
    module-level ``soft_prompt_tensor`` is concatenated in front of them
    along dimension 1.

    Args:
        input_ids: Tensor of token ids for the current prompt.
            NOTE(review): appears to be shaped (1, seq_len), since the
            filler tensor is built with a leading dimension of 1 — confirm.

    Returns:
        A ``(inputs_embeds, filler_input_ids)`` tuple, where
        ``filler_input_ids`` has shape ``(1, inputs_embeds.shape[1])``, the
        dtype of ``input_ids``, lives on ``model.device`` and holds
        ``model.config.bos_token_id`` in every position.
    """
    token_embeds = model.transformer.wte(input_ids)
    inputs_embeds = torch.cat((soft_prompt_tensor, token_embeds), dim=1)

    # Dummy input_ids: one bos token per position of the combined sequence.
    combined_length = inputs_embeds.shape[1]
    filler_input_ids = torch.full(
        (1, combined_length),
        model.config.bos_token_id,
        dtype=input_ids.dtype,
        device=model.device,
    )
    return inputs_embeds, filler_input_ids
|
||||||
|
|
||||||
def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
|
def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
|
||||||
global model_name, model, tokenizer, soft_prompt, soft_prompt_tensor
|
global model_name, model, tokenizer, soft_prompt, soft_prompt_tensor
|
||||||
|
|
||||||
|
@ -326,43 +333,36 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
|
||||||
generate_params.append(f"max_new_tokens=8")
|
generate_params.append(f"max_new_tokens=8")
|
||||||
|
|
||||||
if soft_prompt:
|
if soft_prompt:
|
||||||
inputs_embeds = model.transformer.wte(input_ids)
|
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
||||||
inputs_embeds = torch.cat((soft_prompt_tensor, inputs_embeds), dim=1)
|
|
||||||
filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(model.device)
|
|
||||||
filler_input_ids += model.config.bos_token_id # setting dummy input_ids to bos tokens
|
|
||||||
generate_params.insert(0, "inputs_embeds=inputs_embeds")
|
generate_params.insert(0, "inputs_embeds=inputs_embeds")
|
||||||
generate_params.insert(0, "filler_input_ids")
|
generate_params.insert(0, "filler_input_ids")
|
||||||
else:
|
else:
|
||||||
filler_input_ids = None
|
|
||||||
generate_params.insert(0, "input_ids")
|
generate_params.insert(0, "input_ids")
|
||||||
|
|
||||||
# Generate the entire reply at once
|
# Generate the entire reply at once
|
||||||
if args.no_stream:
|
if args.no_stream:
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = eval(f"model.generate({','.join(generate_params)}){cuda}")
|
output = eval(f"model.generate({','.join(generate_params)}){cuda}")[0]
|
||||||
if soft_prompt:
|
if soft_prompt:
|
||||||
output = torch.cat((input_ids[0], output[0][filler_input_ids.shape[1]:]))
|
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||||
else:
|
|
||||||
output = output[0]
|
|
||||||
reply = decode(output)
|
reply = decode(output)
|
||||||
t1 = time.time()
|
|
||||||
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output)-len(input_ids[0])} tokens)")
|
|
||||||
if not (args.chat or args.cai_chat):
|
if not (args.chat or args.cai_chat):
|
||||||
reply = original_question + apply_extensions(reply[len(question):], "output")
|
reply = original_question + apply_extensions(reply[len(question):], "output")
|
||||||
yield formatted_outputs(reply, model_name)
|
yield formatted_outputs(reply, model_name)
|
||||||
|
|
||||||
|
t1 = time.time()
|
||||||
|
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output)-len(input_ids[0])} tokens)")
|
||||||
|
|
||||||
# Generate the reply 1 token at a time
|
# Generate the reply 1 token at a time
|
||||||
else:
|
else:
|
||||||
yield formatted_outputs(original_question, model_name)
|
yield formatted_outputs(original_question, model_name)
|
||||||
for i in tqdm(range(tokens//8+1)):
|
for i in tqdm(range(tokens//8+1)):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = eval(f"model.generate({','.join(generate_params)}){cuda}")
|
output = eval(f"model.generate({','.join(generate_params)}){cuda}")[0]
|
||||||
|
|
||||||
if soft_prompt:
|
if soft_prompt:
|
||||||
output = torch.cat((input_ids[0], output[0][filler_input_ids.shape[1]:]))
|
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||||
else:
|
|
||||||
output = output[0]
|
|
||||||
|
|
||||||
reply = decode(output)
|
reply = decode(output)
|
||||||
if not (args.chat or args.cai_chat):
|
if not (args.chat or args.cai_chat):
|
||||||
|
@ -371,10 +371,7 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
|
||||||
|
|
||||||
input_ids = torch.reshape(output, (1, output.shape[0]))
|
input_ids = torch.reshape(output, (1, output.shape[0]))
|
||||||
if soft_prompt:
|
if soft_prompt:
|
||||||
inputs_embeds = model.transformer.wte(input_ids)
|
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
||||||
inputs_embeds = torch.cat((soft_prompt_tensor, inputs_embeds), dim=1)
|
|
||||||
filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(model.device)
|
|
||||||
filler_input_ids += model.config.bos_token_id # setting dummy input_ids to bos tokens
|
|
||||||
|
|
||||||
if output[-1] == n:
|
if output[-1] == n:
|
||||||
break
|
break
|
||||||
|
@ -486,7 +483,7 @@ def create_settings_menus():
|
||||||
model_menu.change(load_model_wrapper, [model_menu], [model_menu], show_progress=True)
|
model_menu.change(load_model_wrapper, [model_menu], [model_menu], show_progress=True)
|
||||||
preset_menu.change(load_preset_values, [preset_menu], [do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping])
|
preset_menu.change(load_preset_values, [preset_menu], [do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping])
|
||||||
softprompts_menu.change(load_soft_prompt, [softprompts_menu], [softprompts_menu], show_progress=True)
|
softprompts_menu.change(load_soft_prompt, [softprompts_menu], [softprompts_menu], show_progress=True)
|
||||||
upload_softprompt.upload(upload_softprompt_event, [upload_softprompt], [softprompts_menu])
|
upload_softprompt.upload(upload_soft_prompt, [upload_softprompt], [softprompts_menu])
|
||||||
return preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping
|
return preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping
|
||||||
|
|
||||||
# This gets the new line characters right.
|
# This gets the new line characters right.
|
||||||
|
|
Loading…
Reference in a new issue