v0.10.1rc1

This commit is contained in:
2025-09-09 09:40:35 +08:00
parent d6f6ef41fe
commit 9149384e03
432 changed files with 84698 additions and 1 deletions

View File

@@ -0,0 +1,28 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from vllm.distributed.kv_transfer.kv_connector.factory import \
KVConnectorFactory
KVConnectorFactory.register_connector(
"LLMDataDistCMgrConnector",
"vllm_ascend.distributed.llmdatadist_c_mgr_connector",
"LLMDataDistCMgrConnector")
KVConnectorFactory.register_connector(
"MooncakeConnectorV1", "vllm_ascend.distributed.mooncake_connector",
"MooncakeConnector")

View File

@@ -0,0 +1,25 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
from vllm.distributed.parallel_state import get_dp_group
def data_parallel_reduce_scatter(input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
"""Reduce-Scatter the input tensor across data parallel group."""
return get_dp_group().reduce_scatter(input_, dim)

View File

@@ -0,0 +1,75 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import List, Optional
import torch
import torch.distributed as dist
from vllm.distributed.device_communicators.base_device_communicator import \
DeviceCommunicatorBase
class NPUCommunicator(DeviceCommunicatorBase):
def __init__(self,
cpu_group: dist.ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[dist.ProcessGroup] = None,
unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name)
# TODO(hz): Refer to CudaCommunicator's implementation to integrate PyHcclCommunicator
# init device according to rank
self.device = torch.npu.current_device()
def all_to_all(self,
input_: torch.Tensor,
scatter_dim: int = 0,
gather_dim: int = -1,
scatter_sizes: Optional[List[int]] = None,
gather_sizes: Optional[List[int]] = None) -> torch.Tensor:
if scatter_dim < 0:
scatter_dim += input_.dim()
if gather_dim < 0:
gather_dim += input_.dim()
if scatter_sizes is not None and gather_sizes is not None:
input_list = [
t.contiguous()
for t in torch.split(input_, scatter_sizes, scatter_dim)
]
output_list = []
tensor_shape_base = input_list[self.rank].size()
for i in range(self.world_size):
tensor_shape = list(tensor_shape_base)
tensor_shape[gather_dim] = gather_sizes[i]
output_list.append(
torch.empty(tensor_shape,
dtype=input_.dtype,
device=input_.device))
else:
input_list = [
t.contiguous() for t in torch.tensor_split(
input_, self.world_size, scatter_dim)
]
output_list = [
torch.empty_like(input_list[i]) for i in range(self.world_size)
]
dist.all_to_all(output_list, input_list, group=self.device_group)
output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
return output_tensor

View File

@@ -0,0 +1,165 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import logger
from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import (
HCCLLibrary, aclrtStream_t, buffer_type, hcclComm_t, hcclDataTypeEnum,
hcclRedOpTypeEnum, hcclUniqueId)
from vllm_ascend.utils import current_stream
class PyHcclCommunicator:
def __init__(
self,
group: Union[ProcessGroup, StatelessProcessGroup],
device: Union[int, str, torch.device],
library_path: Optional[str] = None,
):
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the PyHcclCommunicator to. If None,
it will be bind to f"npu:{local_rank}".
library_path: the path to the HCCL library. If None, it will
use the default library path.
It is the caller's responsibility to make sure each communicator
is bind to a unique device.
"""
if not isinstance(group, StatelessProcessGroup):
assert dist.is_initialized()
assert dist.get_backend(group) != dist.Backend.HCCL, (
"PyHcclCommunicator should be attached to a non-HCCL group.")
# note: this rank is the rank in the group
self.rank = dist.get_rank(group)
self.world_size = dist.get_world_size(group)
else:
self.rank = group.rank
self.world_size = group.world_size
self.group = group
# if world_size == 1, no need to create communicator
if self.world_size == 1:
self.available = False
self.disabled = True
return
try:
self.hccl = HCCLLibrary(library_path)
except Exception:
# disable because of missing HCCL library
# e.g. in a non-NPU environment
self.available = False
self.disabled = True
return
self.available = True
self.disabled = False
logger.info("vLLM is using pyhccl")
if isinstance(device, int):
device = torch.device(f"npu:{device}")
elif isinstance(device, str):
device = torch.device(device)
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
if self.rank == 0:
# get the unique id from HCCL
with torch.npu.device(device):
self.unique_id = self.hccl.hcclGetUniqueId()
else:
# construct an empty unique id
self.unique_id = hcclUniqueId()
if not isinstance(group, StatelessProcessGroup):
tensor = torch.ByteTensor(list(self.unique_id.internal))
ranks = dist.get_process_group_ranks(group)
# arg `src` in `broadcast` is the global rank
dist.broadcast(tensor, src=ranks[0], group=group)
byte_list = tensor.tolist()
for i, byte in enumerate(byte_list):
self.unique_id.internal[i] = byte
else:
self.unique_id = group.broadcast_obj(self.unique_id, src=0)
# hccl communicator and stream will use this device
# `torch.npu.device` is a context manager that changes the
# current npu device to the specified one
with torch.npu.device(device):
self.comm: hcclComm_t = self.hccl.hcclCommInitRank(
self.world_size, self.unique_id, self.rank)
stream = current_stream()
# A small all_reduce for warmup.
data = torch.zeros(1, device=device)
self.all_reduce(data)
stream.synchronize()
del data
def all_reduce(self,
in_tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None) -> torch.Tensor:
if self.disabled:
return None
# hccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert in_tensor.device == self.device, (
f"this hccl communicator is created to work on {self.device}, "
f"but the input tensor is on {in_tensor.device}")
out_tensor = torch.empty_like(in_tensor)
if stream is None:
stream = current_stream()
self.hccl.hcclAllReduce(buffer_type(in_tensor.data_ptr()),
buffer_type(out_tensor.data_ptr()),
in_tensor.numel(),
hcclDataTypeEnum.from_torch(in_tensor.dtype),
hcclRedOpTypeEnum.from_torch(op), self.comm,
aclrtStream_t(stream.npu_stream))
return out_tensor
def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
if self.disabled:
return
assert tensor.device == self.device, (
f"this hccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = current_stream()
if src == self.rank:
buffer = buffer_type(tensor.data_ptr())
else:
buffer = buffer_type(tensor.data_ptr())
self.hccl.hcclBroadcast(buffer, tensor.numel(),
hcclDataTypeEnum.from_torch(tensor.dtype), src,
self.comm, aclrtStream_t(stream.npu_stream))

View File

@@ -0,0 +1,253 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import ctypes
import platform
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import torch
from torch.distributed import ReduceOp
from vllm.logger import logger
from vllm_ascend.utils import find_hccl_library
# export types and functions from hccl to Python ===
# for the original hccl definition, please check
# https://github.com/EternalLied/cann-hccl-new/blob/64ec6ce2923319caa5df8c3c531e06bdc148ce9c/inc/hccl/hccl.h#L90
# https://github.com/EternalLied/cann-hccl-new/blob/64ec6ce2923319caa5df8c3c531e06bdc148ce9c/inc/hccl/hccl_types.h#L48
hcclResult_t = ctypes.c_int
hcclComm_t = ctypes.c_void_p
class hcclUniqueId(ctypes.Structure):
_fields_ = [("internal", ctypes.c_byte * 4108)]
aclrtStream_t = ctypes.c_void_p
buffer_type = ctypes.c_void_p
hcclDataType_t = ctypes.c_int
class hcclDataTypeEnum:
hcclInt8 = 0
hcclInt16 = 1
hcclInt32 = 2
hcclFloat16 = 3
hcclFloat32 = 4
hcclInt64 = 5
hcclUint64 = 6
hcclUint8 = 7
hcclUint16 = 8
hcclUint32 = 9
hcclFloat64 = 10
hcclBfloat16 = 11
hcclInt128 = 12
@classmethod
def from_torch(cls, dtype: torch.dtype) -> int:
if dtype == torch.int8:
return cls.hcclInt8
if dtype == torch.uint8:
return cls.hcclUint8
if dtype == torch.int32:
return cls.hcclInt32
if dtype == torch.int64:
return cls.hcclInt64
if dtype == torch.float16:
return cls.hcclFloat16
if dtype == torch.float32:
return cls.hcclFloat32
if dtype == torch.float64:
return cls.hcclFloat64
if dtype == torch.bfloat16:
return cls.hcclBfloat16
raise ValueError(f"Unsupported dtype: {dtype}")
hcclRedOp_t = ctypes.c_int
class hcclRedOpTypeEnum:
hcclSum = 0
hcclProd = 1
hcclMax = 2
hcclMin = 3
@classmethod
def from_torch(cls, op: ReduceOp) -> int:
if op == ReduceOp.SUM:
return cls.hcclSum
if op == ReduceOp.PRODUCT:
return cls.hcclProd
if op == ReduceOp.MAX:
return cls.hcclMax
if op == ReduceOp.MIN:
return cls.hcclMin
raise ValueError(f"Unsupported op: {op}")
@dataclass
class Function:
name: str
restype: Any
argtypes: List[Any]
class HCCLLibrary:
exported_functions = [
# const char* HcclGetErrorString(HcclResult code);
Function("HcclGetErrorString", ctypes.c_char_p, [hcclResult_t]),
# HcclResult HcclGetRootInfo(HcclRootInfo *rootInfo);
Function("HcclGetRootInfo", hcclResult_t,
[ctypes.POINTER(hcclUniqueId)]),
# HcclResult HcclCommInitRootInfo(
# uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, HcclComm *comm);
# note that HcclComm is a pointer type, so the last argument is a pointer to a pointer
Function("HcclCommInitRootInfo", hcclResult_t, [
ctypes.c_int,
ctypes.POINTER(hcclUniqueId),
ctypes.c_int,
ctypes.POINTER(hcclComm_t),
]),
# HcclResult HcclAllReduce(
# void *sendBuf, void *recvBuf, uint64_t count,
# HcclDataType dataType, HcclReduceOp op, HcclComm comm,
# aclrtStream stream);
Function("HcclAllReduce", hcclResult_t, [
buffer_type,
buffer_type,
ctypes.c_size_t,
hcclDataType_t,
hcclRedOp_t,
hcclComm_t,
aclrtStream_t,
]),
# HcclResult HcclBroadcast(
# void *buf, uint64_t count,
# HcclDataType dataType, uint32_t root,
# HcclComm comm, aclrtStream stream);
Function("HcclBroadcast", hcclResult_t, [
buffer_type,
ctypes.c_size_t,
hcclDataType_t,
ctypes.c_int,
hcclComm_t,
aclrtStream_t,
]),
# HcclResult HcclCommDestroy(HcclComm comm);
Function("HcclCommDestroy", hcclResult_t, [hcclComm_t]),
]
# class attribute to store the mapping from the path to the library
# to avoid loading the same library multiple times
path_to_library_cache: Dict[str, Any] = {}
# class attribute to store the mapping from library path
# to the correspongding directory
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
def __init__(self, so_file: Optional[str] = None):
so_file = so_file or find_hccl_library()
try:
if so_file not in HCCLLibrary.path_to_dict_mapping:
lib = ctypes.CDLL(so_file)
HCCLLibrary.path_to_library_cache[so_file] = lib
self.lib = HCCLLibrary.path_to_library_cache[so_file]
except Exception as e:
logger.error(
"Failed to load HCCL library from %s. "
"It is expected if you are not running on Ascend NPUs."
"Otherwise, the hccl library might not exist, be corrupted "
"or it does not support the current platform %s. "
"If you already have the library, please set the "
"environment variable HCCL_SO_PATH"
" to point to the correct hccl library path.", so_file,
platform.platform())
raise e
if so_file not in HCCLLibrary.path_to_dict_mapping:
_funcs: Dict[str, Any] = {}
for func in HCCLLibrary.exported_functions:
f = getattr(self.lib, func.name)
f.restype = func.restype
f.argtypes = func.argtypes
_funcs[func.name] = f
HCCLLibrary.path_to_dict_mapping[so_file] = _funcs
self._funcs = HCCLLibrary.path_to_dict_mapping[so_file]
def hcclGetErrorString(self, result: hcclResult_t) -> str:
return self._funcs["HcclGetErrorString"](result).decode("utf-8")
def HCCL_CHECK(self, result: hcclResult_t) -> None:
if result != 0:
error_str = self.hcclGetErrorString(result)
raise RuntimeError(f"HCCL error: {error_str}")
def hcclGetUniqueId(self) -> hcclUniqueId:
unique_id = hcclUniqueId()
self.HCCL_CHECK(self._funcs["HcclGetRootInfo"](
ctypes.byref(unique_id)))
return unique_id
def hcclCommInitRank(self, world_size: int, unique_id: hcclUniqueId,
rank: int) -> hcclComm_t:
comm = hcclComm_t()
self.HCCL_CHECK(self._funcs["HcclCommInitRootInfo"](
world_size, ctypes.byref(unique_id), rank, ctypes.byref(comm)))
return comm
def hcclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, op: int, comm: hcclComm_t,
stream: aclrtStream_t) -> None:
# `datatype` actually should be `hcclDataType_t`
# and `op` should be `hcclRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.HCCL_CHECK(self._funcs["HcclAllReduce"](sendbuff, recvbuff, count,
datatype, op, comm,
stream))
def hcclBroadcast(self, buf: buffer_type, count: int, datatype: int,
root: int, comm: hcclComm_t,
stream: aclrtStream_t) -> None:
self.HCCL_CHECK(self._funcs["HcclBroadcast"](buf, count, datatype,
root, comm, stream))
def hcclCommDestroy(self, comm: hcclComm_t) -> None:
self.HCCL_CHECK(self._funcs["HcclCommDestroy"](comm))
__all__ = [
"HCCLLibrary",
"hcclDataTypeEnum",
"hcclRedOpTypeEnum",
"hcclUniqueId",
"hcclComm_t",
"aclrtStream_t",
"buffer_type",
]

View File

@@ -0,0 +1,894 @@
import contextlib
import json
import math
import os
import threading
import time
from collections import defaultdict
from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Optional, Tuple
import llm_datadist # type: ignore
import msgspec
import torch
import zmq
from llm_datadist import (BlocksCacheKey, CacheDesc, LLMConfig, LLMDataDist,
LLMException, LLMRole)
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import get_tp_group, get_world_group
from vllm.forward_context import ForwardContext
from vllm.utils import get_ip, logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import Request, RequestStatus
import vllm_ascend.envs as envs_ascend
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
TORCH_DTYPE_TO_NPU_DTYPE = {
torch.half: llm_datadist.DataType.DT_FLOAT16,
torch.float16: llm_datadist.DataType.DT_FLOAT16,
torch.bfloat16: llm_datadist.DataType.DT_BF16,
torch.float: llm_datadist.DataType.DT_FLOAT,
torch.float32: llm_datadist.DataType.DT_FLOAT,
torch.int8: llm_datadist.DataType.DT_INT8,
torch.int64: llm_datadist.DataType.DT_INT64,
torch.int32: llm_datadist.DataType.DT_INT32
}
class LLMDataDistCMgrEvent(Enum):
ReqForMetadata = 0
ReqForFinished = 1
class LLMDataDistCMgrAgentMetadata(msgspec.Struct):
super_pod_id: str
server_id: str
device_id: str
device_ip: str
super_device_id: str
cluster_id: int
@dataclass
class ReqMeta:
local_block_ids: list[int]
remote_block_ids: list[int]
remote_host: str
remote_port: str
engine_id: str
remote_tp_size: str
class LLMDataDistCMgrConnectorMetadata(KVConnectorMetadata):
def __init__(self):
self.requests: dict[str, ReqMeta] = {}
def add_new_req(self, request_id: str, local_block_ids: list[int],
kv_transfer_params: dict[str, Any]):
self.requests[request_id] = ReqMeta(
local_block_ids=local_block_ids,
remote_block_ids=kv_transfer_params["remote_block_ids"],
engine_id=kv_transfer_params["remote_engine_id"],
remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params["remote_port"],
remote_tp_size=kv_transfer_params["remote_tp_size"],
)
class LLMDataDistCMgrConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
assert vllm_config.kv_transfer_config is not None
self.engine_id = vllm_config.kv_transfer_config.engine_id
if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler: Optional[
LLMDataDistCMgrConnectorScheduler] = LLMDataDistCMgrConnectorScheduler(
vllm_config, self.engine_id)
elif role == KVConnectorRole.WORKER:
self.connector_scheduler = None
self.connector_worker = LLMDataDistCMgrConnectorWorker(vllm_config)
############################################################
# Scheduler Side Methods
############################################################
def get_num_new_matched_tokens(
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
assert self.connector_scheduler is not None
return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens)
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
assert self.connector_scheduler is not None
return self.connector_scheduler.update_state_after_alloc(
request, blocks, num_external_tokens)
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
assert self.connector_scheduler is not None
return self.connector_scheduler.build_connector_meta(scheduler_output)
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)
############################################################
# Worker Side Methods
############################################################
def register_kv_caches(
self,
kv_caches: dict[
str, # type: ignore[override]
Tuple[torch.Tensor]]):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
"""Get the finished recving and sending requests."""
assert self.connector_worker is not None
return self.connector_worker.get_finished(finished_req_ids)
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
assert self.connector_worker is not None
assert isinstance(self._connector_metadata,
LLMDataDistCMgrConnectorMetadata)
self.connector_worker.start_load_kv(self._connector_metadata)
def wait_for_layer_load(self, layer_name: str) -> None:
"""LLMDataDistCMgrConnector does not do layerwise saving, the load is in blocking manager."""
pass
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata, **kwargs) -> None:
"""LLMDataDistCMgrConnector does not save explicitly."""
pass
def wait_for_save(self):
"""LLMDataDistCMgrConnector does not save explicitly."""
pass
class LLMDataDistCMgrConnectorScheduler():
def __init__(self, vllm_config: VllmConfig, engine_id: Optional[str]):
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
self.engine_id = engine_id
self.local_ip = get_ip()
# Can not retrieve the parallel config since it is not initialized.
self.local_dp_rank = None
self.tp_size = None
dp_rank_local = self.vllm_config.parallel_config.data_parallel_rank_local
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
self.port = dp_rank_local * tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT
self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
def get_num_new_matched_tokens(
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
"""
For remote prefill, pull all prompt blocks from remote
asynchronously relative to engine execution.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
* the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
* true if the external KV cache tokens will be loaded
asynchronously (between scheduler steps).
"""
params = request.kv_transfer_params
logger.debug(
f"LLMDataDistCMgrConnector get_num_new_matched_tokens: num_computed_tokens={num_computed_tokens}, kv_transfer_params={params}"
)
if params is not None and params.get("do_remote_prefill"):
# Remote prefill: get all prompt blocks from remote.
assert num_computed_tokens % self.block_size == 0
# Note: We use the full token count as transmit data here.
count = max(len(request.prompt_token_ids) - num_computed_tokens, 0)
return count, count > 0
# No remote prefill for this request.
return 0, False
def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks,
num_externel_tokens: int):
params = request.kv_transfer_params
logger.debug(
f"LLMDataDistCMgrConnector update states num_externel_tokens: {num_externel_tokens} kv_transfer_params: {params}"
)
if params is not None and params.get("do_remote_prefill"):
if params.get("remote_block_ids"):
if all(p in params for p in ("remote_engine_id", "remote_host",
"remote_port", "remote_tp_size")):
self._reqs_need_recv[request.request_id] = (
request, blocks.get_unhashed_block_ids())
else:
logger.warning("" \
f"Invalid KVTransferParams {params}, This request will be discard")
else:
assert num_externel_tokens == 0
params["do_remote_prefill"] = False
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
meta = LLMDataDistCMgrConnectorMetadata()
for req_id, (req, block_ids) in self._reqs_need_recv.items():
assert req.kv_transfer_params is not None
meta.add_new_req(request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params)
self._reqs_need_recv.clear()
return meta
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
params = request.kv_transfer_params
logger.debug(
"LLMDataDistCMgrConnector request_finished, request_status=%s, "
"kv_transfer_params=%s", request.status, params)
if (params is None or not params.get("do_remote_decode")
or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
return False, None
# note: NIXL transfer the full block only, but I don't see any reason to do that, so here
# we just transfer any data that computed from prefill node
# note: there might be some issue on this, check it if there is any unexpected result
computed_block_ids = block_ids
delay_free_blocks = len(computed_block_ids) > 0
if delay_free_blocks:
logger.info("Delaying free of %d blocks for request %s",
len(computed_block_ids), request.request_id)
return delay_free_blocks, dict(
do_remote_prefill=True,
do_remote_decode=False,
remote_block_ids=computed_block_ids,
remote_engine_id=self.engine_id,
remote_host=self.local_ip,
remote_port=self.port,
remote_tp_size=str(
self.vllm_config.parallel_config.tensor_parallel_size),
)
class LLMDataDistCMgrConnectorWorker():
"""
Implementation of Worker side methods
"""
def __init__(self, vllm_config: VllmConfig):
assert vllm_config.kv_transfer_config is not None
logger.info("Initialize the LLMDataDistCMgrConnectorWorker")
# we assume the local node only contains dp and tp, and tp will not communicate inter-node.
# for any scenario beyond this scope, the functionality of this connector is not guaranteed.
self.local_rank_on_node = get_world_group().rank % (
vllm_config.parallel_config.data_parallel_size_local *
vllm_config.parallel_config.tensor_parallel_size)
self.local_rank = get_world_group().local_rank
self.local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
self.tp_size = vllm_config.parallel_config.tensor_parallel_size
self.tp_rank = get_tp_group().rank_in_group
self.rank = get_world_group().rank
self.local_ip = get_ip()
self.kv_transfer_config: KVTransferConfig = vllm_config.kv_transfer_config
self.local_agent_metadata: Optional[
LLMDataDistCMgrAgentMetadata] = None
self.vllm_config = vllm_config
self.executor = ThreadPoolExecutor(1)
self.thread_lock = threading.Lock()
self.llm_datadist_role = None
self.llm_datadist_remote_role = None
if self.kv_transfer_config.kv_role == "kv_producer":
self.llm_datadist_role = LLMRole.PROMPT
self.llm_datadist_remote_role = LLMRole.DECODER
elif self.kv_transfer_config.kv_role == "kv_consumer":
self.llm_datadist_role = LLMRole.DECODER
self.llm_datadist_remote_role = LLMRole.PROMPT
else:
raise RuntimeError(
f"LLMDataDistWorker: Receive unexpected kv role in LLMDataDistWorker, this worker now only support kv_producer and kv_consumer, but receiving {vllm_config.kv_transfer_config.kv_role}"
)
# linked_cluster record the cluster that already build the connection its format should be {"cluster_id": "comm_name"}
self.linked_cluster: dict[Any, Any] = {}
self.prefill_device_list: list[tuple[int, int]] = []
self.decode_device_list: list[tuple[int, int]] = []
global_rank_table = self.read_offline_rank_table()
self.local_agent_metadata = self.read_agent_metadata(global_rank_table)
self.llm_datadist = LLMDataDist(self.llm_datadist_role,
self.local_agent_metadata.cluster_id)
self.init_llm_datadist()
self.finished_reqs: set[str] = set()
self.soc_info = get_ascend_soc_version()
# Set hccl deterministic for model execute
os.environ["HCCL_DETERMINISTIC"] = "true"
self.done_receiving_counts: defaultdict[str,
set[int]] = defaultdict(set)
def listen_for_agent_metadata_req(self, event: threading.Event):
assert self.local_agent_metadata is not None
port = envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT + self.local_dp_rank * self.tp_size + self.tp_rank if self.local_dp_rank is not None else envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT + self.tp_size + self.tp_rank
url = f"tcp://{envs_ascend.VLLM_ASCEND_LLMDD_RPC_IP}:{port}"
msg_encoder = msgspec.msgpack.Encoder()
msg_decoder = msgspec.msgpack.Decoder()
msg_to_send = msg_encoder.encode(self.local_agent_metadata)
logger.debug(f"Start to listen to address: {url}")
logger.debug(
f"The local agent metadata have {len(msg_to_send)} bytes here")
logger.info(
f"LLMDataDistCMgrConnectorWorker: Cluster {self.local_agent_metadata.cluster_id} start to listen request from peers"
)
with zmq_ctx(zmq.ROUTER, url) as sock: # type: ignore[attr-defined]
event.set()
while True:
identity, _, msg = sock.recv_multipart()
event_msg, decode_msg = msg_decoder.decode(msg)
event_msg = LLMDataDistCMgrEvent(event_msg)
if event_msg == LLMDataDistCMgrEvent.ReqForMetadata:
if "cluster_id" in decode_msg:
decode_msg = LLMDataDistCMgrAgentMetadata(**decode_msg)
logger.info(
f"LLMDataDistCMgrConnectorWorker: Receive message from cluster {decode_msg.cluster_id}"
)
sock.send_multipart((identity, b"", msg_to_send))
self.add_remote_agent(decode_msg)
else:
logger.warning(
f"LLMDataDistCMgrConnectorWorker: receiving unrecognized data {decode_msg}"
)
elif event_msg == LLMDataDistCMgrEvent.ReqForFinished:
finished_req_id = decode_msg[0]
decode_tp_rank = decode_msg[1]
decode_tp_size = decode_msg[2]
with self.thread_lock:
if self._increment_task_count(finished_req_id,
decode_tp_rank,
decode_tp_size):
logger.debug(
f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
)
self.finished_reqs.add(finished_req_id)
sock.send_multipart(
(identity, b"", b"receiving decode finished"))
else:
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Receiving unexpected request event {event_msg} from remote !"
)
def _increment_task_count(self, request_id: str, tp_rank: int,
decode_tp_size: int):
if request_id not in self.done_receiving_counts:
self.done_receiving_counts[request_id] = set()
if tp_rank in self.done_receiving_counts[request_id]:
logger.warning(
f"Received duplicate done signal for request {request_id} "
f"from tp rank {tp_rank}. Ignoring.")
return False
self.done_receiving_counts[request_id].add(tp_rank)
if len(self.done_receiving_counts[request_id]) == decode_tp_size:
self.done_receiving_counts.pop(request_id)
logger.info("All transfers completed for request: "
f"{request_id}. Total ranks: "
f"{decode_tp_size}.")
return True
return False
def init_llm_datadist(self):
assert self.local_agent_metadata is not None
llm_config = LLMConfig()
llm_config.device_id = self.local_rank
llm_config.sync_kv_timeout = 20000
llm_config.enable_switch_role = True
llm_config.enable_cache_manager = True
llm_config.enable_remote_cache_accessible = True
llm_config_options = llm_config.generate_options()
self.llm_datadist.init(llm_config_options)
self.cache_manager = self.llm_datadist.cache_manager
logger.info(
f"Done initialize llm_datadist in rank {self.rank}, local rank {self.local_rank}, cluster id {self.local_agent_metadata.cluster_id}"
)
def read_offline_rank_table(self):
assert (
envs_ascend.DISAGGREGATED_PREFILL_RANK_TABLE_PATH
), "Please set path of rank_table to env variable DISAGGREGATED_PREFILL_RANK_TABLE_PATH"
rank_table_path = envs_ascend.DISAGGREGATED_PREFILL_RANK_TABLE_PATH
with open(rank_table_path, "r", encoding="utf-8") as f:
global_rank_table = json.load(f)
decode_device_list = global_rank_table["decode_device_list"]
for decode_device in decode_device_list:
server_id = decode_device["server_id"]
device_id = decode_device["device_id"]
self.decode_device_list.append((server_id, device_id))
prefill_device_list = global_rank_table["prefill_device_list"]
for prefill_device in prefill_device_list:
server_id = prefill_device["server_id"]
device_id = prefill_device["device_id"]
self.prefill_device_list.append((server_id, device_id))
# global_rank_table = json.dumps(global_rank_table)
return global_rank_table
@staticmethod
def _get_visible_devices() -> Callable[[str], bool]:
"""
Return a test function that check if the given device ID is visible.
i.e. ASCEND_RT_VISIBLE_DEVICES is not set or contains the device_id.
"""
visible_devices = os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "")
if not visible_devices:
return lambda device_id: True
visible_device_list = visible_devices.split(",")
return lambda device_id: device_id in visible_device_list
def read_agent_metadata(self, global_rank_table):
device_filter = LLMDataDistCMgrConnectorWorker._get_visible_devices()
devices_type_list = []
agent_metadata = None
if self.llm_datadist_role == LLMRole.PROMPT:
devices_type_list.append("prefill_device_list")
elif self.llm_datadist_role == LLMRole.DECODER:
devices_type_list.append("decode_device_list")
else:
devices_type_list.append("prefill_device_list")
devices_type_list.append("decode_device_list")
for device_type in devices_type_list:
device_list = global_rank_table[device_type]
device_list = [
d for d in device_list if d.get("server_id") == self.local_ip
and device_filter(d.get("device_id", ""))
]
if len(device_list) <= self.tp_rank:
continue
device_info = device_list[self.tp_rank]
super_pod_id_ = device_info.get("super_pod_id", None)
server_id_ = device_info["server_id"]
device_id_ = device_info["device_id"]
device_ip_ = device_info["device_ip"]
super_device_id_ = device_info.get("super_device_id", None)
cluster_id_ = int(device_info["cluster_id"])
agent_metadata = LLMDataDistCMgrAgentMetadata(
super_pod_id=super_pod_id_,
server_id=server_id_,
device_id=device_id_,
device_ip=device_ip_,
super_device_id=super_device_id_,
cluster_id=cluster_id_,
)
assert agent_metadata is not None, f"Can't read the target server_id {self.local_ip} and device_rank {self.rank} from rank table"
return agent_metadata
def register_kv_caches(self, kv_caches: dict[str, Tuple[torch.Tensor]]):
_, first_kv_cache_tuple = next(iter(kv_caches.items()))
first_kv_cache = first_kv_cache_tuple[0]
assert len(first_kv_cache_tuple) > 1
assert self.local_agent_metadata is not None
kv_cache_dtype = first_kv_cache.dtype
self.use_mla: bool = first_kv_cache_tuple[0].size(
-1) != first_kv_cache_tuple[1].size(-1)
# MLA case. [2 (k_normed, k_pe), num_blocks, ...]
# MHA case. [2 (k and v), num_blocks, ...]
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, latent_dim]
block_shape = first_kv_cache.shape[-block_rank:]
self.block_len = math.prod(block_shape)
self.cache_addr: list[int] = []
alignment = 2 * 1024 * 1024
if self.use_mla:
cache_k_normed_addr_list = []
cache_k_pe_addr_list = []
k_normed = None
k_pe = None
for cache_or_caches in kv_caches.values():
assert len(cache_or_caches) > 1
k_normed, k_pe = cache_or_caches[0], cache_or_caches[1]
cache_k_normed_addr_list.append(k_normed.data_ptr())
cache_k_pe_addr_list.append(k_pe.data_ptr())
self.cache_addr = (cache_k_normed_addr_list, cache_k_pe_addr_list)
cache_desc_k_normed = CacheDesc(
len(self.cache_addr[0]), [*k_normed.shape],
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
cache_desc_k_pe = CacheDesc(
len(self.cache_addr[1]), [*k_pe.shape],
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
cache_key_k_normed = BlocksCacheKey(cluster_id=int(
self.local_agent_metadata.cluster_id),
model_id=0)
cache_key_k_pe = BlocksCacheKey(cluster_id=int(
self.local_agent_metadata.cluster_id),
model_id=1)
self.cache_desc = (cache_desc_k_normed, cache_desc_k_pe)
self.cache_key = (cache_key_k_normed, cache_key_k_pe)
try:
cache_k_normed = self.cache_manager.register_blocks_cache(
self.cache_desc[0], self.cache_addr[0], self.cache_key[0])
cache_k_pe = self.cache_manager.register_blocks_cache(
self.cache_desc[1], self.cache_addr[1], self.cache_key[1])
self.cache = (cache_k_normed, cache_k_pe)
logger.info("LLMDataDistWorker: End of register Paged Cache.")
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
)
else:
for cache_or_caches in kv_caches.values():
for cache in cache_or_caches:
base_addr = cache.data_ptr()
assert base_addr % alignment == 0, "The address of the registered kv cache should be aligned to 2M"
self.cache_addr.append(base_addr)
# register paged kv cache into the llm_cache manager
self.cache_desc = CacheDesc(
len(self.cache_addr), [*cache.shape],
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
self.cache_key = BlocksCacheKey(
cluster_id=int(self.local_agent_metadata.cluster_id))
logger.info(
f"num of cache: {len(self.cache_addr)}, size of cache: {[*cache.shape]}, real size of cache: {first_kv_cache.shape}"
)
try:
self.cache = self.cache_manager.register_blocks_cache(
self.cache_desc, self.cache_addr, self.cache_key)
logger.info(
"LLMDataDistCMgrConnectorWorker: End of register Paged Cache."
)
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
)
self.ready_event = threading.Event()
self.metadata_agent_listener_t = threading.Thread(
target=self.listen_for_agent_metadata_req,
args=(self.ready_event, ),
daemon=True,
name="metadata_agent_listener")
self.metadata_agent_listener_t.start()
self.ready_event.wait()
def start_load_kv(self, metadata: LLMDataDistCMgrConnectorMetadata):
futures = []
for req_id, meta in metadata.requests.items():
logger.debug(f"Start to transmit {req_id}")
future = self.executor.submit(
self._read_blocks,
local_block_ids=meta.local_block_ids,
remote_block_ids=meta.remote_block_ids,
remote_ip=meta.remote_host,
remote_port=int(meta.remote_port),
remote_engine_id=meta.engine_id,
request_id=req_id,
remote_tp_size=meta.remote_tp_size,
)
futures.append(future)
def handle_exception(future):
if future.exception():
logger.error(f"KV transfer task failed: {future.exception()}")
for future in futures:
future.add_done_callback(handle_exception)
def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int:
assert self.local_agent_metadata is not None
remote_cluster_id = metadata.cluster_id
if remote_cluster_id in self.linked_cluster:
logger.debug(
f"LLMDataDistCMgrConnectorWorker: remote cluster_id: {metadata.cluster_id} already linked with this server, skip the connection"
)
return remote_cluster_id
remote_super_pod_id = metadata.super_pod_id
remote_server_id = metadata.server_id
is_same_server = remote_server_id == self.local_agent_metadata.server_id
is_same_pod = remote_super_pod_id == self.local_agent_metadata.super_pod_id
if self.llm_datadist_role == LLMRole.PROMPT:
prefill_metadata = self.local_agent_metadata
decode_metadata = metadata
else:
prefill_metadata = metadata
decode_metadata = self.local_agent_metadata
comm_name = f"pd_comm_{prefill_metadata.device_ip}_{decode_metadata.device_ip}"
cluster_rank_info = {
prefill_metadata.cluster_id: 0,
decode_metadata.cluster_id: 1
}
rank_table = {}
rank_table["version"] = "1.2"
rank_table["server_count"] = "1" if is_same_server else "2"
rank_table["status"] = "completed"
# generate server_list for rank table
rank_table["server_list"] = [] # type: ignore[assignment]
decode_server_device_info = None
prefill_server_device_info = {
"device": [{
k: v
for k, v in [(
"device_id", prefill_metadata.device_id
), ("device_ip", prefill_metadata.device_ip
), ("super_device_id",
prefill_metadata.super_device_id), ("rank_id", "0")]
if v is not None
}],
"server_id":
prefill_metadata.server_id
}
if is_same_server:
prefill_server_device_info["device"].append( # type: ignore[attr-defined]
{
k: v
for k, v in [(
"device_id", decode_metadata.device_id
), ("device_ip", decode_metadata.device_ip
), ("super_device_id",
decode_metadata.super_device_id), ("rank_id", "1")]
if v is not None
})
else:
decode_server_device_info = {
"device": [{
k: v
for k, v in [(
"device_id", decode_metadata.device_id
), ("device_ip", decode_metadata.device_ip
), ("super_device_id",
decode_metadata.super_device_id), ("rank_id", "1")]
if v is not None
}],
"server_id":
decode_metadata.server_id
}
rank_table["server_list"].append( # type: ignore[attr-defined]
prefill_server_device_info)
if decode_server_device_info is not None:
rank_table["server_list"].append( # type: ignore[attr-defined]
decode_server_device_info)
if self.soc_info == AscendSocVersion.A3:
# generate super_pod_list for rank table
super_pod_list = []
prefill_super_pod_info = {
"super_pod_id": prefill_metadata.super_pod_id,
"server_list": [{
"server_id": prefill_metadata.server_id
}],
}
if is_same_pod and not is_same_server:
prefill_super_pod_info[
"server_list"].append( # type: ignore[attr-defined]
{"server_id": decode_metadata.server_id})
super_pod_list.append(prefill_super_pod_info)
if not is_same_pod:
decode_super_pod_id = {
"super_pod_id": decode_metadata.super_pod_id,
"server_list": [{
"server_id": decode_metadata.server_id
}],
}
super_pod_list.append(decode_super_pod_id)
rank_table[
"super_pod_list"] = super_pod_list # type: ignore[assignment]
logger.info(
f"LLMDataDistCMgrConnectorWorker: try link with remote, comm id: {comm_name}"
)
logger.info(f"rank table \n{rank_table}")
logger.info(f"comm name: {comm_name}")
logger.info(f"cluster rank info: {cluster_rank_info}")
comm_id = self.llm_datadist.link(comm_name, cluster_rank_info,
json.dumps(rank_table))
while True:
ret = self.llm_datadist.query_register_mem_status(comm_id=comm_id)
if ret == llm_datadist.RegisterMemStatus.OK:
logger.info(
f"LLMDataDistCMgrConnectorWorker: Linking success, comm id: {comm_id}"
)
break
elif ret == llm_datadist.RegisterMemStatus.FAILED:
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Linking failed, comm id: {comm_id}"
)
time.sleep(1)
logger.info("Checking query_register_mem_status again")
self.linked_cluster.update({remote_cluster_id: comm_id})
logger.info(f"cached linked cluster: {self.linked_cluster}")
logger.info(
f"Successfully build link with cluster id {remote_cluster_id} with cluster name {comm_name} !"
)
return remote_cluster_id
def remove_remote_agent(self, cluster_id: int):
if cluster_id not in self.linked_cluster:
logger.warning(
f"LLMDataDistCMgrConnectorWorker: Warning! Can't remove remote client with cluster id {cluster_id} for its not exist in linked_cluster list"
)
comm_id = self.linked_cluster[cluster_id]
try:
self.llm_datadist.unlink(comm_id)
self.linked_cluster.pop(cluster_id)
except LLMException:
logger.error(
f"Try to remove remote client with cluster id {cluster_id} failed!, program won't terminate, but please carefully check your environment"
)
logger.info(
f"Successfully remove remote client with cluster id {cluster_id} !"
)
def connect_to_remote_agent(self, host: str, port: int) -> int:
url = f"tcp://{host}:{port}"
logger.debug(f"Querying metadata from url: {url}")
msg_encoder = msgspec.msgpack.Encoder()
msg_send = msg_encoder.encode(
[LLMDataDistCMgrEvent.ReqForMetadata, self.local_agent_metadata])
with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined]
logger.info("Try request remote metadata from socket......")
sock.send(msg_send)
metadata_bytes = sock.recv()
decoder = msgspec.msgpack.Decoder()
metadata = decoder.decode(metadata_bytes)
metadata = LLMDataDistCMgrAgentMetadata(**metadata)
logger.info(f"recving metadata: {metadata}")
cluster_id = self.add_remote_agent(metadata)
return cluster_id
def send_finish_to_remote(self, host: str, port: int, request_id):
url = f"tcp://{host}:{port}"
logger.debug(f"Sending finished to remote: {url}")
msg_encoder = msgspec.msgpack.Encoder()
msg_send = msg_encoder.encode([
LLMDataDistCMgrEvent.ReqForFinished,
[request_id, self.tp_rank, self.tp_size]
])
with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined]
try:
sock.send(msg_send)
logger.debug(
f"Request id {request_id} finished message send to remote {url}"
)
_ = sock.recv()
except Exception as e:
logger.error(
f"Failed to send reqest_id {request_id} to prefill: {e}")
def _read_blocks(
self,
local_block_ids: list[int],
remote_block_ids: list[int],
remote_ip: str,
remote_port: int,
remote_engine_id: str,
request_id: str,
remote_tp_size: str,
):
# if remote_ip not in self.linked_cluster:
tp_offset = self.tp_rank % int(remote_tp_size)
remote_cluster_id = self.connect_to_remote_agent(
remote_ip, remote_port + tp_offset)
num_local_blocks = len(local_block_ids)
if num_local_blocks == 0:
return
num_remote_blocks = len(remote_block_ids)
assert num_local_blocks <= num_remote_blocks
if num_local_blocks < num_remote_blocks:
remote_block_ids = remote_block_ids[-num_local_blocks:]
logger.info(f"remote cluster id is: {remote_cluster_id}")
if self.use_mla:
remote_cache_key_k_normed = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=0)
remote_cache_key_k_pe = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=1)
logger.info("Try pull blocks from remote server")
try:
self.cache_manager.pull_blocks(
remote_cache_key_k_normed,
self.cache[0], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
self.cache_manager.pull_blocks(
remote_cache_key_k_pe,
self.cache[1], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
)
except LLMException:
raise RuntimeError(
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
)
else:
remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id)
logger.info("Try pull blocks from remote server")
try:
self.cache_manager.pull_blocks(
remote_cache_key,
self.cache, # type: ignore[has-type]
remote_block_ids,
local_block_ids)
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
)
except LLMException:
raise RuntimeError(
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
)
self.send_finish_to_remote(remote_ip, remote_port, request_id)
with self.thread_lock:
self.finished_reqs.add(request_id)
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
"""Get the finished recving and sending requuests."""
import copy
with self.thread_lock:
req_ids_to_ret = copy.deepcopy(self.finished_reqs)
self.finished_reqs.clear()
if self.llm_datadist_role == LLMRole.PROMPT:
return req_ids_to_ret, None
else:
return None, req_ids_to_ret
# adopt this from https://github.com/vllm-project/vllm/blob/main/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@contextlib.contextmanager
def zmq_ctx(socket_type: Any,
addr: str) -> Iterator[zmq.Socket]: # type: ignore[name-defined]
"""Context manager for a ZMQ socket"""
ctx: Optional[zmq.Context] = None # type: ignore[name-defined]
try:
ctx = zmq.Context() # type: ignore[attr-defined]
if socket_type == zmq.ROUTER: # type: ignore[attr-defined]
socket = ctx.socket(zmq.ROUTER) # type: ignore[attr-defined]
socket.bind(addr)
elif socket_type == zmq.REQ: # type: ignore[attr-defined]
socket = ctx.socket(zmq.REQ) # type: ignore[attr-defined]
socket.connect(addr)
else:
raise ValueError(f"Unexpected socket type: {socket_type}")
yield socket
finally:
if ctx is not None:
ctx.destroy(linger=0)

View File

@@ -0,0 +1,556 @@
from abc import ABC, abstractmethod
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_npu
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.distributed.communication_op import \
data_parallel_reduce_scatter
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
class MoECommMethod(ABC):
"""Base class for MoE communication methods."""
def __init__(self, moe_config: FusedMoEConfig):
self.moe_config = moe_config
@abstractmethod
def prepare(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Prepare the MoE communication method.
This method is called before quant_method.apply to prepare the
communication method. It can be used to initialize any necessary
resources or configurations.
"""
pass
@abstractmethod
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
"""Finalize the MoE communication method.
This method is called after quant_method.apply to finalize the
communication method. It can be used to clean up any resources or
configurations.
"""
pass
@abstractmethod
def permute(
self,
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
expert_map: torch.Tensor,
num_experts: int,
apply_a8_quantization: bool,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
"""Pre-process before MLP.
Args:
hidden_states (torch.Tensor): Tensor of shape (num_tokens, hidden_size)
topk_ids (torch.Tensor): Tensor of shape (num_tokens, top_k_num)
topk_weights (torch.Tensor): Tensor of shape (num_tokens, top_k_num)
expert_map (torch.Tensor): Tensor of shape (global_num_experts, )
Mapping from global expert IDs to local expert IDs.
num_experts (int): Number of local experts (experts on this device).
apply_a8_quantization (bool): Whether to apply A8 quantization (W4A8 and W8A8).
Returns:
tuple[torch.Tensor, torch.Tensor, int]: Return a tuple containing:
- permuted_hidden_states (torch.Tensor): Tensor of shape
(num_tokens * top_k_num, hidden_size) after permuting
hidden_states based on topk_ids.
- expert_tokens (torch.Tensor): Tensor of shape (num_experts, )
Number of tokens assigned to each expert.
- dynamic_scale (torch.Tensor, optional): Tensor of shape (num_experts, )
Dynamic scale for each expert, used for quantization.
- group_list_type (int): Type of group list, 0 for `cumsum`
and 1 for `count`. This is mainly for `npu_grouped_matmul`
to determine how to handle the output.
Raises:
NotImplementedError: If the method is not implemented in the subclass.
"""
pass
@abstractmethod
def unpermute(self, mlp_output: torch.Tensor,
hidden_states: torch.Tensor) -> None:
"""Post-process after MLP.
Args:
mlp_output (torch.Tensor): Tensor of shape
(num_tokens * top_k_num, hidden_size) after MLP.
hidden_states (torch.Tensor): Tensor of shape
(num_tokens, hidden_size) to be updated with the final output.
"""
pass
class AllGatherCommImpl(MoECommMethod):
"""This implementation is the same as NativeAllGatherCommImpl,
but uses NPU-specific ops for better performance.
This implementation should be compatible with all scenarios, and
thus it is the default implementation for MoE communication methods.
It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing
and `torch_npu.npu_moe_token_unpermute` for post-processing
to handle the token-to-expert mapping and communication efficiently.
NOTE(Yizhou): TBH, it is really weird that we were supposed to use
`torch_npu.npu_moe_init_routing_v2` and `torch_npu.npu_moe_finalize_routing`
or `torch_npu.npu_moe_token_permute` and `torch_npu.npu_moe_token_unpermute`
for pre-processing and post-processing, respectively.
But `npu_moe_finalize_routing` will lead to accuracy issues so we have to
use `torch_npu.npu_moe_token_unpermute` instead.
This is a workaround and should be removed after the issue is fixed.
"""
def prepare(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""When DP size > 1, pad the hidden states and router logits for communication."""
if self.moe_config.dp_size > 1:
forward_context = get_forward_context()
max_tokens_across_dp = forward_context.max_tokens_across_dp
self.num_tokens = hidden_states.shape[0]
pad_size = max_tokens_across_dp - self.num_tokens
if pad_size > 0:
hidden_states = nn.functional.pad(hidden_states,
(0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
hidden_states = self.moe_config.dp_group.all_gather(
hidden_states, 0)
router_logits = self.moe_config.dp_group.all_gather(
router_logits, 0)
return hidden_states, router_logits
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
"""When DP size > 1, reduce-scatter the hidden states to get the final output.
When TP size > 1, all-reduce the hidden states to get the final output.
"""
if self.moe_config.dp_size > 1:
hidden_states = data_parallel_reduce_scatter(hidden_states, dim=0)
hidden_states = hidden_states[:self.num_tokens]
if reduce_results and (self.moe_config.tp_size > 1
or self.moe_config.ep_size > 1):
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
return hidden_states
def permute(
self,
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
expert_map: torch.Tensor, # noqa: F841
num_experts: int,
apply_a8_quantization: bool,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
num_tokens = hidden_states.shape[0]
self.topk_weights = topk_weights
self.topk_ids = topk_ids
first_expert_idx = 0
if expert_map is not None:
# FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
# So we need to filter out invalid tokens by zeroing their weights.
# This is a workaround and should be removed after the issue is fixed
mask = expert_map[topk_ids] != -1
# NOTE: This is equivalent to self.topk_weights[~mask] = 0.0,
# but ~mask will dispatch to aclnnNonzeroV2, which is not supported in ACL Graph
self.topk_weights = torch.where(mask, topk_weights, 0.0)
first_expert_idx = self.moe_config.ep_rank * num_experts
last_expert_idx = first_expert_idx + num_experts
permuted_hidden_states, expanded_row_idx, expert_tokens, _ = (
torch_npu.npu_moe_init_routing_v2(
hidden_states,
topk_ids,
active_num=num_tokens * self.moe_config.experts_per_token,
expert_num=self.moe_config.num_experts,
expert_tokens_num_type=1, # Only support `count` mode now
expert_tokens_num_flag=True, # Output `expert_tokens`
active_expert_range=[first_expert_idx, last_expert_idx],
quant_mode=-1,
))
self.expanded_row_idx = expanded_row_idx
permuted_hidden_states = permuted_hidden_states
group_list_type = 1 # `count` mode
return permuted_hidden_states, expert_tokens, None, group_list_type
def unpermute(self, mlp_output: torch.Tensor,
hidden_states: torch.Tensor) -> None:
hidden_states[:] = torch_npu.npu_moe_token_unpermute(
permuted_tokens=mlp_output,
sorted_indices=self.expanded_row_idx,
probs=self.topk_weights)
class NativeAllGatherCommImpl(AllGatherCommImpl):
"""This implementation should be compatible with all scenarios.
Note that this implementation purely consists of native PyTorch ops
and does not use any NPU-specific ops. So the performance may not be optimal.
But it is a good fallback for scenarios where NPU-specific ops are not available.
"""
def permute(
self,
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
expert_map: torch.Tensor,
num_experts: int,
apply_a8_quantization: bool,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
num_tokens = hidden_states.shape[0]
# Generate token indices and flatten
token_indices = torch.arange(num_tokens,
device=hidden_states.device,
dtype=torch.int64)
token_indices = (token_indices.unsqueeze(1).expand(
-1, self.moe_config.experts_per_token).reshape(-1))
# Flatten token-to-expert mappings and map to local experts
weights_flat = topk_weights.view(-1)
experts_flat = topk_ids.view(-1)
local_experts_flat = (expert_map[experts_flat]
if expert_map is not None else experts_flat)
# Filter valid token-expert pairs
mask = local_experts_flat != -1
# FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
# So we need to filter out invalid tokens by zeroing their weights.
# This is a workaround and should be removed after the issue is fixed
filtered_weights = torch.where(mask, weights_flat,
torch.zeros_like(weights_flat)).to(
topk_weights.dtype)
filtered_experts = torch.where(
mask,
local_experts_flat,
torch.full_like(local_experts_flat, num_experts),
).to(topk_ids.dtype)
# Sort by local expert IDs
sort_indices = torch.argsort(filtered_experts.view(torch.float32))
self.sorted_token_indices = token_indices[sort_indices]
self.sorted_weights = filtered_weights[sort_indices]
# Compute token counts with minlength of num_experts
# This is equivalent to but faster than:
# >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
token_counts = torch.zeros(num_experts + 1,
device=hidden_states.device,
dtype=torch.int64)
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
expert_tokens = token_counts[:num_experts]
# Rearrange hidden_states
permuted_hidden_states = hidden_states[self.sorted_token_indices]
group_list_type = 1 # `count` mode
return permuted_hidden_states, expert_tokens, None, group_list_type
def unpermute(self, mlp_output: torch.Tensor,
hidden_states: torch.Tensor) -> None:
mlp_output = mlp_output * self.sorted_weights.unsqueeze(1)
final_hidden_states = torch.zeros_like(hidden_states)
final_hidden_states.index_add_(0, self.sorted_token_indices,
mlp_output)
hidden_states[:] = final_hidden_states
class MC2CommImpl(MoECommMethod):
"""This implementation is for the scenarios listed below:
1. `enable_expert_parallel=True`.
2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available.
3. `enable_expert_parallel=False` is not supported.
This implementation uses the MC2 communication method, which is optimized for
Communication and Computation parallelism on Ascend devices.
"""
def __init__(self, moe_config: Optional[FusedMoEConfig]):
super().__init__(moe_config)
# NOTE: We do not need to use mc2_group's rank and world size
# because ep_group and mc2_group basically have the same init params.
# We only init another group because of the restriction of MC2:
# "No other groups can be used in the same process as the MC2 group."
self.mc2_comm_name = get_mc2_group().device_group._get_backend(
torch.device("npu")).get_hccl_comm_name(self.moe_config.ep_rank)
# Feature flags
self.enable_dispatch_v2 = hasattr(torch_npu,
"npu_moe_distribute_dispatch_v2")
self.is_ascend_a3 = get_ascend_soc_version() == AscendSocVersion.A3
self.need_extra_args = self.is_ascend_a3
self._restore_tp_across_dp()
def _restore_tp_across_dp(self):
# NOTE: Since vLLM flatten tp across dp, we need to restore the original
# tp_size and tp_rank.
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
def prepare(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""The target_pad_length is calculated in forward_context, here we pad the
hidden states and router logits. And if TP size > 1, we also need to split
the tensors accordingly.
"""
self.num_tokens, _ = hidden_states.shape
forward_context = get_forward_context()
self.mc2_mask = forward_context.mc2_mask
target_pad_length = forward_context.padded_num_tokens
pad_size = target_pad_length - self.num_tokens
if pad_size > 0:
hidden_states = nn.functional.pad(hidden_states,
(0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
if self.tp_size > 1:
split_hidden_states = torch.tensor_split(hidden_states,
self.tp_size,
dim=0)
split_router_logits = torch.tensor_split(router_logits,
self.tp_size,
dim=0)
split_mc2_mask = torch.tensor_split(self.mc2_mask,
self.tp_size,
dim=0)
self.split_hidden_states = split_hidden_states
hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank]
self.mc2_mask = split_mc2_mask[self.tp_rank]
return hidden_states, router_logits
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
"""If TP size > 1, all-gather the hidden states to get the final output.
Also, unpad the hidden states if needed.
"""
if self.tp_size > 1:
dist.all_gather(list(self.split_hidden_states), hidden_states,
self.moe_config.tp_group.device_group)
hidden_states = torch.cat(self.split_hidden_states, dim=0)
if self.num_tokens < hidden_states.shape[0]:
hidden_states = hidden_states[:self.num_tokens]
return hidden_states
def permute(
self,
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
expert_map: torch.Tensor,
num_experts: int,
apply_a8_quantization: bool,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
# Store tensors needed for post_process
self.topk_ids = topk_ids
self.topk_weights = topk_weights.to(torch.float32)
dispatch_kwargs = {
"x": hidden_states,
"expert_ids": self.topk_ids,
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": self.moe_config.num_experts,
"global_bs": 0,
"scales": None,
"quant_mode": 2 if apply_a8_quantization else 0,
"group_ep": self.mc2_comm_name,
"ep_world_size": self.moe_config.ep_size,
"ep_rank_id": self.moe_config.ep_rank,
}
if self.need_extra_args:
dispatch_kwargs.update({
"group_tp": self.mc2_comm_name,
"tp_world_size": 1,
"tp_rank_id": 0,
})
if self.is_ascend_a3 and self.enable_dispatch_v2:
dispatch_kwargs.update({
"x_active_mask": self.mc2_mask,
})
dispatch = torch_npu.npu_moe_distribute_dispatch_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch
(
permuted_hidden_states,
dynamic_scale,
self.assist_info_for_combine,
expert_tokens,
self.ep_recv_counts,
self.tp_recv_counts,
) = dispatch(**dispatch_kwargs)[:6]
group_list_type = 1
return permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type
def unpermute(self, mlp_output: torch.Tensor,
hidden_states: torch.Tensor) -> None:
combine_kwargs = {
"expand_x": mlp_output,
"expert_ids": self.topk_ids,
"expert_scales": self.topk_weights,
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": self.moe_config.num_experts,
"global_bs": 0,
"ep_send_counts": self.ep_recv_counts,
"group_ep": self.mc2_comm_name,
"ep_world_size": self.moe_config.ep_size,
"ep_rank_id": self.moe_config.ep_rank,
}
if self.enable_dispatch_v2:
combine_kwargs[
"assist_info_for_combine"] = self.assist_info_for_combine
else:
combine_kwargs["expand_idx"] = self.assist_info_for_combine
if self.need_extra_args:
combine_kwargs.update({
"tp_send_counts": self.tp_recv_counts,
"group_tp": self.mc2_comm_name,
"tp_world_size": 1,
"tp_rank_id": 0,
})
if self.is_ascend_a3 and self.enable_dispatch_v2:
combine_kwargs.update({
"x_active_mask": self.mc2_mask,
})
combine = torch_npu.npu_moe_distribute_combine_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine
hidden_states[:] = combine(**combine_kwargs)
class AlltoAllCommImpl(MoECommMethod):
"""This implementation is for the scenarios listed below:
1. `enable_expert_parallel=True`.
2. `npu_grouped_matmul` is available.
This implementation uses all-to-all communication to exchange tokens
between data parallel ranks before and after the MLP computation. It should
have better performance than AllGatherCommImpl when DP size > 1.
"""
def __init__(self, moe_config: Optional[FusedMoEConfig]):
super().__init__(moe_config)
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
get_token_dispatcher
self.token_dispatcher = get_token_dispatcher(
"TokenDispatcherWithAll2AllV")
self._restore_tp_across_dp()
def _restore_tp_across_dp(self):
# NOTE: Since vLLM flatten tp across dp, we need to restore the original
# tp_size and tp_rank.
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
def prepare(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
self.num_tokens, _ = hidden_states.shape
pad_size = self.tp_size - self.num_tokens
if pad_size > 0:
hidden_states = nn.functional.pad(hidden_states,
(0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
if self.tp_size > 1:
split_hidden_states = torch.tensor_split(hidden_states,
self.tp_size,
dim=0)
split_router_logits = torch.tensor_split(router_logits,
self.tp_size,
dim=0)
self.split_hidden_states = split_hidden_states
hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank]
return hidden_states, router_logits
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
"""If TP size > 1, all-gather the hidden states to get the final output.
Also, unpad the hidden states if needed.
"""
if self.tp_size > 1:
dist.all_gather(list(self.split_hidden_states), hidden_states,
self.moe_config.tp_group.device_group)
hidden_states = torch.cat(self.split_hidden_states, dim=0)
if self.num_tokens < hidden_states.shape[0]:
hidden_states = hidden_states[:self.num_tokens]
return hidden_states
def permute(
self,
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
expert_map: torch.Tensor,
num_experts: int,
apply_a8_quantization: bool,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
results = self.token_dispatcher.token_dispatch(
hidden_states,
topk_weights,
topk_ids,
None,
log2phy=None,
with_quant=apply_a8_quantization)
return results["hidden_states"], results["group_list"], results[
"dynamic_scale"], results["group_list_type"]
def unpermute(self, mlp_output: torch.Tensor,
hidden_states: torch.Tensor) -> None:
hidden_states[:] = self.token_dispatcher.token_combine(mlp_output)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,119 @@
from typing import Optional
import torch
from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
init_model_parallel_group)
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
# Currently, mc2 op need their own group coordinator.
_MC2: Optional[GroupCoordinator] = None
_MLP_TP: Optional[GroupCoordinator] = None
_LMTP: Optional[GroupCoordinator] = None
def get_mc2_group() -> GroupCoordinator:
assert _MC2 is not None, ("mc2 group is not initialized")
return _MC2
def get_lmhead_tp_group() -> GroupCoordinator:
assert _LMTP is not None, (
"lm head tensor parallel group is not initialized")
return _LMTP
def get_mlp_tp_group() -> GroupCoordinator:
assert _MLP_TP is not None, ("mlp group is not initialized")
return _MLP_TP
def model_parallel_initialized():
return (_MC2 is not None)
def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
if model_parallel_initialized():
return
assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size()
backend = torch.distributed.get_backend(get_world_group().device_group)
# The layout of all ranks: ExternalDP * EP
# ExternalDP is the data parallel group that is not part of the model,
# every dp rank can generate independently (in verl integration).
all_ranks = torch.arange(world_size).reshape(
-1, parallel_config.data_parallel_size *
parallel_config.tensor_parallel_size)
global _MC2
group_ranks = all_ranks.unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_MC2 = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name="mc2")
if envs_ascend.VLLM_ASCEND_ENABLE_MLP_OPTIMIZE:
global _MLP_TP
assert _MLP_TP is None, (
"mlp tensor model parallel group is already initialized")
mlp_tp = parallel_config.data_parallel_size
all_ranks_mlp_head = torch.arange(world_size).reshape(
-1, mlp_tp, parallel_config.pipeline_parallel_size, 1) # noqa
group_ranks = all_ranks_mlp_head.view(-1, mlp_tp).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
# message queue broadcaster is only used in tensor model parallel group
_MLP_TP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name="mlp_tp")
lmhead_tensor_parallel_size = get_ascend_config(
).lmhead_tensor_parallel_size
if lmhead_tensor_parallel_size is not None:
group_ranks = []
global _LMTP
num_lmhead_tensor_parallel_groups: int = (world_size //
lmhead_tensor_parallel_size)
for i in range(num_lmhead_tensor_parallel_groups):
ranks = list(
range(i * lmhead_tensor_parallel_size,
(i + 1) * lmhead_tensor_parallel_size))
group_ranks.append(ranks)
_LMTP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name="lmheadtp")
def get_mlp_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
return get_mlp_tp_group().world_size
def get_mlp_tensor_model_parallel_rank():
"""Return world size for the tensor model parallel group."""
return get_mlp_tp_group().rank_in_group
def destroy_ascend_model_parallel():
global _MC2
if _MC2:
_MC2.destroy()
_MC2 = None
global _MLP_TP
if _MLP_TP:
_MLP_TP.destroy()
_MLP_TP = None
global _LMTP
if _LMTP:
_LMTP.destroy()
_LMTP = None

View File

@@ -0,0 +1,248 @@
# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapts from: Megatron/megatron/core/tensor_parallel/mappings.py.
# This file is a part of the vllm-ascend project.
import torch
def _gather_along_first_dim(input_, group, output_split_sizes=None):
"""Gather tensors and concatenate along the first dimension.
Args:
input_tensor (torch.Tensor):
A tensor to be gathered.
output_split_sizes (List[int], optional):
A list specifying the sizes of the output splits along the first dimension.
If None, equal splitting is assumed. Default: None.
Returns:
torch.Tensor: Gathered tensor.
"""
world_size = torch.distributed.get_world_size(group)
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
dim_size = list(input_.size())
if output_split_sizes is None:
dim_size[0] = dim_size[0] * world_size
output = torch.empty(dim_size,
dtype=input_.dtype,
device=torch.npu.current_device())
torch.distributed.all_gather_into_tensor(output,
input_.contiguous(),
group=group)
else:
dim_size[0] = sum(output_split_sizes)
output = torch.empty(dim_size,
dtype=input_.dtype,
device=torch.npu.current_device())
output_tensor_list = list(
torch.split(output, output_split_sizes, dim=0))
torch.distributed.all_gather(output_tensor_list, input_, group=group)
return output
def _gather_along_last_dim(input_, group):
"""Gather tensors and concatenate along the last dimension."""
world_size = torch.distributed.get_world_size(group)
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
dim_size = list(input_.size())
dim_size[0] = dim_size[0] * world_size
output = torch.empty(dim_size,
dtype=input_.dtype,
device=torch.npu.current_device())
torch.distributed.all_gather_into_tensor(output,
input_.contiguous(),
group=group)
tensor_list = output.chunk(world_size, dim=0)
output = torch.cat(tensor_list, dim=-1).contiguous()
return output
def _reduce_scatter_along_first_dim(input_,
group,
input_split_sizes=None,
use_global_buffer=False):
"""Reduce-scatter the input tensor across model parallel group.
Args:
input_ (torch.Tensor): The input tensor to be reduce-scattered.
input_split_sizes (List[int], optional): A list specifying the sizes of
the input splits along the first dimension for each rank. If None,
equal splitting is assumed. Default: None.
"""
world_size = torch.distributed.get_world_size(group)
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
if input_split_sizes is None:
dim_size = list(input_.size())
assert (
dim_size[0] % world_size == 0
), "First dimension of the tensor should be divisible by tensor parallel size"
dim_size[0] = dim_size[0] // world_size
output = torch.empty(dim_size,
dtype=input_.dtype,
device=torch.npu.current_device())
torch.distributed.reduce_scatter_tensor(output,
input_.contiguous(),
group=group)
else:
rank = torch.distributed.get_rank(group)
input_tensor_list = list(torch.split(input_, input_split_sizes, dim=0))
output = torch.empty_like(input_tensor_list[rank])
torch.distributed.reduce_scatter(output,
input_tensor_list,
group=group)
return output
def _reduce_scatter_along_last_dim(input_, group):
"""Reduce-scatter tensors on the last dimension."""
world_size = torch.distributed.get_world_size(group)
target_shape = list(input_.size())
target_shape[-1] = target_shape[-1] // world_size
input_ = input_.reshape(-1, input_.shape[-1])
split_tensors = torch.split(input_,
split_size_or_sections=input_.shape[-1] //
world_size,
dim=1)
concat_tensor = torch.cat(split_tensors, dim=0)
output = _reduce_scatter_along_first_dim(concat_tensor,
group).reshape(target_shape)
return output
def all_gather_last_dim_from_tensor_parallel_region(input_, group):
"""Wrapper for autograd function: forward: AG, backward RS <last dim>"""
return _gather_along_last_dim(input_, group)
def reduce_scatter_to_sequence_parallel_region(input_,
group,
input_split_sizes=None):
"""Wrapper for autograd function: forward: RS, backward AG <first dim>"""
return _reduce_scatter_along_first_dim(input_, group, input_split_sizes)
def reduce_scatter_last_dim_to_tensor_parallel_region(input_, group):
"""Wrapper for autograd function: forward: RS, backward AG: AG <last dim>"""
return _reduce_scatter_along_last_dim(input_, group)
def gather_from_sequence_parallel_region(
input_,
group,
output_split_sizes=None,
):
"""Wrapper for autograd function: forward: AG, backward: RS <first dim>"""
return _gather_along_first_dim(input_, group, output_split_sizes)
def all_to_all(group, input, output_split_sizes=None, input_split_sizes=None):
world_size = torch.distributed.get_world_size(group=group)
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input
input = input.contiguous()
if output_split_sizes is None:
# Equal split (all2all)
output = torch.empty_like(input)
else:
# Unequal split (all2all-v)
output = input.new_empty(
size=[sum(output_split_sizes)] + list(input.size()[1:]),
dtype=input.dtype,
device=torch.npu.current_device(),
)
torch.distributed.all_to_all_single(
output,
input,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=group,
)
return output
def all_to_all_sp2hp(input_, group):
"""
Perform AlltoAll communication on tensor parallel group, transform the input tensor from shape
[num_tokens/TP, H] to [num_tokens, H/TP].
Args:
input_ (torch.Tensor):
The input tensor which has been distributed along the sequence
dimension.
Returns:
torch.Tensor: The output tensor with shape [num_tokens, H/TP].
"""
if group is None:
return input_
world_size = torch.distributed.get_world_size(group=group)
tp_group = group
input_ = input_.reshape(-1, input_.shape[-1])
split_tensors = torch.split(input_,
split_size_or_sections=input_.shape[-1] //
world_size,
dim=1)
concat_tensor = torch.cat(split_tensors, dim=0)
output = all_to_all(tp_group, concat_tensor)
return output
def all_to_all_hp2sp(input_, group):
"""
Perform AlltoAll communication on tensor parallel group, transform the input tensor from shape
[num_tokens, H/TP] to [num_tokens/TP, H].
Args:
input_ (torch.Tensor):
The input tensor which has been distributed along the hidden
dimension.
Returns:
torch.Tensor: The output tensor with shape [num_tokens/TP, H].
"""
if group is None:
return input_
world_size = torch.distributed.get_world_size(group=group)
input_ = input_.reshape(-1, input_.shape[-1])
tp_group = group
input_exchanged = all_to_all(tp_group, input_)
input_reshaped = input_exchanged.reshape(-1, input_exchanged.shape[-1])
split_tensors = torch.split(
input_reshaped,
split_size_or_sections=input_reshaped.shape[0] // world_size,
dim=0)
output = torch.cat(split_tensors, dim=-1)
return output