Prevent deadlocks in OpenAI API with simultaneous requests

This commit is contained in:
oobabooga 2023-11-08 20:55:39 -08:00
parent 21ed9a260e
commit effb3aef42

View file

@@ -1,3 +1,4 @@
import asyncio
import json import json
import os import os
import traceback import traceback
@@ -46,6 +47,9 @@ params = {
} }
streaming_semaphore = asyncio.Semaphore(1)
def verify_api_key(authorization: str = Header(None)) -> None: def verify_api_key(authorization: str = Header(None)) -> None:
expected_api_key = shared.args.api_key expected_api_key = shared.args.api_key
if expected_api_key and (authorization is None or authorization != f"Bearer {expected_api_key}"): if expected_api_key and (authorization is None or authorization != f"Bearer {expected_api_key}"):
@@ -84,9 +88,10 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
if request_data.stream: if request_data.stream:
async def generator(): async def generator():
response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy) async with streaming_semaphore:
for resp in response: response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy)
yield {"data": json.dumps(resp)} for resp in response:
yield {"data": json.dumps(resp)}
return EventSourceResponse(generator()) # SSE streaming return EventSourceResponse(generator()) # SSE streaming
@@ -102,9 +107,10 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion
if request_data.stream: if request_data.stream:
async def generator(): async def generator():
response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy) async with streaming_semaphore:
for resp in response: response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy)
yield {"data": json.dumps(resp)} for resp in response:
yield {"data": json.dumps(resp)}
return EventSourceResponse(generator()) # SSE streaming return EventSourceResponse(generator()) # SSE streaming