Simplify encode() function

This commit is contained in:
oobabooga 2023-02-02 13:31:32 -03:00
parent afc2b0f4c8
commit 3f05cf5ddd

View file

@ -168,16 +168,13 @@ def fix_galactica(s):
return s
def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
    """Tokenize *prompt* and place the resulting tensor on the active device.

    The prompt is converted with ``str()`` and tokenized exactly once; the
    device placement is then chosen from the command-line flags.

    Args:
        prompt: Text (or any object convertible with ``str()``) to tokenize.
        tokens_to_generate: Number of tokens reserved for the reply. The
            prompt is truncated to ``2048 - tokens_to_generate`` tokens.
        add_special_tokens: Forwarded unchanged to ``tokenizer.encode``.

    Returns:
        The token ids as returned by ``tokenizer.encode(...,
        return_tensors='pt')``: left untouched when ``args.cpu`` is set,
        moved to ``local_rank`` when ``args.deepspeed`` is set, and moved
        with ``.cuda()`` otherwise.
    """
    # 2048 is presumably the model's context length -- TODO confirm; the
    # budget left for the prompt shrinks by the number of tokens to generate.
    input_ids = tokenizer.encode(
        str(prompt),
        return_tensors='pt',
        truncation=True,
        max_length=2048 - tokens_to_generate,
        add_special_tokens=add_special_tokens,
    )

    # args.cpu takes precedence over args.deepspeed, matching the branch order
    # of the simplified implementation.
    if args.cpu:
        return input_ids
    elif args.deepspeed:
        return input_ids.to(device=local_rank)
    else:
        return input_ids.cuda()
def decode(output_ids):
reply = tokenizer.decode(output_ids, skip_special_tokens=True)