From fd4e46bce296d2c6b19b76ac84afbff73fe6108a Mon Sep 17 00:00:00 2001
From: wangshuai09 <391746016@qq.com>
Date: Fri, 12 Apr 2024 05:42:20 +0800
Subject: [PATCH] Add Ascend NPU support (basic) (#5541)

---
 modules/callbacks.py       |  4 +++-
 modules/logits.py          |  6 +++++-
 modules/models.py          | 12 +++++++++++-
 modules/text_generation.py | 15 ++++++++++++---
 modules/ui_model_menu.py   |  5 ++++-
 5 files changed, 35 insertions(+), 7 deletions(-)

diff --git a/modules/callbacks.py b/modules/callbacks.py
index 0b219954..2b039ef1 100644
--- a/modules/callbacks.py
+++ b/modules/callbacks.py
@@ -5,7 +5,7 @@ from threading import Thread
 
 import torch
 import transformers
-from transformers import is_torch_xpu_available
+from transformers import is_torch_npu_available, is_torch_xpu_available
 
 import modules.shared as shared
 
@@ -99,5 +99,7 @@ def clear_torch_cache():
     if not shared.args.cpu:
         if is_torch_xpu_available():
             torch.xpu.empty_cache()
+        elif is_torch_npu_available():
+            torch.npu.empty_cache()
         else:
             torch.cuda.empty_cache()
diff --git a/modules/logits.py b/modules/logits.py
index c630be88..f2fd233b 100644
--- a/modules/logits.py
+++ b/modules/logits.py
@@ -1,5 +1,5 @@
 import torch
-from transformers import is_torch_xpu_available
+from transformers import is_torch_npu_available, is_torch_xpu_available
 
 from modules import sampler_hijack, shared
 from modules.logging_colors import logger
@@ -34,6 +34,8 @@ def get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return
     if is_non_hf_exllamav2:
         if is_torch_xpu_available():
             tokens = shared.tokenizer.encode(prompt).to("xpu:0")
+        elif is_torch_npu_available():
+            tokens = shared.tokenizer.encode(prompt).to("npu:0")
         else:
             tokens = shared.tokenizer.encode(prompt).cuda()
         scores = shared.model.get_logits(tokens)[-1][-1]
@@ -43,6 +45,8 @@ def get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return
     else:
         if is_torch_xpu_available():
             tokens = shared.tokenizer.encode(prompt, return_tensors='pt').to("xpu:0")
+        elif is_torch_npu_available():
+            tokens = shared.tokenizer.encode(prompt, return_tensors='pt').to("npu:0")
         else:
             tokens = shared.tokenizer.encode(prompt, return_tensors='pt').cuda()
         output = shared.model(input_ids=tokens)
diff --git a/modules/models.py b/modules/models.py
index 1519fc89..20a65764 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -10,7 +10,11 @@ from pathlib import Path
 import torch
 import transformers
 from accelerate import infer_auto_device_map, init_empty_weights
-from accelerate.utils import is_ccl_available, is_xpu_available
+from accelerate.utils import (
+    is_ccl_available,
+    is_npu_available,
+    is_xpu_available
+)
 from transformers import (
     AutoConfig,
     AutoModel,
@@ -45,6 +49,9 @@ if shared.args.deepspeed:
     if is_xpu_available() and is_ccl_available():
         torch.xpu.set_device(local_rank)
         deepspeed.init_distributed(backend="ccl")
+    elif is_npu_available():
+        torch.npu.set_device(local_rank)
+        deepspeed.init_distributed(dist_backend="hccl")
     else:
         torch.cuda.set_device(local_rank)
         deepspeed.init_distributed()
@@ -164,6 +171,9 @@ def huggingface_loader(model_name):
         elif is_xpu_available():
             device = torch.device("xpu")
             model = model.to(device)
+        elif is_npu_available():
+            device = torch.device("npu")
+            model = model.to(device)
         else:
             model = model.cuda()
 
diff --git a/modules/text_generation.py b/modules/text_generation.py
index f99c605e..5c5727d6 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -10,7 +10,11 @@ import traceback
 import numpy as np
 import torch
 import transformers
-from transformers import LogitsProcessorList, is_torch_xpu_available
+from transformers import (
+    LogitsProcessorList,
+    is_torch_npu_available,
+    is_torch_xpu_available
+)
 
 import modules.shared as shared
 from modules.cache_utils import process_llamacpp_cache
@@ -24,7 +28,7 @@ from modules.grammar.grammar_utils import initialize_grammar
 from modules.grammar.logits_process import GrammarConstrainedLogitsProcessor
 from modules.html_generator import generate_basic_html
 from modules.logging_colors import logger
-from modules.models import clear_torch_cache, local_rank
+from modules.models import clear_torch_cache
 
 
 def generate_reply(*args, **kwargs):
@@ -131,12 +135,15 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
     if shared.model.__class__.__name__ in ['LlamaCppModel', 'Exllamav2Model'] or shared.args.cpu:
         return input_ids
     elif shared.args.deepspeed:
-        return input_ids.to(device=local_rank)
+        import deepspeed
+        return input_ids.to(deepspeed.get_accelerator().current_device_name())
     elif torch.backends.mps.is_available():
         device = torch.device('mps')
         return input_ids.to(device)
     elif is_torch_xpu_available():
         return input_ids.to("xpu:0")
+    elif is_torch_npu_available():
+        return input_ids.to("npu:0")
     else:
         return input_ids.cuda()
 
@@ -213,6 +220,8 @@ def set_manual_seed(seed):
         torch.cuda.manual_seed_all(seed)
     elif is_torch_xpu_available():
         torch.xpu.manual_seed_all(seed)
+    elif is_torch_npu_available():
+        torch.npu.manual_seed_all(seed)
 
     return seed
 
diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py
index 29f0b926..7f7a3ab8 100644
--- a/modules/ui_model_menu.py
+++ b/modules/ui_model_menu.py
@@ -8,7 +8,7 @@ from pathlib import Path
 import gradio as gr
 import psutil
 import torch
-from transformers import is_torch_xpu_available
+from transformers import is_torch_npu_available, is_torch_xpu_available
 
 from modules import loaders, shared, ui, utils
 from modules.logging_colors import logger
@@ -32,6 +32,9 @@ def create_ui():
     if is_torch_xpu_available():
         for i in range(torch.xpu.device_count()):
             total_mem.append(math.floor(torch.xpu.get_device_properties(i).total_memory / (1024 * 1024)))
+    elif is_torch_npu_available():
+        for i in range(torch.npu.device_count()):
+            total_mem.append(math.floor(torch.npu.get_device_properties(i).total_memory / (1024 * 1024)))
     else:
         for i in range(torch.cuda.device_count()):
             total_mem.append(math.floor(torch.cuda.get_device_properties(i).total_memory / (1024 * 1024)))
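
The hunks above apply the same accelerator-dispatch pattern in several places (cache clearing, tensor placement, seeding, memory probing). The following is a minimal standalone sketch of that pattern, not part of the patch itself; it assumes the Ascend `torch_npu` plugin is installed so that `torch.npu` is registered, and the helper names `pick_device` and `empty_device_cache` are illustrative only.

```python
import torch
from transformers import is_torch_npu_available, is_torch_xpu_available


def pick_device() -> torch.device:
    # Prefer XPU, then NPU, then CUDA; fall back to CPU.
    # This mirrors the priority order used in the patched modules.
    if is_torch_xpu_available():
        return torch.device("xpu:0")
    elif is_torch_npu_available():
        return torch.device("npu:0")
    elif torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")


def empty_device_cache() -> None:
    # Mirrors clear_torch_cache() in modules/callbacks.py: release cached
    # allocator memory on whichever accelerator is present.
    if is_torch_xpu_available():
        torch.xpu.empty_cache()
    elif is_torch_npu_available():
        torch.npu.empty_cache()  # requires torch_npu to be installed (assumption)
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()


if __name__ == "__main__":
    device = pick_device()
    x = torch.ones(2, 2).to(device)  # tensors follow the selected device
    print(device, x.device)
    empty_device_cache()
```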