From 1b69694fe9c461b901b6050d8e1c164166e39d3c Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Tue, 7 Nov 2023 19:05:36 -0800
Subject: [PATCH] Add types to the encode/decode/token-count endpoints

---
 extensions/openai/script.py | 30 +++++++++++++++---------------
 extensions/openai/tokens.py | 28 +++++++++-------------------
 extensions/openai/typing.py | 21 +++++++++++++++++++++
 modules/llamacpp_model.py   |  2 +-
 modules/text_generation.py  |  2 +-
 5 files changed, 47 insertions(+), 36 deletions(-)

diff --git a/extensions/openai/script.py b/extensions/openai/script.py
index 72c2776b..361b97a3 100644
--- a/extensions/openai/script.py
+++ b/extensions/openai/script.py
@@ -27,7 +27,12 @@ from .typing import (
     ChatCompletionResponse,
     CompletionRequest,
     CompletionResponse,
+    DecodeRequest,
+    DecodeResponse,
+    EncodeRequest,
+    EncodeResponse,
     ModelInfoResponse,
+    TokenCountResponse,
     to_dict
 )
 
@@ -206,26 +211,21 @@ async def handle_moderations(request: Request):
     return JSONResponse(response)
 
 
-@app.post("/v1/internal/encode")
-async def handle_token_encode(request: Request):
-    body = await request.json()
-    encoding_format = body.get("encoding_format", "")
-    response = token_encode(body["input"], encoding_format)
+@app.post("/v1/internal/encode", response_model=EncodeResponse)
+async def handle_token_encode(request_data: EncodeRequest):
+    response = token_encode(request_data.text)
     return JSONResponse(response)
 
 
-@app.post("/v1/internal/decode")
-async def handle_token_decode(request: Request):
-    body = await request.json()
-    encoding_format = body.get("encoding_format", "")
-    response = token_decode(body["input"], encoding_format)
-    return JSONResponse(response, no_debug=True)
+@app.post("/v1/internal/decode", response_model=DecodeResponse)
+async def handle_token_decode(request_data: DecodeRequest):
+    response = token_decode(request_data.tokens)
+    return JSONResponse(response)
 
 
-@app.post("/v1/internal/token-count")
-async def handle_token_count(request: Request):
-    body = await request.json()
-    response = token_count(body['prompt'])
+@app.post("/v1/internal/token-count", response_model=TokenCountResponse)
+async def handle_token_count(request_data: EncodeRequest):
+    response = token_count(request_data.text)
     return JSONResponse(response)
 
 
diff --git a/extensions/openai/tokens.py b/extensions/openai/tokens.py
index 0338e7f2..9e92d362 100644
--- a/extensions/openai/tokens.py
+++ b/extensions/openai/tokens.py
@@ -3,34 +3,24 @@ from modules.text_generation import decode, encode
 
 def token_count(prompt):
     tokens = encode(prompt)[0]
-
     return {
-        'results': [{
-            'tokens': len(tokens)
-        }]
+        'length': len(tokens)
     }
 
 
-def token_encode(input, encoding_format):
-    # if isinstance(input, list):
+def token_encode(input):
     tokens = encode(input)[0]
+    if tokens.__class__.__name__ in ['Tensor', 'ndarray']:
+        tokens = tokens.tolist()
 
     return {
-        'results': [{
-            'tokens': tokens,
-            'length': len(tokens),
-        }]
+        'tokens': tokens,
+        'length': len(tokens),
     }
 
 
-def token_decode(tokens, encoding_format):
-    # if isinstance(input, list):
-    #     if encoding_format == "base64":
-    #         tokens = base64_to_float_list(tokens)
-    output = decode(tokens)[0]
-
+def token_decode(tokens):
+    output = decode(tokens)
     return {
-        'results': [{
-            'text': output
-        }]
+        'text': output
     }
diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py
index 4e0211b2..da19e2be 100644
--- a/extensions/openai/typing.py
+++ b/extensions/openai/typing.py
@@ -121,6 +121,27 @@ class ChatCompletionResponse(BaseModel):
     usage: dict
 
 
+class EncodeRequest(BaseModel):
+    text: str
+
+
+class DecodeRequest(BaseModel):
+    tokens: List[int]
+
+
+class EncodeResponse(BaseModel):
+    tokens: List[int]
+    length: int
+
+
+class DecodeResponse(BaseModel):
+    text: str
+
+
+class TokenCountResponse(BaseModel):
+    length: int
+
+
 class ModelInfoResponse(BaseModel):
     model_name: str
     lora_names: List[str]
diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py
index 25d171b1..93f22e95 100644
--- a/modules/llamacpp_model.py
+++ b/modules/llamacpp_model.py
@@ -101,7 +101,7 @@ class LlamaCppModel:
 
         return self.model.tokenize(string)
 
-    def decode(self, ids):
+    def decode(self, ids, **kwargs):
         return self.model.detokenize(ids).decode('utf-8')
 
     def get_logits(self, tokens):
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 310525d2..6034ef31 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -145,7 +145,7 @@ def decode(output_ids, skip_special_tokens=True):
     if shared.tokenizer is None:
         raise ValueError('No tokenizer is loaded')
 
-    return shared.tokenizer.decode(output_ids, skip_special_tokens)
+    return shared.tokenizer.decode(output_ids, skip_special_tokens=skip_special_tokens)
 
 
 def get_encoded_length(prompt):
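
A minimal client-side sketch of how the retyped endpoints could be exercised once this patch is applied. The base URL below is an assumption (adjust it to wherever the OpenAI-compatible API extension is actually listening); the request/response field names (text, tokens, length) come from the Pydantic models added in extensions/openai/typing.py.

import requests

BASE_URL = "http://127.0.0.1:5000"  # assumed local API address; adjust to your setup

# /v1/internal/encode takes {"text": ...} and returns {"tokens": [...], "length": N}
encoded = requests.post(f"{BASE_URL}/v1/internal/encode", json={"text": "Hello there"}).json()

# /v1/internal/decode takes {"tokens": [...]} and returns {"text": ...}
decoded = requests.post(f"{BASE_URL}/v1/internal/decode", json={"tokens": encoded["tokens"]}).json()

# /v1/internal/token-count reuses the EncodeRequest schema and returns {"length": N}
count = requests.post(f"{BASE_URL}/v1/internal/token-count", json={"text": "Hello there"}).json()

print(encoded["length"], decoded["text"], count["length"])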