Adding parameter to compose API's model list

This commit is contained in:
jsboige 2024-09-01 12:24:51 +02:00
parent 5522584992
commit 6e294af0a6
3 changed files with 26 additions and 2 deletions

View file

@ -15,7 +15,30 @@ def get_current_model_info():
def list_models():
return {'model_names': get_available_models()[1:]}
mode = shared.args.model_selection_mode
result = {
"object": "list",
"data": []
}
# Inclure les dummy models si le bit 0 est activé
if mode & 1:
dummy_models = ['gpt-3.5-turbo', 'text-embedding-ada-002']
for model in dummy_models:
result["data"].append(model_info_dict(model))
# Inclure les modèles locaux si le bit 1 est activé
if mode & 2:
if mode & 4:
# Ne renvoyer que le modèle actuellement chargé
result["data"].append(model_info_dict(shared.model_name))
else:
# Renvoyer tous les modèles disponibles
for model in get_available_models():
result["data"].append(model_info_dict(model))
return result
def list_dummy_models():

View file

@ -147,7 +147,7 @@ async def handle_models(request: Request):
is_list = request.url.path.split('?')[0].split('#')[0] == '/v1/models'
if is_list:
response = OAImodels.list_dummy_models()
response = OAImodels.list_models()
else:
model_name = path[len('/v1/models/'):]
response = OAImodels.model_info_dict(model_name)

View file

@ -200,6 +200,7 @@ group.add_argument('--api-port', type=int, default=5000, help='The listening por
group.add_argument('--api-key', type=str, default='', help='API authentication key.')
group.add_argument('--admin-key', type=str, default='', help='API authentication key for admin tasks like loading and unloading models. If not set, will be the same as --api-key.')
group.add_argument('--nowebui', action='store_true', help='Do not launch the Gradio UI. Useful for launching the API in standalone mode.')
group.add_argument('--model-selection-mode', type=int, default=0, help='Model selection mode: bitwise flag. 1=Include dummy models, 2=Include local models, 4=Return only the currently loaded model if local models are included.')
# Multimodal
group = parser.add_argument_group('Multimodal')