first commit

This commit is contained in:
2026-03-10 13:31:25 +08:00
parent ba974cecfa
commit b62b889355
2604 changed files with 438977 additions and 0 deletions

View File

@@ -0,0 +1,17 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import device_communicators # noqa: F401
from . import kv_transfer # noqa: F401

View File

@@ -0,0 +1,60 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Optional
import torch
import torch.distributed as dist
from vllm.distributed.device_communicators.base_device_communicator import (
DeviceCommunicatorBase)
from vllm.logger import logger
from vllm_br import envs
class SUPACommunicator(DeviceCommunicatorBase):
def __init__(self,
cpu_group: dist.ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[dist.ProcessGroup] = None,
unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name)
self.device = torch.supa.current_device()
# TODO: Deprecate this method in the future if torch_br support gather
def gather(self,
input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> torch.Tensor:
"""All gather as gather"""
output_tensor = self.all_gather(input_, dim)
if self.rank_in_group == dst:
return output_tensor
return None
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
if envs.VLLM_BR_USE_FP32_ALL_REDUCE and input_ is not None and input_.dtype == torch.bfloat16:
logger.debug(
'[Patch] patch all_reduce: use fp32 all_reduce when env VLLM_BR_USE_FP32_ALL_REDUCE is set'
)
input_ = input_.to(torch.float32)
dist.all_reduce(input_, group=self.device_group)
input_ = input_.to(torch.bfloat16)
else:
dist.all_reduce(input_, group=self.device_group)
return input_

View File

@@ -0,0 +1,18 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import base_device_communicator # noqa: F401
from . import pysccl_wrapper # noqa: F401

View File

@@ -0,0 +1,44 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import torch
import vllm
def supa_prepare_communication_buffer_for_model(
self, model: torch.nn.Module) -> None:
"""
Prepare the communication buffer for the model.
"""
if not self.use_all2all:
return
if not self.is_ep_communicator:
return
moe_modules = [
module for module in model.modules()
# TODO(bnell): Should use isinstance but can't. Maybe search for
# presence of quant_method.init_prepare_finalize?
if (module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE")
]
for module in moe_modules:
module.quant_method.init_prepare_finalize(module)
vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase.prepare_communication_buffer_for_model = supa_prepare_communication_buffer_for_model

View File

@@ -0,0 +1,420 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# This file is a pure Python wrapper for the SCCL library.
# The main purpose is to use SCCL combined with CUDA graph.
# Before writing this script, we tried the following approach:
# 1. We tried to use `cupy`, it calls SCCL correctly, but `cupy` itself
# often gets stuck when initializing the SCCL communicator.
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
# contains many other potential cuda APIs, that are not allowed during
# capturing the CUDA graph. For further details, please check
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
#
# Another rejected idea is to write a C/C++ binding for SCCL. It is usually
# doable, but we often encounter issues related with succl versions, and need
# to switch between different versions of SCCL. See
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
# A C/C++ binding is not flexible enough to handle this. It requires
# recompilation of the code every time we want to switch between different
# versions. This current implementation, with a **pure** Python wrapper, is
# more flexible. We can easily switch between different versions of SCCL by
# changing the environment variable `VLLM_SCCL_SO_PATH`, or the `so_file`
# variable in the code.
import ctypes
import platform
from dataclasses import dataclass
from typing import Any, Optional
import torch
from torch.distributed import ReduceOp
from vllm.logger import logger
from vllm_br import envs
# === export types and functions from nccl to Python ===
# for the original nccl definition, please check
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
succlResult_t = ctypes.c_int
succlComm_t = ctypes.c_void_p
class succlUniqueId(ctypes.Structure):
_fields_ = [("internal", ctypes.c_byte * 128)]
suStream_t = ctypes.c_void_p
buffer_type = ctypes.c_void_p
succlDataType_t = ctypes.c_int
class succlDataTypeEnum:
succlInt8 = 0
succlChar = 0
succlUint8 = 1
succlInt16 = 2
succlUint16 = 3
succlInt32 = 4
succlInt = 4
succlUint32 = 5
succlInt64 = 6
succlUint64 = 7
succlBfloat16 = 8
succlFloat32 = 9
succlFloat = 9
succlFloat64 = 10
succlDouble = 10
succlNumTypes = 11
@classmethod
def from_torch(cls, dtype: torch.dtype) -> int:
if dtype == torch.int8:
return cls.succlInt8
if dtype == torch.uint8:
return cls.succlUint8
if dtype == torch.int32:
return cls.succlInt32
if dtype == torch.int64:
return cls.succlInt64
if dtype == torch.float16:
return cls.succlBfloat16
if dtype == torch.float32:
return cls.succlFloat32
if dtype == torch.float64:
return cls.succlFloat64
if dtype == torch.bfloat16:
return cls.succlBfloat16
raise ValueError(f"Unsupported dtype: {dtype}")
succlRedOp_t = ctypes.c_int
class succlRedOpTypeEnum:
succlSum = 0
succlProd = 1
succlMax = 2
succlMin = 3
succlAvg = 4
succlNumOps = 5
@classmethod
def from_torch(cls, op: ReduceOp) -> int:
if op == ReduceOp.SUM:
return cls.succlSum
if op == ReduceOp.PRODUCT:
return cls.succlProd
if op == ReduceOp.MAX:
return cls.succlMax
if op == ReduceOp.MIN:
return cls.succlMin
if op == ReduceOp.AVG:
return cls.succlAvg
raise ValueError(f"Unsupported op: {op}")
@dataclass
class Function:
name: str
restype: Any
argtypes: list[Any]
class SCCLLibrary:
exported_functions = [
# const char* succlGetErrorString(succlResult_t result)
Function("succlGetErrorString", ctypes.c_char_p, [succlResult_t]),
# succlResult_t succlGetVersion(int *version);
Function("succlGetVersion", succlResult_t,
[ctypes.POINTER(ctypes.c_int)]),
# succlResult_t succlGetUniqueId(succlUniqueId* uniqueId);
Function("succlGetUniqueId", succlResult_t,
[ctypes.POINTER(succlUniqueId)]),
# succlResult_t succlCommInitRank(
# succlComm_t* comm, int nranks, succlUniqueId commId, int rank);
# note that succlComm_t is a pointer type, so the first argument
# is a pointer to a pointer
Function("succlCommInitRank", succlResult_t, [
ctypes.POINTER(succlComm_t), ctypes.c_int, succlUniqueId,
ctypes.c_int, ctypes.c_void_p
]),
# succlResult_t succlAllReduce(
# const void* sendbuff, void* recvbuff, size_t count,
# succlDataType_t datatype, succlRedOp_t op, succlComm_t comm,
# suStream_t stream);
# note that suStream_t is a pointer type, so the last argument
# is a pointer
Function("succlAllReduce", succlResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, succlDataType_t,
succlRedOp_t, succlComm_t, suStream_t, ctypes.c_void_p
]),
# succlResult_t succlReduce(
# const void* sendbuff, void* recvbuff, size_t count,
# succlDataType_t datatype, succlRedOp_t op, int root,
# succlComm_t comm, suStream_t stream);
# note that suStream_t is a pointer type, so the last argument
# is a pointer
Function("succlReduce", succlResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, succlDataType_t,
succlRedOp_t, ctypes.c_int, succlComm_t, suStream_t,
ctypes.c_void_p
]),
# succlResult_t succlAllGather(
# const void* sendbuff, void* recvbuff, size_t count,
# succlDataType_t datatype, succlComm_t comm,
# suStream_t stream);
# note that suStream_t is a pointer type, so the last argument
# is a pointer
Function("succlAllGather", succlResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, succlDataType_t,
succlComm_t, suStream_t, ctypes.c_void_p
]),
# succlResult_t succlReduceScatter(
# const void* sendbuff, void* recvbuff, size_t count,
# succlDataType_t datatype, succlRedOp_t op, succlComm_t comm,
# suStream_t stream);
# note that suStream_t is a pointer type, so the last argument
# is a pointer
Function("succlReduceScatter", succlResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, succlDataType_t,
succlRedOp_t, succlComm_t, suStream_t, ctypes.c_void_p
]),
# succlResult_t succlSend(
# const void* sendbuff, size_t count, succlDataType_t datatype,
# int dest, succlComm_t comm, suStream_t stream);
Function("succlSend", succlResult_t, [
buffer_type, ctypes.c_size_t, succlDataType_t, ctypes.c_int,
succlComm_t, suStream_t, ctypes.c_void_p
]),
# succlResult_t succlRecv(
# void* recvbuff, size_t count, succlDataType_t datatype,
# int src, succlComm_t comm, suStream_t stream);
Function("succlRecv", succlResult_t, [
buffer_type, ctypes.c_size_t, succlDataType_t, ctypes.c_int,
succlComm_t, suStream_t, ctypes.c_void_p
]),
# succlResult_t succlBroadcast(
# const void* sendbuff, void* recvbuff, size_t count,
# succlDataType_t datatype, int root, succlComm_t comm,
# suStream_t stream);
Function("succlBroadcast", succlResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, succlDataType_t,
ctypes.c_int, succlComm_t, suStream_t, ctypes.c_void_p
]),
# be cautious! this is a collective call, it will block until all
# processes in the communicator have called this function.
# because Python object destruction can happen in random order,
# it is better not to call it at all.
# succlResult_t succlCommDestroy(succlComm_t comm);
Function("succlCommDestroy", succlResult_t, [succlComm_t]),
# succlResult_t succlGroupStart();
Function("succlGroupStart", succlResult_t, []),
# succlResult_t succlGroupEnd();
Function("succlGroupEnd", succlResult_t, []),
# Function("succldemoSetdevice", succlResult_t, [ctypes.c_int]),
]
# class attribute to store the mapping from the path to the library
# to avoid loading the same library multiple times
path_to_library_cache: dict[str, Any] = {}
# class attribute to store the mapping from library path
# to the corresponding dictionary
path_to_dict_mapping: dict[str, dict[str, Any]] = {}
def __init__(self, so_file: Optional[str] = None):
so_file = so_file or find_sccl_library()
try:
if so_file not in SCCLLibrary.path_to_dict_mapping:
lib = ctypes.CDLL(so_file)
SCCLLibrary.path_to_library_cache[so_file] = lib
self.lib = SCCLLibrary.path_to_library_cache[so_file]
except Exception as e:
logger.error(
"Failed to load SCCL library from %s. "
"It is expected if you are not running on NVIDIA/AMD GPUs."
"Otherwise, the sccl library might not exist, be corrupted "
"or it does not support the current platform %s. "
"If you already have the library, please set the "
"environment variable VLLM_SCCL_SO_PATH"
" to point to the correct sccl library path.", so_file,
platform.platform())
raise e
if so_file not in SCCLLibrary.path_to_dict_mapping:
_funcs: dict[str, Any] = {}
for func in SCCLLibrary.exported_functions:
f = getattr(self.lib, func.name)
f.restype = func.restype
f.argtypes = func.argtypes
_funcs[func.name] = f
SCCLLibrary.path_to_dict_mapping[so_file] = _funcs
self._funcs = SCCLLibrary.path_to_dict_mapping[so_file]
def succlGetErrorString(self, result: succlResult_t) -> str:
return self._funcs["succlGetErrorString"](result).decode("utf-8")
def SUCCL_CHECK(self, result: succlResult_t) -> None:
if result != 0:
error_str = self.succlGetErrorString(result)
raise RuntimeError(f"SCCL error: {error_str}")
def succlGetVersion(self) -> str:
version = ctypes.c_int()
self.SUCCL_CHECK(self._funcs["succlGetVersion"](ctypes.byref(version)))
version_str = str(version.value)
# something like 21903 --> "2.19.3"
major = version_str[0].lstrip("0")
minor = version_str[1:3].lstrip("0")
patch = version_str[3:].lstrip("0")
return f"{major}.{minor}.{patch}"
def succlGetUniqueId(self) -> succlUniqueId:
unique_id = succlUniqueId()
self.SUCCL_CHECK(self._funcs["succlGetUniqueId"](
ctypes.byref(unique_id)))
return unique_id
def unique_id_from_bytes(self, data: bytes) -> succlUniqueId:
if len(data) != 128:
raise ValueError(
f"Expected 128 bytes for succlUniqueId, got {len(data)} bytes")
unique_id = succlUniqueId()
ctypes.memmove(ctypes.addressof(unique_id.internal), data, 128)
return unique_id
def succlCommInitRank(self, world_size: int, unique_id: succlUniqueId,
rank: int) -> succlComm_t:
comm = succlComm_t()
result = self._funcs["succlCommInitRank"](ctypes.byref(comm),
world_size, unique_id, rank,
None)
self.SUCCL_CHECK(result)
return comm
# def succldemoSetdevice(self, deviceid:int):
# self._funcs["succldemoSetdevice"](deviceid)
def succlAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, op: int, comm: succlComm_t,
stream: suStream_t) -> None:
# `datatype` actually should be `succlDataType_t`
# and `op` should be `succlRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.SUCCL_CHECK(self._funcs["succlAllReduce"](sendbuff, recvbuff,
count, datatype, op,
comm, stream, None))
def succlReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, op: int, root: int,
comm: succlComm_t, stream: suStream_t) -> None:
# `datatype` actually should be `succlDataType_t`
# and `op` should be `succlRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.SUCCL_CHECK(self._funcs["succlReduce"](sendbuff, recvbuff, count,
datatype, op, root, comm,
stream, None))
def succlReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, op: int,
comm: succlComm_t, stream: suStream_t) -> None:
# `datatype` actually should be `succlDataType_t`
# and `op` should be `succlRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.SUCCL_CHECK(self._funcs["succlReduceScatter"](sendbuff, recvbuff,
count, datatype, op,
comm, stream, None))
def succlAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, comm: succlComm_t,
stream: suStream_t) -> None:
# `datatype` actually should be `succlDataType_t`
# which is an aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.SUCCL_CHECK(self._funcs["succlAllGather"](sendbuff, recvbuff,
count, datatype, comm,
stream, None))
def succlSend(self, sendbuff: buffer_type, count: int, datatype: int,
dest: int, comm: succlComm_t, stream: suStream_t) -> None:
self.SUCCL_CHECK(self._funcs["succlSend"](sendbuff, count, datatype,
dest, comm, stream, None))
def succlRecv(self, recvbuff: buffer_type, count: int, datatype: int,
src: int, comm: succlComm_t, stream: suStream_t) -> None:
self.SUCCL_CHECK(self._funcs["succlRecv"](recvbuff, count, datatype,
src, comm, stream, None))
def succlBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, root: int, comm: succlComm_t,
stream: suStream_t) -> None:
self.SUCCL_CHECK(self._funcs["succlBroadcast"](sendbuff, recvbuff,
count, datatype, root,
comm, stream, None))
def succlCommDestroy(self, comm: succlComm_t) -> None:
self.SUCCL_CHECK(self._funcs["succlCommDestroy"](comm))
def succlGroupStart(self) -> None:
self.SUCCL_CHECK(self._funcs["succlGroupStart"]())
def succlGroupEnd(self) -> None:
self.SUCCL_CHECK(self._funcs["succlGroupEnd"]())
def find_sccl_library() -> str:
"""
We either use the library file specified by the `VLLM_SCCL_SO_PATH`
environment variable, or we find the library file brought by PyTorch.
After importing `torch`, `libsuccl.so.2` or `librccl.so.1` can be
found by `ctypes` automatically.
"""
so_file = envs.VLLM_SCCL_SO_PATH
# manually load the sccl library
if so_file:
logger.info(
"Found sccl from environment variable VLLM_SCCL_SO_PATH=%s",
so_file)
else:
raise ValueError("SCCL lib file not found.")
return so_file
__all__ = [
"SCCLLibrary", "succlDataTypeEnum", "succlRedOpTypeEnum", "succlUniqueId",
"succlComm_t", "suStream_t", "buffer_type"
]

View File

@@ -0,0 +1,17 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import kv_connector # noqa: F401

View File

@@ -0,0 +1,17 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import v1 # noqa: F401

View File

@@ -0,0 +1,17 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import base, p2p # noqa: F401

View File

@@ -0,0 +1,28 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# from vllm.logger import logger
# from vllm.v1.core.sched.output import SchedulerOutput
# from vllm.v1.outputs import KVConnectorOutput
# class KVConnectorRole(enum.Enum):
# # Connector running in the scheduler process
# SCHEDULER = 0
# # Connector running in the worker process
# WORKER = 1
# vllm.distributed.kv_transfer.kv_connector.v1.base.KVConnectorRole=KVConnectorRole

View File

@@ -0,0 +1,17 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import p2p_succl_engine # noqa: F401
from . import p2p_succl_connector, tensor_memory_pool # noqa: F401

View File

@@ -0,0 +1,535 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import regex as re
import torch
import torch_br
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import get_world_group
from vllm.logger import logger
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm_br.distributed.kv_transfer.kv_connector.v1.p2p.p2p_succl_engine import (
P2pSucclEngine)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
@dataclass
class ReqMeta:
# Request Id
request_id: str
# Request block ids
block_ids: torch.Tensor
# Request num tokens
num_tokens: int
@staticmethod
def make_meta(request_id: str, token_ids: list[int], block_ids: list[int],
block_size: int) -> "ReqMeta":
block_ids_tensor = torch.tensor(block_ids)
return ReqMeta(
request_id=request_id,
block_ids=block_ids_tensor,
num_tokens=len(token_ids),
)
@dataclass
class P2pSucclConnectorMetadata(KVConnectorMetadata):
requests: list[ReqMeta]
def __init__(self):
self.requests = []
def add_request(
self,
request_id: str,
token_ids: list[int],
block_ids: list[int],
block_size: int,
) -> None:
self.requests.append(
ReqMeta.make_meta(request_id, token_ids, block_ids, block_size))
class P2pSucclConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self._block_size = vllm_config.cache_config.block_size
self._requests_need_load: dict[str, Any] = {}
self.config = vllm_config.kv_transfer_config
self.is_producer = self.config.is_kv_producer
self.chunked_prefill: dict[str, Any] = {}
self._rank = get_world_group().rank \
if role == KVConnectorRole.WORKER else 0
self._local_rank = get_world_group().local_rank \
if role == KVConnectorRole.WORKER else 0
self.p2p_nccl_engine = P2pSucclEngine(
local_rank=self._local_rank,
config=self.config,
hostname="",
port_offset=self._rank,
) if role == KVConnectorRole.WORKER else None
# ==============================
# Worker-side methods
# ==============================
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
"""Start loading the KV cache from the connector buffer to vLLM's
paged KV buffer.
Args:
forward_context (ForwardContext): the forward context.
**kwargs: additional arguments for the load operation
Note:
The number of elements in kv_caches and layer_names should be
the same.
"""
# Only consumer/decode loads KV Cache
if self.is_producer:
return
assert self.p2p_nccl_engine is not None
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
def inject_kv_into_layer(
layer: torch.Tensor,
kv_cache: torch.Tensor,
block_ids: torch.Tensor,
request_id: str,
) -> None:
"""
Inject KV cache data into a given attention layer tensor.
This function updates `layer` in-place with values from `kv_cache`,
handling different backend layouts:
- MLA (Multi-Linear Attention) or FlashInfer: KV tensors are
indexed along the first dimension.
- FlashAttention: KV tensors are indexed along the second
dimension.
If the number of provided block IDs does not match the number of KV
blocks, only the overlapping portion is updated, and a warning is
logged.
Args:
layer (torch.Tensor): The attention layer KV tensor to update.
kv_cache (torch.Tensor): The KV cache tensor to inject.
block_ids (torch.Tensor): Indices of the blocks to update.
request_id (str): Request identifier used for logging.
Returns:
None. The function modifies `layer` in-place.
"""
if (isinstance(attn_metadata, MLACommonMetadata)
or layer.shape[1] == 2): # MLA or FlashInfer
num_block = kv_cache.shape[1]
block_len = min(len(block_ids), num_block)
block_ids = block_ids[:block_len]
th_gran = layer.shape[2] // self._block_size
for i, block_index in enumerate(block_ids.tolist()):
dst_0 = block_index // th_gran
dst_1 = (block_index % th_gran) * self._block_size
layer[0][dst_0][dst_1:dst_1 +
self._block_size] = kv_cache[0][i]
elif layer.shape[0] == 2: # FlashAttention
num_block = kv_cache.shape[1]
block_len = min(len(block_ids), num_block)
block_ids = block_ids[:block_len]
th_gran = layer.shape[2] // self._block_size
for i, block_index in enumerate(block_ids.tolist()):
dst_0 = block_index // th_gran
dst_1 = (block_index % th_gran) * self._block_size
layer[0][dst_0][dst_1:dst_1 +
self._block_size] = kv_cache[0][i]
layer[1][dst_0][dst_1:dst_1 +
self._block_size] = kv_cache[1][i]
# Get the metadata
metadata: KVConnectorMetadata = \
self._get_connector_metadata()
assert isinstance(metadata, P2pSucclConnectorMetadata)
if metadata is None:
return
# Load the KV for each request each layer
for request in metadata.requests:
request_id = request.request_id
ip, port = self.parse_request_id(request_id, False)
remote_address = ip + ":" + str(port + self._rank)
for layer_name in forward_context.no_compile_layers:
layer = forward_context.no_compile_layers[layer_name]
# Only process layers that have kv_cache
# attribute (attention layers) Skip non-attention
# layers like FusedMoE
kv_cache = getattr(layer, 'kv_cache', None)
if kv_cache is None:
continue
layer = kv_cache[forward_context.virtual_engine]
kv_cache = self.p2p_nccl_engine.recv_tensor(
request.request_id + "#" + layer_name, remote_address)
if kv_cache is None:
logger.warning("🚧kv_cache is None, %s", request.request_id)
continue
inject_kv_into_layer(layer, kv_cache, request.block_ids,
request.request_id)
def wait_for_layer_load(self, layer_name: str) -> None:
"""Blocking until the KV for a specific layer is loaded into vLLM's
paged buffer.
This interface will be useful for layer-by-layer pipelining.
Args:
layer_name: the name of that layer
"""
return
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
"""Start saving the KV cache of the layer from vLLM's paged buffer
to the connector.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
# Only producer/prefill saves KV Cache
if not self.is_producer:
return
assert self.p2p_nccl_engine is not None
def extract_kv_from_layer(
layer: torch.Tensor,
block_ids: torch.Tensor,
) -> torch.Tensor:
"""
Extract KV cache slices from a given attention layer tensor.
This function handles multiple backend layouts:
- MLA (Multi-Linear Attention) or FlashInfer: KV tensors are
indexed along the first dimension.
- FlashAttention: KV tensors are indexed along the second
dimension.
Args:
layer (torch.Tensor): The KV cache from the attention layer.
block_ids (torch.Tensor): Indices of blocks to extract.
Returns:
torch.Tensor: A tensor containing the extracted KV slices.
Returns None if the layout is unsupported.
"""
if (isinstance(attn_metadata, MLACommonMetadata)
or layer.shape[1] == 2): # MLA or FlashInfer
origin_shape = layer.shape
shape = [
origin_shape[0],
len(block_ids), self._block_size, origin_shape[3]
]
layer_send = torch_br._empty_ut_only(shape,
dtype=layer.dtype,
tensor_type='BUFFER_ANY',
device=layer.device)
th_gran = origin_shape[2] // self._block_size
for i, block_index in enumerate(block_ids.tolist()):
dst_0 = block_index // th_gran
dst_1 = (block_index % th_gran) * self._block_size
layer_send[0][i] = layer[0][dst_0][dst_1:dst_1 +
self._block_size]
return layer_send
if layer.shape[0] == 2: # FlashAttention
origin_shape = layer.shape
shape = [
origin_shape[0],
len(block_ids), self._block_size, origin_shape[3]
]
layer_send = torch_br._empty_ut_only(shape,
dtype=layer.dtype,
tensor_type='BUFFER_ANY',
device=layer.device)
th_gran = origin_shape[2] // self._block_size
for i, block_index in enumerate(block_ids.tolist()):
dst_0 = block_index // th_gran
dst_1 = (block_index % th_gran) * self._block_size
layer_send[0][i] = layer[0][dst_0][dst_1:dst_1 +
self._block_size]
layer_send[1][i] = layer[1][dst_0][dst_1:dst_1 +
self._block_size]
return layer_send
return None
connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, P2pSucclConnectorMetadata)
for request in connector_metadata.requests:
request_id = request.request_id
ip, port = self.parse_request_id(request_id, True)
remote_address = ip + ":" + str(port + self._rank)
kv_cache = extract_kv_from_layer(kv_layer, request.block_ids)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
def wait_for_save(self):
if self.is_producer:
assert self.p2p_nccl_engine is not None
self.p2p_nccl_engine.wait_for_sent()
def get_finished(
self, finished_req_ids: set[str],
**kwargs) -> tuple[Optional[set[str]], Optional[set[str]]]:
"""
Notifies worker-side connector ids of requests that have
finished generating tokens.
Returns:
ids of requests that have finished asynchronous transfer,
tuple of (sending/saving ids, recving/loading ids).
The finished saves/sends req ids must belong to a set provided in a
call to this method (this call or a prior one).
"""
assert self.p2p_nccl_engine is not None
no_compile_layers = (
self._vllm_config.compilation_config.static_forward_context)
return self.p2p_nccl_engine.get_finished(finished_req_ids,
no_compile_layers)
# ==============================
# Scheduler-side methods
# ==============================
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int, bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
if self.is_producer:
return 0, False
num_external_tokens = (len(request.prompt_token_ids) - 1 -
num_computed_tokens)
if num_external_tokens < 0:
num_external_tokens = 0
return num_external_tokens, False
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
"""
Update KVConnector state after block allocation.
"""
if not self.is_producer and num_external_tokens > 0:
self._requests_need_load[request.request_id] = (
request, blocks.get_block_ids()[0])
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
"""Build the connector metadata for this step.
This function should NOT modify any fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
meta = P2pSucclConnectorMetadata()
for new_req in scheduler_output.scheduled_new_reqs:
if self.is_producer:
num_scheduled_tokens = (
scheduler_output.num_scheduled_tokens)[new_req.req_id]
num_tokens = num_scheduled_tokens + new_req.num_computed_tokens
# the request's prompt is chunked prefill
if num_tokens < len(new_req.prompt_token_ids):
# 'CachedRequestData' has no attribute 'prompt_token_ids'
self.chunked_prefill[new_req.req_id] = (
new_req.block_ids[0], new_req.prompt_token_ids)
continue
# the request's prompt is not chunked prefill
meta.add_request(request_id=new_req.req_id,
token_ids=new_req.prompt_token_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size)
continue
if new_req.req_id in self._requests_need_load:
meta.add_request(request_id=new_req.req_id,
token_ids=new_req.prompt_token_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size)
self._requests_need_load.pop(new_req.req_id)
cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids):
num_computed_tokens = cached_reqs.num_computed_tokens[i]
new_block_ids = cached_reqs.new_block_ids[i]
resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
if self.is_producer:
num_scheduled_tokens = (
scheduler_output.num_scheduled_tokens)[req_id]
num_tokens = (num_scheduled_tokens + num_computed_tokens)
assert req_id in self.chunked_prefill
block_ids = new_block_ids[0]
if not resumed_from_preemption:
block_ids = (self.chunked_prefill[req_id][0] + block_ids)
prompt_token_ids = self.chunked_prefill[req_id][1]
# the request's prompt is chunked prefill again
if num_tokens < len(prompt_token_ids):
self.chunked_prefill[req_id] = (block_ids,
prompt_token_ids)
continue
# the request's prompt is all prefilled finally
meta.add_request(request_id=req_id,
token_ids=prompt_token_ids,
block_ids=block_ids,
block_size=self._block_size)
self.chunked_prefill.pop(req_id, None)
continue
# NOTE(rob): here we rely on the resumed requests being
# the first N requests in the list scheduled_cache_reqs.
if not resumed_from_preemption:
break
if req_id in self._requests_need_load:
request, _ = self._requests_need_load.pop(req_id)
total_tokens = num_computed_tokens + 1
token_ids = request.all_token_ids[:total_tokens]
# NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request.
block_ids = new_block_ids[0]
meta.add_request(request_id=req_id,
token_ids=token_ids,
block_ids=block_ids,
block_size=self._block_size)
self._requests_need_load.clear()
return meta
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
"""
Called when a request has finished, before its blocks are freed.
Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
self.chunked_prefill.pop(request.request_id, None)
return False, None
# ==============================
# Static methods
# ==============================
@staticmethod
def parse_request_id(request_id: str, is_prefill=True) -> tuple[str, int]:
# Regular expression to match the string hostname and integer port
if is_prefill:
pattern = r"___decode_addr_(.*):(\d+)"
else:
pattern = r"___prefill_addr_(.*):(\d+)___"
# Use re.search to find the pattern in the request_id
match = re.search(pattern, request_id)
if match:
# Extract the ranks
ip = match.group(1)
port = int(match.group(2))
return ip, port
raise ValueError(
f"Request id {request_id} does not contain hostname and port")
@staticmethod
def check_tensors_except_dim(tensor1, tensor2, dim):
shape1 = tensor1.size()
shape2 = tensor2.size()
if len(shape1) != len(shape2) or not all(
s1 == s2
for i, (s1, s2) in enumerate(zip(shape1, shape2)) if i != dim):
raise NotImplementedError(
"Currently, only symmetric TP is supported. Asymmetric TP, PP,"
"and others will be supported in future PRs.")
KVConnectorFactory.register_connector(
"P2pSucclConnector",
"vllm_br.distributed.kv_transfer.kv_connector.v1.p2p.p2p_succl_connector",
"P2pSucclConnector")

View File

@@ -0,0 +1,572 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import os
import threading
import time
import typing
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Optional
import msgpack
import torch
import torch_br
import zmq
from torch_br.supa._internal import get_tensor_info
from vllm.config import KVTransferConfig
# import vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine
from vllm.utils import get_ip
from vllm_br.distributed.device_communicators.pysccl_wrapper import (
SCCLLibrary, buffer_type, succlComm_t, succlDataTypeEnum, suStream_t)
from vllm_br.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501
TensorMemoryPool)
from vllm_br.platform import SUPAPlatform
logger = logging.getLogger(__name__)
DEFAULT_MEM_POOL_SIZE_GB = 1
@contextmanager
def set_p2p_succl_context(num_channels: str):
original_values: dict[str, Any] = {}
env_vars = [
'SUCCL_MAX_NCHANNELS',
'SUCCL_MIN_NCHANNELS',
'SUCCL_CUMEM_ENABLE',
'SUCCL_BUFFSIZE',
'SUCCL_PROTO', # LL,LL128,SIMPLE
'SUCCL_ALGO', # RING,TREE
]
for var in env_vars:
original_values[var] = os.environ.get(var)
logger.info("set_p2p_succl_context, original_values: %s", original_values)
try:
os.environ['SUCCL_MAX_NCHANNELS'] = num_channels
os.environ['SUCCL_MIN_NCHANNELS'] = num_channels
os.environ['SUCCL_CUMEM_ENABLE'] = '1'
yield
finally:
for var in env_vars:
if original_values[var] is not None:
os.environ[var] = original_values[var]
else:
os.environ.pop(var, None)
@dataclass
class SendQueueItem:
tensor_id: str
remote_address: str
tensor: torch.Tensor
class P2pSucclEngine:
def __init__(self,
local_rank: int,
config: KVTransferConfig,
hostname: str = "",
port_offset: int = 0,
library_path: Optional[str] = None) -> None:
self.config = config
self.rank = port_offset
self.local_rank = local_rank
self.device = torch.device(f"supa:{self.local_rank}")
if config is not None:
device_cursor = self.config.get_from_extra_config(
"device_cursor", 0)
self.device = torch.device(
f"supa:{self.local_rank + int(device_cursor)}")
SUPAPlatform.set_device(self.device)
self.succl = SCCLLibrary(library_path)
if not hostname:
hostname = get_ip()
port = int(self.config.kv_port) + port_offset
if port == 0:
raise ValueError("Port cannot be 0")
self._hostname = hostname
self._port = port
# Each card corresponds to a ZMQ address.
self.zmq_address = f"{self._hostname}:{self._port}"
# The `http_port` must be consistent with the port of OpenAI.
self.http_address = (
f"{self._hostname}:"
f"{self.config.kv_connector_extra_config['http_port']}")
# If `proxy_ip` or `proxy_port` is `""`,
# then the ping thread will not be enabled.
proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
proxy_port = self.config.get_from_extra_config("proxy_port", "")
if proxy_ip == "" or proxy_port == "":
self.proxy_address = ""
else:
self.proxy_address = proxy_ip + ":" + proxy_port
self.context = zmq.Context()
self.router_socket = self.context.socket(zmq.ROUTER)
self.router_socket.bind(f"tcp://{self.zmq_address}")
self.poller = zmq.Poller()
self.poller.register(self.router_socket, zmq.POLLIN)
self.send_store_cv = threading.Condition()
self.send_queue_cv = threading.Condition()
self.recv_store_cv = threading.Condition()
self.send_stream = torch_br.supa.Stream()
self.recv_stream = self.send_stream
mem_pool_size_gb = float(
self.config.get_from_extra_config("mem_pool_size_gb",
DEFAULT_MEM_POOL_SIZE_GB))
self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb *
1024**3)) # GB
# The sending type includes tree mutually exclusive options:
# PUT, GET, PUT_ASYNC.
self.send_type = self.config.get_from_extra_config(
"send_type", "PUT_ASYNC")
if self.send_type == "GET":
# tensor_id: torch.Tensor
self.send_store: dict[str, torch.Tensor] = {}
else:
# PUT or PUT_ASYNC
# tensor_id: torch.Tensor
self.send_queue: deque[SendQueueItem] = deque()
if self.send_type == "PUT_ASYNC":
self._send_thread = threading.Thread(target=self.send_async,
daemon=True)
self._send_thread.start()
# tensor_id: torch.Tensor/(addr, dtype, shape)
self.recv_store: dict[str, Any] = {}
self.recv_request_id_to_tensor_ids: dict[str, set[str]] = {}
self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
self.socks: dict[str, Any] = {} # remote_address: client socket
self.comms: dict[str, Any] = {} # remote_address: (succlComm_t, rank)
self.buffer_size = 0
self.buffer_size_threshold = float(self.config.kv_buffer_size)
self.succl_num_channels = self.config.get_from_extra_config(
"nccl_num_channels", "8")
self._listener_thread = threading.Thread(
target=self.listen_for_requests, daemon=True)
self._listener_thread.start()
self._ping_thread = None
if port_offset == 0 and self.proxy_address != "":
self._ping_thread = threading.Thread(target=self.ping, daemon=True)
self._ping_thread.start()
logger.info(
"💯P2pSucclEngine init, rank:%d, local_rank:%d, http_address:%s, "
"zmq_address:%s, proxy_address:%s, send_type:%s, buffer_size_"
"threshold:%.2f, nccl_num_channels:%s", self.rank, self.local_rank,
self.http_address, self.zmq_address, self.proxy_address,
self.send_type, self.buffer_size_threshold,
self.succl_num_channels)
def create_connect(self, remote_address: typing.Optional[str] = None):
assert remote_address is not None
if remote_address not in self.socks:
sock = self.context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
sock.connect(f"tcp://{remote_address}")
self.socks[remote_address] = sock
if remote_address in self.comms:
logger.info("👋comm exists, remote_address:%s, comms:%s",
remote_address, self.comms)
return sock, self.comms[remote_address]
unique_id = self.succl.succlGetUniqueId()
data = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)}
sock.send(msgpack.dumps(data))
rank = 0
SUPAPlatform.set_device(self.device)
comm: succlComm_t = self.succl.succlCommInitRank(
2, unique_id, rank)
self.comms[remote_address] = (comm, rank)
logger.info("🤝succlCommInitRank Success, %s👉%s, MyRank:%s",
self.zmq_address, remote_address, rank)
return self.socks[remote_address], self.comms[remote_address]
def send_tensor(
self,
tensor_id: str,
tensor: torch.Tensor,
remote_address: typing.Optional[str] = None,
) -> bool:
if remote_address is None:
with self.recv_store_cv:
self.recv_store[tensor_id] = tensor
self.recv_store_cv.notify()
return True
item = SendQueueItem(tensor_id=tensor_id,
remote_address=remote_address,
tensor=tensor)
if self.send_type == "PUT":
return self.send_sync(item)
if self.send_type == "PUT_ASYNC":
with self.send_queue_cv:
self.send_queue.append(item)
self.send_queue_cv.notify()
return True
# GET
with self.send_store_cv:
tensor_size = tensor.element_size() * tensor.numel()
if tensor_size > self.buffer_size_threshold:
logger.warning(
"❗[GET]tensor_id:%s, tensor_size:%d, is greater than"
"buffer size threshold :%d, skip send to %s, rank:%d",
tensor_id, tensor_size, self.buffer_size_threshold,
remote_address, self.rank)
return False
while (self.buffer_size + tensor_size
> self.buffer_size_threshold):
assert len(self.send_store) > 0
oldest_tensor_id = next(iter(self.send_store))
oldest_tensor = self.send_store.pop(oldest_tensor_id)
oldest_tensor_size = oldest_tensor.element_size(
) * oldest_tensor.numel()
self.buffer_size -= oldest_tensor_size
logger.debug(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tensor_size:%d, rank:%d",
remote_address, tensor_id, tensor_size, self.buffer_size,
oldest_tensor_size, self.rank)
self.send_store[tensor_id] = tensor
self.buffer_size += tensor_size
logger.debug(
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)", remote_address,
tensor_id, tensor_size, tensor.shape, self.rank,
self.buffer_size,
self.buffer_size / self.buffer_size_threshold * 100)
return True
def recv_tensor(
self,
tensor_id: str,
remote_address: typing.Optional[str] = None,
) -> torch.Tensor:
if self.send_type == "PUT" or self.send_type == "PUT_ASYNC":
start_time = time.time()
with self.recv_store_cv:
while tensor_id not in self.recv_store:
self.recv_store_cv.wait()
tensor = self.recv_store[tensor_id]
if tensor is not None:
if isinstance(tensor, tuple):
addr, dtype, shape = tensor
tensor = self.pool.load_tensor(addr, dtype, shape,
self.device)
else:
self.buffer_size -= (tensor.element_size() *
tensor.numel())
else:
duration = time.time() - start_time
logger.warning(
"🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, "
"rank:%d", remote_address, tensor_id, duration * 1000,
self.rank)
return tensor
# GET
if remote_address is None:
return None
if remote_address not in self.socks:
self.create_connect(remote_address)
sock = self.socks[remote_address]
comm, rank = self.comms[remote_address]
data = {"cmd": "GET", "tensor_id": tensor_id}
sock.send(msgpack.dumps(data))
message = sock.recv()
data = msgpack.loads(message)
if data["ret"] != 0:
logger.warning("🔴[GET]Recv From %s, tensor_id: %s, ret: %d",
remote_address, tensor_id, data["ret"])
return None
with torch_br.supa.stream(self.recv_stream):
tensor = torch_br._empty_ut_only(data["shape"],
dtype=getattr(
torch, data["dtype"]),
tensor_type='BUFFER_ANY',
device=self.device)
self.recv(comm, tensor, rank ^ 1, self.recv_stream)
return tensor
def listen_for_requests(self):
while True:
socks = dict(self.poller.poll())
if self.router_socket not in socks:
continue
remote_address, message = self.router_socket.recv_multipart()
data = msgpack.loads(message)
if data["cmd"] == "NEW":
unique_id = self.succl.unique_id_from_bytes(
bytes(data["unique_id"]))
rank = 1
SUPAPlatform.set_device(self.device)
comm: succlComm_t = self.succl.succlCommInitRank(
2, unique_id, rank)
self.comms[remote_address.decode()] = (comm, rank)
logger.info("🤝suclCommInitRank Success, %s👈%s, MyRank:%s",
self.zmq_address, remote_address.decode(), rank)
elif data["cmd"] == "PUT":
tensor_id = data["tensor_id"]
try:
with torch_br.supa.stream(self.recv_stream):
tensor = torch_br._empty_ut_only(
data["shape"],
dtype=getattr(torch, data["dtype"]),
tensor_type='BUFFER_ANY',
device=self.device)
self.router_socket.send_multipart([remote_address, b"0"])
comm, rank = self.comms[remote_address.decode()]
self.recv(comm, tensor, rank ^ 1, self.recv_stream)
tensor_size = tensor.element_size() * tensor.numel()
if (self.buffer_size + tensor_size
> self.buffer_size_threshold):
# Store Tensor in memory pool
addr = self.pool.store_tensor(tensor)
tensor = (addr, tensor.dtype, tensor.shape)
logger.warning(
"🔴[PUT]Recv Tensor, Out Of Threshold, "
"%s👈%s, data:%s, addr:%d", self.zmq_address,
remote_address.decode(), data, addr)
else:
self.buffer_size += tensor_size
except torch.cuda.OutOfMemoryError:
self.router_socket.send_multipart([remote_address, b"1"])
tensor = None
logger.warning(
"🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
"data:%s", self.zmq_address, remote_address.decode(),
data)
with self.recv_store_cv:
self.recv_store[tensor_id] = tensor
self.have_received_tensor_id(tensor_id)
self.recv_store_cv.notify()
elif data["cmd"] == "GET":
tensor_id = data["tensor_id"]
with self.send_store_cv:
tensor = self.send_store.pop(tensor_id, None)
if tensor is not None:
data = {
"ret": 0,
"shape": tensor.shape,
"dtype": str(tensor.dtype).replace("torch.", ""),
"tensor_type": get_tensor_info(tensor)[0]['layout']
}
# LRU
self.send_store[tensor_id] = tensor
self.have_sent_tensor_id(tensor_id)
else:
data = {"ret": 1}
self.router_socket.send_multipart(
[remote_address, msgpack.dumps(data)])
if data["ret"] == 0:
comm, rank = self.comms[remote_address.decode()]
self.send(comm, tensor.to(self.device), rank ^ 1,
self.send_stream)
else:
logger.warning(
"🚧Unexpected, Received message from %s, data:%s",
remote_address, data)
def have_sent_tensor_id(self, tensor_id: str):
request_id = tensor_id.split('#')[0]
if request_id not in self.send_request_id_to_tensor_ids:
self.send_request_id_to_tensor_ids[request_id] = set()
self.send_request_id_to_tensor_ids[request_id].add(tensor_id)
def have_received_tensor_id(self, tensor_id: str):
request_id = tensor_id.split('#')[0]
if request_id not in self.recv_request_id_to_tensor_ids:
self.recv_request_id_to_tensor_ids[request_id] = set()
self.recv_request_id_to_tensor_ids[request_id].add(tensor_id)
def send_async(self):
while True:
with self.send_queue_cv:
while not self.send_queue:
self.send_queue_cv.wait()
item = self.send_queue.popleft()
if not self.send_queue:
self.send_queue_cv.notify()
self.send_sync(item)
def wait_for_sent(self):
if self.send_type == "PUT_ASYNC":
start_time = time.time()
with self.send_queue_cv:
while self.send_queue:
self.send_queue_cv.wait()
duration = time.time() - start_time
logger.debug(
"🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
" to be empty, rank:%d", duration * 1000, self.rank)
def send_sync(self, item: SendQueueItem) -> bool:
if item.remote_address is None:
return False
if item.remote_address not in self.socks:
self.create_connect(item.remote_address)
tensor = item.tensor
sock = self.socks[item.remote_address]
comm, rank = self.comms[item.remote_address]
data = {
"cmd": "PUT",
"tensor_id": item.tensor_id,
"shape": tensor.shape,
"dtype": str(tensor.dtype).replace("torch.", ""),
"tensor_type": get_tensor_info(tensor)[0]['layout']
}
sock.send(msgpack.dumps(data))
response = sock.recv()
if response != b"0":
logger.error(
"🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
"MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s",
self.zmq_address, item.remote_address, rank, data,
tensor.shape,
tensor.element_size() * tensor.numel() / 1024**3,
response.decode())
return False
self.send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)
if self.send_type == "PUT_ASYNC":
self.have_sent_tensor_id(item.tensor_id)
return True
def get_finished(
self, finished_req_ids: set[str], no_compile_layers
) -> tuple[Optional[set[str]], Optional[set[str]]]:
"""
Notifies worker-side connector ids of requests that have
finished generating tokens.
Returns:
ids of requests that have finished asynchronous transfer,
tuple of (sending/saving ids, recving/loading ids).
The finished saves/sends req ids must belong to a set provided in a
call to this method (this call or a prior one).
"""
# Clear the buffer upon request completion.
for request_id in finished_req_ids:
for layer_name in no_compile_layers:
tensor_id = request_id + "#" + layer_name
if tensor_id in self.recv_store:
with self.recv_store_cv:
tensor = self.recv_store.pop(tensor_id, None)
self.send_request_id_to_tensor_ids.pop(
request_id, None)
self.recv_request_id_to_tensor_ids.pop(
request_id, None)
if isinstance(tensor, tuple):
addr, _, _ = tensor
self.pool.free(addr)
# TODO:Retrieve requests that have already sent the KV cache.
finished_sending: set[str] = set()
# TODO:Retrieve requests that have already received the KV cache.
finished_recving: set[str] = set()
return finished_sending or None, finished_recving or None
def ping(self):
sock = self.context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
logger.debug("ping start, zmq_address:%s", self.zmq_address)
sock.connect(f"tcp://{self.proxy_address}")
data = {
"type": "P" if self.config.is_kv_producer else "D",
"http_address": self.http_address,
"zmq_address": self.zmq_address
}
while True:
sock.send(msgpack.dumps(data))
time.sleep(3)
def send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
assert tensor.device == self.device, (
f"this succl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = torch_br.supa.Stream()
with torch_br.supa.stream(stream):
self.succl.succlSend(buffer_type(tensor.data_ptr()),
tensor.numel(),
succlDataTypeEnum.from_torch(tensor.dtype),
dst, comm, suStream_t(stream.supa_stream))
stream.synchronize()
def recv(self, comm, tensor: torch.Tensor, src: int, stream=None):
assert tensor.device == self.device, (
f"this succl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = torch_br.supa.Stream()
with torch_br.supa.stream(stream):
self.succl.succlRecv(buffer_type(tensor.data_ptr()),
tensor.numel(),
succlDataTypeEnum.from_torch(tensor.dtype),
src, comm, suStream_t(stream.supa_stream))
stream.synchronize()
def close(self) -> None:
self._listener_thread.join()
if self.send_type == "PUT_ASYNC":
self._send_thread.join()
if self._ping_thread is not None:
self._ping_thread.join()

View File

@@ -0,0 +1,280 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import atexit
import ctypes
import math
from dataclasses import dataclass
import torch
from vllm.logger import logger
@dataclass
class MemoryBlock:
size: int
addr: int
"""A memory pool for managing pinned host memory allocations for tensors.
This class implements a buddy allocation system to efficiently manage pinned
host memory for tensor storage. It supports allocation, deallocation, and
tensor storage/retrieval operations.
Key Features:
- Uses power-of-two block sizes for efficient buddy allocation
- Supports splitting and merging of memory blocks
- Provides methods to store CUDA tensors in pinned host memory
- Allows loading tensors from pinned memory back to device
- Automatically cleans up memory on destruction
Attributes:
max_block_size (int): Maximum block size (rounded to nearest power of two)
min_block_size (int): Minimum block size (rounded to nearest power of two)
free_lists (dict): Dictionary of free memory blocks by size
allocated_blocks (dict): Dictionary of currently allocated blocks
base_tensor (torch.Tensor): Base pinned memory tensor
base_address (int): Base memory address of the pinned memory region
Example:
>>> pool = TensorMemoryPool(max_block_size=1024*1024)
>>> tensor = torch.randn(100, device='cuda')
>>> addr = pool.store_tensor(tensor)
>>> loaded_tensor = pool.load_tensor(addr, tensor.dtype,
... tensor.shape, 'cuda')
>>> pool.free(addr)
"""
class TensorMemoryPool:
"""Initializes the memory pool with given size constraints.
Args:
max_block_size (int): Maximum size of memory blocks to manage
min_block_size (int, optional): Minimum size of memory blocks
to manage. Defaults to 512.
Raises:
ValueError: If block sizes are invalid or max_block_size is less
than min_block_size
"""
def __init__(self, max_block_size: int, min_block_size: int = 512):
if max_block_size <= 0 or min_block_size <= 0:
raise ValueError("Block sizes must be positive")
if max_block_size < min_block_size:
raise ValueError(
"Max block size must be greater than min block size")
self.max_block_size = self._round_to_power_of_two(max_block_size)
self.min_block_size = self._round_to_power_of_two(min_block_size)
self.free_lists: dict[int, dict[int, MemoryBlock]] = {}
self.allocated_blocks: dict[int, MemoryBlock] = {}
self._initialize_free_lists()
self._allocate_pinned_memory()
atexit.register(self.cleanup)
def _round_to_power_of_two(self, size: int) -> int:
return 1 << (size - 1).bit_length()
def _initialize_free_lists(self):
size = self.max_block_size
while size >= self.min_block_size:
self.free_lists[size] = {}
size //= 2
def _allocate_pinned_memory(self):
self.base_tensor = torch.empty(self.max_block_size // 4,
dtype=torch.float32,
pin_memory=True)
self.base_address = self.base_tensor.data_ptr()
initial_block = MemoryBlock(size=self.max_block_size,
addr=self.base_address)
self.free_lists[self.max_block_size][
initial_block.addr] = initial_block
logger.debug("TensorMemoryPool, base_address:%d, max_block_size:%d",
self.base_address, self.max_block_size)
def allocate(self, size: int) -> int:
"""Allocates a memory block of at least the requested size.
Args:
size (int): Minimum size of memory to allocate
Returns:
int: Address of the allocated memory block
Raises:
ValueError: If size is invalid or insufficient memory is available
"""
if size <= 0:
raise ValueError("Allocation size must be positive")
required_size = self._round_to_power_of_two(
max(size, self.min_block_size))
if required_size > self.max_block_size:
raise ValueError("Requested size exceeds maximum block size")
current_size = required_size
while current_size <= self.max_block_size:
if self.free_lists[current_size]:
_, block = self.free_lists[current_size].popitem()
self._split_block(block, required_size)
self.allocated_blocks[block.addr] = block
return block.addr
current_size *= 2
raise ValueError("Insufficient memory")
def _split_block(self, block: MemoryBlock, required_size: int):
while (block.size > required_size
and block.size // 2 >= self.min_block_size):
buddy_size = block.size // 2
buddy_addr = block.addr + buddy_size
buddy = MemoryBlock(size=buddy_size, addr=buddy_addr)
block.size = buddy_size
self.free_lists[buddy_size][buddy.addr] = buddy
def free(self, addr: int):
"""Frees an allocated memory block.
Args:
addr (int): Address of the block to free
Raises:
ValueError: If address is invalid or not allocated
"""
if addr not in self.allocated_blocks:
raise ValueError("Invalid address to free")
block = self.allocated_blocks.pop(addr)
self._merge_buddies(block)
def _merge_buddies(self, block: MemoryBlock):
MAX_MERGE_DEPTH = 30
depth = 0
while depth < MAX_MERGE_DEPTH:
buddy_offset = block.size if (block.addr - self.base_address) % (
2 * block.size) == 0 else -block.size
buddy_addr = block.addr + buddy_offset
buddy = self.free_lists[block.size].get(buddy_addr)
if buddy:
del self.free_lists[buddy.size][buddy.addr]
merged_addr = min(block.addr, buddy.addr)
merged_size = block.size * 2
block = MemoryBlock(size=merged_size, addr=merged_addr)
depth += 1
else:
break
self.free_lists[block.size][block.addr] = block
def store_tensor(self, tensor: torch.Tensor) -> int:
"""Stores a CUDA tensor in pinned host memory.
Args:
tensor (torch.Tensor): CUDA tensor to store
Returns:
int: Address where the tensor is stored
Raises:
ValueError: If tensor is not on CUDA or allocation fails
"""
if not tensor.is_cuda:
raise ValueError("Only CUDA tensors can be stored")
size = tensor.element_size() * tensor.numel()
addr = self.allocate(size)
block = self.allocated_blocks[addr]
if block.size < size:
self.free(addr)
raise ValueError(
f"Allocated block size {block.size} is smaller than "
f"required size {size}")
try:
buffer = (ctypes.c_byte * block.size).from_address(block.addr)
cpu_tensor = torch.frombuffer(buffer,
dtype=tensor.dtype,
count=tensor.numel()).reshape(
tensor.shape)
except ValueError as err:
self.free(addr)
raise ValueError(f"Failed to create tensor view: {err}") from err
cpu_tensor.copy_(tensor)
return addr
def load_tensor(self, addr: int, dtype: torch.dtype,
shape: tuple[int, ...], device) -> torch.Tensor:
"""Loads a tensor from pinned host memory to the specified device.
Args:
addr (int): Address where tensor is stored
dtype (torch.dtype): Data type of the tensor
shape (tuple[int, ...]): Shape of the tensor
device: Target device for the loaded tensor
Returns:
torch.Tensor: The loaded tensor on the specified device
Raises:
ValueError: If address is invalid or sizes don't match
"""
if addr not in self.allocated_blocks:
raise ValueError("Invalid address to load")
block = self.allocated_blocks[addr]
num_elements = math.prod(shape)
dtype_size = torch.tensor([], dtype=dtype).element_size()
required_size = num_elements * dtype_size
if required_size > block.size:
raise ValueError("Requested tensor size exceeds block size")
buffer = (ctypes.c_byte * block.size).from_address(block.addr)
cpu_tensor = torch.frombuffer(buffer, dtype=dtype,
count=num_elements).reshape(shape)
cuda_tensor = torch.empty(shape, dtype=dtype, device=device)
cuda_tensor.copy_(cpu_tensor)
return cuda_tensor
def cleanup(self):
"""Cleans up all memory resources and resets the pool state."""
self.free_lists.clear()
self.allocated_blocks.clear()
if hasattr(self, 'base_tensor'):
del self.base_tensor
def __del__(self):
self.cleanup()

View File

@@ -0,0 +1,473 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Optional, Union
import torch
import torch.distributed
import torch_br
import vllm
import vllm.distributed.parallel_state
from vllm.distributed import GroupCoordinator
from vllm.distributed.parallel_state import (_WORLD, TensorMetadata,
_split_tensor_dict, get_pp_group,
get_tp_group, get_world_group,
init_model_parallel_group, logger)
from vllm_br import envs
@dataclass
class GraphCaptureContext:
stream: torch_br.supa.Stream
@contextmanager
#@patch_to(GroupCoordinator.graph_capture)
def graph_capture_(self,
graph_capture_context: Optional[GraphCaptureContext] = None
):
if graph_capture_context is None:
stream = torch_br.supa.Stream()
graph_capture_context = GraphCaptureContext(stream)
else:
stream = graph_capture_context.stream
# only supa uses this function,
# so we don't abstract it into the base class
#maybe_ca_context = nullcontext()
#from vllm_br.distributed.communicator import SUPACommunicator
#if self.device_communicator is not None:
# assert isinstance(self.device_communicator, SUPACommunicator)
# ca_comm = self.device_communicator.ca_comm
# if ca_comm is not None:
# maybe_ca_context = ca_comm.capture() # type: ignore
# ensure all initialization operations complete before attempting to
# capture the graph on another stream
curr_stream = torch_br.supa.current_stream()
if curr_stream != stream:
stream.wait_stream(curr_stream)
with torch_br.supa.stream(stream):
yield graph_capture_context
vllm.distributed.parallel_state.GroupCoordinator.graph_capture = graph_capture_
@contextmanager
#@patch_to(graph_capture)
def graph_capture_supa(device: torch.device):
"""
`graph_capture` is a context manager which should surround the code that
is capturing the SUPA graph. Its main purpose is to ensure that the
some operations will be run after the graph is captured, before the graph
is replayed. It returns a `GraphCaptureContext` object which contains the
necessary data for the graph capture. Currently, it only contains the
stream that the graph capture is running on. This stream is set to the
current SUPA stream when the context manager is entered and reset to the
default stream when the context manager is exited. This is to ensure that
the graph capture is running on a separate stream from the default stream,
in order to explicitly distinguish the kernels to capture
from other kernels possibly launched on background in the default stream.
"""
context = GraphCaptureContext(torch_br.supa.Stream(device=device))
with get_tp_group().graph_capture(context), get_pp_group().graph_capture(
context):
yield context
vllm.distributed.parallel_state.graph_capture = graph_capture_supa
def is_global_first_rank() -> bool:
"""
Check if the current process is the first rank globally across all
parallelism strategies (PP, TP, DP, EP, etc.).
Unlike group-specific checks like `get_tensor_model_parallel_rank() == 0`
or `get_pp_group().is_first_rank`, this function checks the global rank
across all parallelism dimensions.
Returns:
bool: True if this is the global first rank (rank 0), False otherwise.
Returns True if distributed is not initialized (single process).
"""
try:
# If world group is available, use it for the most accurate check
if _WORLD is not None:
return _WORLD.is_first_rank
# If torch distributed is not initialized, assume single process
if not torch.distributed.is_initialized():
return True
# Fallback to torch's global rank
return torch.distributed.get_rank() == 0
except Exception:
# If anything goes wrong, assume this is the first rank
return True
def generate_multi_node_parallel_groups(
total_procs: int,
tp_size: int,
pp_size: int,
dp_size: int,
) -> dict:
if total_procs == 16 and tp_size == 8 and pp_size == 2 and dp_size == 1:
tp_groups = [[0, 1, 2, 3, 8, 9, 10, 11], [4, 5, 6, 7, 12, 13, 14, 15]]
pp_groups = [[0, 4], [1, 5], [2, 6], [3, 7], [8, 12], [9, 13],
[10, 14], [11, 15]]
dp_groups = [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10],
[11], [12], [13], [14], [15]]
ep_groups = [[0, 1, 2, 3, 8, 9, 10, 11], [4, 5, 6, 7, 12, 13, 14, 15]]
else:
raise ValueError(
"Unsupported VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE parallel config of"
" tp_size: {tp_size} pp_size: {pp_size} dp_size: {dp_size}"
"Currently only 'tp8pp2dp1' is allowed.")
return {
"tp_groups": tp_groups,
"pp_groups": pp_groups,
"dp_groups": dp_groups,
"ep_groups": ep_groups,
}
# sync v0.11 api update, while code logic possibly need sync with vllm original code implementation
def initialize_model_parallel_cross_tp(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
decode_context_model_parallel_size: Optional[int] = 1,
backend: Optional[str] = None,
) -> None:
"""
Initialize model parallel groups.
Arguments:
tensor_model_parallel_size: number of GPUs used for tensor model
parallelism.
pipeline_model_parallel_size: number of GPUs used for pipeline model
parallelism.
backend: name of torch distributed communication backend.
Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
the model pipeline. The present function will
create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
4 tensor model-parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7]
2 pipeline model-parallel groups:
[g0, g2, g4, g6], [g1, g3, g5, g7]
Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example if we are using 2 DGX-1 boxes
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
"""
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
backend = backend or torch.distributed.get_backend(
get_world_group().device_group)
data_parallel_size = 1
from vllm.config import get_current_vllm_config
config = get_current_vllm_config()
if config is not None:
data_parallel_size = config.parallel_config.data_parallel_size
# the layout order is: ExternalDP x DP x PP x TP
# ExternalDP is the data parallel group that is not part of the model,
# every dp rank can generate independently (in verl integration).
# DP is the data parallel group that is part of the model,
# all the ranks in the same DP group should generate simultaneously,
# i.e. the `generate` call in the same DP group should be called together,
# otherwise it will cause deadlock.
# to get group_ranks for each dimension, transpose that dimension to the
# last dimension, then reshape to 2D, then unbind the last dimension
all_ranks = torch.arange(world_size).reshape(
-1, data_parallel_size, pipeline_model_parallel_size,
tensor_model_parallel_size) # noqa
if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
groups = generate_multi_node_parallel_groups(
world_size, tensor_model_parallel_size,
pipeline_model_parallel_size, data_parallel_size)
logger.info("supernode reorganized groups: %s", groups)
# Build the tensor model-parallel groups.
assert vllm.distributed.parallel_state._TP is None, (
"tensor model parallel group is already initialized")
if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
group_ranks = groups['tp_groups']
else:
group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
# message queue broadcaster is only used in tensor model parallel group
vllm.distributed.parallel_state._TP = init_model_parallel_group(
group_ranks,
get_world_group().local_rank,
backend,
use_message_queue_broadcaster=True,
group_name="tp")
# Build the DCP model-parallel groups.
# global _DCP
assert vllm.distributed.parallel_state._DCP is None, (
"decode context model parallel group is already initialized")
# Note(hc): In the current implementation of decode context parallel,
# dcp_size must not exceed tp_size, because the world size does not
# change by DCP, it simply reuses the GPUs of TP group, and split one
# TP group into tp_size//dcp_size DCP groups.
group_ranks = all_ranks.reshape(
-1, decode_context_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
vllm.distributed.parallel_state._DCP = init_model_parallel_group(
group_ranks,
get_world_group().local_rank,
backend,
use_message_queue_broadcaster=True,
group_name="dcp")
# Build the pipeline model-parallel groups.
assert vllm.distributed.parallel_state._PP is None, (
"pipeline model parallel group is already initialized")
if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
group_ranks = groups['pp_groups']
else:
group_ranks = all_ranks.transpose(2, 3).reshape(
-1, pipeline_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
vllm.distributed.parallel_state._PP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="pp")
assert vllm.distributed.parallel_state._DP is None, (
"data parallel group is already initialized")
if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
group_ranks = groups['dp_groups']
else:
group_ranks = all_ranks.transpose(1, 3).reshape(
-1, data_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
vllm.distributed.parallel_state._DP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="dp")
assert vllm.distributed.parallel_state._EP is None, (
"expert parallel group is already initialized")
if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
group_ranks = groups['ep_groups']
else:
group_ranks = all_ranks.transpose(1, 2).reshape(
-1, data_parallel_size * tensor_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
vllm.distributed.parallel_state._EP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="ep")
logger.info(
"rank %s in world size %s is assigned as (br) "
"DP rank %s, PP rank %s, TP rank %s, EP rank %s", rank, world_size,
vllm.distributed.parallel_state._DP.rank_in_group,
vllm.distributed.parallel_state._PP.rank_in_group,
vllm.distributed.parallel_state._TP.rank_in_group,
vllm.distributed.parallel_state._EP.rank_in_group)
vllm.distributed.parallel_state.initialize_model_parallel = initialize_model_parallel_cross_tp
def send_tensor_dict(
self,
tensor_dict: dict[str, Union[torch.Tensor, Any]],
dst: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
all_gather_tensors: Optional[dict[str, bool]] = None,
) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
"""Send the input tensor dictionary.
NOTE: `dst` is the local rank of the source rank.
all_gather_group: The group for the all-gather operation. If provided,
an optimization is enabled where each rank in the group sends a
slice of a tensor and the receiver reconstructs it using an
all-gather, which can improve performance. This is typically the
tensor-parallel group.
all_gather_tensors: A dictionary to specify which tensors should use
the all-gather optimization, which is only effective when
`all_gather_group` is provided. By default, this optimization is
on for any tensor whose size is divisible by the
`all_gather_group`'s world size. However, it should be disabled
for tensors that are not fully replicated across the group (e.g.,
the residual tensor when sequence parallelism is enabled). This
dictionary allows overriding the default behavior on a per-tensor
basis.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return tensor_dict
all_gather_size = (1 if all_gather_group is None else
all_gather_group.world_size)
all_gather_rank = (0 if all_gather_group is None else
all_gather_group.rank_in_group)
group = self.device_group
metadata_group = self.cpu_group
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
assert dst < self.world_size, f"Invalid dst rank ({dst})"
if self.use_cpu_custom_send_recv:
if self.device_communicator is None:
raise ValueError("No device communicator found")
self.device_communicator.send_tensor_dict( # type: ignore
tensor_dict, dst)
return None
metadata_list: list[tuple[Any, Any]] = []
assert isinstance(tensor_dict,
dict), f"Expecting a dictionary, got {type(tensor_dict)}"
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
# `metadata_list` lives in CPU memory.
# `send_object_list` has serialization & deserialization,
# all happening on CPU. Therefore, we can use the CPU group.
self.send_object(metadata_list, dst=dst)
tensor_keys = [
k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor)
]
assert len(tensor_keys) == len(tensor_list)
for key, tensor in zip(tensor_keys, tensor_list):
if tensor.numel() == 0:
# Skip sending empty tensors.
continue
# send-allgather: send only a slice, then do allgather.
use_all_gather = (all_gather_group is not None
and tensor.numel() % all_gather_size == 0)
use_all_gather = all_gather_tensors.get(key, use_all_gather) \
if all_gather_tensors else use_all_gather
if use_all_gather:
tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.send(tensor,
dst=self.ranks[dst],
group=metadata_group)
else:
# ensure tensor is ready
torch.supa.synchronize()
# use group for GPU tensors
torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
return None
def recv_tensor_dict(
self,
src: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
all_gather_tensors: Optional[dict[str, bool]] = None,
) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
"""Recv the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
all_gather_group: The group for the all-gather operation. If provided,
an optimization is enabled where each rank in the group sends a
slice of a tensor and the receiver reconstructs it using an
all-gather, which can improve performance. This is typically the
tensor-parallel group.
all_gather_tensors: A dictionary to specify which tensors should use
the all-gather optimization, which is only effective when
`all_gather_group` is provided. By default, this optimization is
on for any tensor whose size is divisible by the
`all_gather_group`'s world size. However, it should be disabled
for tensors that are not fully replicated across the group (e.g.,
the residual tensor when sequence parallelism is enabled). This
dictionary allows overriding the default behavior on a per-tensor
basis.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return None
all_gather_size = (1 if all_gather_group is None else
all_gather_group.world_size)
all_gather_rank = (0 if all_gather_group is None else
all_gather_group.rank_in_group)
group = self.device_group
metadata_group = self.cpu_group
if src is None:
src = (self.rank_in_group - 1) % self.world_size
assert src < self.world_size, f"Invalid src rank ({src})"
if self.use_cpu_custom_send_recv:
if self.device_communicator is None:
raise ValueError("No device communicator found")
return self.device_communicator.recv_tensor_dict( # type: ignore
src)
recv_metadata_list = self.recv_object(src=src)
tensor_dict: dict[str, Any] = {}
for key, value in recv_metadata_list:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size,
dtype=value.dtype,
device=value.device)
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
tensor_dict[key] = tensor
continue
# send-allgather: send only a slice, then do allgather.
use_all_gather = (all_gather_group is not None
and tensor.numel() % all_gather_size == 0)
use_all_gather = all_gather_tensors.get(key, use_all_gather) \
if all_gather_tensors else use_all_gather
if use_all_gather:
orig_shape = tensor.shape
tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.recv(tensor,
src=self.ranks[src],
group=metadata_group)
else:
# use group for GPU tensors
torch.distributed.recv(tensor,
src=self.ranks[src],
group=group)
# ensure recv is done
torch.supa.synchronize()
if use_all_gather:
# do the allgather
tensor = all_gather_group.all_gather( # type: ignore
tensor, dim=0)
tensor = tensor.reshape(orig_shape)
tensor_dict[key] = tensor
else:
tensor_dict[key] = value
return tensor_dict
vllm.distributed.GroupCoordinator.send_tensor_dict = send_tensor_dict
vllm.distributed.GroupCoordinator.recv_tensor_dict = recv_tensor_dict