From f1f2c4c3f4e478ab0ff86261c23ce6f2fe2750dc Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 17 Dec 2023 12:08:33 -0300
Subject: [PATCH] Add --num_experts_per_token parameter (ExLlamav2) (#4955)

---
 README.md                |  1 +
 modules/exllamav2.py     |  1 +
 modules/exllamav2_hf.py  |  1 +
 modules/loaders.py       | 42 ++++++++++++++++++++++--------------------
 modules/shared.py        |  1 +
 modules/ui.py            |  1 +
 modules/ui_model_menu.py |  1 +
 7 files changed, 28 insertions(+), 20 deletions(-)

diff --git a/README.md b/README.md
index 57b7afee..ad8087ee 100644
--- a/README.md
+++ b/README.md
@@ -274,6 +274,7 @@ List of command-line flags
 |`--cfg-cache` | ExLlama_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader, but not necessary for CFG with base ExLlama. |
 |`--no_flash_attn` | Force flash-attention to not be used. |
 |`--cache_8bit` | Use 8-bit cache to save VRAM. |
+|`--num_experts_per_token NUM_EXPERTS_PER_TOKEN` | Number of experts to use for generation. Applies to MoE models like Mixtral. |
 
 #### AutoGPTQ
 
diff --git a/modules/exllamav2.py b/modules/exllamav2.py
index d755a36a..2cf4a039 100644
--- a/modules/exllamav2.py
+++ b/modules/exllamav2.py
@@ -48,6 +48,7 @@ class Exllamav2Model:
         config.scale_pos_emb = shared.args.compress_pos_emb
         config.scale_alpha_value = shared.args.alpha_value
         config.no_flash_attn = shared.args.no_flash_attn
+        config.num_experts_per_token = int(shared.args.num_experts_per_token)
 
         model = ExLlamaV2(config)
 
diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py
index 30e3fe48..944c39dd 100644
--- a/modules/exllamav2_hf.py
+++ b/modules/exllamav2_hf.py
@@ -165,5 +165,6 @@ class Exllamav2HF(PreTrainedModel):
         config.scale_pos_emb = shared.args.compress_pos_emb
         config.scale_alpha_value = shared.args.alpha_value
         config.no_flash_attn = shared.args.no_flash_attn
+        config.num_experts_per_token = int(shared.args.num_experts_per_token)
 
         return Exllamav2HF(config)
diff --git a/modules/loaders.py b/modules/loaders.py
index c7e7653e..9f1c70d1 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -65,6 +65,18 @@ loaders_and_params = OrderedDict({
         'logits_all',
         'llamacpp_HF_info',
     ],
+    'ExLlamav2_HF': [
+        'gpu_split',
+        'max_seq_len',
+        'cfg_cache',
+        'no_flash_attn',
+        'num_experts_per_token',
+        'cache_8bit',
+        'alpha_value',
+        'compress_pos_emb',
+        'trust_remote_code',
+        'no_use_fast',
+    ],
     'ExLlama_HF': [
         'gpu_split',
         'max_seq_len',
@@ -75,17 +87,6 @@ loaders_and_params = OrderedDict({
         'trust_remote_code',
         'no_use_fast',
     ],
-    'ExLlamav2_HF': [
-        'gpu_split',
-        'max_seq_len',
-        'cfg_cache',
-        'no_flash_attn',
-        'cache_8bit',
-        'alpha_value',
-        'compress_pos_emb',
-        'trust_remote_code',
-        'no_use_fast',
-    ],
     'AutoGPTQ': [
         'triton',
         'no_inject_fused_attention',
@@ -123,6 +124,16 @@ loaders_and_params = OrderedDict({
         'no_use_fast',
         'gptq_for_llama_info',
     ],
+    'ExLlamav2': [
+        'gpu_split',
+        'max_seq_len',
+        'no_flash_attn',
+        'num_experts_per_token',
+        'cache_8bit',
+        'alpha_value',
+        'compress_pos_emb',
+        'exllamav2_info',
+    ],
     'ExLlama': [
         'gpu_split',
         'max_seq_len',
@@ -131,15 +142,6 @@ loaders_and_params = OrderedDict({
         'compress_pos_emb',
         'exllama_info',
     ],
-    'ExLlamav2': [
-        'gpu_split',
-        'max_seq_len',
-        'no_flash_attn',
-        'cache_8bit',
-        'alpha_value',
-        'compress_pos_emb',
-        'exllamav2_info',
-    ],
     'ctransformers': [
         'n_ctx',
         'n_gpu_layers',
diff --git a/modules/shared.py b/modules/shared.py
index adebe62d..edd74af1 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -125,6 +125,7 @@ parser.add_argument('--max_seq_len', type=int, default=2048, help='Maximum seque
 parser.add_argument('--cfg-cache', action='store_true', help='ExLlama_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader, but not necessary for CFG with base ExLlama.')
 parser.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.')
 parser.add_argument('--cache_8bit', action='store_true', help='Use 8-bit cache to save VRAM.')
+parser.add_argument('--num_experts_per_token', type=int, default=2, help='Number of experts to use for generation. Applies to MoE models like Mixtral.')
 
 # AutoGPTQ
 parser.add_argument('--triton', action='store_true', help='Use triton.')
diff --git a/modules/ui.py b/modules/ui.py
index 8bfc9491..285e2fc3 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -73,6 +73,7 @@ def list_model_elements():
         'disable_exllamav2',
         'cfg_cache',
         'no_flash_attn',
+        'num_experts_per_token',
         'cache_8bit',
         'threads',
         'threads_batch',
diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py
index 7242d117..7f81ca2d 100644
--- a/modules/ui_model_menu.py
+++ b/modules/ui_model_menu.py
@@ -129,6 +129,7 @@ def create_ui():
             shared.gradio['no_flash_attn'] = gr.Checkbox(label="no_flash_attn", value=shared.args.no_flash_attn, info='Force flash-attention to not be used.')
             shared.gradio['cache_8bit'] = gr.Checkbox(label="cache_8bit", value=shared.args.cache_8bit, info='Use 8-bit cache to save VRAM.')
             shared.gradio['no_use_fast'] = gr.Checkbox(label="no_use_fast", value=shared.args.no_use_fast, info='Set use_fast=False while loading the tokenizer.')
+            shared.gradio['num_experts_per_token'] = gr.Number(label="Number of experts per token", value=shared.args.num_experts_per_token, info='Only applies to MoE models like Mixtral.')
             shared.gradio['gptq_for_llama_info'] = gr.Markdown('Legacy loader for compatibility with older GPUs. ExLlama_HF or AutoGPTQ are preferred for GPTQ models when supported.')
             shared.gradio['exllama_info'] = gr.Markdown("ExLlama_HF is recommended over ExLlama for better integration with extensions and more consistent sampling behavior across loaders.")
             shared.gradio['exllamav2_info'] = gr.Markdown("ExLlamav2_HF is recommended over ExLlamav2 for better integration with extensions and more consistent sampling behavior across loaders.")
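
As a reading aid (not part of the patch itself), the sketch below mirrors the wiring this change introduces: an argparse flag `--num_experts_per_token` with a default of 2, whose value is copied onto the loader's config object before the model is built. The `build_config()` helper and the `SimpleNamespace` stand-in for the real ExLlamaV2 config are illustrative assumptions; in the actual code the value lands on `config.num_experts_per_token` exactly as shown in the diff.

```python
# Illustrative sketch only -- mirrors the flag-to-config wiring added by this
# patch. SimpleNamespace stands in for the real ExLlamaV2 config object, and
# build_config() is a hypothetical helper, not code from the repository.
import argparse
from types import SimpleNamespace

parser = argparse.ArgumentParser()
# Same definition the patch adds to modules/shared.py.
parser.add_argument('--num_experts_per_token', type=int, default=2,
                    help='Number of experts to use for generation. Applies to MoE models like Mixtral.')


def build_config(args: argparse.Namespace) -> SimpleNamespace:
    config = SimpleNamespace()
    # Same assignment the patch adds to modules/exllamav2.py and modules/exllamav2_hf.py.
    config.num_experts_per_token = int(args.num_experts_per_token)
    return config


if __name__ == '__main__':
    args = parser.parse_args()  # e.g. --num_experts_per_token 3
    print(build_config(args).num_experts_per_token)
```

From the web UI's command line this would correspond to something like `python server.py --loader exllamav2 --num_experts_per_token 3` (loader name and entry point assumed from the project's usual layout).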