From 791a38bad14763829d754bb68b115cf6affd36f2 Mon Sep 17 00:00:00 2001 From: Jeffrey Lin Date: Mon, 8 May 2023 18:31:34 -0700 Subject: [PATCH] [extensions/openai] Support undocumented base64 'encoding_format' param for compatibility with official OpenAI client (#1876) --- extensions/openai/script.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/extensions/openai/script.py b/extensions/openai/script.py index c168ec95..9eb35a46 100644 --- a/extensions/openai/script.py +++ b/extensions/openai/script.py @@ -1,4 +1,6 @@ +import base64 import json +import numpy as np import os import time from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer @@ -45,6 +47,20 @@ def clamp(value, minvalue, maxvalue): return max(minvalue, min(value, maxvalue)) +def float_list_to_base64(float_list): + # Convert the list to a float32 array that the OpenAPI client expects + float_array = np.array(float_list, dtype="float32") + + # Get raw bytes + bytes_array = float_array.tobytes() + + # Encode bytes into base64 + encoded_bytes = base64.b64encode(bytes_array) + + # Turn raw base64 encoded bytes into ASCII + ascii_string = encoded_bytes.decode('ascii') + return ascii_string + class Handler(BaseHTTPRequestHandler): def do_GET(self): if self.path.startswith('/v1/models'): @@ -435,7 +451,13 @@ class Handler(BaseHTTPRequestHandler): embeddings = embedding_model.encode(input).tolist() - data = [{"object": "embedding", "embedding": emb, "index": n} for n, emb in enumerate(embeddings)] + def enc_emb(emb): + # If base64 is specified, encode. Otherwise, do nothing. + if body.get("encoding_format", "") == "base64": + return float_list_to_base64(emb) + else: + return emb + data = [{"object": "embedding", "embedding": enc_emb(emb), "index": n} for n, emb in enumerate(embeddings)] response = json.dumps({ "object": "list",