Load the model by name

This commit is contained in:
oobabooga 2023-02-28 02:52:29 -03:00
parent f871971de1
commit 6837d4d72a

View file

@ -81,7 +81,7 @@ def load_model(model_name):
elif shared.is_RWKV:
from modules.RWKV import load_RWKV_model
return load_RWKV_model(Path('models/RWKV-4-Pile-169M-20220807-8023.pth')), None
return load_RWKV_model(Path(f'models/{model_name}')), None
# Custom
else: