diff --git a/docker/TensorRT-LLM/Dockerfile b/docker/TensorRT-LLM/Dockerfile
new file mode 100644
index 00000000..ae503c94
--- /dev/null
+++ b/docker/TensorRT-LLM/Dockerfile
@@ -0,0 +1,27 @@
+FROM pytorch/pytorch:2.2.1-cuda12.1-cudnn8-runtime
+
+# Install Git
+RUN apt update && apt install -y git
+
+# System-wide TensorRT-LLM requirements
+RUN apt install -y openmpi-bin libopenmpi-dev
+
+# Set the working directory
+WORKDIR /app
+
+# Install text-generation-webui
+RUN git clone https://github.com/oobabooga/text-generation-webui
+WORKDIR /app/text-generation-webui
+RUN pip install -r requirements.txt
+
+# This is needed to avoid an error about "Failed to build mpi4py" in the next command
+ENV LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH
+
+# Install TensorRT-LLM
+RUN pip3 install tensorrt_llm==0.10.0 -U --pre --extra-index-url https://pypi.nvidia.com
+
+# Expose the necessary port for the Python server
+EXPOSE 7860 5000
+
+# Run the Python server.py script with the specified command
+CMD ["python", "server.py", "--api", "--listen"]
diff --git a/modules/loaders.py b/modules/loaders.py
index 7bf1cde4..1da37595 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -131,6 +131,11 @@ loaders_and_params = OrderedDict({
         'hqq_backend',
         'trust_remote_code',
         'no_use_fast',
+    ],
+    'TensorRT-LLM': [
+        'max_seq_len',
+        'cpp_runner',
+        'tensorrt_llm_info',
     ]
 })
 
@@ -316,6 +321,16 @@ loaders_samplers = {
         'skip_special_tokens',
         'auto_max_new_tokens',
     },
+    'TensorRT-LLM': {
+        'temperature',
+        'top_p',
+        'top_k',
+        'repetition_penalty',
+        'presence_penalty',
+        'frequency_penalty',
+        'ban_eos_token',
+        'auto_max_new_tokens',
+    }
 }
 
 
diff --git a/modules/models.py b/modules/models.py
index bd54c146..da741cb0 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -77,6 +77,7 @@ def load_model(model_name, loader=None):
         'ExLlamav2_HF': ExLlamav2_HF_loader,
         'AutoAWQ': AutoAWQ_loader,
         'HQQ': HQQ_loader,
+        'TensorRT-LLM': TensorRT_LLM_loader,
     }
 
     metadata = get_model_metadata(model_name)
@@ -101,7 +102,7 @@ def load_model(model_name, loader=None):
             tokenizer = load_tokenizer(model_name, model)
 
     shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings})
-    if loader.lower().startswith('exllama'):
+    if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt'):
         shared.settings['truncation_length'] = shared.args.max_seq_len
     elif loader in ['llama.cpp', 'llamacpp_HF']:
         shared.settings['truncation_length'] = shared.args.n_ctx
@@ -337,6 +338,13 @@ def HQQ_loader(model_name):
     return model
 
 
+def TensorRT_LLM_loader(model_name):
+    from modules.tensorrt_llm import TensorRTLLMModel
+
+    model = TensorRTLLMModel.from_pretrained(model_name)
+    return model
+
+
 def get_max_memory_dict():
     max_memory = {}
     max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
diff --git a/modules/models_settings.py b/modules/models_settings.py
index e9645c96..387c5658 100644
--- a/modules/models_settings.py
+++ b/modules/models_settings.py
@@ -81,6 +81,9 @@ def get_model_metadata(model):
     # Transformers metadata
     if hf_metadata is not None:
         metadata = json.loads(open(path, 'r', encoding='utf-8').read())
+        if 'pretrained_config' in metadata:
+            metadata = metadata['pretrained_config']
+
         for k in ['max_position_embeddings', 'model_max_length', 'max_seq_len']:
             if k in metadata:
                 model_settings['truncation_length'] = metadata[k]
diff --git a/modules/shared.py b/modules/shared.py
index 373089dc..ebbfc268 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -165,6 +165,10 @@ group.add_argument('--no_inject_fused_attention', action='store_true', help='Dis
 group = parser.add_argument_group('HQQ')
 group.add_argument('--hqq-backend', type=str, default='PYTORCH_COMPILE', help='Backend for the HQQ loader. Valid options: PYTORCH, PYTORCH_COMPILE, ATEN.')
 
+# TensorRT-LLM
+group = parser.add_argument_group('TensorRT-LLM')
+group.add_argument('--cpp-runner', action='store_true', help='Use the ModelRunnerCpp runner, which is faster than the default ModelRunner but doesn\'t support streaming yet.')
+
 # DeepSpeed
 group = parser.add_argument_group('DeepSpeed')
 group.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
@@ -263,6 +267,8 @@ def fix_loader_name(name):
         return 'AutoAWQ'
     elif name in ['hqq']:
         return 'HQQ'
+    elif name in ['tensorrt', 'tensorrtllm', 'tensorrt_llm', 'tensorrt-llm', 'tensort', 'tensortllm']:
+        return 'TensorRT-LLM'
 
 
 def add_extension(name, last=False):
diff --git a/modules/tensorrt_llm.py b/modules/tensorrt_llm.py
new file mode 100644
index 00000000..c2685b75
--- /dev/null
+++ b/modules/tensorrt_llm.py
@@ -0,0 +1,131 @@
+from pathlib import Path
+
+import tensorrt_llm
+import torch
+from tensorrt_llm.runtime import ModelRunner, ModelRunnerCpp
+
+from modules import shared
+from modules.logging_colors import logger
+from modules.text_generation import (
+    get_max_prompt_length,
+    get_reply_from_output_ids
+)
+
+
+class TensorRTLLMModel:
+    def __init__(self):
+        pass
+
+    @classmethod
+    def from_pretrained(self, path_to_model):
+
+        path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
+        runtime_rank = tensorrt_llm.mpi_rank()
+
+        # Define model settings
+        runner_kwargs = dict(
+            engine_dir=str(path_to_model),
+            lora_dir=None,
+            rank=runtime_rank,
+            debug_mode=False,
+            lora_ckpt_source="hf",
+        )
+
+        if shared.args.cpp_runner:
+            logger.info("TensorRT-LLM: Using \"ModelRunnerCpp\"")
+            runner_kwargs.update(
+                max_batch_size=1,
+                max_input_len=shared.args.max_seq_len - 512,
+                max_output_len=512,
+                max_beam_width=1,
+                max_attention_window_size=None,
+                sink_token_length=None,
+            )
+        else:
+            logger.info("TensorRT-LLM: Using \"ModelRunner\"")
+
+        # Load the model
+        runner_cls = ModelRunnerCpp if shared.args.cpp_runner else ModelRunner
+        runner = runner_cls.from_dir(**runner_kwargs)
+
+        result = self()
+        result.model = runner
+        result.runtime_rank = runtime_rank
+
+        return result
+
+    def generate_with_streaming(self, prompt, state):
+        batch_input_ids = []
+        input_ids = shared.tokenizer.encode(
+            prompt,
+            add_special_tokens=True,
+            truncation=False,
+        )
+        input_ids = torch.tensor(input_ids, dtype=torch.int32)
+        input_ids = input_ids[-get_max_prompt_length(state):]  # Apply truncation_length
+        batch_input_ids.append(input_ids)
+
+        if shared.args.cpp_runner:
+            max_new_tokens = min(512, state['max_new_tokens'])
+        elif state['auto_max_new_tokens']:
+            max_new_tokens = state['truncation_length'] - input_ids.shape[-1]
+        else:
+            max_new_tokens = state['max_new_tokens']
+
+        with torch.no_grad():
+            generator = self.model.generate(
+                batch_input_ids,
+                max_new_tokens=max_new_tokens,
+                max_attention_window_size=None,
+                sink_token_length=None,
+                end_id=shared.tokenizer.eos_token_id if not state['ban_eos_token'] else -1,
+                pad_id=shared.tokenizer.pad_token_id or shared.tokenizer.eos_token_id,
+                temperature=state['temperature'],
+                top_k=state['top_k'],
+                top_p=state['top_p'],
+                num_beams=1,
+                length_penalty=1.0,
+                repetition_penalty=state['repetition_penalty'],
+                presence_penalty=state['presence_penalty'],
+                frequency_penalty=state['frequency_penalty'],
+                stop_words_list=None,
+                bad_words_list=None,
+                lora_uids=None,
+                prompt_table_path=None,
+                prompt_tasks=None,
+                streaming=not shared.args.cpp_runner,
+                output_sequence_lengths=True,
+                return_dict=True,
+                medusa_choices=None
+            )
+
+        torch.cuda.synchronize()
+
+        cumulative_reply = ''
+        starting_from = batch_input_ids[0].shape[-1]
+
+        if shared.args.cpp_runner:
+            sequence_length = generator['sequence_lengths'][0].item()
+            output_ids = generator['output_ids'][0][0][:sequence_length].tolist()
+
+            cumulative_reply += get_reply_from_output_ids(output_ids, state, starting_from=starting_from)
+            starting_from = sequence_length
+            yield cumulative_reply
+        else:
+            for curr_outputs in generator:
+                if shared.stop_everything:
+                    break
+
+                sequence_length = curr_outputs['sequence_lengths'][0].item()
+                output_ids = curr_outputs['output_ids'][0][0][:sequence_length].tolist()
+
+                cumulative_reply += get_reply_from_output_ids(output_ids, state, starting_from=starting_from)
+                starting_from = sequence_length
+                yield cumulative_reply
+
+    def generate(self, prompt, state):
+        output = ''
+        for output in self.generate_with_streaming(prompt, state):
+            pass
+
+        return output
diff --git a/modules/text_generation.py b/modules/text_generation.py
index ca42ba1f..d971a30e 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -54,7 +54,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
             yield ''
             return
 
-    if shared.model.__class__.__name__ in ['LlamaCppModel', 'Exllamav2Model']:
+    if shared.model.__class__.__name__ in ['LlamaCppModel', 'Exllamav2Model', 'TensorRTLLMModel']:
         generate_func = generate_reply_custom
     else:
         generate_func = generate_reply_HF
@@ -132,7 +132,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
     if shared.tokenizer is None:
         raise ValueError('No tokenizer is loaded')
 
-    if shared.model.__class__.__name__ in ['LlamaCppModel', 'Exllamav2Model']:
+    if shared.model.__class__.__name__ in ['LlamaCppModel', 'Exllamav2Model', 'TensorRTLLMModel']:
         input_ids = shared.tokenizer.encode(str(prompt))
         if shared.model.__class__.__name__ not in ['Exllamav2Model']:
             input_ids = np.array(input_ids).reshape(1, len(input_ids))
@@ -158,7 +158,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
     if truncation_length is not None:
         input_ids = input_ids[:, -truncation_length:]
 
-    if shared.model.__class__.__name__ in ['LlamaCppModel', 'Exllamav2Model'] or shared.args.cpu:
+    if shared.model.__class__.__name__ in ['LlamaCppModel', 'Exllamav2Model', 'TensorRTLLMModel'] or shared.args.cpu:
         return input_ids
     elif shared.args.deepspeed:
         import deepspeed
diff --git a/modules/ui.py b/modules/ui.py
index f88c0a82..c20a7888 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -106,6 +106,7 @@ def list_model_elements():
         'streaming_llm',
         'attention_sink_size',
         'hqq_backend',
+        'cpp_runner',
     ]
     if is_torch_xpu_available():
         for i in range(torch.xpu.device_count()):
diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py
index 53a9a238..3ebcd126 100644
--- a/modules/ui_model_menu.py
+++ b/modules/ui_model_menu.py
@@ -139,6 +139,7 @@ def create_ui():
                             shared.gradio['autosplit'] = gr.Checkbox(label="autosplit", value=shared.args.autosplit, info='Automatically split the model tensors across the available GPUs.')
                             shared.gradio['no_flash_attn'] = gr.Checkbox(label="no_flash_attn", value=shared.args.no_flash_attn, info='Force flash-attention to not be used.')
                             shared.gradio['cfg_cache'] = gr.Checkbox(label="cfg-cache", value=shared.args.cfg_cache, info='Necessary to use CFG with this loader.')
+                            shared.gradio['cpp_runner'] = gr.Checkbox(label="cpp-runner", value=shared.args.cpp_runner, info='Enable inference with ModelRunnerCpp, which is faster than the default ModelRunner.')
                             shared.gradio['num_experts_per_token'] = gr.Number(label="Number of experts per token", value=shared.args.num_experts_per_token, info='Only applies to MoE models like Mixtral.')
                             with gr.Blocks():
                                 shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='Set trust_remote_code=True while loading the tokenizer/model. To enable this option, start the web UI with the --trust-remote-code flag.', interactive=shared.args.trust_remote_code)
@@ -149,6 +150,7 @@ def create_ui():
                             shared.gradio['disable_exllamav2'] = gr.Checkbox(label="disable_exllamav2", value=shared.args.disable_exllamav2, info='Disable ExLlamav2 kernel for GPTQ models.')
                             shared.gradio['exllamav2_info'] = gr.Markdown("ExLlamav2_HF is recommended over ExLlamav2 for better integration with extensions and more consistent sampling behavior across loaders.")
                             shared.gradio['llamacpp_HF_info'] = gr.Markdown("llamacpp_HF loads llama.cpp as a Transformers model. To use it, you need to place your GGUF in a subfolder of models/ with the necessary tokenizer files.\n\nYou can use the \"llamacpp_HF creator\" menu to do that automatically.")
+                            shared.gradio['tensorrt_llm_info'] = gr.Markdown('* TensorRT-LLM has to be installed manually in a separate Python 3.10 environment at the moment. For a guide, consult the description of [this PR](https://github.com/oobabooga/text-generation-webui/pull/5715). \n\n* `max_seq_len` is only used when `cpp-runner` is checked.\n\n* `cpp_runner` does not support streaming at the moment.')
 
             with gr.Column():
                 with gr.Row():
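
For reference, a rough usage sketch (not part of the patch) of how the new loader path can be driven once this diff is applied. It assumes TensorRT-LLM and the web UI requirements are installed in the same environment, that a prebuilt engine directory with its HF tokenizer files sits under `models/`, and that the folder name, prompt, and sampler values below are placeholders rather than anything taken from the PR:

```python
# Hypothetical driver script, run from the text-generation-webui root.
# Everything it touches (shared, load_model, TensorRTLLMModel.generate_with_streaming)
# comes from the modules changed in this diff; 'my-llama-trt-engine' is a made-up folder name.
from modules import shared
from modules.models import load_model

shared.args.max_seq_len = 4096   # also becomes truncation_length for this loader
shared.args.cpp_runner = False   # keep the Python ModelRunner so streaming works

# Expects models/my-llama-trt-engine/ to hold the TensorRT-LLM engine
# (config.json with a 'pretrained_config' block) plus the tokenizer files.
shared.model, shared.tokenizer = load_model('my-llama-trt-engine', loader='TensorRT-LLM')

# Only the keys read by generate_with_streaming / get_reply_from_output_ids.
state = {
    'max_new_tokens': 200,
    'auto_max_new_tokens': False,
    'truncation_length': shared.args.max_seq_len,
    'ban_eos_token': False,
    'skip_special_tokens': True,
    'temperature': 0.7,
    'top_p': 0.9,
    'top_k': 20,
    'repetition_penalty': 1.15,
    'presence_penalty': 0,
    'frequency_penalty': 0,
}

# Each yield is the cumulative reply so far; keep the last one.
reply = ''
for reply in shared.model.generate_with_streaming('Write a haiku about TensorRT.', state):
    pass

print(reply)
```

With `--cpp-runner` the same call goes through ModelRunnerCpp instead: it is faster, but the reply comes back in one piece and is capped at 512 new tokens, which is what the `tensorrt_llm_info` note and the `min(512, ...)` clamp in `modules/tensorrt_llm.py` reflect.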