Merge #1073 from sgsdxzy/triton

* Multi-GPU support for triton
* Better quantized model filename detection
This commit is contained in:
oobabooga 2023-04-13 11:31:21 -03:00 committed by GitHub
commit 8b482b4127
WARNING! Although there is a key with this ID in the database it does not verify this commit! This commit is SUSPICIOUS.
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 50 additions and 32 deletions

View file

@ -239,6 +239,7 @@ Optionally, you can use the following command-line flags:
| `--model_type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported. |
| `--groupsize GROUPSIZE` | GPTQ: Group size. |
| `--pre_layer PRE_LAYER` | GPTQ: The number of layers to allocate to the GPU. Setting this parameter enables CPU offloading for 4-bit models. |
| `--no-warmup_autotune` | GPTQ: Disable warmup autotune for triton. |
#### FlexGen

View file

@ -61,6 +61,16 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
model.load_state_dict(safe_load(checkpoint), strict=False)
else:
model.load_state_dict(torch.load(checkpoint), strict=False)
try:
from quant import autotune_warmup, make_quant_attn
# triton branch
make_quant_attn(model)
if not shared.args.no_warmup_autotune:
autotune_warmup(model)
except ImportError: # not triton branch
pass
model.seqlen = 2048
print('Done.')
@ -68,8 +78,9 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
def load_quantized(model_name):
# Find the model type
if not shared.args.model_type:
# Try to determine model type from model name
name = model_name.lower()
if any((k in name for k in ['llama', 'alpaca', 'vicuna'])):
model_type = 'llama'
@ -84,6 +95,7 @@ def load_quantized(model_name):
else:
model_type = shared.args.model_type.lower()
# Select the appropriate load_quant function
if shared.args.pre_layer and model_type == 'llama':
load_quant = llama_inference_offload.load_quant
elif model_type in ('llama', 'opt', 'gptj'):
@ -94,33 +106,34 @@ def load_quantized(model_name):
print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
exit()
# Now we are going to try to locate the quantized model file.
# Locate the quantized model file
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
found_pts = list(path_to_model.glob("*.pt"))
found_safetensors = list(path_to_model.glob("*.safetensors"))
pt_path = None
priority_name_list = [
Path(f'{shared.args.model_dir}/{model_name}{hyphen}{shared.args.wbits}bit{group}{ext}')
for ext in ['.safetensors', '.pt']
for group in ([f'-{shared.args.groupsize}g', ''] if shared.args.groupsize > 0 else [''])
for hyphen in ['-', f'/{model_name}-', '/']
]
if shared.args.groupsize > 0:
priority_name_list = [i for i in priority_name_list if str(shared.args.groupsize) in i.name] + [i for i in priority_name_list if str(shared.args.groupsize) not in i.name]
for path in priority_name_list:
if path.exists():
pt_path = path
break
if len(found_pts) > 0:
pt_path = found_pts[-1]
elif len(found_safetensors) > 0:
pt_path = found_safetensors[-1]
else:
if path_to_model.name.lower().startswith('llama-7b'):
pt_model = f'llama-7b-{shared.args.wbits}bit'
elif path_to_model.name.lower().startswith('llama-13b'):
pt_model = f'llama-13b-{shared.args.wbits}bit'
elif path_to_model.name.lower().startswith('llama-30b'):
pt_model = f'llama-30b-{shared.args.wbits}bit'
elif path_to_model.name.lower().startswith('llama-65b'):
pt_model = f'llama-65b-{shared.args.wbits}bit'
else:
pt_model = f'{model_name}-{shared.args.wbits}bit'
# If the model hasn't been found with a well-behaved name, pick the last .pt
# or the last .safetensors found in its folder as a last resort
if not pt_path:
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
found_pts = list(path_to_model.glob("*.pt"))
found_safetensors = list(path_to_model.glob("*.safetensors"))
pt_path = None
# Try to find the .safetensors or .pt both in the model dir and in the subfolder
for path in [Path(p + ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
if path.exists():
pt_path = path
break
if len(found_pts) > 0:
pt_path = found_pts[-1]
elif len(found_safetensors) > 0:
pt_path = found_safetensors[-1]
if not pt_path:
print("Could not find the quantized model in .pt or .safetensors format, exiting...")
@ -136,16 +149,19 @@ def load_quantized(model_name):
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, kernel_switch_threshold=threshold)
# accelerate offload (doesn't work properly)
if shared.args.gpu_memory:
memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
max_memory = {}
for i in range(len(memory_map)):
max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else memory_map[i]
max_memory['cpu'] = max_cpu_memory
if shared.args.gpu_memory or torch.cuda.device_count() > 1:
if shared.args.gpu_memory:
memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
max_memory = {}
for i in range(len(memory_map)):
max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else memory_map[i]
max_memory['cpu'] = max_cpu_memory
else:
max_memory = accelerate.utils.get_balanced_memory(model)
device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
print("Using the following device map for the 4-bit model:", device_map)
print("Using the following device map for the quantized model:", device_map)
# https://huggingface.co/docs/accelerate/package_reference/big_modeling#accelerate.dispatch_model
model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True)

View file

@ -117,6 +117,7 @@ parser.add_argument('--wbits', type=int, default=0, help='GPTQ: Load a pre-quant
parser.add_argument('--model_type', type=str, help='GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported.')
parser.add_argument('--groupsize', type=int, default=-1, help='GPTQ: Group size.')
parser.add_argument('--pre_layer', type=int, default=0, help='GPTQ: The number of layers to allocate to the GPU. Setting this parameter enables CPU offloading for 4-bit models.')
parser.add_argument('--no-warmup_autotune', action='store_true', help='GPTQ: Disable warmup autotune for triton.')
# FlexGen
parser.add_argument('--flexgen', action='store_true', help='Enable the use of FlexGen offloading.')