Backend cleanup (#6025)

This commit is contained in:
oobabooga 2024-05-21 13:32:02 -03:00 committed by GitHub
parent 6a1682aa95
commit bd7cc4234d
WARNING! Although there is a key with this ID in the database it does not verify this commit! This commit is SUSPICIOUS.
GPG key ID: B5690EEEBB952194
23 changed files with 57 additions and 442 deletions

View file

@ -11,7 +11,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
## Features ## Features
* 3 interface modes: default (two columns), notebook, and chat. * 3 interface modes: default (two columns), notebook, and chat.
* Multiple model backends: [Transformers](https://github.com/huggingface/transformers), [llama.cpp](https://github.com/ggerganov/llama.cpp) (through [llama-cpp-python](https://github.com/abetlen/llama-cpp-python)), [ExLlamaV2](https://github.com/turboderp/exllamav2), [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ), [AutoAWQ](https://github.com/casper-hansen/AutoAWQ), [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa), [QuIP#](https://github.com/Cornell-RelaxML/quip-sharp). * Multiple model backends: [Transformers](https://github.com/huggingface/transformers), [llama.cpp](https://github.com/ggerganov/llama.cpp) (through [llama-cpp-python](https://github.com/abetlen/llama-cpp-python)), [ExLlamaV2](https://github.com/turboderp/exllamav2), [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ), [AutoAWQ](https://github.com/casper-hansen/AutoAWQ).
* Dropdown menu for quickly switching between different models. * Dropdown menu for quickly switching between different models.
* Large number of extensions (built-in and user-contributed), including Coqui TTS for realistic voice outputs, Whisper STT for voice inputs, translation, [multimodal pipelines](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/multimodal), vector databases, Stable Diffusion integration, and a lot more. See [the wiki](https://github.com/oobabooga/text-generation-webui/wiki/07-%E2%80%90-Extensions) and [the extensions directory](https://github.com/oobabooga/text-generation-webui-extensions) for details. * Large number of extensions (built-in and user-contributed), including Coqui TTS for realistic voice outputs, Whisper STT for voice inputs, translation, [multimodal pipelines](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/multimodal), vector databases, Stable Diffusion integration, and a lot more. See [the wiki](https://github.com/oobabooga/text-generation-webui/wiki/07-%E2%80%90-Extensions) and [the extensions directory](https://github.com/oobabooga/text-generation-webui-extensions) for details.
* [Chat with custom characters](https://github.com/oobabooga/text-generation-webui/wiki/03-%E2%80%90-Parameters-Tab#character). * [Chat with custom characters](https://github.com/oobabooga/text-generation-webui/wiki/03-%E2%80%90-Parameters-Tab#character).
@ -208,12 +208,12 @@ usage: server.py [-h] [--multi-user] [--character CHARACTER] [--model MODEL] [--
[--tensorcores] [--n_ctx N_CTX] [--threads THREADS] [--threads-batch THREADS_BATCH] [--no_mul_mat_q] [--n_batch N_BATCH] [--no-mmap] [--mlock] [--n-gpu-layers N_GPU_LAYERS] [--tensorcores] [--n_ctx N_CTX] [--threads THREADS] [--threads-batch THREADS_BATCH] [--no_mul_mat_q] [--n_batch N_BATCH] [--no-mmap] [--mlock] [--n-gpu-layers N_GPU_LAYERS]
[--tensor_split TENSOR_SPLIT] [--numa] [--logits_all] [--no_offload_kqv] [--cache-capacity CACHE_CAPACITY] [--row_split] [--streaming-llm] [--attention-sink-size ATTENTION_SINK_SIZE] [--tensor_split TENSOR_SPLIT] [--numa] [--logits_all] [--no_offload_kqv] [--cache-capacity CACHE_CAPACITY] [--row_split] [--streaming-llm] [--attention-sink-size ATTENTION_SINK_SIZE]
[--gpu-split GPU_SPLIT] [--autosplit] [--max_seq_len MAX_SEQ_LEN] [--cfg-cache] [--no_flash_attn] [--cache_8bit] [--cache_4bit] [--num_experts_per_token NUM_EXPERTS_PER_TOKEN] [--gpu-split GPU_SPLIT] [--autosplit] [--max_seq_len MAX_SEQ_LEN] [--cfg-cache] [--no_flash_attn] [--cache_8bit] [--cache_4bit] [--num_experts_per_token NUM_EXPERTS_PER_TOKEN]
[--triton] [--no_inject_fused_attention] [--no_inject_fused_mlp] [--no_use_cuda_fp16] [--desc_act] [--disable_exllama] [--disable_exllamav2] [--wbits WBITS] [--model_type MODEL_TYPE] [--triton] [--no_inject_fused_mlp] [--no_use_cuda_fp16] [--desc_act] [--disable_exllama] [--disable_exllamav2] [--wbits WBITS] [--groupsize GROUPSIZE] [--no_inject_fused_attention]
[--groupsize GROUPSIZE] [--pre_layer PRE_LAYER [PRE_LAYER ...]] [--checkpoint CHECKPOINT] [--monkey-patch] [--hqq-backend HQQ_BACKEND] [--deepspeed] [--hqq-backend HQQ_BACKEND] [--deepspeed] [--nvme-offload-dir NVME_OFFLOAD_DIR] [--local_rank LOCAL_RANK] [--alpha_value ALPHA_VALUE] [--rope_freq_base ROPE_FREQ_BASE]
[--nvme-offload-dir NVME_OFFLOAD_DIR] [--local_rank LOCAL_RANK] [--alpha_value ALPHA_VALUE] [--rope_freq_base ROPE_FREQ_BASE] [--compress_pos_emb COMPRESS_POS_EMB] [--listen] [--compress_pos_emb COMPRESS_POS_EMB] [--listen] [--listen-port LISTEN_PORT] [--listen-host LISTEN_HOST] [--share] [--auto-launch] [--gradio-auth GRADIO_AUTH]
[--listen-port LISTEN_PORT] [--listen-host LISTEN_HOST] [--share] [--auto-launch] [--gradio-auth GRADIO_AUTH] [--gradio-auth-path GRADIO_AUTH_PATH] [--ssl-keyfile SSL_KEYFILE] [--gradio-auth-path GRADIO_AUTH_PATH] [--ssl-keyfile SSL_KEYFILE] [--ssl-certfile SSL_CERTFILE] [--api] [--public-api] [--public-api-id PUBLIC_API_ID] [--api-port API_PORT]
[--ssl-certfile SSL_CERTFILE] [--api] [--public-api] [--public-api-id PUBLIC_API_ID] [--api-port API_PORT] [--api-key API_KEY] [--admin-key ADMIN_KEY] [--nowebui] [--api-key API_KEY] [--admin-key ADMIN_KEY] [--nowebui] [--multimodal-pipeline MULTIMODAL_PIPELINE] [--model_type MODEL_TYPE] [--pre_layer PRE_LAYER [PRE_LAYER ...]]
[--multimodal-pipeline MULTIMODAL_PIPELINE] [--checkpoint CHECKPOINT] [--monkey-patch]
Text generation web UI Text generation web UI
@ -237,7 +237,7 @@ Basic settings:
Model loader: Model loader:
--loader LOADER Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, llamacpp_HF, ExLlamav2_HF, ExLlamav2, --loader LOADER Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, llamacpp_HF, ExLlamav2_HF, ExLlamav2,
AutoGPTQ, AutoAWQ, GPTQ-for-LLaMa, QuIP#. AutoGPTQ, AutoAWQ.
Transformers/Accelerate: Transformers/Accelerate:
--cpu Use the CPU to generate text. Warning: Training on CPU is extremely slow. --cpu Use the CPU to generate text. Warning: Training on CPU is extremely slow.
@ -293,21 +293,16 @@ ExLlamaV2:
AutoGPTQ: AutoGPTQ:
--triton Use triton. --triton Use triton.
--no_inject_fused_attention Disable the use of fused attention, which will use less VRAM at the cost of slower inference.
--no_inject_fused_mlp Triton mode only: disable the use of fused MLP, which will use less VRAM at the cost of slower inference. --no_inject_fused_mlp Triton mode only: disable the use of fused MLP, which will use less VRAM at the cost of slower inference.
--no_use_cuda_fp16 This can make models faster on some systems. --no_use_cuda_fp16 This can make models faster on some systems.
--desc_act For models that do not have a quantize_config.json, this parameter is used to define whether to set desc_act or not in BaseQuantizeConfig. --desc_act For models that do not have a quantize_config.json, this parameter is used to define whether to set desc_act or not in BaseQuantizeConfig.
--disable_exllama Disable ExLlama kernel, which can improve inference speed on some systems. --disable_exllama Disable ExLlama kernel, which can improve inference speed on some systems.
--disable_exllamav2 Disable ExLlamav2 kernel. --disable_exllamav2 Disable ExLlamav2 kernel.
GPTQ-for-LLaMa:
--wbits WBITS Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported. --wbits WBITS Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported.
--model_type MODEL_TYPE Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported.
--groupsize GROUPSIZE Group size. --groupsize GROUPSIZE Group size.
--pre_layer PRE_LAYER [PRE_LAYER ...] The number of layers to allocate to the GPU. Setting this parameter enables CPU offloading for 4-bit models. For multi-gpu, write the numbers separated
by spaces, eg --pre_layer 30 60. AutoAWQ:
--checkpoint CHECKPOINT The path to the quantized checkpoint file. If not specified, it will be automatically detected. --no_inject_fused_attention Disable the use of fused attention, which will use less VRAM at the cost of slower inference.
--monkey-patch Apply the monkey patch for using LoRAs with quantized models.
HQQ: HQQ:
--hqq-backend HQQ_BACKEND Backend for the HQQ loader. Valid options: PYTORCH, PYTORCH_COMPILE, ATEN. --hqq-backend HQQ_BACKEND Backend for the HQQ loader. Valid options: PYTORCH, PYTORCH_COMPILE, ATEN.

View file

@ -64,14 +64,6 @@ Loads: GPTQ models.
* **no_use_cuda_fp16**: On some systems, the performance can be very bad with this unset. Can usually be ignored. * **no_use_cuda_fp16**: On some systems, the performance can be very bad with this unset. Can usually be ignored.
* **desc_act**: For ancient models without proper metadata, sets the model "act-order" parameter manually. Can usually be ignored. * **desc_act**: For ancient models without proper metadata, sets the model "act-order" parameter manually. Can usually be ignored.
### GPTQ-for-LLaMa
Loads: GPTQ models.
Ancient loader, the first one to implement 4-bit quantization. It works on older GPUs for which ExLlamaV2 and AutoGPTQ do not work, and it doesn't work with "act-order", so you should use it with simple 4-bit-128g models.
* **pre_layer**: Used for CPU offloading. The higher the number, the more layers will be sent to the GPU. GPTQ-for-LLaMa CPU offloading was faster than the one implemented in AutoGPTQ the last time I checked.
### llama.cpp ### llama.cpp
Loads: GGUF models. Note: GGML models have been deprecated and do not work anymore. Loads: GGUF models. Note: GGML models have been deprecated and do not work anymore.

View file

@ -13,28 +13,6 @@ Source: https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/1126
This file will be automatically detected the next time you start the web UI. This file will be automatically detected the next time you start the web UI.
## Using LoRAs with GPTQ-for-LLaMa
This requires using a monkey patch that is supported by this web UI: https://github.com/johnsmith0031/alpaca_lora_4bit
To use it:
Install alpaca_lora_4bit using pip
```
git clone https://github.com/johnsmith0031/alpaca_lora_4bit.git
cd alpaca_lora_4bit
git fetch origin winglian-setup_pip
git checkout winglian-setup_pip
pip install .
```
Start the UI with the --monkey-patch flag:
```
python server.py --model llama-7b-4bit-128g --listen --lora tloen_alpaca-lora-7b --monkey-patch
```
## DeepSpeed ## DeepSpeed
`DeepSpeed ZeRO-3` is an alternative offloading strategy for full-precision (16-bit) transformers models. `DeepSpeed ZeRO-3` is an alternative offloading strategy for full-precision (16-bit) transformers models.

View file

@ -2,15 +2,13 @@
| Loader | Loading 1 LoRA | Loading 2 or more LoRAs | Training LoRAs | Multimodal extension | Perplexity evaluation | | Loader | Loading 1 LoRA | Loading 2 or more LoRAs | Training LoRAs | Multimodal extension | Perplexity evaluation |
|----------------|----------------|-------------------------|----------------|----------------------|-----------------------| |----------------|----------------|-------------------------|----------------|----------------------|-----------------------|
| Transformers | ✅ | ✅\*\*\* | ✅\* | ✅ | ✅ | | Transformers | ✅ | ✅\*\* | ✅\* | ✅ | ✅ |
| llama.cpp | ❌ | ❌ | ❌ | ❌ | use llamacpp_HF | | llama.cpp | ❌ | ❌ | ❌ | ❌ | use llamacpp_HF |
| llamacpp_HF | ❌ | ❌ | ❌ | ❌ | ✅ | | llamacpp_HF | ❌ | ❌ | ❌ | ❌ | ✅ |
| ExLlamav2_HF | ✅ | ✅ | ❌ | ❌ | ✅ | | ExLlamav2_HF | ✅ | ✅ | ❌ | ❌ | ✅ |
| ExLlamav2 | ✅ | ✅ | ❌ | ❌ | use ExLlamav2_HF | | ExLlamav2 | ✅ | ✅ | ❌ | ❌ | use ExLlamav2_HF |
| AutoGPTQ | ✅ | ❌ | ❌ | ✅ | ✅ | | AutoGPTQ | ✅ | ❌ | ❌ | ✅ | ✅ |
| AutoAWQ | ? | ❌ | ? | ? | ✅ | | AutoAWQ | ? | ❌ | ? | ? | ✅ |
| GPTQ-for-LLaMa | ✅\*\* | ✅\*\*\* | ✅ | ✅ | ✅ |
| QuIP# | ? | ? | ? | ? | ✅ |
| HQQ | ? | ? | ? | ? | ✅ | | HQQ | ? | ? | ? | ? | ✅ |
❌ = not implemented ❌ = not implemented
@ -19,6 +17,4 @@
\* Training LoRAs with GPTQ models also works with the Transformers loader. Make sure to check "auto-devices" and "disable_exllama" before loading the model. \* Training LoRAs with GPTQ models also works with the Transformers loader. Make sure to check "auto-devices" and "disable_exllama" before loading the model.
\*\* Requires the monkey-patch. The instructions can be found [here](https://github.com/oobabooga/text-generation-webui/wiki/08-%E2%80%90-Additional-Tips#using-loras-with-gptq-for-llama). \*\* Multi-LoRA in PEFT is tricky and the current implementation does not work reliably in all cases.
\*\*\* Multi-LoRA in PEFT is tricky and the current implementation does not work reliably in all cases.

View file

@ -44,7 +44,7 @@ def load_quantized(model_name):
'model_basename': pt_path.stem, 'model_basename': pt_path.stem,
'device': "xpu:0" if is_xpu_available() else "cuda:0" if not shared.args.cpu else "cpu", 'device': "xpu:0" if is_xpu_available() else "cuda:0" if not shared.args.cpu else "cpu",
'use_triton': shared.args.triton, 'use_triton': shared.args.triton,
'inject_fused_attention': not shared.args.no_inject_fused_attention, 'inject_fused_attention': False,
'inject_fused_mlp': not shared.args.no_inject_fused_mlp, 'inject_fused_mlp': not shared.args.no_inject_fused_mlp,
'use_safetensors': use_safetensors, 'use_safetensors': use_safetensors,
'trust_remote_code': shared.args.trust_remote_code, 'trust_remote_code': shared.args.trust_remote_code,

View file

@ -1,171 +0,0 @@
import inspect
import re
from pathlib import Path
import accelerate
import torch
import transformers
from accelerate.utils import is_xpu_available
from gptq_for_llama import llama_inference_offload
from gptq_for_llama.modelutils import find_layers
from gptq_for_llama.quant import make_quant
from transformers import AutoConfig, AutoModelForCausalLM
import modules.shared as shared
from modules.logging_colors import logger
# This function is a replacement for the load_quant function in the
# GPTQ-for_LLaMa repository. It supports more models and branches.
def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=None, kernel_switch_threshold=128, eval=True):
exclude_layers = exclude_layers or ['lm_head']
def noop(*args, **kwargs):
pass
config = AutoConfig.from_pretrained(model, trust_remote_code=shared.args.trust_remote_code)
torch.nn.init.kaiming_uniform_ = noop
torch.nn.init.uniform_ = noop
torch.nn.init.normal_ = noop
torch.set_default_dtype(torch.half)
transformers.modeling_utils._init_weights = False
torch.set_default_dtype(torch.half)
model = AutoModelForCausalLM.from_config(config, trust_remote_code=shared.args.trust_remote_code)
torch.set_default_dtype(torch.float)
if eval:
model = model.eval()
layers = find_layers(model)
for name in exclude_layers:
if name in layers:
del layers[name]
gptq_args = inspect.getfullargspec(make_quant).args
make_quant_kwargs = {
'module': model,
'names': layers,
'bits': wbits,
}
if 'groupsize' in gptq_args:
make_quant_kwargs['groupsize'] = groupsize
if 'faster' in gptq_args:
make_quant_kwargs['faster'] = faster_kernel
if 'kernel_switch_threshold' in gptq_args:
make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold
make_quant(**make_quant_kwargs)
del layers
if checkpoint.endswith('.safetensors'):
from safetensors.torch import load_file as safe_load
model.load_state_dict(safe_load(checkpoint), strict=False)
else:
model.load_state_dict(torch.load(checkpoint, weights_only=True), strict=False)
model.seqlen = 2048
return model
# Used to locate the .pt/.safetensors quantized file
def find_quantized_model_file(model_name):
if shared.args.checkpoint:
return Path(shared.args.checkpoint)
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
pt_path = None
priority_name_list = [
Path(f'{shared.args.model_dir}/{model_name}{hyphen}{shared.args.wbits}bit{group}{ext}')
for group in ([f'-{shared.args.groupsize}g', ''] if shared.args.groupsize > 0 else [''])
for ext in ['.safetensors', '.pt']
for hyphen in ['-', f'/{model_name}-', '/']
]
for path in priority_name_list:
if path.exists():
pt_path = path
break
# If the model hasn't been found with a well-behaved name, pick the last .pt
# or the last .safetensors found in its folder as a last resort
if not pt_path:
for ext in ['.pt', '.safetensors']:
found = list(path_to_model.glob(f"*{ext}"))
if len(found) > 0:
if len(found) > 1:
logger.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')
pt_path = found[-1]
break
return pt_path
# The function that loads the model in modules/models.py
def load_quantized(model_name):
if shared.args.model_type is None:
logger.error("The model could not be loaded because its type could not be inferred from its name.")
logger.error("Please specify the type manually using the --model_type argument.")
return None
# Select the appropriate load_quant function
model_type = shared.args.model_type.lower()
if shared.args.pre_layer and model_type == 'llama':
load_quant = llama_inference_offload.load_quant
elif model_type in ('llama', 'opt', 'gptj'):
if shared.args.pre_layer:
logger.warning("Ignoring --pre_layer because it only works for llama model type.")
load_quant = _load_quant
else:
logger.error("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
exit()
# Find the quantized model weights file (.pt/.safetensors)
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
pt_path = find_quantized_model_file(model_name)
if not pt_path:
logger.error("Could not find the quantized model in .pt or .safetensors format. Exiting.")
exit()
else:
logger.info(f"Found the following quantized model: {pt_path}")
# qwopqwop200's offload
if model_type == 'llama' and shared.args.pre_layer:
if len(shared.args.pre_layer) == 1:
pre_layer = shared.args.pre_layer[0]
else:
pre_layer = shared.args.pre_layer
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, pre_layer)
else:
threshold = False if model_type == 'gptj' else 128
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, kernel_switch_threshold=threshold)
# accelerate offload (doesn't work properly)
if shared.args.gpu_memory or torch.cuda.device_count() > 1 or (is_xpu_available() and torch.xpu.device_count() > 1):
if shared.args.gpu_memory:
memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
max_memory = {}
for i in range(len(memory_map)):
max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else memory_map[i]
max_memory['cpu'] = f'{max_cpu_memory}GiB' if not re.match('.*ib$', max_cpu_memory.lower()) else max_cpu_memory
else:
max_memory = accelerate.utils.get_balanced_memory(model)
device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
logger.info("Using the following device map for the quantized model:", device_map)
# https://huggingface.co/docs/accelerate/package_reference/big_modeling#accelerate.dispatch_model
model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True)
# No offload
elif not shared.args.cpu:
if is_xpu_available():
model = model.to(torch.device("xpu:0"))
else:
model = model.to(torch.device('cuda:0'))
return model

View file

@ -105,7 +105,6 @@ loaders_and_params = OrderedDict({
], ],
'AutoGPTQ': [ 'AutoGPTQ': [
'triton', 'triton',
'no_inject_fused_attention',
'no_inject_fused_mlp', 'no_inject_fused_mlp',
'no_use_cuda_fp16', 'no_use_cuda_fp16',
'wbits', 'wbits',
@ -131,21 +130,6 @@ loaders_and_params = OrderedDict({
'trust_remote_code', 'trust_remote_code',
'no_use_fast', 'no_use_fast',
], ],
'GPTQ-for-LLaMa': [
'wbits',
'groupsize',
'model_type',
'pre_layer',
'trust_remote_code',
'no_use_fast',
'gptq_for_llama_info',
],
'QuIP#': [
'trust_remote_code',
'no_use_fast',
'no_flash_attn',
'quipsharp_info',
],
'HQQ': [ 'HQQ': [
'hqq_backend', 'hqq_backend',
'trust_remote_code', 'trust_remote_code',
@ -205,9 +189,7 @@ def transformers_samplers():
loaders_samplers = { loaders_samplers = {
'Transformers': transformers_samplers(), 'Transformers': transformers_samplers(),
'AutoGPTQ': transformers_samplers(), 'AutoGPTQ': transformers_samplers(),
'GPTQ-for-LLaMa': transformers_samplers(),
'AutoAWQ': transformers_samplers(), 'AutoAWQ': transformers_samplers(),
'QuIP#': transformers_samplers(),
'HQQ': transformers_samplers(), 'HQQ': transformers_samplers(),
'ExLlamav2': { 'ExLlamav2': {
'temperature', 'temperature',
@ -339,15 +321,6 @@ loaders_samplers = {
}, },
} }
loaders_model_types = {
'GPTQ-for-LLaMa': [
"None",
"llama",
"opt",
"gptj"
],
}
@functools.cache @functools.cache
def list_all_samplers(): def list_all_samplers():
@ -375,13 +348,6 @@ def blacklist_samplers(loader, dynamic_temperature):
return output return output
def get_model_types(loader):
if loader in loaders_model_types:
return loaders_model_types[loader]
return ["None"]
def get_gpu_memory_keys(): def get_gpu_memory_keys():
return [k for k in shared.gradio if k.startswith('gpu_memory')] return [k for k in shared.gradio if k.startswith('gpu_memory')]

View file

@ -73,13 +73,11 @@ def load_model(model_name, loader=None):
load_func_map = { load_func_map = {
'Transformers': huggingface_loader, 'Transformers': huggingface_loader,
'AutoGPTQ': AutoGPTQ_loader, 'AutoGPTQ': AutoGPTQ_loader,
'GPTQ-for-LLaMa': GPTQ_loader,
'llama.cpp': llamacpp_loader, 'llama.cpp': llamacpp_loader,
'llamacpp_HF': llamacpp_HF_loader, 'llamacpp_HF': llamacpp_HF_loader,
'ExLlamav2': ExLlamav2_loader, 'ExLlamav2': ExLlamav2_loader,
'ExLlamav2_HF': ExLlamav2_HF_loader, 'ExLlamav2_HF': ExLlamav2_HF_loader,
'AutoAWQ': AutoAWQ_loader, 'AutoAWQ': AutoAWQ_loader,
'QuIP#': QuipSharp_loader,
'HQQ': HQQ_loader, 'HQQ': HQQ_loader,
} }
@ -310,55 +308,6 @@ def AutoAWQ_loader(model_name):
return model return model
def QuipSharp_loader(model_name):
try:
with RelativeImport("repositories/quip-sharp"):
from lib.utils.unsafe_import import model_from_hf_path
except:
logger.error(
"\nQuIP# has not been found. It must be installed manually for now.\n"
"For instructions on how to do that, please consult:\n"
"https://github.com/oobabooga/text-generation-webui/pull/4803\n"
)
return None, None
# This fixes duplicate logging messages after the import above.
handlers = logging.getLogger().handlers
if len(handlers) > 1:
logging.getLogger().removeHandler(handlers[1])
model_dir = Path(f'{shared.args.model_dir}/{model_name}')
if not all((model_dir / file).exists() for file in ['tokenizer_config.json', 'special_tokens_map.json', 'tokenizer.model']):
logger.error(f"Could not load the model because the tokenizer files could not be found in the model folder. Please download the following files from the original (unquantized) model into {model_dir}: special_tokens_map.json, tokenizer.json, tokenizer.model, tokenizer_config.json.")
return None, None
model, model_str = model_from_hf_path(
model_dir,
use_cuda_graph=False,
use_flash_attn=not shared.args.no_flash_attn
)
return model
def GPTQ_loader(model_name):
# Monkey patch
if shared.args.monkey_patch:
logger.warning("Applying the monkey patch for using LoRAs with GPTQ models. It may cause undefined behavior outside its intended scope.")
from modules.monkey_patch_gptq_lora import load_model_llama
model, _ = load_model_llama(model_name)
# No monkey patch
else:
import modules.GPTQ_loader
model = modules.GPTQ_loader.load_quantized(model_name)
return model
def AutoGPTQ_loader(model_name): def AutoGPTQ_loader(model_name):
import modules.AutoGPTQ_loader import modules.AutoGPTQ_loader
@ -380,12 +329,12 @@ def ExLlamav2_HF_loader(model_name):
def HQQ_loader(model_name): def HQQ_loader(model_name):
from hqq.core.quantize import HQQBackend, HQQLinear from hqq.core.quantize import HQQBackend, HQQLinear
from hqq.engine.hf import HQQModelForCausalLM from hqq.models.hf.base import AutoHQQHFModel
logger.info(f"Loading HQQ model with backend: \"{shared.args.hqq_backend}\"") logger.info(f"Loading HQQ model with backend: \"{shared.args.hqq_backend}\"")
model_dir = Path(f'{shared.args.model_dir}/{model_name}') model_dir = Path(f'{shared.args.model_dir}/{model_name}')
model = HQQModelForCausalLM.from_quantized(str(model_dir)) model = AutoHQQHFModel.from_quantized(str(model_dir))
HQQLinear.set_backend(getattr(HQQBackend, shared.args.hqq_backend)) HQQLinear.set_backend(getattr(HQQBackend, shared.args.hqq_backend))
return model return model

View file

@ -40,12 +40,7 @@ def get_model_metadata(model):
hf_metadata = None hf_metadata = None
if 'loader' not in model_settings: if 'loader' not in model_settings:
if hf_metadata is not None and 'quip_params' in hf_metadata: model_settings['loader'] = infer_loader(model, model_settings)
loader = 'QuIP#'
else:
loader = infer_loader(model, model_settings)
model_settings['loader'] = loader
# GGUF metadata # GGUF metadata
if model_settings['loader'] in ['llama.cpp', 'llamacpp_HF']: if model_settings['loader'] in ['llama.cpp', 'llamacpp_HF']:
@ -242,7 +237,7 @@ def apply_model_settings_to_state(model, state):
loader = model_settings.pop('loader') loader = model_settings.pop('loader')
# If the user is using an alternative loader for the same model type, let them keep using it # If the user is using an alternative loader for the same model type, let them keep using it
if not (loader == 'ExLlamav2_HF' and state['loader'] in ['GPTQ-for-LLaMa', 'ExLlamav2', 'AutoGPTQ']): if not (loader == 'ExLlamav2_HF' and state['loader'] in ['ExLlamav2', 'AutoGPTQ']):
state['loader'] = loader state['loader'] = loader
for k in model_settings: for k in model_settings:

View file

@ -1,39 +0,0 @@
# Copied from https://github.com/johnsmith0031/alpaca_lora_4bit
from pathlib import Path
import alpaca_lora_4bit.autograd_4bit as autograd_4bit
from alpaca_lora_4bit.amp_wrapper import AMPWrapper
from alpaca_lora_4bit.autograd_4bit import (
Autograd4bitQuantLinear,
load_llama_model_4bit_low_ram
)
from alpaca_lora_4bit.models import Linear4bitLt
from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
replace_peft_model_with_int4_lora_model
)
from modules import shared
from modules.GPTQ_loader import find_quantized_model_file
replace_peft_model_with_int4_lora_model()
def load_model_llama(model_name):
config_path = str(Path(f'{shared.args.model_dir}/{model_name}'))
model_path = str(find_quantized_model_file(model_name))
model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path, groupsize=shared.args.groupsize, is_v1_model=False)
for _, m in model.named_modules():
if isinstance(m, Autograd4bitQuantLinear) or isinstance(m, Linear4bitLt):
if m.is_v1_model:
m.zeros = m.zeros.half()
m.scales = m.scales.half()
m.bias = m.bias.half()
autograd_4bit.auto_switch = True
model.half()
wrapper = AMPWrapper(model)
wrapper.apply_generate()
return model, tokenizer

View file

@ -89,7 +89,7 @@ group.add_argument('--idle-timeout', type=int, default=0, help='Unload model aft
# Model loader # Model loader
group = parser.add_argument_group('Model loader') group = parser.add_argument_group('Model loader')
group.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, llamacpp_HF, ExLlamav2_HF, ExLlamav2, AutoGPTQ, AutoAWQ, GPTQ-for-LLaMa, QuIP#.') group.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, llamacpp_HF, ExLlamav2_HF, ExLlamav2, AutoGPTQ, AutoAWQ.')
# Transformers/Accelerate # Transformers/Accelerate
group = parser.add_argument_group('Transformers/Accelerate') group = parser.add_argument_group('Transformers/Accelerate')
@ -149,21 +149,17 @@ group.add_argument('--num_experts_per_token', type=int, default=2, help='Number
# AutoGPTQ # AutoGPTQ
group = parser.add_argument_group('AutoGPTQ') group = parser.add_argument_group('AutoGPTQ')
group.add_argument('--triton', action='store_true', help='Use triton.') group.add_argument('--triton', action='store_true', help='Use triton.')
group.add_argument('--no_inject_fused_attention', action='store_true', help='Disable the use of fused attention, which will use less VRAM at the cost of slower inference.')
group.add_argument('--no_inject_fused_mlp', action='store_true', help='Triton mode only: disable the use of fused MLP, which will use less VRAM at the cost of slower inference.') group.add_argument('--no_inject_fused_mlp', action='store_true', help='Triton mode only: disable the use of fused MLP, which will use less VRAM at the cost of slower inference.')
group.add_argument('--no_use_cuda_fp16', action='store_true', help='This can make models faster on some systems.') group.add_argument('--no_use_cuda_fp16', action='store_true', help='This can make models faster on some systems.')
group.add_argument('--desc_act', action='store_true', help='For models that do not have a quantize_config.json, this parameter is used to define whether to set desc_act or not in BaseQuantizeConfig.') group.add_argument('--desc_act', action='store_true', help='For models that do not have a quantize_config.json, this parameter is used to define whether to set desc_act or not in BaseQuantizeConfig.')
group.add_argument('--disable_exllama', action='store_true', help='Disable ExLlama kernel, which can improve inference speed on some systems.') group.add_argument('--disable_exllama', action='store_true', help='Disable ExLlama kernel, which can improve inference speed on some systems.')
group.add_argument('--disable_exllamav2', action='store_true', help='Disable ExLlamav2 kernel.') group.add_argument('--disable_exllamav2', action='store_true', help='Disable ExLlamav2 kernel.')
# GPTQ-for-LLaMa
group = parser.add_argument_group('GPTQ-for-LLaMa')
group.add_argument('--wbits', type=int, default=0, help='Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported.') group.add_argument('--wbits', type=int, default=0, help='Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported.')
group.add_argument('--model_type', type=str, help='Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported.')
group.add_argument('--groupsize', type=int, default=-1, help='Group size.') group.add_argument('--groupsize', type=int, default=-1, help='Group size.')
group.add_argument('--pre_layer', type=int, nargs='+', help='The number of layers to allocate to the GPU. Setting this parameter enables CPU offloading for 4-bit models. For multi-gpu, write the numbers separated by spaces, eg --pre_layer 30 60.')
group.add_argument('--checkpoint', type=str, help='The path to the quantized checkpoint file. If not specified, it will be automatically detected.') # AutoAWQ
group.add_argument('--monkey-patch', action='store_true', help='Apply the monkey patch for using LoRAs with quantized models.') group = parser.add_argument_group('AutoAWQ')
group.add_argument('--no_inject_fused_attention', action='store_true', help='Disable the use of fused attention, which will use less VRAM at the cost of slower inference.')
# HQQ # HQQ
group = parser.add_argument_group('HQQ') group = parser.add_argument_group('HQQ')
@ -208,7 +204,11 @@ group = parser.add_argument_group('Multimodal')
group.add_argument('--multimodal-pipeline', type=str, default=None, help='The multimodal pipeline to use. Examples: llava-7b, llava-13b.') group.add_argument('--multimodal-pipeline', type=str, default=None, help='The multimodal pipeline to use. Examples: llava-7b, llava-13b.')
# Deprecated parameters # Deprecated parameters
# group = parser.add_argument_group('Deprecated') group = parser.add_argument_group('Deprecated')
group.add_argument('--model_type', type=str, help='DEPRECATED')
group.add_argument('--pre_layer', type=int, nargs='+', help='DEPRECATED')
group.add_argument('--checkpoint', type=str, help='DEPRECATED')
group.add_argument('--monkey-patch', action='store_true', help='DEPRECATED')
args = parser.parse_args() args = parser.parse_args()
args_defaults = parser.parse_args([]) args_defaults = parser.parse_args([])
@ -253,8 +253,6 @@ def fix_loader_name(name):
return 'Transformers' return 'Transformers'
elif name in ['autogptq', 'auto-gptq', 'auto_gptq', 'auto gptq']: elif name in ['autogptq', 'auto-gptq', 'auto_gptq', 'auto gptq']:
return 'AutoGPTQ' return 'AutoGPTQ'
elif name in ['gptq-for-llama', 'gptqforllama', 'gptqllama', 'gptq for llama', 'gptq_for_llama']:
return 'GPTQ-for-LLaMa'
elif name in ['exllama', 'ex-llama', 'ex_llama', 'exlama']: elif name in ['exllama', 'ex-llama', 'ex_llama', 'exlama']:
return 'ExLlama' return 'ExLlama'
elif name in ['exllamav2', 'exllama-v2', 'ex_llama-v2', 'exlamav2', 'exlama-v2', 'exllama2', 'exllama-2']: elif name in ['exllamav2', 'exllama-v2', 'ex_llama-v2', 'exlamav2', 'exlama-v2', 'exllama2', 'exllama-2']:
@ -263,8 +261,6 @@ def fix_loader_name(name):
return 'ExLlamav2_HF' return 'ExLlamav2_HF'
elif name in ['autoawq', 'awq', 'auto-awq']: elif name in ['autoawq', 'awq', 'auto-awq']:
return 'AutoAWQ' return 'AutoAWQ'
elif name in ['quip#', 'quip-sharp', 'quipsharp', 'quip_sharp']:
return 'QuIP#'
elif name in ['hqq']: elif name in ['hqq']:
return 'HQQ' return 'HQQ'

View file

@ -292,12 +292,6 @@ def calc_trainable_parameters(model):
def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en: bool, k_proj_en: bool, o_proj_en: bool, gate_proj_en: bool, down_proj_en: bool, up_proj_en: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float, add_eos_token: bool, min_chars: int, report_to: str): def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en: bool, k_proj_en: bool, o_proj_en: bool, gate_proj_en: bool, down_proj_en: bool, up_proj_en: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float, add_eos_token: bool, min_chars: int, report_to: str):
if shared.args.monkey_patch:
from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
replace_peft_model_with_int4_lora_model
)
replace_peft_model_with_int4_lora_model()
global WANT_INTERRUPT global WANT_INTERRUPT
WANT_INTERRUPT = False WANT_INTERRUPT = False
@ -329,10 +323,6 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
time.sleep(5) time.sleep(5)
if shared.args.loader == 'GPTQ-for-LLaMa' and not shared.args.monkey_patch:
yield "LoRA training with GPTQ-for-LLaMa requires loading with `--monkey-patch`"
return
if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0: if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
yield "Cannot input zeroes." yield "Cannot input zeroes."
return return
@ -553,15 +543,6 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
yield traceback.format_exc().replace('\n', '\n\n') yield traceback.format_exc().replace('\n', '\n\n')
return return
if shared.args.monkey_patch:
from alpaca_lora_4bit.autograd_4bit import Autograd4bitQuantLinear
from alpaca_lora_4bit.models import Linear4bitLt
for _, m in lora_model.named_modules():
if isinstance(m, Autograd4bitQuantLinear) or isinstance(m, Linear4bitLt):
if m.is_v1_model:
m.zeros = m.zeros.half()
m.scales = m.scales.half()
class Tracked(): class Tracked():
def __init__(self): def __init__(self):
self.current_steps = 0 self.current_steps = 0

View file

@ -111,7 +111,6 @@ def create_ui():
shared.gradio['compress_pos_emb'] = gr.Slider(label='compress_pos_emb', minimum=1, maximum=8, step=1, info='Positional embeddings compression factor. Should be set to (context length) / (model\'s original context length). Equal to 1/rope_freq_scale.', value=shared.args.compress_pos_emb) shared.gradio['compress_pos_emb'] = gr.Slider(label='compress_pos_emb', minimum=1, maximum=8, step=1, info='Positional embeddings compression factor. Should be set to (context length) / (model\'s original context length). Equal to 1/rope_freq_scale.', value=shared.args.compress_pos_emb)
shared.gradio['autogptq_info'] = gr.Markdown('ExLlamav2_HF is recommended over AutoGPTQ for models derived from Llama.') shared.gradio['autogptq_info'] = gr.Markdown('ExLlamav2_HF is recommended over AutoGPTQ for models derived from Llama.')
shared.gradio['quipsharp_info'] = gr.Markdown('QuIP# has to be installed manually at the moment.')
with gr.Column(): with gr.Column():
shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit) shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit)

View file

@ -388,7 +388,7 @@ def update_requirements(initial_installation=False, pull=True):
# Prepare the requirements file # Prepare the requirements file
textgen_requirements = open(requirements_file).read().splitlines() textgen_requirements = open(requirements_file).read().splitlines()
if is_cuda118: if is_cuda118:
textgen_requirements = [req.replace('+cu121', '+cu118').replace('+cu122', '+cu118') for req in textgen_requirements] textgen_requirements = [req.replace('+cu121', '+cu118').replace('+cu122', '+cu118') for req in textgen_requirements if "auto-gptq" not in req]
if is_windows() and is_cuda118: # No flash-attention on Windows for CUDA 11 if is_windows() and is_cuda118: # No flash-attention on Windows for CUDA 11
textgen_requirements = [req for req in textgen_requirements if 'oobabooga/flash-attention' not in req] textgen_requirements = [req for req in textgen_requirements if 'oobabooga/flash-attention' not in req]

View file

@ -1,11 +1,12 @@
accelerate==0.30.* accelerate==0.30.*
aqlm[gpu,cpu]==1.1.3; platform_system == "Linux" aqlm[gpu,cpu]==1.1.5; platform_system == "Linux"
auto-gptq==0.7.1
bitsandbytes==0.43.* bitsandbytes==0.43.*
colorama colorama
datasets datasets
einops einops
gradio==4.26.* gradio==4.26.*
hqq==0.1.5 hqq==0.1.7.post2
jinja2==3.1.2 jinja2==3.1.2
lm_eval==0.3.0 lm_eval==0.3.0
markdown markdown
@ -23,7 +24,7 @@ safetensors==0.4.*
scipy scipy
sentencepiece sentencepiece
tensorboard tensorboard
transformers==4.40.* transformers==4.41.*
tqdm tqdm
wandb wandb
@ -52,10 +53,6 @@ https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/te
https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.2.75+cu121-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.2.75+cu121-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
# CUDA wheels # CUDA wheels
https://github.com/jllllll/AutoGPTQ/releases/download/v0.6.0/auto_gptq-0.6.0+cu121-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
https://github.com/jllllll/AutoGPTQ/releases/download/v0.6.0/auto_gptq-0.6.0+cu121-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10"
https://github.com/jllllll/AutoGPTQ/releases/download/v0.6.0/auto_gptq-0.6.0+cu121-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
https://github.com/jllllll/AutoGPTQ/releases/download/v0.6.0/auto_gptq-0.6.0+cu121-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
https://github.com/oobabooga/exllamav2/releases/download/v0.0.20/exllamav2-0.0.20+cu121-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" https://github.com/oobabooga/exllamav2/releases/download/v0.0.20/exllamav2-0.0.20+cu121-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
https://github.com/oobabooga/exllamav2/releases/download/v0.0.20/exllamav2-0.0.20+cu121-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" https://github.com/oobabooga/exllamav2/releases/download/v0.0.20/exllamav2-0.0.20+cu121-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10"
https://github.com/oobabooga/exllamav2/releases/download/v0.0.20/exllamav2-0.0.20+cu121-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" https://github.com/oobabooga/exllamav2/releases/download/v0.0.20/exllamav2-0.0.20+cu121-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
@ -65,8 +62,4 @@ https://github.com/oobabooga/flash-attention/releases/download/v2.5.6/flash_attn
https://github.com/oobabooga/flash-attention/releases/download/v2.5.6/flash_attn-2.5.6+cu122torch2.2.0cxx11abiFALSE-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" https://github.com/oobabooga/flash-attention/releases/download/v2.5.6/flash_attn-2.5.6+cu122torch2.2.0cxx11abiFALSE-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10"
https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.6/flash_attn-2.5.6+cu122torch2.2cxx11abiFALSE-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.6/flash_attn-2.5.6+cu122torch2.2cxx11abiFALSE-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.6/flash_attn-2.5.6+cu122torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.6/flash_attn-2.5.6+cu122torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases/download/0.1.1/gptq_for_llama-0.1.1+cu121-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" autoawq==0.2.5; platform_system == "Linux" or platform_system == "Windows"
https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases/download/0.1.1/gptq_for_llama-0.1.1+cu121-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10"
https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases/download/0.1.1/gptq_for_llama-0.1.1+cu121-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases/download/0.1.1/gptq_for_llama-0.1.1+cu121-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
autoawq==0.2.3; platform_system == "Linux" or platform_system == "Windows"

View file

@ -3,7 +3,7 @@ colorama
datasets datasets
einops einops
gradio==4.26.* gradio==4.26.*
hqq==0.1.5 hqq==0.1.7.post2
jinja2==3.1.2 jinja2==3.1.2
lm_eval==0.3.0 lm_eval==0.3.0
markdown markdown
@ -21,7 +21,7 @@ safetensors==0.4.*
scipy scipy
sentencepiece sentencepiece
tensorboard tensorboard
transformers==4.40.* transformers==4.41.*
tqdm tqdm
wandb wandb
@ -40,12 +40,8 @@ https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cp
# AMD wheels # AMD wheels
https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/rocm/llama_cpp_python_cuda-0.2.75+rocm5.6.1-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/rocm/llama_cpp_python_cuda-0.2.75+rocm5.6.1-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/rocm/llama_cpp_python_cuda-0.2.75+rocm5.6.1-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/rocm/llama_cpp_python_cuda-0.2.75+rocm5.6.1-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
https://github.com/jllllll/AutoGPTQ/releases/download/v0.6.0/auto_gptq-0.6.0+rocm5.6-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
https://github.com/jllllll/AutoGPTQ/releases/download/v0.6.0/auto_gptq-0.6.0+rocm5.6-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
https://github.com/oobabooga/exllamav2/releases/download/v0.0.20/exllamav2-0.0.20+rocm5.6-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" https://github.com/oobabooga/exllamav2/releases/download/v0.0.20/exllamav2-0.0.20+rocm5.6-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
https://github.com/oobabooga/exllamav2/releases/download/v0.0.20/exllamav2-0.0.20+rocm5.6-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" https://github.com/oobabooga/exllamav2/releases/download/v0.0.20/exllamav2-0.0.20+rocm5.6-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
https://github.com/oobabooga/exllamav2/releases/download/v0.0.20/exllamav2-0.0.20-py3-none-any.whl; platform_system != "Darwin" and platform_machine != "x86_64" https://github.com/oobabooga/exllamav2/releases/download/v0.0.20/exllamav2-0.0.20-py3-none-any.whl; platform_system != "Darwin" and platform_machine != "x86_64"
https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases/download/0.1.1/gptq_for_llama-0.1.1+rocm5.6-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.5/autoawq-0.2.5+rocm561-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases/download/0.1.1/gptq_for_llama-0.1.1+rocm5.6-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.5/autoawq-0.2.5+rocm561-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.3/autoawq-0.2.3+rocm561-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.3/autoawq-0.2.3+rocm561-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"

View file

@ -3,7 +3,7 @@ colorama
datasets datasets
einops einops
gradio==4.26.* gradio==4.26.*
hqq==0.1.5 hqq==0.1.7.post2
jinja2==3.1.2 jinja2==3.1.2
lm_eval==0.3.0 lm_eval==0.3.0
markdown markdown
@ -21,7 +21,7 @@ safetensors==0.4.*
scipy scipy
sentencepiece sentencepiece
tensorboard tensorboard
transformers==4.40.* transformers==4.41.*
tqdm tqdm
wandb wandb
@ -38,12 +38,8 @@ https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cp
https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.75+cpuavx-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.75+cpuavx-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10"
# AMD wheels # AMD wheels
https://github.com/jllllll/AutoGPTQ/releases/download/v0.6.0/auto_gptq-0.6.0+rocm5.6-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
https://github.com/jllllll/AutoGPTQ/releases/download/v0.6.0/auto_gptq-0.6.0+rocm5.6-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
https://github.com/oobabooga/exllamav2/releases/download/v0.0.20/exllamav2-0.0.20+rocm5.6-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" https://github.com/oobabooga/exllamav2/releases/download/v0.0.20/exllamav2-0.0.20+rocm5.6-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
https://github.com/oobabooga/exllamav2/releases/download/v0.0.20/exllamav2-0.0.20+rocm5.6-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" https://github.com/oobabooga/exllamav2/releases/download/v0.0.20/exllamav2-0.0.20+rocm5.6-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
https://github.com/oobabooga/exllamav2/releases/download/v0.0.20/exllamav2-0.0.20-py3-none-any.whl; platform_system != "Darwin" and platform_machine != "x86_64" https://github.com/oobabooga/exllamav2/releases/download/v0.0.20/exllamav2-0.0.20-py3-none-any.whl; platform_system != "Darwin" and platform_machine != "x86_64"
https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases/download/0.1.1/gptq_for_llama-0.1.1+rocm5.6-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.5/autoawq-0.2.5+rocm561-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases/download/0.1.1/gptq_for_llama-0.1.1+rocm5.6-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.5/autoawq-0.2.5+rocm561-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.3/autoawq-0.2.3+rocm561-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.3/autoawq-0.2.3+rocm561-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"

View file

@ -3,7 +3,7 @@ colorama
datasets datasets
einops einops
gradio==4.26.* gradio==4.26.*
hqq==0.1.5 hqq==0.1.7.post2
jinja2==3.1.2 jinja2==3.1.2
lm_eval==0.3.0 lm_eval==0.3.0
markdown markdown
@ -21,7 +21,7 @@ safetensors==0.4.*
scipy scipy
sentencepiece sentencepiece
tensorboard tensorboard
transformers==4.40.* transformers==4.41.*
tqdm tqdm
wandb wandb

View file

@ -3,7 +3,7 @@ colorama
datasets datasets
einops einops
gradio==4.26.* gradio==4.26.*
hqq==0.1.5 hqq==0.1.7.post2
jinja2==3.1.2 jinja2==3.1.2
lm_eval==0.3.0 lm_eval==0.3.0
markdown markdown
@ -21,7 +21,7 @@ safetensors==0.4.*
scipy scipy
sentencepiece sentencepiece
tensorboard tensorboard
transformers==4.40.* transformers==4.41.*
tqdm tqdm
wandb wandb

View file

@ -3,7 +3,7 @@ colorama
datasets datasets
einops einops
gradio==4.26.* gradio==4.26.*
hqq==0.1.5 hqq==0.1.7.post2
jinja2==3.1.2 jinja2==3.1.2
lm_eval==0.3.0 lm_eval==0.3.0
markdown markdown
@ -21,7 +21,7 @@ safetensors==0.4.*
scipy scipy
sentencepiece sentencepiece
tensorboard tensorboard
transformers==4.40.* transformers==4.41.*
tqdm tqdm
wandb wandb

View file

@ -3,7 +3,7 @@ colorama
datasets datasets
einops einops
gradio==4.26.* gradio==4.26.*
hqq==0.1.5 hqq==0.1.7.post2
jinja2==3.1.2 jinja2==3.1.2
lm_eval==0.3.0 lm_eval==0.3.0
markdown markdown
@ -21,7 +21,7 @@ safetensors==0.4.*
scipy scipy
sentencepiece sentencepiece
tensorboard tensorboard
transformers==4.40.* transformers==4.41.*
tqdm tqdm
wandb wandb

View file

@ -1,11 +1,12 @@
accelerate==0.30.* accelerate==0.30.*
aqlm[gpu,cpu]==1.1.3; platform_system == "Linux" aqlm[gpu,cpu]==1.1.5; platform_system == "Linux"
auto-gptq==0.7.1
bitsandbytes==0.43.* bitsandbytes==0.43.*
colorama colorama
datasets datasets
einops einops
gradio==4.26.* gradio==4.26.*
hqq==0.1.5 hqq==0.1.7.post2
jinja2==3.1.2 jinja2==3.1.2
lm_eval==0.3.0 lm_eval==0.3.0
markdown markdown
@ -23,7 +24,7 @@ safetensors==0.4.*
scipy scipy
sentencepiece sentencepiece
tensorboard tensorboard
transformers==4.40.* transformers==4.41.*
tqdm tqdm
wandb wandb
@ -52,10 +53,6 @@ https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/te
https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.2.75+cu121avx-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.2.75+cu121avx-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
# CUDA wheels # CUDA wheels
https://github.com/jllllll/AutoGPTQ/releases/download/v0.6.0/auto_gptq-0.6.0+cu121-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
https://github.com/jllllll/AutoGPTQ/releases/download/v0.6.0/auto_gptq-0.6.0+cu121-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10"
https://github.com/jllllll/AutoGPTQ/releases/download/v0.6.0/auto_gptq-0.6.0+cu121-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
https://github.com/jllllll/AutoGPTQ/releases/download/v0.6.0/auto_gptq-0.6.0+cu121-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
https://github.com/oobabooga/exllamav2/releases/download/v0.0.20/exllamav2-0.0.20+cu121-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" https://github.com/oobabooga/exllamav2/releases/download/v0.0.20/exllamav2-0.0.20+cu121-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
https://github.com/oobabooga/exllamav2/releases/download/v0.0.20/exllamav2-0.0.20+cu121-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" https://github.com/oobabooga/exllamav2/releases/download/v0.0.20/exllamav2-0.0.20+cu121-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10"
https://github.com/oobabooga/exllamav2/releases/download/v0.0.20/exllamav2-0.0.20+cu121-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" https://github.com/oobabooga/exllamav2/releases/download/v0.0.20/exllamav2-0.0.20+cu121-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
@ -65,8 +62,4 @@ https://github.com/oobabooga/flash-attention/releases/download/v2.5.6/flash_attn
https://github.com/oobabooga/flash-attention/releases/download/v2.5.6/flash_attn-2.5.6+cu122torch2.2.0cxx11abiFALSE-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" https://github.com/oobabooga/flash-attention/releases/download/v2.5.6/flash_attn-2.5.6+cu122torch2.2.0cxx11abiFALSE-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10"
https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.6/flash_attn-2.5.6+cu122torch2.2cxx11abiFALSE-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.6/flash_attn-2.5.6+cu122torch2.2cxx11abiFALSE-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.6/flash_attn-2.5.6+cu122torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.6/flash_attn-2.5.6+cu122torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases/download/0.1.1/gptq_for_llama-0.1.1+cu121-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" autoawq==0.2.5; platform_system == "Linux" or platform_system == "Windows"
https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases/download/0.1.1/gptq_for_llama-0.1.1+cu121-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10"
https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases/download/0.1.1/gptq_for_llama-0.1.1+cu121-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases/download/0.1.1/gptq_for_llama-0.1.1+cu121-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
autoawq==0.2.3; platform_system == "Linux" or platform_system == "Windows"

View file

@ -3,7 +3,7 @@ colorama
datasets datasets
einops einops
gradio==4.26.* gradio==4.26.*
hqq==0.1.5 hqq==0.1.7.post2
jinja2==3.1.2 jinja2==3.1.2
lm_eval==0.3.0 lm_eval==0.3.0
markdown markdown
@ -21,7 +21,7 @@ safetensors==0.4.*
scipy scipy
sentencepiece sentencepiece
tensorboard tensorboard
transformers==4.40.* transformers==4.41.*
tqdm tqdm
wandb wandb