Merge pull request #43 from 81300/ds

Add DeepSpeed ZeRO-3 integration
oobabooga 2023-02-02 10:03:19 -03:00 committed by GitHub
commit 1a658b41aa
3 changed files with 133 additions and 5 deletions

characters/.gitignore (new file)

@@ -0,0 +1,4 @@
+*
+!Example.json
+!Example.png
+!.gitignore
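The pattern whitelists rather than blacklists: the lone asterisk ignores everything under characters/, while the three negated entries re-include the bundled Example character (JSON plus portrait) and the .gitignore itself, so user-created characters stay out of version control.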

requirements.txt

@@ -13,6 +13,7 @@ charset-normalizer==2.1.1
 click==8.1.3
 contourpy==1.0.6
 cycler==0.11.0
+deepspeed==0.8.0
 entrypoints==0.4
 fastapi==0.88.0
 ffmpy==0.3.0
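The only new dependency is the pinned deepspeed==0.8.0. A quick sanity check before launching with --deepspeed (not part of the PR, and assuming a standard pip environment):

# Hedged sketch: confirm the pinned DeepSpeed build and a CUDA-capable torch
# are importable before starting server.py with --deepspeed.
import deepspeed
import torch

print(deepspeed.__version__)      # expected to start with "0.8"
print(torch.cuda.is_available())  # the ZeRO-3 path below assumes at least one GPU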

server.py

@@ -8,6 +8,7 @@ import json
 import io
 import base64
 import sys
+import os
 from pathlib import Path
 from PIL import Image
 import copy
@@ -15,7 +16,7 @@ import gradio as gr
 import warnings
 from tqdm import tqdm
 import transformers
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
 from modules.html_generator import *
 from modules.ui import *
 from modules.stopping_criteria import _SentinelTokenStoppingCriteria
@@ -34,6 +35,10 @@ parser.add_argument('--disk', action='store_true', help='If the model is too lar
 parser.add_argument('--disk-cache-dir', type=str, help='Directory to save the disk cache to. Defaults to "cache/".')
 parser.add_argument('--gpu-memory', type=int, help='Maximum GPU memory in GiB to allocate. This is useful if you get out of memory errors while trying to generate text. Must be an integer number.')
 parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.')
+parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
+parser.add_argument('--nvme-offload-dir', type=str, help='Directory to use for DeepSpeed ZeRO-3 NVME offloading.')
+parser.add_argument('--bf16', action='store_true', help='Instantiate the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
+parser.add_argument('--local_rank', type=int, default=0, help='Optional argument for DeepSpeed distributed setups.')
 parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.')
 parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
 parser.add_argument('--extensions', type=str, help='The list of extensions to load. If you want to load more than one extension, write the names separated by commas and between quotation marks, "like,this".')
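These four flags are the entire user-facing surface of the integration: --deepspeed turns it on, --nvme-offload-dir switches ZeRO-3 parameter offload from CPU to NVMe, --bf16 selects bfloat16 on Ampere-class GPUs, and --local_rank is what a distributed launcher passes to each worker process (the server is presumably started through the DeepSpeed launcher, e.g. deepspeed --num_gpus=1 server.py --deepspeed). A small, hypothetical validation helper, not part of this PR, that only assumes the argument names defined above:

import os

# Hypothetical helper (not in the PR): fail fast on flag combinations that the
# DeepSpeed code path further down cannot do anything useful with.
def validate_deepspeed_args(args):
    if args.nvme_offload_dir and not args.deepspeed:
        raise ValueError("--nvme-offload-dir only has an effect together with --deepspeed")
    if args.nvme_offload_dir and not os.path.isdir(args.nvme_offload_dir):
        raise ValueError(f"NVMe offload directory does not exist: {args.nvme_offload_dir}")
    return args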
@@ -72,12 +77,104 @@ if args.settings is not None and Path(args.settings).exists():
     for item in new_settings:
         settings[item] = new_settings[item]
+
+if args.deepspeed:
+    import deepspeed
+    from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_zero3_enabled
+
+    # Distributed setup
+    if args.local_rank is not None:
+        local_rank = args.local_rank
+    else:
+        local_rank = int(os.getenv("LOCAL_RANK", "0"))
+    world_size = int(os.getenv("WORLD_SIZE", "1"))
+    torch.cuda.set_device(local_rank)
+    deepspeed.init_distributed()
+
+    # DeepSpeed configuration
+    # https://huggingface.co/docs/transformers/main_classes/deepspeed
+    if args.bf16:
+        ds_fp16 = False
+        ds_bf16 = True
+    else:
+        ds_fp16 = True
+        ds_bf16 = False
+    train_batch_size = 1 * world_size
+
+    if args.nvme_offload_dir:
+        ds_config = {
+            "fp16": {
+                "enabled": ds_fp16,
+            },
+            "bf16": {
+                "enabled": ds_bf16,
+            },
+            "zero_optimization": {
+                "stage": 3,
+                "offload_param": {
+                    "device": "nvme",
+                    "nvme_path": args.nvme_offload_dir,
+                    "pin_memory": True,
+                    "buffer_count": 5,
+                    "buffer_size": 1e9,
+                    "max_in_cpu": 1e9
+                },
+                "overlap_comm": True,
+                "reduce_bucket_size": "auto",
+                "contiguous_gradients": True,
+                "sub_group_size": 1e8,
+                "stage3_prefetch_bucket_size": "auto",
+                "stage3_param_persistence_threshold": "auto",
+                "stage3_max_live_parameters": "auto",
+                "stage3_max_reuse_distance": "auto",
+            },
+            "aio": {
+                "block_size": 262144,
+                "queue_depth": 32,
+                "thread_count": 1,
+                "single_submit": False,
+                "overlap_events": True
+            },
+            "steps_per_print": 2000,
+            "train_batch_size": train_batch_size,
+            "train_micro_batch_size_per_gpu": 1,
+            "wall_clock_breakdown": False
+        }
+    else:
+        ds_config = {
+            "fp16": {
+                "enabled": ds_fp16,
+            },
+            "bf16": {
+                "enabled": ds_bf16,
+            },
+            "zero_optimization": {
+                "stage": 3,
+                "offload_param": {
+                    "device": "cpu",
+                    "pin_memory": True
+                },
+                "overlap_comm": True,
+                "contiguous_gradients": True,
+                "reduce_bucket_size": "auto",
+                "stage3_prefetch_bucket_size": "auto",
+                "stage3_param_persistence_threshold": "auto",
+                "stage3_max_live_parameters": "auto",
+                "stage3_max_reuse_distance": "auto",
+            },
+            "steps_per_print": 2000,
+            "train_batch_size": train_batch_size,
+            "train_micro_batch_size_per_gpu": 1,
+            "wall_clock_breakdown": False
+        }
+    dschf = HfDeepSpeedConfig(ds_config)  # Keep this object alive for the Transformers integration
 
 def load_model(model_name):
     print(f"Loading {model_name}...")
     t0 = time.time()
 
     # Default settings
-    if not (args.cpu or args.load_in_8bit or args.auto_devices or args.disk or args.gpu_memory is not None or args.cpu_memory is not None):
+    if not (args.cpu or args.load_in_8bit or args.auto_devices or args.disk or args.gpu_memory is not None or args.cpu_memory is not None or args.deepspeed):
         if Path(f"torch-dumps/{model_name}.pt").exists():
             print("Loading in .pt format...")
             model = torch.load(Path(f"torch-dumps/{model_name}.pt"))
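The ordering constraint that makes this work is that HfDeepSpeedConfig(ds_config) is constructed, and kept referenced via dschf, before any call to from_pretrained: while a ZeRO-3 HfDeepSpeedConfig is alive, the Transformers integration partitions weights across ranks as they are loaded instead of materializing the full model on one device first. A condensed sketch of that ordering (the model path is a placeholder; ds_config is the dict built above):

from transformers import AutoModelForCausalLM
from transformers.deepspeed import HfDeepSpeedConfig

dschf = HfDeepSpeedConfig(ds_config)  # must be created (and stay alive) before the model
model = AutoModelForCausalLM.from_pretrained("models/example-model")  # now loaded ZeRO-3 sharded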
@@ -85,6 +182,21 @@ def load_model(model_name):
             model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), device_map='auto', load_in_8bit=True)
         else:
             model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
+
+    # DeepSpeed ZeRO-3
+    elif args.deepspeed:
+        if args.bf16:
+            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), torch_dtype=torch.bfloat16)
+        else:
+            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), torch_dtype=torch.float16)
+
+        model = deepspeed.initialize(model=model,
+                                     config_params=ds_config,
+                                     model_parameters=None,
+                                     optimizer=None,
+                                     lr_scheduler=None)[0]
+        model.module.eval()  # Inference
+        print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
     # Custom
     else:
         command = "AutoModelForCausalLM.from_pretrained"
@@ -190,6 +302,9 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
     cuda = "" if args.cpu else ".cuda()"
     n = tokenizer.eos_token_id if eos_token is None else tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
-    input_ids = encode(question, tokens)
+    if args.deepspeed:
+        input_ids = encode(question, tokens).to(device=local_rank)
+    else:
+        input_ids = encode(question, tokens)
     if stopping_string is not None:
         # The stopping_criteria code below was copied from
@@ -207,6 +322,10 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
     # Generate the entire reply at once
     if args.no_stream:
         t0 = time.time()
-        output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
+        if args.deepspeed:
+            with torch.no_grad():
+                output = eval(f"model.generate(input_ids, synced_gpus=True, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset})")
+        else:
+            output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
         reply = decode(output[0])
         t1 = time.time()
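synced_gpus=True matters under ZeRO-3 because every rank holds only a shard of each parameter, so all ranks must keep stepping through generate together; it makes ranks whose sequence has already finished continue issuing forward passes until the slowest rank is done, which prevents a multi-GPU hang. Wrapping the call in torch.no_grad() simply avoids building an autograd graph during inference. For readability, the same DeepSpeed branch without the eval-string indirection would look roughly like this (a sketch; generation_kwargs stands in for the sampling preset that the real code splices in as a string):

import torch

# Hedged sketch of the --deepspeed generation path with explicit keyword
# arguments; model, input_ids, n, stopping_criteria_list and decode come from
# the surrounding code, generation_kwargs is a stand-in for the preset.
with torch.no_grad():
    output = model.generate(
        input_ids,
        synced_gpus=True,                      # keep all ZeRO-3 ranks in lockstep
        eos_token_id=n,
        stopping_criteria=stopping_criteria_list,
        **generation_kwargs,
    )
reply = decode(output[0])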
@@ -220,6 +339,10 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
         yield formatted_outputs(original_question, model_name)
         preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=8')
         for i in tqdm(range(tokens//8+1)):
-            output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
+            if args.deepspeed:
+                with torch.no_grad():
+                    output = eval(f"model.generate(input_ids, synced_gpus=True, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset})")
+            else:
+                output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
             reply = decode(output[0])
             if not (args.chat or args.cai_chat):