diff --git a/python/pyproject.toml b/python/pyproject.toml index 946a058cd..29cfcc793 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -89,6 +89,8 @@ srt_hpu = ["sglang[runtime_common]", "outlines>=0.0.44,<=0.1.11"] # To install vllm for CPU, please follow the instruction here: # https://docs.vllm.ai/en/latest/getting_started/installation/cpu/index.html srt_cpu = ["sglang[runtime_common]", "outlines>=0.0.44,<=0.1.11", "torch"] +# https://vllm-ascend.readthedocs.io/en/latest/installation.html +srt_npu = ["sglang[runtime_common]", "outlines>=0.0.44,<=0.1.11"] openai = ["openai>=1.0", "tiktoken"] anthropic = ["anthropic>=0.20.0"] @@ -107,6 +109,7 @@ all_hip = ["sglang[srt_hip]", "sglang[openai]", "sglang[anthropic]", "sglang[lit all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] all_hpu = ["sglang[srt_hpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] all_cpu = ["sglang[srt_cpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] +all_npu = ["sglang[srt_npu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] dev = ["sglang[all]", "sglang[test]"] dev_hip = ["sglang[all_hip]", "sglang[test]"] diff --git a/python/sglang/srt/configs/device_config.py b/python/sglang/srt/configs/device_config.py index d95e848dd..3b9d3a1ed 100644 --- a/python/sglang/srt/configs/device_config.py +++ b/python/sglang/srt/configs/device_config.py @@ -10,7 +10,7 @@ class DeviceConfig: device: Optional[torch.device] def __init__(self, device: str = "cuda") -> None: - if device in ["cuda", "xpu", "hpu", "cpu"]: + if device in ["cuda", "xpu", "hpu", "cpu", "npu"]: self.device_type = device else: raise RuntimeError(f"Not supported device type: {device}") diff --git a/python/sglang/srt/distributed/device_communicators/npu_communicator.py b/python/sglang/srt/distributed/device_communicators/npu_communicator.py new file mode 100644 index 000000000..cb6eb88e3 --- /dev/null +++ 
class NpuCommunicator:
    """Collective-communication helper bound to one process group on NPU devices.

    The instance wraps ``torch.distributed`` collectives for ``group``.  When
    ``is_npu()`` reports that no NPU is available the instance is marked
    ``disabled`` and neither ``group`` nor ``world_size`` is set, so callers
    must check ``disabled`` before invoking any collective.
    """

    def __init__(self, group: ProcessGroup) -> None:
        """Bind the communicator to ``group`` if an NPU is available.

        Args:
            group: Process group every collective of this instance runs on.
        """
        if not is_npu():
            self.disabled = True
            return
        self.disabled = False
        self.group = group
        self.world_size = dist.get_world_size(self.group)

    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        """All-reduce ``x`` across the group and return it.

        The reduction is performed in place with the default reduce op of
        ``torch.distributed.all_reduce``; the returned tensor is ``x`` itself.
        """
        dist.all_reduce(x, group=self.group)
        return x

    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """Gather ``x`` from every rank and concatenate the pieces along ``dim``.

        Args:
            x: Local tensor; it must have at least one dimension.
            dim: Dimension to concatenate along.  Negative values count from
                the end, as in regular tensor indexing.

        Returns:
            A new tensor whose size along ``dim`` is ``world_size`` times the
            size of ``x`` along ``dim``; all other sizes equal those of ``x``.

        Raises:
            IndexError: If ``dim`` is outside ``[-x.dim(), x.dim())``.
        """
        ndim = x.dim()
        # Validate before allocating or communicating: an out-of-range negative
        # dim would otherwise stay negative after normalisation and be silently
        # reinterpreted by ``movedim``/slicing, yielding a wrongly shaped result.
        if not -ndim <= dim < ndim:
            raise IndexError(
                f"Invalid dim ({dim}) for input tensor with shape {tuple(x.size())}"
            )
        if dim < 0:
            # Convert negative dim to positive.
            dim += ndim

        world_size = self.world_size
        input_size = x.size()
        # The collective stacks the per-rank pieces along dimension 0.
        output_size = (input_size[0] * world_size,) + input_size[1:]
        output_tensor = torch.empty(output_size, dtype=x.dtype, device=x.device)
        dist.all_gather_into_tensor(output_tensor, x, group=self.group)

        # Expose the rank axis, move it next to ``dim``, then merge the two so
        # the per-rank pieces end up concatenated along ``dim``.
        output_tensor = output_tensor.reshape((world_size,) + input_size)
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(
            input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
        )
        return output_tensor
self.use_custom_allreduce = use_custom_allreduce self.use_hpu_communicator = use_hpu_communicator self.use_xpu_communicator = use_xpu_communicator + self.use_npu_communicator = use_npu_communicator self.use_message_queue_broadcaster = use_message_queue_broadcaster # lazy import to avoid documentation build error @@ -291,6 +294,14 @@ class GroupCoordinator: if use_xpu_communicator and self.world_size > 1: self.xpu_communicator = XpuCommunicator(group=self.device_group) + from sglang.srt.distributed.device_communicators.npu_communicator import ( + NpuCommunicator, + ) + + self.npu_communicator: Optional[NpuCommunicator] = None + if use_npu_communicator and self.world_size > 1: + self.npu_communicator = NpuCommunicator(group=self.device_group) + from sglang.srt.distributed.device_communicators.shm_broadcast import ( MessageQueue, ) @@ -418,6 +429,9 @@ class GroupCoordinator: if self.xpu_communicator is not None and not self.xpu_communicator.disabled: return self.xpu_communicator.all_reduce(input_) + if self.npu_communicator is not None and not self.npu_communicator.disabled: + return self.npu_communicator.all_reduce(input_) + if ( self.ca_comm is not None and not self.ca_comm.disabled @@ -497,6 +511,11 @@ class GroupCoordinator: if hpu_comm is not None and not hpu_comm.disabled: return hpu_comm.all_gather(input_, dim) + # For NPUs, use NPU communicator. + npu_comm = self.npu_communicator + if npu_comm is not None and not npu_comm.disabled: + return npu_comm.all_gather(input_, dim) + if dim < 0: # Convert negative dim to positive. 
dim += input_.dim() @@ -941,6 +960,7 @@ def init_world_group( use_custom_allreduce=False, use_hpu_communicator=False, use_xpu_communicator=False, + use_npu_communicator=False, group_name="world", ) @@ -959,10 +979,11 @@ def init_model_parallel_group( group_ranks=group_ranks, local_rank=local_rank, torch_distributed_backend=backend, - use_pynccl=True, + use_pynccl=not is_npu(), use_custom_allreduce=use_custom_allreduce, use_hpu_communicator=True, use_xpu_communicator=True, + use_npu_communicator=True, use_message_queue_broadcaster=use_message_queue_broadcaster, group_name=group_name, ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index bf2c91080..589cc9b06 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -361,6 +361,8 @@ class ModelRunner: backend = "hccl" elif self.device == "cpu": backend = "gloo" + elif self.device == "npu": + backend = "hccl" before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id) if not self.server_args.enable_p2p_check: diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 56c1f916b..ba91cd2ac 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -555,7 +555,7 @@ class ServerArgs: "--device", type=str, default=ServerArgs.device, - help="The device to use ('cuda', 'xpu', 'hpu', 'cpu'). Defaults to auto-detection if not specified.", + help="The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). 
def is_npu() -> bool:
    """Return True when ``torch`` exposes an ``npu`` module that reports an available device."""
    npu_module = getattr(torch, "npu", None)
    if npu_module is None:
        # No ``torch.npu`` attribute: NPU support is not loaded in this process.
        return False
    return npu_module.is_available()