forked from EngineX-Ascend/enginex-ascend-910-vllm
v0.10.1rc1
This commit is contained in:
28
vllm_ascend/distributed/__init__.py
Normal file
28
vllm_ascend/distributed/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import \
|
||||
KVConnectorFactory
|
||||
|
||||
# Register the Ascend-specific KV-transfer connectors with vLLM's factory so
# they can be selected by name via kv_transfer_config.
_ASCEND_CONNECTORS = (
    ("LLMDataDistCMgrConnector",
     "vllm_ascend.distributed.llmdatadist_c_mgr_connector",
     "LLMDataDistCMgrConnector"),
    ("MooncakeConnectorV1",
     "vllm_ascend.distributed.mooncake_connector",
     "MooncakeConnector"),
)

for _name, _module_path, _class_name in _ASCEND_CONNECTORS:
    KVConnectorFactory.register_connector(_name, _module_path, _class_name)
|
||||
25
vllm_ascend/distributed/communication_op.py
Normal file
25
vllm_ascend/distributed/communication_op.py
Normal file
@@ -0,0 +1,25 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import torch
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
|
||||
|
||||
def data_parallel_reduce_scatter(input_: torch.Tensor,
                                 dim: int = -1) -> torch.Tensor:
    """Perform a reduce-scatter of ``input_`` over the data-parallel group.

    Args:
        input_: tensor to reduce and scatter across DP ranks.
        dim: dimension along which the scattered shards are split
            (default: last dimension).
    """
    dp_group = get_dp_group()
    return dp_group.reduce_scatter(input_, dim)
||||
75
vllm_ascend/distributed/communicator.py
Normal file
75
vllm_ascend/distributed/communicator.py
Normal file
@@ -0,0 +1,75 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from vllm.distributed.device_communicators.base_device_communicator import \
|
||||
DeviceCommunicatorBase
|
||||
|
||||
|
||||
class NPUCommunicator(DeviceCommunicatorBase):
    """Device communicator backed by Ascend NPUs.

    Extends the base communicator with an all-to-all collective that
    supports both even and uneven (explicitly sized) splits.
    """

    def __init__(self,
                 cpu_group: dist.ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[dist.ProcessGroup] = None,
                 unique_name: str = ""):
        super().__init__(cpu_group, device, device_group, unique_name)
        # TODO(hz): Refer to CudaCommunicator's implementation to integrate
        # PyHcclCommunicator.
        # Bind to whichever NPU is current for this rank.
        self.device = torch.npu.current_device()

    def all_to_all(self,
                   input_: torch.Tensor,
                   scatter_dim: int = 0,
                   gather_dim: int = -1,
                   scatter_sizes: Optional[List[int]] = None,
                   gather_sizes: Optional[List[int]] = None) -> torch.Tensor:
        """All-to-all exchange: scatter ``input_`` along ``scatter_dim`` and
        concatenate the received chunks along ``gather_dim``.

        When both ``scatter_sizes`` and ``gather_sizes`` are given, chunks
        may be unevenly sized; otherwise ``input_`` is split evenly into
        ``world_size`` pieces.
        """
        # Normalize negative dims to absolute indices.
        if scatter_dim < 0:
            scatter_dim += input_.dim()
        if gather_dim < 0:
            gather_dim += input_.dim()

        if scatter_sizes is not None and gather_sizes is not None:
            # Uneven split: send chunks sized per-rank by scatter_sizes and
            # pre-allocate receive chunks sized per-rank by gather_sizes.
            send_chunks = [
                chunk.contiguous()
                for chunk in torch.split(input_, scatter_sizes, scatter_dim)
            ]
            # Receive shapes match this rank's own chunk except along
            # gather_dim, where each peer contributes gather_sizes[peer].
            base_shape = list(send_chunks[self.rank].size())
            recv_chunks = []
            for peer in range(self.world_size):
                shape = list(base_shape)
                shape[gather_dim] = gather_sizes[peer]
                recv_chunks.append(
                    torch.empty(shape,
                                dtype=input_.dtype,
                                device=input_.device))
        else:
            # Even split into world_size chunks; receive buffers mirror the
            # send chunk shapes.
            send_chunks = [
                chunk.contiguous() for chunk in torch.tensor_split(
                    input_, self.world_size, scatter_dim)
            ]
            recv_chunks = [torch.empty_like(chunk) for chunk in send_chunks]

        dist.all_to_all(recv_chunks, send_chunks, group=self.device_group)
        return torch.cat(recv_chunks, dim=gather_dim).contiguous()
|
||||
165
vllm_ascend/distributed/device_communicators/pyhccl.py
Normal file
165
vllm_ascend/distributed/device_communicators/pyhccl.py
Normal file
@@ -0,0 +1,165 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import (
|
||||
HCCLLibrary, aclrtStream_t, buffer_type, hcclComm_t, hcclDataTypeEnum,
|
||||
hcclRedOpTypeEnum, hcclUniqueId)
|
||||
from vllm_ascend.utils import current_stream
|
||||
|
||||
|
||||
class PyHcclCommunicator:
    """Pure-Python HCCL communicator built on the ctypes HCCLLibrary wrapper.

    Mirrors vLLM's PyNcclCommunicator: rank 0 creates a unique id, the id is
    broadcast over the (non-HCCL) bootstrap group, and every rank then joins
    the HCCL communicator on its own NPU. If there is only one rank, or the
    HCCL library cannot be loaded, the communicator stays disabled and the
    collective methods become no-ops.
    """

    def __init__(
        self,
        group: Union[ProcessGroup, StatelessProcessGroup],
        device: Union[int, str, torch.device],
        library_path: Optional[str] = None,
    ):
        """
        Args:
            group: the process group to work on. If None, it will use the
                default process group.
            device: the device to bind the PyHcclCommunicator to. If None,
                it will be bind to f"npu:{local_rank}".
            library_path: the path to the HCCL library. If None, it will
                use the default library path.
        It is the caller's responsibility to make sure each communicator
        is bind to a unique device.
        """
        if not isinstance(group, StatelessProcessGroup):
            assert dist.is_initialized()
            # The bootstrap group must not itself be HCCL, since it is used
            # to exchange the unique id before HCCL is up.
            assert dist.get_backend(group) != dist.Backend.HCCL, (
                "PyHcclCommunicator should be attached to a non-HCCL group.")
            # note: this rank is the rank in the group
            self.rank = dist.get_rank(group)
            self.world_size = dist.get_world_size(group)
        else:
            self.rank = group.rank
            self.world_size = group.world_size

        self.group = group

        # if world_size == 1, no need to create communicator
        if self.world_size == 1:
            self.available = False
            self.disabled = True
            return
        try:
            self.hccl = HCCLLibrary(library_path)
        except Exception:
            # disable because of missing HCCL library
            # e.g. in a non-NPU environment
            self.available = False
            self.disabled = True
            return

        self.available = True
        self.disabled = False

        logger.info("vLLM is using pyhccl")

        if isinstance(device, int):
            device = torch.device(f"npu:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device

        if self.rank == 0:
            # get the unique id from HCCL
            with torch.npu.device(device):
                self.unique_id = self.hccl.hcclGetUniqueId()
        else:
            # construct an empty unique id
            self.unique_id = hcclUniqueId()

        if not isinstance(group, StatelessProcessGroup):
            # Ship the raw id bytes through a tensor broadcast on the
            # bootstrap group, then copy them back into the ctypes struct.
            tensor = torch.ByteTensor(list(self.unique_id.internal))
            ranks = dist.get_process_group_ranks(group)
            # arg `src` in `broadcast` is the global rank
            dist.broadcast(tensor, src=ranks[0], group=group)
            byte_list = tensor.tolist()
            for i, byte in enumerate(byte_list):
                self.unique_id.internal[i] = byte
        else:
            self.unique_id = group.broadcast_obj(self.unique_id, src=0)

        # hccl communicator and stream will use this device
        # `torch.npu.device` is a context manager that changes the
        # current npu device to the specified one
        with torch.npu.device(device):
            self.comm: hcclComm_t = self.hccl.hcclCommInitRank(
                self.world_size, self.unique_id, self.rank)

            stream = current_stream()
            # A small all_reduce for warmup.
            data = torch.zeros(1, device=device)
            self.all_reduce(data)
            stream.synchronize()
            del data

    def all_reduce(self,
                   in_tensor: torch.Tensor,
                   op: ReduceOp = ReduceOp.SUM,
                   stream=None) -> torch.Tensor:
        """All-reduce ``in_tensor`` and return a new tensor with the result.

        Returns None when the communicator is disabled.
        """
        if self.disabled:
            return None
        # hccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert in_tensor.device == self.device, (
            f"this hccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {in_tensor.device}")

        out_tensor = torch.empty_like(in_tensor)

        if stream is None:
            stream = current_stream()
        self.hccl.hcclAllReduce(buffer_type(in_tensor.data_ptr()),
                                buffer_type(out_tensor.data_ptr()),
                                in_tensor.numel(),
                                hcclDataTypeEnum.from_torch(in_tensor.dtype),
                                hcclRedOpTypeEnum.from_torch(op), self.comm,
                                aclrtStream_t(stream.npu_stream))
        return out_tensor

    def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
        """Broadcast ``tensor`` in-place from global rank ``src``."""
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this hccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = current_stream()
        # HcclBroadcast uses a single in/out buffer on every rank, so the
        # same pointer is passed whether or not this rank is the source.
        # (The original code had an if src == self.rank / else whose two
        # branches were byte-identical — collapsed here.)
        buffer = buffer_type(tensor.data_ptr())
        self.hccl.hcclBroadcast(buffer, tensor.numel(),
                                hcclDataTypeEnum.from_torch(tensor.dtype), src,
                                self.comm, aclrtStream_t(stream.npu_stream))
|
||||
253
vllm_ascend/distributed/device_communicators/pyhccl_wrapper.py
Normal file
253
vllm_ascend/distributed/device_communicators/pyhccl_wrapper.py
Normal file
@@ -0,0 +1,253 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import ctypes
|
||||
import platform
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ReduceOp
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_ascend.utils import find_hccl_library
|
||||
|
||||
# export types and functions from hccl to Python ===
|
||||
# for the original hccl definition, please check
|
||||
# https://github.com/EternalLied/cann-hccl-new/blob/64ec6ce2923319caa5df8c3c531e06bdc148ce9c/inc/hccl/hccl.h#L90
|
||||
# https://github.com/EternalLied/cann-hccl-new/blob/64ec6ce2923319caa5df8c3c531e06bdc148ce9c/inc/hccl/hccl_types.h#L48
|
||||
|
||||
# ctypes mirrors of the C typedefs from hccl.h / hccl_types.h (see the
# header links above).
hcclResult_t = ctypes.c_int  # HcclResult error code
hcclComm_t = ctypes.c_void_p  # HcclComm: opaque communicator handle


class hcclUniqueId(ctypes.Structure):
    # Opaque root-info blob generated on rank 0 and shared with all other
    # ranks before communicator init. The 4108-byte size mirrors the
    # HcclRootInfo struct in the CANN hccl_types.h header linked above.
    _fields_ = [("internal", ctypes.c_byte * 4108)]


aclrtStream_t = ctypes.c_void_p  # aclrtStream: device stream handle
buffer_type = ctypes.c_void_p  # raw device buffer pointer

hcclDataType_t = ctypes.c_int  # HcclDataType enum value
|
||||
|
||||
|
||||
class hcclDataTypeEnum:
    """Integer codes of HCCL's HcclDataType C enum."""
    hcclInt8 = 0
    hcclInt16 = 1
    hcclInt32 = 2
    hcclFloat16 = 3
    hcclFloat32 = 4
    hcclInt64 = 5
    hcclUint64 = 6
    hcclUint8 = 7
    hcclUint16 = 8
    hcclUint32 = 9
    hcclFloat64 = 10
    hcclBfloat16 = 11
    hcclInt128 = 12

    @classmethod
    def from_torch(cls, dtype: torch.dtype) -> int:
        """Translate a torch dtype into the matching HCCL data-type code.

        Raises:
            ValueError: if the dtype has no HCCL equivalent.
        """
        translation = {
            torch.int8: cls.hcclInt8,
            torch.uint8: cls.hcclUint8,
            torch.int32: cls.hcclInt32,
            torch.int64: cls.hcclInt64,
            torch.float16: cls.hcclFloat16,
            torch.float32: cls.hcclFloat32,
            torch.float64: cls.hcclFloat64,
            torch.bfloat16: cls.hcclBfloat16,
        }
        try:
            return translation[dtype]
        except KeyError:
            raise ValueError(f"Unsupported dtype: {dtype}") from None
|
||||
|
||||
|
||||
hcclRedOp_t = ctypes.c_int  # HcclReduceOp enum value
|
||||
|
||||
|
||||
class hcclRedOpTypeEnum:
    """Integer codes of HCCL's HcclReduceOp C enum."""
    hcclSum = 0
    hcclProd = 1
    hcclMax = 2
    hcclMin = 3

    @classmethod
    def from_torch(cls, op: ReduceOp) -> int:
        """Translate a torch.distributed ReduceOp into the HCCL op code.

        Raises:
            ValueError: if the op has no HCCL equivalent.
        """
        # Compare with `==` (not a dict lookup) to keep exactly the
        # equality semantics of torch's ReduceOp objects.
        for torch_op, hccl_code in ((ReduceOp.SUM, cls.hcclSum),
                                    (ReduceOp.PRODUCT, cls.hcclProd),
                                    (ReduceOp.MAX, cls.hcclMax),
                                    (ReduceOp.MIN, cls.hcclMin)):
            if op == torch_op:
                return hccl_code
        raise ValueError(f"Unsupported op: {op}")
|
||||
|
||||
|
||||
@dataclass
class Function:
    """ctypes binding spec for one symbol exported by the HCCL library."""
    name: str  # symbol name in the shared library
    restype: Any  # ctypes return type
    argtypes: List[Any]  # ctypes argument types, in call order
|
||||
|
||||
|
||||
class HCCLLibrary:
    """ctypes loader for the HCCL shared library.

    The loaded library handle and its bound function table are cached per
    path in class attributes, so constructing several instances (e.g. one
    per communicator) loads and binds the library only once per process.
    """

    exported_functions = [
        # const char* HcclGetErrorString(HcclResult code);
        Function("HcclGetErrorString", ctypes.c_char_p, [hcclResult_t]),

        # HcclResult HcclGetRootInfo(HcclRootInfo *rootInfo);
        Function("HcclGetRootInfo", hcclResult_t,
                 [ctypes.POINTER(hcclUniqueId)]),

        # HcclResult HcclCommInitRootInfo(
        #   uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, HcclComm *comm);
        # note that HcclComm is a pointer type, so the last argument is a pointer to a pointer
        Function("HcclCommInitRootInfo", hcclResult_t, [
            ctypes.c_int,
            ctypes.POINTER(hcclUniqueId),
            ctypes.c_int,
            ctypes.POINTER(hcclComm_t),
        ]),

        # HcclResult HcclAllReduce(
        #   void *sendBuf, void *recvBuf, uint64_t count,
        #   HcclDataType dataType, HcclReduceOp op, HcclComm comm,
        #   aclrtStream stream);
        Function("HcclAllReduce", hcclResult_t, [
            buffer_type,
            buffer_type,
            ctypes.c_size_t,
            hcclDataType_t,
            hcclRedOp_t,
            hcclComm_t,
            aclrtStream_t,
        ]),

        # HcclResult HcclBroadcast(
        #   void *buf, uint64_t count,
        #   HcclDataType dataType, uint32_t root,
        #   HcclComm comm, aclrtStream stream);
        Function("HcclBroadcast", hcclResult_t, [
            buffer_type,
            ctypes.c_size_t,
            hcclDataType_t,
            ctypes.c_int,
            hcclComm_t,
            aclrtStream_t,
        ]),

        # HcclResult HcclCommDestroy(HcclComm comm);
        Function("HcclCommDestroy", hcclResult_t, [hcclComm_t]),
    ]

    # class attribute to store the mapping from the path to the library
    # to avoid loading the same library multiple times
    path_to_library_cache: Dict[str, Any] = {}

    # class attribute to store the mapping from library path
    # to the corresponding bound-function table
    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}

    def __init__(self, so_file: Optional[str] = None):
        """Load (or reuse) the HCCL shared library and bind its functions.

        Args:
            so_file: explicit path to the HCCL .so; when None it is resolved
                via find_hccl_library().

        Raises:
            OSError: if the shared library cannot be loaded.
        """
        so_file = so_file or find_hccl_library()

        try:
            # Bug fix: check the library cache itself. The original tested
            # `path_to_dict_mapping` here, then read
            # `path_to_library_cache[so_file]` — a KeyError waiting to
            # happen if the two caches ever got out of sync.
            if so_file not in HCCLLibrary.path_to_library_cache:
                lib = ctypes.CDLL(so_file)
                HCCLLibrary.path_to_library_cache[so_file] = lib
            self.lib = HCCLLibrary.path_to_library_cache[so_file]
        except Exception as e:
            logger.error(
                "Failed to load HCCL library from %s. "
                "It is expected if you are not running on Ascend NPUs."
                "Otherwise, the hccl library might not exist, be corrupted "
                "or it does not support the current platform %s. "
                "If you already have the library, please set the "
                "environment variable HCCL_SO_PATH"
                " to point to the correct hccl library path.", so_file,
                platform.platform())
            raise e

        if so_file not in HCCLLibrary.path_to_dict_mapping:
            # Bind each exported symbol once and memoize the table.
            _funcs: Dict[str, Any] = {}
            for func in HCCLLibrary.exported_functions:
                f = getattr(self.lib, func.name)
                f.restype = func.restype
                f.argtypes = func.argtypes
                _funcs[func.name] = f
            HCCLLibrary.path_to_dict_mapping[so_file] = _funcs
        self._funcs = HCCLLibrary.path_to_dict_mapping[so_file]

    def hcclGetErrorString(self, result: hcclResult_t) -> str:
        """Return the human-readable message for an HCCL result code."""
        return self._funcs["HcclGetErrorString"](result).decode("utf-8")

    def HCCL_CHECK(self, result: hcclResult_t) -> None:
        """Raise RuntimeError if an HCCL call returned a non-zero code."""
        if result != 0:
            error_str = self.hcclGetErrorString(result)
            raise RuntimeError(f"HCCL error: {error_str}")

    def hcclGetUniqueId(self) -> hcclUniqueId:
        """Generate fresh root info (unique id); called on rank 0 only."""
        unique_id = hcclUniqueId()
        self.HCCL_CHECK(self._funcs["HcclGetRootInfo"](
            ctypes.byref(unique_id)))
        return unique_id

    def hcclCommInitRank(self, world_size: int, unique_id: hcclUniqueId,
                         rank: int) -> hcclComm_t:
        """Create and return the HCCL communicator handle for this rank."""
        comm = hcclComm_t()
        self.HCCL_CHECK(self._funcs["HcclCommInitRootInfo"](
            world_size, ctypes.byref(unique_id), rank, ctypes.byref(comm)))
        return comm

    def hcclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
                      count: int, datatype: int, op: int, comm: hcclComm_t,
                      stream: aclrtStream_t) -> None:
        """All-reduce `count` elements from sendbuff into recvbuff."""
        # `datatype` actually should be `hcclDataType_t`
        # and `op` should be `hcclRedOp_t`
        # both are aliases of `ctypes.c_int`
        # when we pass int to a function, it will be converted to `ctypes.c_int`
        # by ctypes automatically
        self.HCCL_CHECK(self._funcs["HcclAllReduce"](sendbuff, recvbuff, count,
                                                     datatype, op, comm,
                                                     stream))

    def hcclBroadcast(self, buf: buffer_type, count: int, datatype: int,
                      root: int, comm: hcclComm_t,
                      stream: aclrtStream_t) -> None:
        """Broadcast `count` elements in-place from rank `root`."""
        self.HCCL_CHECK(self._funcs["HcclBroadcast"](buf, count, datatype,
                                                     root, comm, stream))

    def hcclCommDestroy(self, comm: hcclComm_t) -> None:
        """Destroy a communicator created by hcclCommInitRank."""
        self.HCCL_CHECK(self._funcs["HcclCommDestroy"](comm))
|
||||
|
||||
|
||||
# Public API of this module.
__all__ = [
    "HCCLLibrary",
    "hcclDataTypeEnum",
    "hcclRedOpTypeEnum",
    "hcclUniqueId",
    "hcclComm_t",
    "aclrtStream_t",
    "buffer_type",
]
|
||||
894
vllm_ascend/distributed/llmdatadist_c_mgr_connector.py
Normal file
894
vllm_ascend/distributed/llmdatadist_c_mgr_connector.py
Normal file
@@ -0,0 +1,894 @@
|
||||
import contextlib
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional, Tuple
|
||||
|
||||
import llm_datadist # type: ignore
|
||||
import msgspec
|
||||
import torch
|
||||
import zmq
|
||||
from llm_datadist import (BlocksCacheKey, CacheDesc, LLMConfig, LLMDataDist,
|
||||
LLMException, LLMRole)
|
||||
from vllm.config import KVTransferConfig, VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.distributed.parallel_state import get_tp_group, get_world_group
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.utils import get_ip, logger
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
|
||||
|
||||
# Translation from torch dtypes to llm_datadist DataType codes used when
# describing KV-cache buffers to the DataDist engine. torch.half/torch.float16
# and torch.float/torch.float32 alias the same entries.
TORCH_DTYPE_TO_NPU_DTYPE = {
    torch.half: llm_datadist.DataType.DT_FLOAT16,
    torch.float16: llm_datadist.DataType.DT_FLOAT16,
    torch.bfloat16: llm_datadist.DataType.DT_BF16,
    torch.float: llm_datadist.DataType.DT_FLOAT,
    torch.float32: llm_datadist.DataType.DT_FLOAT,
    torch.int8: llm_datadist.DataType.DT_INT8,
    torch.int64: llm_datadist.DataType.DT_INT64,
    torch.int32: llm_datadist.DataType.DT_INT32
}
|
||||
|
||||
|
||||
class LLMDataDistCMgrEvent(Enum):
    # Control-plane message kinds for this connector: a peer either requests
    # agent metadata (link setup) or asks about finished requests.
    # NOTE(review): the actual wire handling lives on the worker side, which
    # is outside this view — verify the semantics there.
    ReqForMetadata = 0
    ReqForFinished = 1
|
||||
|
||||
|
||||
class LLMDataDistCMgrAgentMetadata(msgspec.Struct):
    # Serializable (msgspec) identity record for one DataDist agent.
    # Field names mirror Ascend device/cluster terminology; presumably the
    # values come from the local rank's device info and are exchanged with
    # the remote side during link setup — verify against the worker code.
    super_pod_id: str
    server_id: str
    device_id: str
    device_ip: str
    super_device_id: str
    cluster_id: int
|
||||
|
||||
|
||||
@dataclass
class ReqMeta:
    """Per-request KV-transfer description built from kv_transfer_params."""
    local_block_ids: list[int]  # blocks allocated locally to receive the KV
    remote_block_ids: list[int]  # blocks holding the KV on the remote node
    remote_host: str
    # NOTE(review): the scheduler side sends an int port (self.port) while
    # this is annotated str — the annotation looks loose; confirm.
    remote_port: str
    engine_id: str  # engine id of the remote (prefill) instance
    remote_tp_size: str  # remote tensor-parallel size, serialized as a str
|
||||
|
||||
|
||||
class LLMDataDistCMgrConnectorMetadata(KVConnectorMetadata):
    """Connector metadata: the requests whose KV blocks the worker must pull
    from a remote prefill node in this scheduler step."""

    def __init__(self):
        # request_id -> transfer description
        self.requests: dict[str, ReqMeta] = {}

    def add_new_req(self, request_id: str, local_block_ids: list[int],
                    kv_transfer_params: dict[str, Any]):
        """Record one request that needs a remote KV pull."""
        params = kv_transfer_params
        meta = ReqMeta(
            local_block_ids=local_block_ids,
            remote_block_ids=params["remote_block_ids"],
            engine_id=params["remote_engine_id"],
            remote_host=params["remote_host"],
            remote_port=params["remote_port"],
            remote_tp_size=params["remote_tp_size"],
        )
        self.requests[request_id] = meta
|
||||
|
||||
|
||||
class LLMDataDistCMgrConnector(KVConnectorBase_V1):
    """Role-dispatching facade for the LLMDataDist C-Mgr KV connector.

    Depending on the role vLLM constructs it with, this class instantiates
    either the scheduler-side helper (block accounting, metadata building)
    or the worker-side helper (the actual KV-cache transfer) and forwards
    every API call to the matching helper.
    """

    def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
        assert vllm_config.kv_transfer_config is not None
        self.engine_id = vllm_config.kv_transfer_config.engine_id
        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler: Optional[
                LLMDataDistCMgrConnectorScheduler] = LLMDataDistCMgrConnectorScheduler(
                    vllm_config, self.engine_id)
        elif role == KVConnectorRole.WORKER:
            # connector_worker exists only for the WORKER role; the
            # worker-side methods below assert on its presence.
            self.connector_scheduler = None
            self.connector_worker = LLMDataDistCMgrConnectorWorker(vllm_config)

    ############################################################
    # Scheduler Side Methods
    ############################################################

    def get_num_new_matched_tokens(
            self, request: "Request",
            num_computed_tokens: int) -> tuple[int, bool]:
        """Delegate to the scheduler-side helper (SCHEDULER role only)."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.get_num_new_matched_tokens(
            request, num_computed_tokens)

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        """Delegate to the scheduler-side helper (SCHEDULER role only)."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.update_state_after_alloc(
            request, blocks, num_external_tokens)

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Delegate to the scheduler-side helper (SCHEDULER role only)."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.build_connector_meta(scheduler_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        """Delegate to the scheduler-side helper (SCHEDULER role only)."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.request_finished(request, block_ids)

    ############################################################
    # Worker Side Methods
    ############################################################
    def register_kv_caches(
            self,
            kv_caches: dict[
                str,  # type: ignore[override]
                Tuple[torch.Tensor]]):
        """Delegate to the worker-side helper (WORKER role only)."""
        assert self.connector_worker is not None
        self.connector_worker.register_kv_caches(kv_caches)

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        """Get the finished recving and sending requests."""
        assert self.connector_worker is not None
        return self.connector_worker.get_finished(finished_req_ids)

    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs) -> None:
        """Kick off the KV load described by the current connector metadata."""
        assert self.connector_worker is not None
        assert isinstance(self._connector_metadata,
                          LLMDataDistCMgrConnectorMetadata)
        self.connector_worker.start_load_kv(self._connector_metadata)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """LLMDataDistCMgrConnector does not do layerwise saving, the load is in blocking manager."""
        pass

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata, **kwargs) -> None:
        """LLMDataDistCMgrConnector does not save explicitly."""
        pass

    def wait_for_save(self):
        """LLMDataDistCMgrConnector does not save explicitly."""
        pass
|
||||
|
||||
|
||||
class LLMDataDistCMgrConnectorScheduler():
    """Scheduler-side half of the LLMDataDist connector.

    Tracks which requests must pull their KV blocks from a remote prefill
    node, packages them into connector metadata for the worker side, and
    advertises this node's transfer endpoint when a prefill finishes.
    """

    def __init__(self, vllm_config: VllmConfig, engine_id: Optional[str]):
        self.vllm_config = vllm_config
        self.block_size = vllm_config.cache_config.block_size
        self.engine_id = engine_id
        self.local_ip = get_ip()
        # Can not retrieve the parallel config since it is not initialized.
        self.local_dp_rank = None
        self.tp_size = None
        dp_rank_local = self.vllm_config.parallel_config.data_parallel_rank_local
        tp_size = self.vllm_config.parallel_config.tensor_parallel_size

        # Offset the RPC port by this DP rank's slot so each DP replica on
        # the node listens on a distinct port.
        # NOTE(review): when dp_rank_local is None the fallback adds tp_size
        # to the base port — confirm this matches the worker-side port
        # computation.
        self.port = (dp_rank_local * tp_size +
                     envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT
                     if dp_rank_local is not None else
                     tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT)

        # Requests whose KV must be fetched remotely, keyed by request id,
        # together with the local block ids allocated to receive the data.
        self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}

    def get_num_new_matched_tokens(
            self, request: "Request",
            num_computed_tokens: int) -> tuple[int, bool]:
        """
        For remote prefill, pull all prompt blocks from remote
        asynchronously relative to engine execution.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request
        Returns:
            * the number of tokens that can be loaded from the
              external KV cache beyond what is already computed.
            * true if the external KV cache tokens will be loaded
              asynchronously (between scheduler steps).
        """
        params = request.kv_transfer_params
        # Lazy %-style logging (consistent with request_finished below and
        # cheaper than the original eager f-string when DEBUG is off).
        logger.debug(
            "LLMDataDistCMgrConnector get_num_new_matched_tokens: "
            "num_computed_tokens=%s, kv_transfer_params=%s",
            num_computed_tokens, params)

        if params is not None and params.get("do_remote_prefill"):
            # Remote prefill: get all prompt blocks from remote.
            assert num_computed_tokens % self.block_size == 0
            # Note: We use the full token count as transmit data here.
            count = max(len(request.prompt_token_ids) - num_computed_tokens, 0)
            return count, count > 0

        # No remote prefill for this request.
        return 0, False

    def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks,
                                 num_externel_tokens: int):
        # NOTE: the parameter keeps its historical misspelling ("externel")
        # so existing keyword callers are not broken.
        params = request.kv_transfer_params
        logger.debug(
            "LLMDataDistCMgrConnector update states num_externel_tokens: %s "
            "kv_transfer_params: %s", num_externel_tokens, params)
        if params is not None and params.get("do_remote_prefill"):
            if params.get("remote_block_ids"):
                if all(p in params for p in ("remote_engine_id", "remote_host",
                                             "remote_port", "remote_tp_size")):
                    # Remember the request for the next build_connector_meta().
                    self._reqs_need_recv[request.request_id] = (
                        request, blocks.get_unhashed_block_ids())
                else:
                    # Fix: single lazy log call — the original concatenated an
                    # empty literal with an eagerly formatted f-string and had
                    # a grammar slip ("will be discard").
                    logger.warning(
                        "Invalid KVTransferParams %s, this request will be "
                        "discarded", params)
            else:
                assert num_externel_tokens == 0
            params["do_remote_prefill"] = False

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Drain the pending-receive set into fresh connector metadata."""
        meta = LLMDataDistCMgrConnectorMetadata()

        for req_id, (req, block_ids) in self._reqs_need_recv.items():
            assert req.kv_transfer_params is not None
            meta.add_new_req(request_id=req_id,
                             local_block_ids=block_ids,
                             kv_transfer_params=req.kv_transfer_params)
        self._reqs_need_recv.clear()

        return meta

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        """On prefill completion, decide whether to delay freeing blocks and
        build the kv_transfer_params the decode side needs to pull them."""
        params = request.kv_transfer_params
        logger.debug(
            "LLMDataDistCMgrConnector request_finished, request_status=%s, "
            "kv_transfer_params=%s", request.status, params)

        if (params is None or not params.get("do_remote_decode")
                or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
            return False, None

        # note: NIXL transfer the full block only, but I don't see any reason to do that, so here
        # we just transfer any data that computed from prefill node
        # note: there might be some issue on this, check it if there is any unexpected result
        computed_block_ids = block_ids
        delay_free_blocks = len(computed_block_ids) > 0
        if delay_free_blocks:
            logger.info("Delaying free of %d blocks for request %s",
                        len(computed_block_ids), request.request_id)
        return delay_free_blocks, dict(
            do_remote_prefill=True,
            do_remote_decode=False,
            remote_block_ids=computed_block_ids,
            remote_engine_id=self.engine_id,
            remote_host=self.local_ip,
            remote_port=self.port,
            remote_tp_size=str(
                self.vllm_config.parallel_config.tensor_parallel_size),
        )
|
||||
|
||||
|
||||
class LLMDataDistCMgrConnectorWorker():
    """
    Implementation of Worker side methods
    """

    def __init__(self, vllm_config: VllmConfig):
        # The connector is meaningless without a kv_transfer_config section.
        assert vllm_config.kv_transfer_config is not None
        logger.info("Initialize the LLMDataDistCMgrConnectorWorker")
        # we assume the local node only contains dp and tp, and tp will not communicate inter-node.
        # for any scenario beyond this scope, the functionality of this connector is not guaranteed.
        self.local_rank_on_node = get_world_group().rank % (
            vllm_config.parallel_config.data_parallel_size_local *
            vllm_config.parallel_config.tensor_parallel_size)
        self.local_rank = get_world_group().local_rank
        self.local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
        self.tp_size = vllm_config.parallel_config.tensor_parallel_size
        self.tp_rank = get_tp_group().rank_in_group
        self.rank = get_world_group().rank
        self.local_ip = get_ip()
        self.kv_transfer_config: KVTransferConfig = vllm_config.kv_transfer_config
        # Filled in below from the offline rank table (read_agent_metadata).
        self.local_agent_metadata: Optional[
            LLMDataDistCMgrAgentMetadata] = None
        self.vllm_config = vllm_config
        # Single worker thread: KV pulls submitted by start_load_kv are
        # serialized on this executor.
        self.executor = ThreadPoolExecutor(1)
        # Guards finished_reqs / done_receiving_counts, which are touched by
        # both the executor thread and the metadata listener thread.
        self.thread_lock = threading.Lock()

        # Map the vLLM kv role onto llm_datadist roles: the kv producer acts
        # as PROMPT (prefill) toward remote DECODER peers, and vice versa.
        self.llm_datadist_role = None
        self.llm_datadist_remote_role = None
        if self.kv_transfer_config.kv_role == "kv_producer":
            self.llm_datadist_role = LLMRole.PROMPT
            self.llm_datadist_remote_role = LLMRole.DECODER
        elif self.kv_transfer_config.kv_role == "kv_consumer":
            self.llm_datadist_role = LLMRole.DECODER
            self.llm_datadist_remote_role = LLMRole.PROMPT
        else:
            raise RuntimeError(
                f"LLMDataDistWorker: Receive unexpected kv role in LLMDataDistWorker, this worker now only support kv_producer and kv_consumer, but receiving {vllm_config.kv_transfer_config.kv_role}"
            )

        # linked_cluster record the cluster that already build the connection its format should be {"cluster_id": "comm_name"}
        self.linked_cluster: dict[Any, Any] = {}
        self.prefill_device_list: list[tuple[int, int]] = []
        self.decode_device_list: list[tuple[int, int]] = []
        global_rank_table = self.read_offline_rank_table()
        self.local_agent_metadata = self.read_agent_metadata(global_rank_table)
        self.llm_datadist = LLMDataDist(self.llm_datadist_role,
                                        self.local_agent_metadata.cluster_id)
        self.init_llm_datadist()
        # Request ids whose KV transfer completed; drained by get_finished().
        self.finished_reqs: set[str] = set()
        self.soc_info = get_ascend_soc_version()
        # Set hccl deterministic for model execute
        os.environ["HCCL_DETERMINISTIC"] = "true"
        # request_id -> set of decode tp ranks that reported "done receiving".
        self.done_receiving_counts: defaultdict[str,
                                                set[int]] = defaultdict(set)
|
||||
|
||||
    def listen_for_agent_metadata_req(self, event: threading.Event):
        """Serve metadata/finished requests from peer workers over ZMQ.

        Runs forever on a daemon thread (started by register_kv_caches).
        Two event kinds are handled:
          * ReqForMetadata: reply with this agent's metadata and link back
            to the requesting cluster via add_remote_agent.
          * ReqForFinished: record that one decode tp rank finished pulling
            a request's KV; once all ranks reported, mark the request done.
        *event* is set once the socket is bound so the starter can wait.
        """
        assert self.local_agent_metadata is not None
        # Each (dp, tp) pair listens on its own offset from the base RPC port.
        # NOTE(review): when local_dp_rank is None the fallback offsets by
        # tp_size instead of 0 — looks intentional for single-dp setups, but
        # confirm against the decode-side port computation in _read_blocks.
        port = envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT + self.local_dp_rank * self.tp_size + self.tp_rank if self.local_dp_rank is not None else envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT + self.tp_size + self.tp_rank
        url = f"tcp://{envs_ascend.VLLM_ASCEND_LLMDD_RPC_IP}:{port}"
        msg_encoder = msgspec.msgpack.Encoder()
        msg_decoder = msgspec.msgpack.Decoder()
        # Our metadata never changes, so the reply is encoded once up front.
        msg_to_send = msg_encoder.encode(self.local_agent_metadata)
        logger.debug(f"Start to listen to address: {url}")
        logger.debug(
            f"The local agent metadata have {len(msg_to_send)} bytes here")
        logger.info(
            f"LLMDataDistCMgrConnectorWorker: Cluster {self.local_agent_metadata.cluster_id} start to listen request from peers"
        )
        with zmq_ctx(zmq.ROUTER, url) as sock:  # type: ignore[attr-defined]
            event.set()
            while True:
                # ROUTER sockets prepend the sender identity frame; keep it
                # to address the reply.
                identity, _, msg = sock.recv_multipart()
                event_msg, decode_msg = msg_decoder.decode(msg)
                event_msg = LLMDataDistCMgrEvent(event_msg)
                if event_msg == LLMDataDistCMgrEvent.ReqForMetadata:
                    if "cluster_id" in decode_msg:
                        decode_msg = LLMDataDistCMgrAgentMetadata(**decode_msg)
                        logger.info(
                            f"LLMDataDistCMgrConnectorWorker: Receive message from cluster {decode_msg.cluster_id}"
                        )
                        sock.send_multipart((identity, b"", msg_to_send))
                        # Link back so later pull_blocks from this peer works.
                        self.add_remote_agent(decode_msg)
                    else:
                        logger.warning(
                            f"LLMDataDistCMgrConnectorWorker: receiving unrecognized data {decode_msg}"
                        )
                elif event_msg == LLMDataDistCMgrEvent.ReqForFinished:
                    finished_req_id = decode_msg[0]
                    decode_tp_rank = decode_msg[1]
                    decode_tp_size = decode_msg[2]
                    with self.thread_lock:
                        # Only mark finished once every decode tp rank has
                        # signalled for this request.
                        if self._increment_task_count(finished_req_id,
                                                      decode_tp_rank,
                                                      decode_tp_size):
                            logger.debug(
                                f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
                            )
                            self.finished_reqs.add(finished_req_id)
                    sock.send_multipart(
                        (identity, b"", b"receiving decode finished"))
                else:
                    raise RuntimeError(
                        f"LLMDataDistCMgrConnectorWorker: Receiving unexpected request event {event_msg} from remote !"
                    )
|
||||
|
||||
def _increment_task_count(self, request_id: str, tp_rank: int,
|
||||
decode_tp_size: int):
|
||||
if request_id not in self.done_receiving_counts:
|
||||
self.done_receiving_counts[request_id] = set()
|
||||
if tp_rank in self.done_receiving_counts[request_id]:
|
||||
logger.warning(
|
||||
f"Received duplicate done signal for request {request_id} "
|
||||
f"from tp rank {tp_rank}. Ignoring.")
|
||||
return False
|
||||
self.done_receiving_counts[request_id].add(tp_rank)
|
||||
if len(self.done_receiving_counts[request_id]) == decode_tp_size:
|
||||
self.done_receiving_counts.pop(request_id)
|
||||
logger.info("All transfers completed for request: "
|
||||
f"{request_id}. Total ranks: "
|
||||
f"{decode_tp_size}.")
|
||||
return True
|
||||
return False
|
||||
|
||||
    def init_llm_datadist(self):
        """Initialize the llm_datadist engine for this device.

        Configures the engine for cache-manager based, remotely accessible
        KV caches and binds it to this worker's local device, then caches
        the resulting cache_manager handle on the instance.
        """
        assert self.local_agent_metadata is not None
        llm_config = LLMConfig()
        llm_config.device_id = self.local_rank
        # Timeout for synchronous KV pulls — presumably milliseconds (20s);
        # confirm against the llm_datadist documentation.
        llm_config.sync_kv_timeout = 20000
        llm_config.enable_switch_role = True
        llm_config.enable_cache_manager = True
        # Allow remote peers to pull directly from our registered caches.
        llm_config.enable_remote_cache_accessible = True
        llm_config_options = llm_config.generate_options()
        self.llm_datadist.init(llm_config_options)
        self.cache_manager = self.llm_datadist.cache_manager
        logger.info(
            f"Done initialize llm_datadist in rank {self.rank}, local rank {self.local_rank}, cluster id {self.local_agent_metadata.cluster_id}"
        )
|
||||
|
||||
def read_offline_rank_table(self):
|
||||
assert (
|
||||
envs_ascend.DISAGGREGATED_PREFILL_RANK_TABLE_PATH
|
||||
), "Please set path of rank_table to env variable DISAGGREGATED_PREFILL_RANK_TABLE_PATH"
|
||||
rank_table_path = envs_ascend.DISAGGREGATED_PREFILL_RANK_TABLE_PATH
|
||||
with open(rank_table_path, "r", encoding="utf-8") as f:
|
||||
global_rank_table = json.load(f)
|
||||
decode_device_list = global_rank_table["decode_device_list"]
|
||||
for decode_device in decode_device_list:
|
||||
server_id = decode_device["server_id"]
|
||||
device_id = decode_device["device_id"]
|
||||
self.decode_device_list.append((server_id, device_id))
|
||||
prefill_device_list = global_rank_table["prefill_device_list"]
|
||||
for prefill_device in prefill_device_list:
|
||||
server_id = prefill_device["server_id"]
|
||||
device_id = prefill_device["device_id"]
|
||||
self.prefill_device_list.append((server_id, device_id))
|
||||
|
||||
# global_rank_table = json.dumps(global_rank_table)
|
||||
return global_rank_table
|
||||
|
||||
@staticmethod
|
||||
def _get_visible_devices() -> Callable[[str], bool]:
|
||||
"""
|
||||
Return a test function that check if the given device ID is visible.
|
||||
i.e. ASCEND_RT_VISIBLE_DEVICES is not set or contains the device_id.
|
||||
"""
|
||||
visible_devices = os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "")
|
||||
if not visible_devices:
|
||||
return lambda device_id: True
|
||||
visible_device_list = visible_devices.split(",")
|
||||
return lambda device_id: device_id in visible_device_list
|
||||
|
||||
    def read_agent_metadata(self, global_rank_table):
        """Locate this worker's own device entry in the global rank table.

        Searches the device list(s) matching our llm_datadist role, keeps
        only entries on this host (server_id == local_ip) whose device_id is
        visible per ASCEND_RT_VISIBLE_DEVICES, and picks the entry at index
        tp_rank. Asserts if no matching entry exists.
        """
        device_filter = LLMDataDistCMgrConnectorWorker._get_visible_devices()
        devices_type_list = []
        agent_metadata = None
        if self.llm_datadist_role == LLMRole.PROMPT:
            devices_type_list.append("prefill_device_list")
        elif self.llm_datadist_role == LLMRole.DECODER:
            devices_type_list.append("decode_device_list")
        else:
            # Role not recognized: search both lists.
            devices_type_list.append("prefill_device_list")
            devices_type_list.append("decode_device_list")
        for device_type in devices_type_list:
            device_list = global_rank_table[device_type]
            # Keep only visible devices that live on this host.
            device_list = [
                d for d in device_list if d.get("server_id") == self.local_ip
                and device_filter(d.get("device_id", ""))
            ]
            if len(device_list) <= self.tp_rank:
                continue
            device_info = device_list[self.tp_rank]
            # super_pod_id / super_device_id are optional (A3-style topology).
            super_pod_id_ = device_info.get("super_pod_id", None)
            server_id_ = device_info["server_id"]
            device_id_ = device_info["device_id"]
            device_ip_ = device_info["device_ip"]
            super_device_id_ = device_info.get("super_device_id", None)
            cluster_id_ = int(device_info["cluster_id"])
            agent_metadata = LLMDataDistCMgrAgentMetadata(
                super_pod_id=super_pod_id_,
                server_id=server_id_,
                device_id=device_id_,
                device_ip=device_ip_,
                super_device_id=super_device_id_,
                cluster_id=cluster_id_,
            )
        assert agent_metadata is not None, f"Can't read the target server_id {self.local_ip} and device_rank {self.rank} from rank table"
        return agent_metadata
|
||||
|
||||
    def register_kv_caches(self, kv_caches: dict[str, Tuple[torch.Tensor]]):
        """Register vLLM's paged KV buffers with the llm_datadist cache manager.

        MLA models register two caches (k_normed / k_pe, told apart by their
        differing last dim) under model_id 0 and 1; standard MHA models
        register a single flat address list. Afterwards the peer-request
        listener thread is started and this call blocks until its socket is
        bound.
        """
        _, first_kv_cache_tuple = next(iter(kv_caches.items()))
        first_kv_cache = first_kv_cache_tuple[0]
        assert len(first_kv_cache_tuple) > 1
        assert self.local_agent_metadata is not None
        kv_cache_dtype = first_kv_cache.dtype
        # MLA caches have different last-dim sizes (k_normed vs k_pe); plain
        # MHA k/v tensors have identical shapes.
        self.use_mla: bool = first_kv_cache_tuple[0].size(
            -1) != first_kv_cache_tuple[1].size(-1)
        # MLA case. [2 (k_normed, k_pe), num_blocks, ...]
        # MHA case. [2 (k and v), num_blocks, ...]
        self.num_blocks = first_kv_cache.shape[0]
        block_rank = 3  # [block_size, latent_dim]
        block_shape = first_kv_cache.shape[-block_rank:]

        # Elements per block (product of the trailing block dims).
        self.block_len = math.prod(block_shape)
        self.cache_addr: list[int] = []
        # NOTE(review): 2 MiB alignment is only asserted for the MHA path
        # below — presumably a llm_datadist registration requirement; confirm.
        alignment = 2 * 1024 * 1024
        if self.use_mla:
            cache_k_normed_addr_list = []
            cache_k_pe_addr_list = []
            k_normed = None
            k_pe = None
            for cache_or_caches in kv_caches.values():
                assert len(cache_or_caches) > 1
                k_normed, k_pe = cache_or_caches[0], cache_or_caches[1]
                cache_k_normed_addr_list.append(k_normed.data_ptr())
                cache_k_pe_addr_list.append(k_pe.data_ptr())
            # NOTE(review): despite the list[int] annotation above, the MLA
            # path stores a pair of address lists here.
            self.cache_addr = (cache_k_normed_addr_list, cache_k_pe_addr_list)

            cache_desc_k_normed = CacheDesc(
                len(self.cache_addr[0]), [*k_normed.shape],
                TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
            cache_desc_k_pe = CacheDesc(
                len(self.cache_addr[1]), [*k_pe.shape],
                TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
            # model_id distinguishes the two MLA sub-caches on one cluster.
            cache_key_k_normed = BlocksCacheKey(cluster_id=int(
                self.local_agent_metadata.cluster_id),
                                                model_id=0)
            cache_key_k_pe = BlocksCacheKey(cluster_id=int(
                self.local_agent_metadata.cluster_id),
                                            model_id=1)
            self.cache_desc = (cache_desc_k_normed, cache_desc_k_pe)
            self.cache_key = (cache_key_k_normed, cache_key_k_pe)
            try:
                cache_k_normed = self.cache_manager.register_blocks_cache(
                    self.cache_desc[0], self.cache_addr[0], self.cache_key[0])
                cache_k_pe = self.cache_manager.register_blocks_cache(
                    self.cache_desc[1], self.cache_addr[1], self.cache_key[1])
                self.cache = (cache_k_normed, cache_k_pe)
                logger.info("LLMDataDistWorker: End of register Paged Cache.")
            except (TypeError, ValueError):
                raise RuntimeError(
                    f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
                )
        else:
            for cache_or_caches in kv_caches.values():
                for cache in cache_or_caches:
                    base_addr = cache.data_ptr()
                    assert base_addr % alignment == 0, "The address of the registered kv cache should be aligned to 2M"
                    self.cache_addr.append(base_addr)
            # register paged kv cache into the llm_cache manager
            self.cache_desc = CacheDesc(
                len(self.cache_addr), [*cache.shape],
                TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
            self.cache_key = BlocksCacheKey(
                cluster_id=int(self.local_agent_metadata.cluster_id))
            logger.info(
                f"num of cache: {len(self.cache_addr)}, size of cache: {[*cache.shape]}, real size of cache: {first_kv_cache.shape}"
            )
            try:
                self.cache = self.cache_manager.register_blocks_cache(
                    self.cache_desc, self.cache_addr, self.cache_key)
                logger.info(
                    "LLMDataDistCMgrConnectorWorker: End of register Paged Cache."
                )
            except (TypeError, ValueError):
                raise RuntimeError(
                    f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
                )
        # Start the peer-request listener; wait until its socket is bound so
        # remote peers can immediately reach us.
        self.ready_event = threading.Event()
        self.metadata_agent_listener_t = threading.Thread(
            target=self.listen_for_agent_metadata_req,
            args=(self.ready_event, ),
            daemon=True,
            name="metadata_agent_listener")
        self.metadata_agent_listener_t.start()
        self.ready_event.wait()
|
||||
|
||||
def start_load_kv(self, metadata: LLMDataDistCMgrConnectorMetadata):
|
||||
futures = []
|
||||
for req_id, meta in metadata.requests.items():
|
||||
logger.debug(f"Start to transmit {req_id}")
|
||||
future = self.executor.submit(
|
||||
self._read_blocks,
|
||||
local_block_ids=meta.local_block_ids,
|
||||
remote_block_ids=meta.remote_block_ids,
|
||||
remote_ip=meta.remote_host,
|
||||
remote_port=int(meta.remote_port),
|
||||
remote_engine_id=meta.engine_id,
|
||||
request_id=req_id,
|
||||
remote_tp_size=meta.remote_tp_size,
|
||||
)
|
||||
futures.append(future)
|
||||
|
||||
def handle_exception(future):
|
||||
if future.exception():
|
||||
logger.error(f"KV transfer task failed: {future.exception()}")
|
||||
|
||||
for future in futures:
|
||||
future.add_done_callback(handle_exception)
|
||||
|
||||
    def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int:
        """Build (or reuse) an llm_datadist comm link to the peer in *metadata*.

        A two-rank hccl rank table is generated on the fly: the prefill agent
        always takes rank 0 and the decode agent rank 1, regardless of which
        side is local. On A3 SoCs a super_pod_list section is added as well.
        Blocks (polling once per second) until memory registration on the
        new link succeeds; raises RuntimeError if registration fails.
        Returns the remote cluster id.
        """
        assert self.local_agent_metadata is not None
        remote_cluster_id = metadata.cluster_id
        # Fast path: link already established earlier.
        if remote_cluster_id in self.linked_cluster:
            logger.debug(
                f"LLMDataDistCMgrConnectorWorker: remote cluster_id: {metadata.cluster_id} already linked with this server, skip the connection"
            )
            return remote_cluster_id
        remote_super_pod_id = metadata.super_pod_id
        remote_server_id = metadata.server_id
        is_same_server = remote_server_id == self.local_agent_metadata.server_id
        is_same_pod = remote_super_pod_id == self.local_agent_metadata.super_pod_id
        # Orient the pair: prefill side becomes rank 0, decode side rank 1.
        if self.llm_datadist_role == LLMRole.PROMPT:
            prefill_metadata = self.local_agent_metadata
            decode_metadata = metadata
        else:
            prefill_metadata = metadata
            decode_metadata = self.local_agent_metadata
        comm_name = f"pd_comm_{prefill_metadata.device_ip}_{decode_metadata.device_ip}"
        cluster_rank_info = {
            prefill_metadata.cluster_id: 0,
            decode_metadata.cluster_id: 1
        }
        rank_table = {}
        rank_table["version"] = "1.2"
        rank_table["server_count"] = "1" if is_same_server else "2"
        rank_table["status"] = "completed"

        # generate server_list for rank table
        rank_table["server_list"] = []  # type: ignore[assignment]
        decode_server_device_info = None
        # Optional fields (e.g. super_device_id on non-A3 setups) are dropped
        # by the `if v is not None` filter below.
        prefill_server_device_info = {
            "device": [{
                k: v
                for k, v in [(
                    "device_id", prefill_metadata.device_id
                ), ("device_ip", prefill_metadata.device_ip
                    ), ("super_device_id",
                        prefill_metadata.super_device_id), ("rank_id", "0")]
                if v is not None
            }],
            "server_id":
            prefill_metadata.server_id
        }
        if is_same_server:
            # Both devices live on one host: append the decode device to the
            # same server entry.
            prefill_server_device_info["device"].append(  # type: ignore[attr-defined]
                {
                    k: v
                    for k, v in [(
                        "device_id", decode_metadata.device_id
                    ), ("device_ip", decode_metadata.device_ip
                        ), ("super_device_id",
                            decode_metadata.super_device_id), ("rank_id", "1")]
                    if v is not None
                })
        else:
            decode_server_device_info = {
                "device": [{
                    k: v
                    for k, v in [(
                        "device_id", decode_metadata.device_id
                    ), ("device_ip", decode_metadata.device_ip
                        ), ("super_device_id",
                            decode_metadata.super_device_id), ("rank_id", "1")]
                    if v is not None
                }],
                "server_id":
                decode_metadata.server_id
            }
        rank_table["server_list"].append(  # type: ignore[attr-defined]
            prefill_server_device_info)
        if decode_server_device_info is not None:
            rank_table["server_list"].append(  # type: ignore[attr-defined]
                decode_server_device_info)

        if self.soc_info == AscendSocVersion.A3:
            # generate super_pod_list for rank table
            super_pod_list = []
            prefill_super_pod_info = {
                "super_pod_id": prefill_metadata.super_pod_id,
                "server_list": [{
                    "server_id": prefill_metadata.server_id
                }],
            }
            if is_same_pod and not is_same_server:
                prefill_super_pod_info[
                    "server_list"].append(  # type: ignore[attr-defined]
                        {"server_id": decode_metadata.server_id})
            super_pod_list.append(prefill_super_pod_info)
            if not is_same_pod:
                decode_super_pod_id = {
                    "super_pod_id": decode_metadata.super_pod_id,
                    "server_list": [{
                        "server_id": decode_metadata.server_id
                    }],
                }
                super_pod_list.append(decode_super_pod_id)
            rank_table[
                "super_pod_list"] = super_pod_list  # type: ignore[assignment]
        logger.info(
            f"LLMDataDistCMgrConnectorWorker: try link with remote, comm id: {comm_name}"
        )
        logger.info(f"rank table \n{rank_table}")
        logger.info(f"comm name: {comm_name}")
        logger.info(f"cluster rank info: {cluster_rank_info}")
        comm_id = self.llm_datadist.link(comm_name, cluster_rank_info,
                                         json.dumps(rank_table))
        # Poll until both sides have registered their memory on the link.
        while True:
            ret = self.llm_datadist.query_register_mem_status(comm_id=comm_id)
            if ret == llm_datadist.RegisterMemStatus.OK:
                logger.info(
                    f"LLMDataDistCMgrConnectorWorker: Linking success, comm id: {comm_id}"
                )
                break
            elif ret == llm_datadist.RegisterMemStatus.FAILED:
                raise RuntimeError(
                    f"LLMDataDistCMgrConnectorWorker: Linking failed, comm id: {comm_id}"
                )
            time.sleep(1)
            logger.info("Checking query_register_mem_status again")
        self.linked_cluster.update({remote_cluster_id: comm_id})
        logger.info(f"cached linked cluster: {self.linked_cluster}")
        logger.info(
            f"Successfully build link with cluster id {remote_cluster_id} with cluster name {comm_name} !"
        )
        return remote_cluster_id
|
||||
|
||||
def remove_remote_agent(self, cluster_id: int):
|
||||
if cluster_id not in self.linked_cluster:
|
||||
logger.warning(
|
||||
f"LLMDataDistCMgrConnectorWorker: Warning! Can't remove remote client with cluster id {cluster_id} for its not exist in linked_cluster list"
|
||||
)
|
||||
comm_id = self.linked_cluster[cluster_id]
|
||||
try:
|
||||
self.llm_datadist.unlink(comm_id)
|
||||
self.linked_cluster.pop(cluster_id)
|
||||
except LLMException:
|
||||
logger.error(
|
||||
f"Try to remove remote client with cluster id {cluster_id} failed!, program won't terminate, but please carefully check your environment"
|
||||
)
|
||||
logger.info(
|
||||
f"Successfully remove remote client with cluster id {cluster_id} !"
|
||||
)
|
||||
|
||||
def connect_to_remote_agent(self, host: str, port: int) -> int:
|
||||
url = f"tcp://{host}:{port}"
|
||||
logger.debug(f"Querying metadata from url: {url}")
|
||||
msg_encoder = msgspec.msgpack.Encoder()
|
||||
msg_send = msg_encoder.encode(
|
||||
[LLMDataDistCMgrEvent.ReqForMetadata, self.local_agent_metadata])
|
||||
with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined]
|
||||
logger.info("Try request remote metadata from socket......")
|
||||
sock.send(msg_send)
|
||||
metadata_bytes = sock.recv()
|
||||
decoder = msgspec.msgpack.Decoder()
|
||||
metadata = decoder.decode(metadata_bytes)
|
||||
metadata = LLMDataDistCMgrAgentMetadata(**metadata)
|
||||
logger.info(f"recving metadata: {metadata}")
|
||||
cluster_id = self.add_remote_agent(metadata)
|
||||
return cluster_id
|
||||
|
||||
def send_finish_to_remote(self, host: str, port: int, request_id):
|
||||
url = f"tcp://{host}:{port}"
|
||||
logger.debug(f"Sending finished to remote: {url}")
|
||||
msg_encoder = msgspec.msgpack.Encoder()
|
||||
msg_send = msg_encoder.encode([
|
||||
LLMDataDistCMgrEvent.ReqForFinished,
|
||||
[request_id, self.tp_rank, self.tp_size]
|
||||
])
|
||||
with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined]
|
||||
try:
|
||||
sock.send(msg_send)
|
||||
logger.debug(
|
||||
f"Request id {request_id} finished message send to remote {url}"
|
||||
)
|
||||
_ = sock.recv()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to send reqest_id {request_id} to prefill: {e}")
|
||||
|
||||
    def _read_blocks(
        self,
        local_block_ids: list[int],
        remote_block_ids: list[int],
        remote_ip: str,
        remote_port: int,
        remote_engine_id: str,
        request_id: str,
        remote_tp_size: str,
    ):
        """Pull one request's KV blocks from the remote prefill node.

        Runs on the executor thread. Links to the appropriate remote tp peer,
        pulls blocks (two sub-caches for MLA, one otherwise), notifies the
        remote side, and marks the request finished locally.
        """
        # if remote_ip not in self.linked_cluster:
        # Choose which remote tp rank this rank maps onto; the remote side
        # lays its listener ports out consecutively per tp rank.
        tp_offset = self.tp_rank % int(remote_tp_size)
        remote_cluster_id = self.connect_to_remote_agent(
            remote_ip, remote_port + tp_offset)
        num_local_blocks = len(local_block_ids)
        if num_local_blocks == 0:
            # Nothing to pull (e.g. full prefix-cache hit).
            return
        num_remote_blocks = len(remote_block_ids)
        assert num_local_blocks <= num_remote_blocks
        if num_local_blocks < num_remote_blocks:
            # Only the tail blocks are needed; the leading remote blocks
            # correspond to blocks we already hold locally.
            remote_block_ids = remote_block_ids[-num_local_blocks:]

        logger.info(f"remote cluster id is: {remote_cluster_id}")
        if self.use_mla:
            # MLA registers two sub-caches (model_id 0/1), pull both.
            remote_cache_key_k_normed = BlocksCacheKey(
                cluster_id=remote_cluster_id, model_id=0)
            remote_cache_key_k_pe = BlocksCacheKey(
                cluster_id=remote_cluster_id, model_id=1)
            logger.info("Try pull blocks from remote server")
            try:
                self.cache_manager.pull_blocks(
                    remote_cache_key_k_normed,
                    self.cache[0],  # type: ignore[has-type]
                    remote_block_ids,
                    local_block_ids)
                self.cache_manager.pull_blocks(
                    remote_cache_key_k_pe,
                    self.cache[1],  # type: ignore[has-type]
                    remote_block_ids,
                    local_block_ids)
            except (TypeError, ValueError):
                raise RuntimeError(
                    f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}"  # type: ignore[has-type]
                )
            except LLMException:
                raise RuntimeError(
                    "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
                )
        else:
            remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id)
            logger.info("Try pull blocks from remote server")
            try:
                self.cache_manager.pull_blocks(
                    remote_cache_key,
                    self.cache,  # type: ignore[has-type]
                    remote_block_ids,
                    local_block_ids)
            except (TypeError, ValueError):
                raise RuntimeError(
                    f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}"  # type: ignore[has-type]
                )
            except LLMException:
                raise RuntimeError(
                    "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
                )
        # Tell the prefill node it may free the blocks, then mark done so
        # get_finished() reports this request.
        self.send_finish_to_remote(remote_ip, remote_port, request_id)
        with self.thread_lock:
            self.finished_reqs.add(request_id)
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[Optional[set[str]], Optional[set[str]]]:
|
||||
"""Get the finished recving and sending requuests."""
|
||||
import copy
|
||||
with self.thread_lock:
|
||||
req_ids_to_ret = copy.deepcopy(self.finished_reqs)
|
||||
self.finished_reqs.clear()
|
||||
if self.llm_datadist_role == LLMRole.PROMPT:
|
||||
return req_ids_to_ret, None
|
||||
else:
|
||||
return None, req_ids_to_ret
|
||||
|
||||
|
||||
# adopt this from https://github.com/vllm-project/vllm/blob/main/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@contextlib.contextmanager
def zmq_ctx(socket_type: Any,
            addr: str) -> Iterator[zmq.Socket]:  # type: ignore[name-defined]
    """Context manager for a ZMQ socket.

    ROUTER sockets bind to *addr* (server side); REQ sockets connect to it
    (client side). Any other socket type raises ValueError. The context —
    and therefore the socket — is always destroyed on exit, even when the
    body raises.
    """
    ctx: Optional[zmq.Context] = None  # type: ignore[name-defined]
    try:
        if socket_type not in (zmq.ROUTER, zmq.REQ):  # type: ignore[attr-defined]
            raise ValueError(f"Unexpected socket type: {socket_type}")
        ctx = zmq.Context()  # type: ignore[attr-defined]
        sock = ctx.socket(socket_type)
        if socket_type == zmq.ROUTER:  # type: ignore[attr-defined]
            sock.bind(addr)
        else:
            sock.connect(addr)
        yield sock
    finally:
        if ctx is not None:
            # linger=0 drops unsent messages instead of blocking shutdown.
            ctx.destroy(linger=0)
|
||||
# ---- end of vllm_ascend/distributed/llmdatadist_c_mgr_connector.py ----
# ---- new file: vllm_ascend/distributed/moe_comm_method.py (556 lines) ----
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
|
||||
from vllm_ascend.distributed.communication_op import \
|
||||
data_parallel_reduce_scatter
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
|
||||
|
||||
|
||||
class MoECommMethod(ABC):
    """Base class for MoE communication methods.

    Subclasses implement one token routing/communication strategy and are
    driven in the order: prepare -> permute -> (expert MLP) -> unpermute
    -> finalize.
    """

    def __init__(self, moe_config: FusedMoEConfig):
        # Parallelism/topology description (dp/tp/ep sizes, groups) shared
        # by all hooks below.
        self.moe_config = moe_config

    @abstractmethod
    def prepare(
            self, hidden_states: torch.Tensor,
            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Prepare the MoE communication method.

        This method is called before quant_method.apply to prepare the
        communication method. It can be used to initialize any necessary
        resources or configurations.
        """
        pass

    @abstractmethod
    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        """Finalize the MoE communication method.

        This method is called after quant_method.apply to finalize the
        communication method. It can be used to clean up any resources or
        configurations.
        """
        pass

    @abstractmethod
    def permute(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
        apply_a8_quantization: bool,
    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
        """Pre-process before MLP.

        Args:
            hidden_states (torch.Tensor): Tensor of shape (num_tokens, hidden_size)
            topk_ids (torch.Tensor): Tensor of shape (num_tokens, top_k_num)
            topk_weights (torch.Tensor): Tensor of shape (num_tokens, top_k_num)
            expert_map (torch.Tensor): Tensor of shape (global_num_experts, )
                Mapping from global expert IDs to local expert IDs.
            num_experts (int): Number of local experts (experts on this device).
            apply_a8_quantization (bool): Whether to apply A8 quantization (W4A8 and W8A8).

        Returns:
            tuple[torch.Tensor, torch.Tensor, int]: Return a tuple containing:
                - permuted_hidden_states (torch.Tensor): Tensor of shape
                    (num_tokens * top_k_num, hidden_size) after permuting
                    hidden_states based on topk_ids.
                - expert_tokens (torch.Tensor): Tensor of shape (num_experts, )
                    Number of tokens assigned to each expert.
                - dynamic_scale (torch.Tensor, optional): Tensor of shape (num_experts, )
                    Dynamic scale for each expert, used for quantization.
                - group_list_type (int): Type of group list, 0 for `cumsum`
                    and 1 for `count`. This is mainly for `npu_grouped_matmul`
                    to determine how to handle the output.
        Raises:
            NotImplementedError: If the method is not implemented in the subclass.
        """
        pass

    @abstractmethod
    def unpermute(self, mlp_output: torch.Tensor,
                  hidden_states: torch.Tensor) -> None:
        """Post-process after MLP.

        Args:
            mlp_output (torch.Tensor): Tensor of shape
                (num_tokens * top_k_num, hidden_size) after MLP.
            hidden_states (torch.Tensor): Tensor of shape
                (num_tokens, hidden_size) to be updated with the final output.
        """
        pass
|
||||
|
||||
|
||||
class AllGatherCommImpl(MoECommMethod):
    """All-gather based MoE communication using NPU-fused ops.

    This implementation is the same as NativeAllGatherCommImpl,
    but uses NPU-specific ops for better performance.

    This implementation should be compatible with all scenarios, and
    thus it is the default implementation for MoE communication methods.
    It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing
    and `torch_npu.npu_moe_token_unpermute` for post-processing
    to handle the token-to-expert mapping and communication efficiently.

    NOTE(Yizhou): TBH, it is really weird that we were supposed to use
    `torch_npu.npu_moe_init_routing_v2` and `torch_npu.npu_moe_finalize_routing`
    or `torch_npu.npu_moe_token_permute` and `torch_npu.npu_moe_token_unpermute`
    for pre-processing and post-processing, respectively.
    But `npu_moe_finalize_routing` will lead to accuracy issues so we have to
    use `torch_npu.npu_moe_token_unpermute` instead.
    This is a workaround and should be removed after the issue is fixed.
    """

    def prepare(
            self, hidden_states: torch.Tensor,
            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """When DP size > 1, pad the hidden states and router logits for
        communication and all-gather them across the DP group."""
        if self.moe_config.dp_size > 1:
            forward_context = get_forward_context()
            max_tokens_across_dp = forward_context.max_tokens_across_dp

            # Remember the unpadded token count so finalize() can strip the
            # padding again after the reduce-scatter.
            self.num_tokens = hidden_states.shape[0]
            pad_size = max_tokens_across_dp - self.num_tokens
            if pad_size > 0:
                hidden_states = nn.functional.pad(hidden_states,
                                                  (0, 0, 0, pad_size))
                router_logits = nn.functional.pad(router_logits,
                                                  (0, 0, 0, pad_size))

            hidden_states = self.moe_config.dp_group.all_gather(
                hidden_states, 0)
            router_logits = self.moe_config.dp_group.all_gather(
                router_logits, 0)

        return hidden_states, router_logits

    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        """When DP size > 1, reduce-scatter the hidden states to get the final output.

        When TP size > 1, all-reduce the hidden states to get the final output.
        """
        if self.moe_config.dp_size > 1:
            hidden_states = data_parallel_reduce_scatter(hidden_states, dim=0)
            # Drop the padding added in prepare().
            hidden_states = hidden_states[:self.num_tokens]

        if reduce_results and (self.moe_config.tp_size > 1
                               or self.moe_config.ep_size > 1):
            hidden_states = tensor_model_parallel_all_reduce(hidden_states)

        return hidden_states

    def permute(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
        apply_a8_quantization: bool,
    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
        """Route tokens to experts with the fused NPU routing kernel.

        Stores `topk_weights`/`expanded_row_idx` on `self` for unpermute().
        `apply_a8_quantization` is ignored here: `quant_mode=-1` disables
        in-kernel quantization and no dynamic scale is returned.
        """
        num_tokens = hidden_states.shape[0]

        self.topk_weights = topk_weights
        self.topk_ids = topk_ids

        first_expert_idx = 0
        if expert_map is not None:
            # FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
            # So we need to filter out invalid tokens by zeroing their weights.
            # This is a workaround and should be removed after the issue is fixed
            mask = expert_map[topk_ids] != -1
            # NOTE: This is equivalent to self.topk_weights[~mask] = 0.0,
            # but ~mask will dispatch to aclnnNonzeroV2, which is not supported in ACL Graph
            self.topk_weights = torch.where(mask, topk_weights, 0.0)

            first_expert_idx = self.moe_config.ep_rank * num_experts
        last_expert_idx = first_expert_idx + num_experts

        permuted_hidden_states, expanded_row_idx, expert_tokens, _ = (
            torch_npu.npu_moe_init_routing_v2(
                hidden_states,
                topk_ids,
                active_num=num_tokens * self.moe_config.experts_per_token,
                expert_num=self.moe_config.num_experts,
                expert_tokens_num_type=1,  # Only support `count` mode now
                expert_tokens_num_flag=True,  # Output `expert_tokens`
                active_expert_range=[first_expert_idx, last_expert_idx],
                quant_mode=-1,
            ))
        self.expanded_row_idx = expanded_row_idx

        group_list_type = 1  # `count` mode

        return permuted_hidden_states, expert_tokens, None, group_list_type

    def unpermute(self, mlp_output: torch.Tensor,
                  hidden_states: torch.Tensor) -> None:
        """Scatter expert outputs back to token order, weighted by router probs.

        Writes the result into `hidden_states` in place.
        """
        hidden_states[:] = torch_npu.npu_moe_token_unpermute(
            permuted_tokens=mlp_output,
            sorted_indices=self.expanded_row_idx,
            probs=self.topk_weights)
|
||||
|
||||
|
||||
class NativeAllGatherCommImpl(AllGatherCommImpl):
    """This implementation should be compatible with all scenarios.

    Note that this implementation purely consists of native PyTorch ops
    and does not use any NPU-specific ops. So the performance may not be optimal.
    But it is a good fallback for scenarios where NPU-specific ops are not available.

    prepare()/finalize() are inherited from AllGatherCommImpl; only the
    permute/unpermute pair is overridden with a pure-PyTorch routing path.
    """

    def permute(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
        apply_a8_quantization: bool,
    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
        """Sort tokens by (local) expert id using only native torch ops.

        Stores `sorted_token_indices` / `sorted_weights` on `self` so
        unpermute() can scatter the MLP outputs back to token order.
        `apply_a8_quantization` is not used on this path: the dynamic
        scale slot of the returned tuple is always None.
        """
        num_tokens = hidden_states.shape[0]

        # Generate token indices and flatten: each token index is repeated
        # once per selected expert (experts_per_token copies).
        token_indices = torch.arange(num_tokens,
                                     device=hidden_states.device,
                                     dtype=torch.int64)
        token_indices = (token_indices.unsqueeze(1).expand(
            -1, self.moe_config.experts_per_token).reshape(-1))

        # Flatten token-to-expert mappings and map to local experts
        weights_flat = topk_weights.view(-1)
        experts_flat = topk_ids.view(-1)
        local_experts_flat = (expert_map[experts_flat]
                              if expert_map is not None else experts_flat)

        # Filter valid token-expert pairs (-1 in expert_map means the expert
        # is not hosted on this rank).
        mask = local_experts_flat != -1
        # FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
        # So we need to filter out invalid tokens by zeroing their weights.
        # This is a workaround and should be removed after the issue is fixed
        filtered_weights = torch.where(mask, weights_flat,
                                       torch.zeros_like(weights_flat)).to(
                                           topk_weights.dtype)
        # Invalid pairs get the sentinel id `num_experts` so they sort last.
        filtered_experts = torch.where(
            mask,
            local_experts_flat,
            torch.full_like(local_experts_flat, num_experts),
        ).to(topk_ids.dtype)

        # Sort by local expert IDs.
        # NOTE(review): the .view(torch.float32) reinterprets the integer ids
        # as float bits before argsort — presumably a faster sort path on NPU;
        # ordering is preserved for small non-negative int32 ids. Confirm
        # topk_ids is int32 before changing this.
        sort_indices = torch.argsort(filtered_experts.view(torch.float32))
        self.sorted_token_indices = token_indices[sort_indices]
        self.sorted_weights = filtered_weights[sort_indices]

        # Compute token counts with minlength of num_experts
        # This is equivalent to but faster than:
        # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
        # (the extra slot collects the `num_experts` sentinel and is dropped)
        token_counts = torch.zeros(num_experts + 1,
                                   device=hidden_states.device,
                                   dtype=torch.int64)
        ones = torch.ones_like(filtered_experts, dtype=torch.int64)
        token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
        expert_tokens = token_counts[:num_experts]

        # Rearrange hidden_states into expert-contiguous order.
        permuted_hidden_states = hidden_states[self.sorted_token_indices]

        group_list_type = 1  # `count` mode

        return permuted_hidden_states, expert_tokens, None, group_list_type

    def unpermute(self, mlp_output: torch.Tensor,
                  hidden_states: torch.Tensor) -> None:
        """Weight the expert outputs and accumulate them back per token.

        Writes the result into `hidden_states` in place.
        """
        # Scale each expert output row by its router weight.
        mlp_output = mlp_output * self.sorted_weights.unsqueeze(1)

        # index_add_ sums the top-k expert contributions for each token.
        final_hidden_states = torch.zeros_like(hidden_states)
        final_hidden_states.index_add_(0, self.sorted_token_indices,
                                       mlp_output)

        hidden_states[:] = final_hidden_states
|
||||
|
||||
|
||||
class MC2CommImpl(MoECommMethod):
    """This implementation is for the scenarios listed below:
    1. `enable_expert_parallel=True`.
    2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available.
    3. `enable_expert_parallel=False` is not supported.

    This implementation uses the MC2 communication method, which is optimized for
    Communication and Computation parallelism on Ascend devices.

    permute() performs the fused dispatch (token exchange across EP ranks) and
    unpermute() performs the matching fused combine; the two calls share state
    stored on `self` in between (topk ids/weights, recv counts, assist info).
    """

    def __init__(self, moe_config: Optional[FusedMoEConfig]):
        super().__init__(moe_config)

        # NOTE: We do not need to use mc2_group's rank and world size
        # because ep_group and mc2_group basically have the same init params.
        # We only init another group because of the restriction of MC2:
        # "No other groups can be used in the same process as the MC2 group."
        self.mc2_comm_name = get_mc2_group().device_group._get_backend(
            torch.device("npu")).get_hccl_comm_name(self.moe_config.ep_rank)

        # Feature flags
        # Newer torch_npu releases expose the v2 dispatch/combine kernels.
        self.enable_dispatch_v2 = hasattr(torch_npu,
                                          "npu_moe_distribute_dispatch_v2")
        self.is_ascend_a3 = get_ascend_soc_version() == AscendSocVersion.A3
        # A3 SoCs require the extra group_tp/tp_world_size/tp_rank_id kwargs.
        self.need_extra_args = self.is_ascend_a3
        self._restore_tp_across_dp()

    def _restore_tp_across_dp(self):
        # NOTE: Since vLLM flatten tp across dp, we need to restore the original
        # tp_size and tp_rank.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

    def prepare(
            self, hidden_states: torch.Tensor,
            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """The target_pad_length is calculated in forward_context, here we pad the
        hidden states and router logits. And if TP size > 1, we also need to split
        the tensors accordingly.
        """
        self.num_tokens, _ = hidden_states.shape
        forward_context = get_forward_context()
        self.mc2_mask = forward_context.mc2_mask
        target_pad_length = forward_context.padded_num_tokens
        pad_size = target_pad_length - self.num_tokens

        if pad_size > 0:
            hidden_states = nn.functional.pad(hidden_states,
                                              (0, 0, 0, pad_size))
            router_logits = nn.functional.pad(router_logits,
                                              (0, 0, 0, pad_size))

        if self.tp_size > 1:
            split_hidden_states = torch.tensor_split(hidden_states,
                                                     self.tp_size,
                                                     dim=0)
            split_router_logits = torch.tensor_split(router_logits,
                                                     self.tp_size,
                                                     dim=0)
            split_mc2_mask = torch.tensor_split(self.mc2_mask,
                                                self.tp_size,
                                                dim=0)
            # Kept so finalize() can all-gather back into these chunks.
            self.split_hidden_states = split_hidden_states

            hidden_states = split_hidden_states[self.tp_rank]
            router_logits = split_router_logits[self.tp_rank]
            self.mc2_mask = split_mc2_mask[self.tp_rank]

        return hidden_states, router_logits

    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        """If TP size > 1, all-gather the hidden states to get the final output.

        Also, unpad the hidden states if needed.
        (`reduce_results` is unused here; the combine already reduces.)
        """
        if self.tp_size > 1:
            dist.all_gather(list(self.split_hidden_states), hidden_states,
                            self.moe_config.tp_group.device_group)
            hidden_states = torch.cat(self.split_hidden_states, dim=0)

        # Strip the padding added in prepare().
        if self.num_tokens < hidden_states.shape[0]:
            hidden_states = hidden_states[:self.num_tokens]

        return hidden_states

    def permute(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
        apply_a8_quantization: bool,
    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
        """Fused dispatch: routes tokens to their experts across EP ranks.

        `expert_map` and `num_experts` are unused on this path — the kernel
        derives routing from `moe_expert_num` and `expert_ids` directly.
        Returns (permuted_hidden_states, expert_tokens, dynamic_scale,
        group_list_type); dynamic_scale is produced by the kernel when
        `apply_a8_quantization` selects quant_mode 2.
        """
        # Store tensors needed for post_process
        self.topk_ids = topk_ids
        self.topk_weights = topk_weights.to(torch.float32)

        dispatch_kwargs = {
            "x": hidden_states,
            "expert_ids": self.topk_ids,
            "expert_shard_type": 0,
            "shared_expert_rank_num": 0,
            "moe_expert_num": self.moe_config.num_experts,
            "global_bs": 0,
            "scales": None,
            "quant_mode": 2 if apply_a8_quantization else 0,
            "group_ep": self.mc2_comm_name,
            "ep_world_size": self.moe_config.ep_size,
            "ep_rank_id": self.moe_config.ep_rank,
        }

        if self.need_extra_args:
            dispatch_kwargs.update({
                "group_tp": self.mc2_comm_name,
                "tp_world_size": 1,
                "tp_rank_id": 0,
            })
        if self.is_ascend_a3 and self.enable_dispatch_v2:
            dispatch_kwargs.update({
                "x_active_mask": self.mc2_mask,
            })

        dispatch = torch_npu.npu_moe_distribute_dispatch_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch

        # The outputs past index 5 (if any) are not needed here.
        (
            permuted_hidden_states,
            dynamic_scale,
            self.assist_info_for_combine,
            expert_tokens,
            self.ep_recv_counts,
            self.tp_recv_counts,
        ) = dispatch(**dispatch_kwargs)[:6]

        group_list_type = 1  # `count` mode

        return permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type

    def unpermute(self, mlp_output: torch.Tensor,
                  hidden_states: torch.Tensor) -> None:
        """Fused combine: gathers expert outputs back and writes the weighted
        sum into `hidden_states` in place, mirroring the preceding dispatch."""
        combine_kwargs = {
            "expand_x": mlp_output,
            "expert_ids": self.topk_ids,
            "expert_scales": self.topk_weights,
            "expert_shard_type": 0,
            "shared_expert_rank_num": 0,
            "moe_expert_num": self.moe_config.num_experts,
            "global_bs": 0,
            "ep_send_counts": self.ep_recv_counts,
            "group_ep": self.mc2_comm_name,
            "ep_world_size": self.moe_config.ep_size,
            "ep_rank_id": self.moe_config.ep_rank,
        }

        # v1 and v2 kernels name the routing-assist argument differently.
        if self.enable_dispatch_v2:
            combine_kwargs[
                "assist_info_for_combine"] = self.assist_info_for_combine
        else:
            combine_kwargs["expand_idx"] = self.assist_info_for_combine

        if self.need_extra_args:
            combine_kwargs.update({
                "tp_send_counts": self.tp_recv_counts,
                "group_tp": self.mc2_comm_name,
                "tp_world_size": 1,
                "tp_rank_id": 0,
            })
        if self.is_ascend_a3 and self.enable_dispatch_v2:
            combine_kwargs.update({
                "x_active_mask": self.mc2_mask,
            })

        combine = torch_npu.npu_moe_distribute_combine_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine

        hidden_states[:] = combine(**combine_kwargs)
|
||||
|
||||
|
||||
class AlltoAllCommImpl(MoECommMethod):
    """MoE communication via all-to-all token exchange.

    Intended for the scenarios below:
    1. `enable_expert_parallel=True`.
    2. `npu_grouped_matmul` is available.

    Tokens are exchanged between data parallel ranks before and after the
    MLP computation, which should outperform AllGatherCommImpl when
    DP size > 1.
    """

    def __init__(self, moe_config: Optional[FusedMoEConfig]):
        super().__init__(moe_config)
        # Imported lazily to avoid import cycles at module load time.
        from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
            get_token_dispatcher
        self.token_dispatcher = get_token_dispatcher(
            "TokenDispatcherWithAll2AllV")
        self._restore_tp_across_dp()

    def _restore_tp_across_dp(self):
        # vLLM flattens TP across DP; recover the true tp_size/tp_rank from
        # the tensor-model-parallel group.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

    def prepare(
            self, hidden_states: torch.Tensor,
            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Pad inputs up to tp_size tokens, then hand each TP rank its chunk."""
        self.num_tokens, _ = hidden_states.shape
        pad_size = self.tp_size - self.num_tokens

        if pad_size > 0:
            pad_spec = (0, 0, 0, pad_size)
            hidden_states = nn.functional.pad(hidden_states, pad_spec)
            router_logits = nn.functional.pad(router_logits, pad_spec)

        if self.tp_size > 1:
            chunked_hidden = torch.tensor_split(hidden_states,
                                                self.tp_size,
                                                dim=0)
            chunked_logits = torch.tensor_split(router_logits,
                                                self.tp_size,
                                                dim=0)
            # Keep the chunk list around; finalize() gathers back into it.
            self.split_hidden_states = chunked_hidden

            hidden_states = chunked_hidden[self.tp_rank]
            router_logits = chunked_logits[self.tp_rank]

        return hidden_states, router_logits

    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        """Gather the per-rank chunks back across TP and drop any padding."""
        if self.tp_size > 1:
            dist.all_gather(list(self.split_hidden_states), hidden_states,
                            self.moe_config.tp_group.device_group)
            hidden_states = torch.cat(self.split_hidden_states, dim=0)

        if self.num_tokens < hidden_states.shape[0]:
            hidden_states = hidden_states[:self.num_tokens]

        return hidden_states

    def permute(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
        apply_a8_quantization: bool,
    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
        """Delegate permutation + all-to-all dispatch to the token dispatcher."""
        dispatch_result = self.token_dispatcher.token_dispatch(
            hidden_states,
            topk_weights,
            topk_ids,
            None,
            log2phy=None,
            with_quant=apply_a8_quantization)
        return (dispatch_result["hidden_states"],
                dispatch_result["group_list"],
                dispatch_result["dynamic_scale"],
                dispatch_result["group_list_type"])

    def unpermute(self, mlp_output: torch.Tensor,
                  hidden_states: torch.Tensor) -> None:
        """Combine expert outputs back into `hidden_states` in place."""
        hidden_states[:] = self.token_dispatcher.token_combine(mlp_output)
|
||||
1070
vllm_ascend/distributed/mooncake_connector.py
Normal file
1070
vllm_ascend/distributed/mooncake_connector.py
Normal file
File diff suppressed because it is too large
Load Diff
119
vllm_ascend/distributed/parallel_state.py
Normal file
119
vllm_ascend/distributed/parallel_state.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
|
||||
init_model_parallel_group)
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
|
||||
# Currently, mc2 op need their own group coordinator.
|
||||
_MC2: Optional[GroupCoordinator] = None
|
||||
_MLP_TP: Optional[GroupCoordinator] = None
|
||||
|
||||
_LMTP: Optional[GroupCoordinator] = None
|
||||
|
||||
|
||||
def get_mc2_group() -> GroupCoordinator:
    """Return the dedicated MC2 group coordinator; fails if uninitialized."""
    group = _MC2
    assert group is not None, "mc2 group is not initialized"
    return group
|
||||
|
||||
|
||||
def get_lmhead_tp_group() -> GroupCoordinator:
    """Return the LM-head tensor parallel group coordinator; fails if uninitialized."""
    group = _LMTP
    assert group is not None, "lm head tensor parallel group is not initialized"
    return group
|
||||
|
||||
|
||||
def get_mlp_tp_group() -> GroupCoordinator:
    """Return the MLP tensor parallel group coordinator; fails if uninitialized."""
    group = _MLP_TP
    assert group is not None, "mlp group is not initialized"
    return group
|
||||
|
||||
|
||||
def model_parallel_initialized():
    """True once init_ascend_model_parallel() has created the MC2 group."""
    return _MC2 is not None
|
||||
|
||||
|
||||
def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
    """Initialize the Ascend-specific parallel groups (module globals).

    Creates `_MC2` always, `_MLP_TP` when VLLM_ASCEND_ENABLE_MLP_OPTIMIZE is
    set, and `_LMTP` when `lmhead_tensor_parallel_size` is configured.
    Idempotent: returns immediately if the MC2 group already exists.
    Requires torch.distributed to be initialized by the caller.
    """
    if model_parallel_initialized():
        return
    assert torch.distributed.is_initialized()
    world_size = torch.distributed.get_world_size()
    # Reuse the backend of the world group for all derived groups.
    backend = torch.distributed.get_backend(get_world_group().device_group)

    # The layout of all ranks: ExternalDP * EP
    # ExternalDP is the data parallel group that is not part of the model,
    # every dp rank can generate independently (in verl integration).
    all_ranks = torch.arange(world_size).reshape(
        -1, parallel_config.data_parallel_size *
        parallel_config.tensor_parallel_size)
    global _MC2
    # Each row of `all_ranks` becomes one MC2 group.
    group_ranks = all_ranks.unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]

    _MC2 = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
                                     group_name="mc2")
    if envs_ascend.VLLM_ASCEND_ENABLE_MLP_OPTIMIZE:
        global _MLP_TP
        assert _MLP_TP is None, (
            "mlp tensor model parallel group is already initialized")

        # MLP TP spans the DP dimension.
        mlp_tp = parallel_config.data_parallel_size

        all_ranks_mlp_head = torch.arange(world_size).reshape(
            -1, mlp_tp, parallel_config.pipeline_parallel_size, 1)  # noqa
        group_ranks = all_ranks_mlp_head.view(-1, mlp_tp).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]

        # message queue broadcaster is only used in tensor model parallel group
        _MLP_TP = init_model_parallel_group(group_ranks,
                                            get_world_group().local_rank,
                                            backend,
                                            group_name="mlp_tp")

    lmhead_tensor_parallel_size = get_ascend_config(
    ).lmhead_tensor_parallel_size
    if lmhead_tensor_parallel_size is not None:
        group_ranks = []
        global _LMTP
        # Partition the world into consecutive blocks of
        # lmhead_tensor_parallel_size ranks.
        num_lmhead_tensor_parallel_groups: int = (world_size //
                                                  lmhead_tensor_parallel_size)
        for i in range(num_lmhead_tensor_parallel_groups):
            ranks = list(
                range(i * lmhead_tensor_parallel_size,
                      (i + 1) * lmhead_tensor_parallel_size))
            group_ranks.append(ranks)
        _LMTP = init_model_parallel_group(group_ranks,
                                          get_world_group().local_rank,
                                          backend,
                                          group_name="lmheadtp")
|
||||
|
||||
|
||||
def get_mlp_tensor_model_parallel_world_size():
    """Return the world size of the MLP tensor model parallel group."""
    group = get_mlp_tp_group()
    return group.world_size
|
||||
|
||||
|
||||
def get_mlp_tensor_model_parallel_rank():
    """Return the caller's rank within the MLP tensor model parallel group."""
    return get_mlp_tp_group().rank_in_group
|
||||
|
||||
|
||||
def destroy_ascend_model_parallel():
    """Destroy every Ascend-specific parallel group and reset the globals."""
    global _MC2, _MLP_TP, _LMTP

    if _MC2:
        _MC2.destroy()
    _MC2 = None

    if _MLP_TP:
        _MLP_TP.destroy()
    _MLP_TP = None

    if _LMTP:
        _LMTP.destroy()
    _LMTP = None
|
||||
248
vllm_ascend/distributed/tensor_parallel.py
Normal file
248
vllm_ascend/distributed/tensor_parallel.py
Normal file
@@ -0,0 +1,248 @@
|
||||
# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Adapts from: Megatron/megatron/core/tensor_parallel/mappings.py.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
import torch
|
||||
|
||||
|
||||
def _gather_along_first_dim(input_, group, output_split_sizes=None):
    """Gather tensors and concatenate along the first dimension.

    Args:
        input_ (torch.Tensor):
            A tensor to be gathered.
        group:
            The process group to gather over.
        output_split_sizes (List[int], optional):
            A list specifying the sizes of the output splits along the first dimension.
            If None, equal splitting is assumed. Default: None.

    Returns:
        torch.Tensor: Gathered tensor.
    """
    world_size = torch.distributed.get_world_size(group)
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    dim_size = list(input_.size())
    if output_split_sizes is None:
        # Equal chunks: a single fused all_gather_into_tensor suffices.
        dim_size[0] = dim_size[0] * world_size

        output = torch.empty(dim_size,
                             dtype=input_.dtype,
                             device=torch.npu.current_device())
        torch.distributed.all_gather_into_tensor(output,
                                                 input_.contiguous(),
                                                 group=group)
    else:
        # Unequal chunks: gather into views of a preallocated output buffer.
        dim_size[0] = sum(output_split_sizes)
        output = torch.empty(dim_size,
                             dtype=input_.dtype,
                             device=torch.npu.current_device())
        output_tensor_list = list(
            torch.split(output, output_split_sizes, dim=0))
        torch.distributed.all_gather(output_tensor_list, input_, group=group)

    return output
|
||||
|
||||
|
||||
def _gather_along_last_dim(input_, group):
    """Gather tensors from all ranks and concatenate along the last dimension."""
    world_size = torch.distributed.get_world_size(group)
    # Nothing to gather on a single rank.
    if world_size == 1:
        return input_

    gathered_shape = list(input_.size())
    gathered_shape[0] = gathered_shape[0] * world_size

    gathered = torch.empty(gathered_shape,
                           dtype=input_.dtype,
                           device=torch.npu.current_device())
    torch.distributed.all_gather_into_tensor(gathered,
                                             input_.contiguous(),
                                             group=group)
    # The fused gather stacks along dim 0; re-split and re-join on the
    # last dimension to get the desired layout.
    per_rank_chunks = gathered.chunk(world_size, dim=0)
    return torch.cat(per_rank_chunks, dim=-1).contiguous()
|
||||
|
||||
|
||||
def _reduce_scatter_along_first_dim(input_,
                                    group,
                                    input_split_sizes=None,
                                    use_global_buffer=False):
    """Reduce-scatter the input tensor across model parallel group.

    Args:
        input_ (torch.Tensor): The input tensor to be reduce-scattered.
        group: The process group to reduce-scatter over.
        input_split_sizes (List[int], optional): A list specifying the sizes of
            the input splits along the first dimension for each rank. If None,
            equal splitting is assumed. Default: None.
        use_global_buffer: accepted for interface parity; never read here.
    """
    world_size = torch.distributed.get_world_size(group)
    # Single rank: the reduce-scatter is the identity.
    if world_size == 1:
        return input_

    if input_split_sizes is None:
        # Equal split: use the fused tensor variant.
        out_shape = list(input_.size())
        assert (
            out_shape[0] % world_size == 0
        ), "First dimension of the tensor should be divisible by tensor parallel size"

        out_shape[0] = out_shape[0] // world_size
        output = torch.empty(out_shape,
                             dtype=input_.dtype,
                             device=torch.npu.current_device())
        torch.distributed.reduce_scatter_tensor(output,
                                                input_.contiguous(),
                                                group=group)
        return output

    # Unequal split: each rank receives its own slice size.
    rank = torch.distributed.get_rank(group)
    input_chunks = list(torch.split(input_, input_split_sizes, dim=0))

    output = torch.empty_like(input_chunks[rank])
    torch.distributed.reduce_scatter(output, input_chunks, group=group)
    return output
|
||||
|
||||
|
||||
def _reduce_scatter_along_last_dim(input_, group):
    """Reduce-scatter tensors on the last dimension."""
    world_size = torch.distributed.get_world_size(group)
    # Record the final shape before flattening the input.
    target_shape = list(input_.size())
    target_shape[-1] = target_shape[-1] // world_size

    flat = input_.reshape(-1, input_.shape[-1])
    # Rearrange last-dim slices into a first-dim stack so the first-dim
    # reduce-scatter can do the work.
    column_blocks = torch.split(flat, flat.shape[-1] // world_size, dim=1)
    stacked = torch.cat(column_blocks, dim=0)
    reduced = _reduce_scatter_along_first_dim(stacked, group)
    return reduced.reshape(target_shape)
|
||||
|
||||
|
||||
def all_gather_last_dim_from_tensor_parallel_region(input_, group):
    """All-gather `input_` along its last dimension (forward AG, backward RS)."""
    return _gather_along_last_dim(input_, group)
|
||||
|
||||
|
||||
def reduce_scatter_to_sequence_parallel_region(input_,
                                               group,
                                               input_split_sizes=None):
    """Reduce-scatter `input_` along its first dimension (forward RS, backward AG)."""
    return _reduce_scatter_along_first_dim(input_, group, input_split_sizes)
|
||||
|
||||
|
||||
def reduce_scatter_last_dim_to_tensor_parallel_region(input_, group):
    """Reduce-scatter `input_` along its last dimension (forward RS, backward AG)."""
    return _reduce_scatter_along_last_dim(input_, group)
|
||||
|
||||
|
||||
def gather_from_sequence_parallel_region(
    input_,
    group,
    output_split_sizes=None,
):
    """All-gather `input_` along its first dimension (forward AG, backward RS)."""
    return _gather_along_first_dim(input_, group, output_split_sizes)
|
||||
|
||||
|
||||
def all_to_all(group, input, output_split_sizes=None, input_split_sizes=None):
    """Exchange first-dim slices of `input` between all ranks of `group`.

    With no split sizes this is a regular all-to-all; with split sizes it is
    the variable-length (all2all-v) form.
    """
    world_size = torch.distributed.get_world_size(group=group)
    # A single rank exchanges with itself only.
    if world_size == 1:
        return input

    input = input.contiguous()
    if output_split_sizes is not None:
        # Unequal split (all2all-v): ranks may receive different slice sizes.
        output = input.new_empty(
            size=[sum(output_split_sizes)] + list(input.size()[1:]),
            dtype=input.dtype,
            device=torch.npu.current_device(),
        )
    else:
        # Equal split (all2all): output mirrors the input shape.
        output = torch.empty_like(input)
    torch.distributed.all_to_all_single(
        output,
        input,
        output_split_sizes=output_split_sizes,
        input_split_sizes=input_split_sizes,
        group=group,
    )
    return output
|
||||
|
||||
|
||||
def all_to_all_sp2hp(input_, group):
    """
    Perform AlltoAll communication on the tensor parallel group, transforming
    the input from shape [num_tokens/TP, H] to [num_tokens, H/TP].

    Args:
        input_ (torch.Tensor):
            The input tensor, distributed along the sequence dimension.
        group:
            The tensor parallel process group (None is a no-op passthrough).

    Returns:
        torch.Tensor: The output tensor with shape [num_tokens, H/TP].
    """
    if group is None:
        return input_
    world_size = torch.distributed.get_world_size(group=group)
    flat = input_.reshape(-1, input_.shape[-1])
    # Slice the hidden dimension into per-rank columns, stack them along the
    # token dimension, and let all_to_all route each block to its rank.
    column_blocks = torch.split(flat,
                                split_size_or_sections=flat.shape[-1] //
                                world_size,
                                dim=1)
    stacked = torch.cat(column_blocks, dim=0)
    return all_to_all(group, stacked)
|
||||
|
||||
|
||||
def all_to_all_hp2sp(input_, group):
    """
    Perform AlltoAll communication on the tensor parallel group, transforming
    the input from shape [num_tokens, H/TP] to [num_tokens/TP, H].

    Args:
        input_ (torch.Tensor):
            The input tensor, distributed along the hidden dimension.
        group:
            The tensor parallel process group (None is a no-op passthrough).

    Returns:
        torch.Tensor: The output tensor with shape [num_tokens/TP, H].
    """
    if group is None:
        return input_
    world_size = torch.distributed.get_world_size(group=group)
    flat = input_.reshape(-1, input_.shape[-1])
    # Exchange token blocks across ranks, then re-join each rank's slice of
    # the hidden dimension side by side.
    exchanged = all_to_all(group, flat)
    exchanged = exchanged.reshape(-1, exchanged.shape[-1])
    row_blocks = torch.split(
        exchanged,
        split_size_or_sections=exchanged.shape[0] // world_size,
        dim=0)
    return torch.cat(row_blocks, dim=-1)
|
||||
Reference in New Issue
Block a user