[Feature] Support for Ascend NPU backend (#3853)
Signed-off-by: Song Zhang <gepin.zs@antgroup.com>
Co-authored-by: 22dimensions <waitingwind@foxmail.com>
@@ -89,6 +89,8 @@ srt_hpu = ["sglang[runtime_common]", "outlines>=0.0.44,<=0.1.11"]
 # To install vllm for CPU, please follow the instruction here:
 # https://docs.vllm.ai/en/latest/getting_started/installation/cpu/index.html
 srt_cpu = ["sglang[runtime_common]", "outlines>=0.0.44,<=0.1.11", "torch"]
+# https://vllm-ascend.readthedocs.io/en/latest/installation.html
+srt_npu = ["sglang[runtime_common]", "outlines>=0.0.44,<=0.1.11"]
 
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
@@ -107,6 +109,7 @@ all_hip = ["sglang[srt_hip]", "sglang[openai]", "sglang[anthropic]", "sglang[lit
 all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
 all_hpu = ["sglang[srt_hpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
 all_cpu = ["sglang[srt_cpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
+all_npu = ["sglang[srt_npu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
 
 dev = ["sglang[all]", "sglang[test]"]
 dev_hip = ["sglang[all_hip]", "sglang[test]"]
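Taken together, these extras suggest an Ascend install along the lines of pip install "sglang[all_npu]", layered on the environment from the vllm-ascend installation guide linked above. That command is inferred from the extras names; the diff itself prescribes no install command.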
@@ -10,7 +10,7 @@ class DeviceConfig:
     device: Optional[torch.device]
 
     def __init__(self, device: str = "cuda") -> None:
-        if device in ["cuda", "xpu", "hpu", "cpu"]:
+        if device in ["cuda", "xpu", "hpu", "cpu", "npu"]:
             self.device_type = device
         else:
             raise RuntimeError(f"Not supported device type: {device}")
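A minimal sketch of the widened validation; the module path is assumed here, not stated in the diff:

# Sketch only: the import path is an assumption.
from sglang.srt.configs.device_config import DeviceConfig

cfg = DeviceConfig("npu")          # accepted after this change
assert cfg.device_type == "npu"

try:
    DeviceConfig("tpu")            # anything else still raises
except RuntimeError as e:
    print(e)                       # Not supported device type: tpu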
@@ -0,0 +1,39 @@
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from sglang.srt.utils import is_npu
+
+
+class NpuCommunicator:
+
+    def __init__(self, group: ProcessGroup):
+        if not is_npu():
+            self.disabled = True
+            return
+        self.disabled = False
+        self.group = group
+        self.world_size = dist.get_world_size(self.group)
+
+    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
+        dist.all_reduce(x, group=self.group)
+        return x
+
+    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        world_size = self.world_size
+        if dim < 0:
+            # Convert negative dim to positive.
+            dim += x.dim()
+        input_size = x.size()
+        output_size = (input_size[0] * world_size,) + input_size[1:]
+        # Allocate output tensor.
+        output_tensor = torch.empty(output_size, dtype=x.dtype, device=x.device)
+        # All-gather.
+        dist.all_gather_into_tensor(output_tensor, x, group=self.group)
+        # Reshape
+        output_tensor = output_tensor.reshape((world_size,) + input_size)
+        output_tensor = output_tensor.movedim(0, dim)
+        output_tensor = output_tensor.reshape(
+            input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
+        )
+        return output_tensor
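The reshape/movedim bookkeeping at the end of all_gather is the standard trick for turning the rank-major output of dist.all_gather_into_tensor into a concatenation along an arbitrary dim. A minimal single-process sketch, emulating a 2-rank gather with torch.cat so no process group is needed:

import torch

# Pretend "rank 0" and "rank 1" each hold a (2, 4) tensor and we gather along dim=1.
world_size = 2
x0 = torch.arange(8).reshape(2, 4)
x1 = torch.arange(8, 16).reshape(2, 4)
input_size = x0.size()
dim = 1

# all_gather_into_tensor concatenates rank tensors along dim 0: shape (4, 4) here.
gathered = torch.cat([x0, x1], dim=0)

# The same bookkeeping NpuCommunicator.all_gather performs:
out = gathered.reshape((world_size,) + input_size)   # (2, 2, 4): world axis in front
out = out.movedim(0, dim)                            # world axis moved to dim
out = out.reshape(
    input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
)                                                    # (2, 8): concatenated along dim

assert torch.equal(out, torch.cat([x0, x1], dim=1))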
@@ -42,6 +42,7 @@ from torch.distributed import Backend, ProcessGroup
 from sglang.srt.utils import (
     direct_register_custom_op,
     is_cuda_alike,
+    is_npu,
     supports_custom_op,
 )
 
@@ -206,6 +207,7 @@ class GroupCoordinator:
         use_custom_allreduce: bool,
         use_hpu_communicator: bool,
         use_xpu_communicator: bool,
+        use_npu_communicator: bool,
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
     ):
@@ -244,6 +246,7 @@ class GroupCoordinator:
         self.use_custom_allreduce = use_custom_allreduce
         self.use_hpu_communicator = use_hpu_communicator
         self.use_xpu_communicator = use_xpu_communicator
+        self.use_npu_communicator = use_npu_communicator
         self.use_message_queue_broadcaster = use_message_queue_broadcaster
 
         # lazy import to avoid documentation build error
@@ -291,6 +294,14 @@ class GroupCoordinator:
         if use_xpu_communicator and self.world_size > 1:
             self.xpu_communicator = XpuCommunicator(group=self.device_group)
 
+        from sglang.srt.distributed.device_communicators.npu_communicator import (
+            NpuCommunicator,
+        )
+
+        self.npu_communicator: Optional[NpuCommunicator] = None
+        if use_npu_communicator and self.world_size > 1:
+            self.npu_communicator = NpuCommunicator(group=self.device_group)
+
         from sglang.srt.distributed.device_communicators.shm_broadcast import (
             MessageQueue,
         )
@@ -418,6 +429,9 @@ class GroupCoordinator:
         if self.xpu_communicator is not None and not self.xpu_communicator.disabled:
             return self.xpu_communicator.all_reduce(input_)
 
+        if self.npu_communicator is not None and not self.npu_communicator.disabled:
+            return self.npu_communicator.all_reduce(input_)
+
         if (
             self.ca_comm is not None
             and not self.ca_comm.disabled
@@ -497,6 +511,11 @@ class GroupCoordinator:
         if hpu_comm is not None and not hpu_comm.disabled:
             return hpu_comm.all_gather(input_, dim)
 
+        # For NPUs, use NPU communicator.
+        npu_comm = self.npu_communicator
+        if npu_comm is not None and not npu_comm.disabled:
+            return npu_comm.all_gather(input_, dim)
+
         if dim < 0:
             # Convert negative dim to positive.
             dim += input_.dim()
@@ -941,6 +960,7 @@ def init_world_group(
         use_custom_allreduce=False,
         use_hpu_communicator=False,
         use_xpu_communicator=False,
+        use_npu_communicator=False,
         group_name="world",
     )
 
@@ -959,10 +979,11 @@ def init_model_parallel_group(
         group_ranks=group_ranks,
         local_rank=local_rank,
         torch_distributed_backend=backend,
-        use_pynccl=True,
+        use_pynccl=not is_npu(),
         use_custom_allreduce=use_custom_allreduce,
         use_hpu_communicator=True,
         use_xpu_communicator=True,
+        use_npu_communicator=True,
         use_message_queue_broadcaster=use_message_queue_broadcaster,
         group_name=group_name,
     )
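The switch to use_pynccl=not is_npu() reflects that the PyNCCL wrapper is CUDA-specific; on Ascend, collectives instead flow through the NpuCommunicator registered above, over the HCCL-backed process group.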
@@ -361,6 +361,8 @@ class ModelRunner:
             backend = "hccl"
         elif self.device == "cpu":
             backend = "gloo"
+        elif self.device == "npu":
+            backend = "hccl"
 
         before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         if not self.server_args.enable_p2p_check:
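On Ascend, "hccl" is the collective backend that torch_npu registers with PyTorch, so process-group setup mirrors the NCCL path. A standalone sketch of that initialization, assuming torch_npu is installed and one NPU per rank (the address and sizes are placeholders):

import torch
import torch.distributed as dist

import torch_npu  # noqa: F401 -- importing registers torch.npu and the "hccl" backend

# Single-rank placeholder values; a real launch takes these from the launcher.
dist.init_process_group(
    backend="hccl",
    init_method="tcp://127.0.0.1:29500",
    world_size=1,
    rank=0,
)
torch.npu.set_device(0)
x = torch.ones(4, device="npu")
dist.all_reduce(x)  # HCCL all-reduce; trivially a no-op with world_size=1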
@@ -555,7 +555,7 @@ class ServerArgs:
             "--device",
             type=str,
             default=ServerArgs.device,
-            help="The device to use ('cuda', 'xpu', 'hpu', 'cpu'). Defaults to auto-detection if not specified.",
+            help="The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified.",
         )
         parser.add_argument(
             "--served-model-name",
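With this flag in place, an NPU launch would presumably be python -m sglang.launch_server --model-path <model> --device npu, or leave --device unset and rely on auto-detection.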
@@ -145,6 +145,10 @@ def is_xpu() -> bool:
     return hasattr(torch, "xpu") and torch.xpu.is_available()
 
 
+def is_npu() -> bool:
+    return hasattr(torch, "npu") and torch.npu.is_available()
+
+
 def is_flashinfer_available():
     """
     Check whether flashinfer is available.
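The hasattr guard matters because torch.npu only exists once torch_npu has been imported; on a CUDA-only build the attribute is absent and the check short-circuits to False without error. A small sketch of using it for device selection (detect_device is a hypothetical helper, not part of this diff):

import torch

from sglang.srt.utils import is_npu

def detect_device() -> str:
    # Hypothetical helper: prefer CUDA, then Ascend NPU, else fall back to CPU.
    if torch.cuda.is_available():
        return "cuda"
    if is_npu():
        return "npu"
    return "cpu"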
@@ -328,6 +332,16 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True
     elif device == "cpu":
         # TODO: rename the variables in the current function to be not GPU specific
         free_gpu_memory = psutil.virtual_memory().available
+    elif device == "npu":
+        num_gpus = torch.npu.device_count()
+        assert gpu_id < num_gpus
+
+        if torch.npu.current_device() != gpu_id:
+            print(
+                f"WARNING: current device is not {gpu_id}, but {torch.npu.current_device()}, ",
+                "which may cause useless memory allocation for torch NPU context.",
+            )
+        free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info()
 
     if distributed:
         tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
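torch.npu.mem_get_info() mirrors torch.cuda.mem_get_info(), returning (free_bytes, total_bytes) for the current device. A small sketch of reading it directly, assuming an Ascend device and torch_npu installed:

import torch
import torch_npu  # noqa: F401

torch.npu.set_device(0)
free_bytes, total_bytes = torch.npu.mem_get_info()
print(f"free {free_bytes / (1 << 30):.1f} GiB / total {total_bytes / (1 << 30):.1f} GiB")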
@@ -1348,6 +1362,9 @@ def get_device_name(device_id: int = 0) -> str:
     if hasattr(torch, "hpu") and torch.hpu.is_available():
         return torch.hpu.get_device_name(device_id)
 
+    if hasattr(torch, "npu") and torch.npu.is_available():
+        return torch.npu.get_device_name(device_id)
+
 
 @lru_cache(maxsize=1)
 def is_habana_available() -> bool:
@@ -1444,6 +1461,13 @@ def get_compiler_backend() -> str:
     if hasattr(torch, "hpu") and torch.hpu.is_available():
         return "hpu_backend"
 
+    if hasattr(torch, "npu") and torch.npu.is_available():
+        import torchair
+
+        config = torchair.CompilerConfig()
+        npu_backend = torchair.get_npu_backend(compiler_config=config)
+        return npu_backend
+
     return "inductor"
 
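Note that on NPU this function returns torchair's backend callable rather than a string, so the -> str annotation in the hunk header is now loose; torch.compile accepts either form. A hedged usage sketch, with the import path taken from the utils hunks above:

import torch
import torch.nn as nn

from sglang.srt.utils import get_compiler_backend

# torch.compile takes a backend name ("inductor") or a backend callable,
# such as the torchair backend returned on Ascend.
model = nn.Linear(16, 16)
compiled = torch.compile(model, backend=get_compiler_backend())
print(compiled(torch.randn(2, 16)).shape)  # torch.Size([2, 16])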