first commit
This commit is contained in:
17
vllm_br/distributed/__init__.py
Normal file
17
vllm_br/distributed/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
from . import device_communicators # noqa: F401
|
||||
from . import kv_transfer # noqa: F401
|
||||
BIN
vllm_br/distributed/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm_br/distributed/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/distributed/__pycache__/communicator.cpython-310.pyc
Normal file
BIN
vllm_br/distributed/__pycache__/communicator.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/distributed/__pycache__/parallel_state.cpython-310.pyc
Normal file
BIN
vllm_br/distributed/__pycache__/parallel_state.cpython-310.pyc
Normal file
Binary file not shown.
60
vllm_br/distributed/communicator.py
Normal file
60
vllm_br/distributed/communicator.py
Normal file
@@ -0,0 +1,60 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.distributed.device_communicators.base_device_communicator import (
|
||||
DeviceCommunicatorBase)
|
||||
from vllm.logger import logger
|
||||
from vllm_br import envs
|
||||
|
||||
|
||||
class SUPACommunicator(DeviceCommunicatorBase):
    """Device communicator used for SUPA devices.

    Extends vLLM's ``DeviceCommunicatorBase`` with a ``gather`` emulated
    through ``all_gather`` and an ``all_reduce`` that can optionally run
    bfloat16 reductions in float32.
    """

    def __init__(self,
                 cpu_group: dist.ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[dist.ProcessGroup] = None,
                 unique_name: str = ""):
        """Initialise the base communicator, then record the SUPA device.

        Args:
            cpu_group: Process group forwarded to the base class.
            device: Device forwarded to the base class.
            device_group: Process group used for the device collectives.
            unique_name: Name forwarded to the base class.
        """
        super().__init__(cpu_group, device, device_group, unique_name)
        # Overwrite ``self.device`` with whatever SUPA device is current
        # for this process, regardless of the ``device`` argument.
        self.device = torch.supa.current_device()

    # TODO: Deprecate this method in the future if torch_br support gather
    def gather(self,
               input_: torch.Tensor,
               dst: int = 0,
               dim: int = -1) -> Optional[torch.Tensor]:
        """Emulate ``gather`` with ``all_gather``.

        Every rank takes part in the all-gather; only the rank whose
        ``rank_in_group`` equals ``dst`` keeps the result.

        Args:
            input_: Tensor contributed by this rank.
            dst: Rank (within the group) that receives the gathered tensor.
            dim: Dimension passed through to ``all_gather``.

        Returns:
            The gathered tensor on rank ``dst``, ``None`` on all other ranks.
        """
        gathered = self.all_gather(input_, dim)
        return gathered if self.rank_in_group == dst else None

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """All-reduce ``input_`` across ``self.device_group``.

        When ``envs.VLLM_BR_USE_FP32_ALL_REDUCE`` is set and ``input_`` is a
        bfloat16 tensor, the reduction is performed on a float32 copy and the
        result is cast back to bfloat16. In that case a *new* tensor is
        returned and ``input_`` itself is left unreduced. Otherwise the
        reduction happens in place and ``input_`` is returned.

        Args:
            input_: Tensor to reduce.

        Returns:
            The reduced tensor.
        """
        reduce_in_fp32 = (envs.VLLM_BR_USE_FP32_ALL_REDUCE
                          and input_ is not None
                          and input_.dtype == torch.bfloat16)
        if not reduce_in_fp32:
            dist.all_reduce(input_, group=self.device_group)
            return input_

        logger.debug(
            '[Patch] patch all_reduce: use fp32 all_reduce when env VLLM_BR_USE_FP32_ALL_REDUCE is set'
        )
        # Up-cast, reduce the float32 copy in place, then cast back.
        input_ = input_.to(torch.float32)
        dist.all_reduce(input_, group=self.device_group)
        input_ = input_.to(torch.bfloat16)
        return input_
|
||||
18
vllm_br/distributed/device_communicators/__init__.py
Normal file
18
vllm_br/distributed/device_communicators/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from . import base_device_communicator # noqa: F401
|
||||
from . import pysccl_wrapper # noqa: F401
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,44 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import torch
|
||||
|
||||
import vllm
|
||||
|
||||
|
||||
def supa_prepare_communication_buffer_for_model(
        self, model: torch.nn.Module) -> None:
    """
    Prepare the communication buffer for the model.

    Does nothing unless both ``self.use_all2all`` and
    ``self.is_ep_communicator`` are truthy. Otherwise every sub-module of
    ``model`` whose class is named ``FusedMoE`` or ``SharedFusedMoE`` has
    ``quant_method.init_prepare_finalize`` called on it.
    """
    if not (self.use_all2all and self.is_ep_communicator):
        return

    # TODO(bnell): Should use isinstance but can't. Maybe search for
    # presence of quant_method.init_prepare_finalize?
    moe_class_names = ("FusedMoE", "SharedFusedMoE")

    # Collect the matching modules first, then initialise them, so the
    # module tree is fully walked before any initialisation runs.
    targets = []
    for candidate in model.modules():
        if candidate.__class__.__name__ in moe_class_names:
            targets.append(candidate)

    for moe in targets:
        moe.quant_method.init_prepare_finalize(moe)
|
||||
|
||||
|
||||
# Monkey-patch: replace ``prepare_communication_buffer_for_model`` on vLLM's
# ``DeviceCommunicatorBase`` with the SUPA variant when this module is imported.
vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase.prepare_communication_buffer_for_model = supa_prepare_communication_buffer_for_model
|
||||
420
vllm_br/distributed/device_communicators/pysccl_wrapper.py
Normal file
420
vllm_br/distributed/device_communicators/pysccl_wrapper.py
Normal file
@@ -0,0 +1,420 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# This file is a pure Python wrapper for the SCCL library.
|
||||
# The main purpose is to use SCCL combined with CUDA graph.
|
||||
# Before writing this script, we tried the following approach:
|
||||
# 1. We tried to use `cupy`, it calls SCCL correctly, but `cupy` itself
|
||||
# often gets stuck when initializing the SCCL communicator.
|
||||
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
|
||||
# contains many other potential cuda APIs, that are not allowed during
|
||||
# capturing the CUDA graph. For further details, please check
|
||||
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
|
||||
#
|
||||
# Another rejected idea is to write a C/C++ binding for SCCL. It is usually
|
||||
# doable, but we often encounter issues related to SCCL versions, and need
|
||||
# to switch between different versions of SCCL. See
|
||||
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
|
||||
# A C/C++ binding is not flexible enough to handle this. It requires
|
||||
# recompilation of the code every time we want to switch between different
|
||||
# versions. This current implementation, with a **pure** Python wrapper, is
|
||||
# more flexible. We can easily switch between different versions of SCCL by
|
||||
# changing the environment variable `VLLM_SCCL_SO_PATH`, or the `so_file`
|
||||
# variable in the code.
|
||||
|
||||
import ctypes
|
||||
import platform
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
from vllm.logger import logger
|
||||
from vllm_br import envs
|
||||
|
||||
# === export types and functions from nccl to Python ===
|
||||
# for the original nccl definition, please check
|
||||
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
|
||||
|
||||
succlResult_t = ctypes.c_int
|
||||
succlComm_t = ctypes.c_void_p
|
||||
|
||||
|
||||
class succlUniqueId(ctypes.Structure):
    """ctypes mirror of the SUCCL unique-id struct: 128 opaque bytes."""

    _fields_ = [
        ("internal", ctypes.c_byte * 128),
    ]
|
||||
|
||||
|
||||
suStream_t = ctypes.c_void_p
|
||||
buffer_type = ctypes.c_void_p
|
||||
|
||||
succlDataType_t = ctypes.c_int
|
||||
|
||||
|
||||
class succlDataTypeEnum:
    """Integer codes passed to SUCCL as ``succlDataType_t``."""

    succlInt8 = 0
    succlChar = 0
    succlUint8 = 1
    succlInt16 = 2
    succlUint16 = 3
    succlInt32 = 4
    succlInt = 4
    succlUint32 = 5
    succlInt64 = 6
    succlUint64 = 7
    succlBfloat16 = 8
    succlFloat32 = 9
    succlFloat = 9
    succlFloat64 = 10
    succlDouble = 10
    succlNumTypes = 11

    @classmethod
    def from_torch(cls, dtype: torch.dtype) -> int:
        """Translate a ``torch.dtype`` into the matching SUCCL type code.

        Args:
            dtype: Element type of the tensor handed to a SUCCL call.

        Returns:
            The integer ``succlDataType_t`` code.

        Raises:
            ValueError: If ``dtype`` has no SUCCL counterpart.
        """
        torch_to_succl = {
            torch.int8: cls.succlInt8,
            torch.uint8: cls.succlUint8,
            torch.int16: cls.succlInt16,
            torch.int32: cls.succlInt32,
            torch.int64: cls.succlInt64,
            # NOTE(review): this enum has no float16 code, so float16 is sent
            # as bfloat16. Both are 16 bits wide, which is presumably fine
            # for pure data movement, but a reduction would reinterpret the
            # bits -- confirm this mapping is intended.
            torch.float16: cls.succlBfloat16,
            torch.float32: cls.succlFloat32,
            torch.float64: cls.succlFloat64,
            torch.bfloat16: cls.succlBfloat16,
        }
        try:
            return torch_to_succl[dtype]
        except (KeyError, TypeError):
            # TypeError covers unhashable arguments, which previously also
            # ended in ValueError.
            raise ValueError(f"Unsupported dtype: {dtype}") from None
|
||||
|
||||
|
||||
succlRedOp_t = ctypes.c_int
|
||||
|
||||
|
||||
class succlRedOpTypeEnum:
    """Integer codes passed to SUCCL as ``succlRedOp_t``."""

    succlSum = 0
    succlProd = 1
    succlMax = 2
    succlMin = 3
    succlAvg = 4
    succlNumOps = 5

    @classmethod
    def from_torch(cls, op: ReduceOp) -> int:
        """Translate a ``torch.distributed.ReduceOp`` into a SUCCL op code.

        Raises:
            ValueError: If ``op`` is not one of SUM, PRODUCT, MAX, MIN, AVG.
        """
        # Compared with ``==`` (not hashed) in the same order as the
        # reduction codes above.
        known_ops = (
            (ReduceOp.SUM, cls.succlSum),
            (ReduceOp.PRODUCT, cls.succlProd),
            (ReduceOp.MAX, cls.succlMax),
            (ReduceOp.MIN, cls.succlMin),
            (ReduceOp.AVG, cls.succlAvg),
        )
        for torch_op, succl_code in known_ops:
            if op == torch_op:
                return succl_code
        raise ValueError(f"Unsupported op: {op}")
|
||||
|
||||
|
||||
@dataclass
class Function:
    """Description of one C function to bind through ctypes."""

    # Symbol name looked up on the loaded shared library.
    name: str
    # Value assigned to the ctypes function's ``restype``.
    restype: Any
    # Value assigned to the ctypes function's ``argtypes``.
    argtypes: list[Any]
|
||||
|
||||
|
||||
class SCCLLibrary:
    """Pure-Python (ctypes) binding for the SUCCL shared library.

    The library is opened with ``ctypes.CDLL``. Both the library handle and
    the table of bound functions are cached per library path in class-level
    dictionaries, so constructing several instances for the same path reuses
    the same handle and function objects.

    Every ``succl*`` method forwards to the C function of the same name and
    raises ``RuntimeError`` (via ``SUCCL_CHECK``) on a non-zero result code.
    """

    # NOTE(review): apart from succlGetErrorString / succlGetVersion /
    # succlGetUniqueId / succlCommDestroy / succlGroupStart / succlGroupEnd,
    # each ``argtypes`` list below ends with one extra ``ctypes.c_void_p``
    # that is not part of the C prototype quoted in the comment above it.
    # The wrappers always pass ``None`` for it. Presumably a SUCCL-specific
    # trailing parameter -- confirm against the SUCCL headers.
    exported_functions = [
        # const char* succlGetErrorString(succlResult_t result)
        Function("succlGetErrorString", ctypes.c_char_p, [succlResult_t]),
        # succlResult_t succlGetVersion(int *version);
        Function("succlGetVersion", succlResult_t,
                 [ctypes.POINTER(ctypes.c_int)]),
        # succlResult_t succlGetUniqueId(succlUniqueId* uniqueId);
        Function("succlGetUniqueId", succlResult_t,
                 [ctypes.POINTER(succlUniqueId)]),
        # succlResult_t succlCommInitRank(
        #   succlComm_t* comm, int nranks, succlUniqueId commId, int rank);
        # note that succlComm_t is a pointer type, so the first argument
        # is a pointer to a pointer
        Function("succlCommInitRank", succlResult_t, [
            ctypes.POINTER(succlComm_t), ctypes.c_int, succlUniqueId,
            ctypes.c_int, ctypes.c_void_p
        ]),
        # succlResult_t succlAllReduce(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   succlDataType_t datatype, succlRedOp_t op, succlComm_t comm,
        #   suStream_t stream);
        # note that suStream_t is a pointer type
        Function("succlAllReduce", succlResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, succlDataType_t,
            succlRedOp_t, succlComm_t, suStream_t, ctypes.c_void_p
        ]),

        # succlResult_t succlReduce(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   succlDataType_t datatype, succlRedOp_t op, int root,
        #   succlComm_t comm, suStream_t stream);
        # note that suStream_t is a pointer type
        Function("succlReduce", succlResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, succlDataType_t,
            succlRedOp_t, ctypes.c_int, succlComm_t, suStream_t,
            ctypes.c_void_p
        ]),

        # succlResult_t succlAllGather(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   succlDataType_t datatype, succlComm_t comm,
        #   suStream_t stream);
        # note that suStream_t is a pointer type
        Function("succlAllGather", succlResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, succlDataType_t,
            succlComm_t, suStream_t, ctypes.c_void_p
        ]),

        # succlResult_t succlReduceScatter(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   succlDataType_t datatype, succlRedOp_t op, succlComm_t comm,
        #   suStream_t stream);
        # note that suStream_t is a pointer type
        Function("succlReduceScatter", succlResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, succlDataType_t,
            succlRedOp_t, succlComm_t, suStream_t, ctypes.c_void_p
        ]),

        # succlResult_t succlSend(
        #   const void* sendbuff, size_t count, succlDataType_t datatype,
        #   int dest, succlComm_t comm, suStream_t stream);
        Function("succlSend", succlResult_t, [
            buffer_type, ctypes.c_size_t, succlDataType_t, ctypes.c_int,
            succlComm_t, suStream_t, ctypes.c_void_p
        ]),

        # succlResult_t succlRecv(
        #   void* recvbuff, size_t count, succlDataType_t datatype,
        #   int src, succlComm_t comm, suStream_t stream);
        Function("succlRecv", succlResult_t, [
            buffer_type, ctypes.c_size_t, succlDataType_t, ctypes.c_int,
            succlComm_t, suStream_t, ctypes.c_void_p
        ]),

        # succlResult_t succlBroadcast(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   succlDataType_t datatype, int root, succlComm_t comm,
        #   suStream_t stream);
        Function("succlBroadcast", succlResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, succlDataType_t,
            ctypes.c_int, succlComm_t, suStream_t, ctypes.c_void_p
        ]),

        # be cautious! this is a collective call, it will block until all
        # processes in the communicator have called this function.
        # because Python object destruction can happen in random order,
        # it is better not to call it at all.
        # succlResult_t succlCommDestroy(succlComm_t comm);
        Function("succlCommDestroy", succlResult_t, [succlComm_t]),
        # succlResult_t succlGroupStart();
        Function("succlGroupStart", succlResult_t, []),
        # succlResult_t succlGroupEnd();
        Function("succlGroupEnd", succlResult_t, []),
        # Function("succldemoSetdevice", succlResult_t, [ctypes.c_int]),
    ]

    # class attribute to store the mapping from the path to the library
    # to avoid loading the same library multiple times
    path_to_library_cache: dict[str, Any] = {}

    # class attribute to store the mapping from library path
    # to the corresponding dictionary of bound functions
    path_to_dict_mapping: dict[str, dict[str, Any]] = {}

    def __init__(self, so_file: Optional[str] = None):
        """Load the SUCCL library and bind its exported functions.

        Args:
            so_file: Path of the shared library. When falsy, the path
                returned by ``find_sccl_library()`` is used.

        Raises:
            Exception: Whatever ``ctypes.CDLL`` raised while loading the
                library; it is logged and re-raised unchanged.
        """
        so_file = so_file or find_sccl_library()

        try:
            # Open the shared object only once per path.
            if so_file not in SCCLLibrary.path_to_library_cache:
                SCCLLibrary.path_to_library_cache[so_file] = ctypes.CDLL(
                    so_file)
            self.lib = SCCLLibrary.path_to_library_cache[so_file]
        except Exception:
            logger.error(
                "Failed to load SCCL library from %s. "
                "It is expected if you are not running on NVIDIA/AMD GPUs. "
                "Otherwise, the sccl library might not exist, be corrupted "
                "or it does not support the current platform %s. "
                "If you already have the library, please set the "
                "environment variable VLLM_SCCL_SO_PATH"
                " to point to the correct sccl library path.", so_file,
                platform.platform())
            # Bare raise keeps the original traceback intact.
            raise

        # Bind restype/argtypes once per path and share the table.
        if so_file not in SCCLLibrary.path_to_dict_mapping:
            _funcs: dict[str, Any] = {}
            for func in SCCLLibrary.exported_functions:
                f = getattr(self.lib, func.name)
                f.restype = func.restype
                f.argtypes = func.argtypes
                _funcs[func.name] = f
            SCCLLibrary.path_to_dict_mapping[so_file] = _funcs
        self._funcs = SCCLLibrary.path_to_dict_mapping[so_file]

    def succlGetErrorString(self, result: succlResult_t) -> str:
        """Return the library's message for ``result``, decoded as UTF-8."""
        return self._funcs["succlGetErrorString"](result).decode("utf-8")

    def SUCCL_CHECK(self, result: succlResult_t) -> None:
        """Raise ``RuntimeError`` when ``result`` is a non-zero code."""
        if result != 0:
            error_str = self.succlGetErrorString(result)
            raise RuntimeError(f"SCCL error: {error_str}")

    def succlGetVersion(self) -> str:
        """Return the library version formatted as ``"major.minor.patch"``.

        The integer reported by the library is split by decimal digits:
        first digit, next two digits, remaining digits (21903 -> "2.19.3").
        Leading zeros are stripped from each part; a part consisting only of
        zeros is reported as ``"0"`` (21900 -> "2.19.0") instead of being
        left empty.
        """
        version = ctypes.c_int()
        self.SUCCL_CHECK(self._funcs["succlGetVersion"](ctypes.byref(version)))
        version_str = str(version.value)
        # something like 21903 --> "2.19.3"
        major = version_str[0].lstrip("0") or "0"
        minor = version_str[1:3].lstrip("0") or "0"
        patch = version_str[3:].lstrip("0") or "0"
        return f"{major}.{minor}.{patch}"

    def succlGetUniqueId(self) -> succlUniqueId:
        """Ask the library for a fresh ``succlUniqueId`` and return it."""
        unique_id = succlUniqueId()
        self.SUCCL_CHECK(self._funcs["succlGetUniqueId"](
            ctypes.byref(unique_id)))
        return unique_id

    def unique_id_from_bytes(self, data: bytes) -> succlUniqueId:
        """Rebuild a ``succlUniqueId`` from its 128 raw bytes.

        Raises:
            ValueError: If ``data`` is not exactly 128 bytes long.
        """
        if len(data) != 128:
            raise ValueError(
                f"Expected 128 bytes for succlUniqueId, got {len(data)} bytes")
        unique_id = succlUniqueId()
        ctypes.memmove(ctypes.addressof(unique_id.internal), data, 128)
        return unique_id

    def succlCommInitRank(self, world_size: int, unique_id: succlUniqueId,
                          rank: int) -> succlComm_t:
        """Create and return the communicator handle for ``rank``."""
        comm = succlComm_t()
        result = self._funcs["succlCommInitRank"](ctypes.byref(comm),
                                                  world_size, unique_id, rank,
                                                  None)
        self.SUCCL_CHECK(result)
        return comm

    # def succldemoSetdevice(self, deviceid:int):
    #     self._funcs["succldemoSetdevice"](deviceid)

    def succlAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
                       count: int, datatype: int, op: int, comm: succlComm_t,
                       stream: suStream_t) -> None:
        """Call ``succlAllReduce``; raises ``RuntimeError`` on failure."""
        # `datatype` actually should be `succlDataType_t`
        # and `op` should be `succlRedOp_t`
        # both are aliases of `ctypes.c_int`
        # when we pass int to a function, it will be converted to `ctypes.c_int`
        # by ctypes automatically
        self.SUCCL_CHECK(self._funcs["succlAllReduce"](sendbuff, recvbuff,
                                                       count, datatype, op,
                                                       comm, stream, None))

    def succlReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
                    count: int, datatype: int, op: int, root: int,
                    comm: succlComm_t, stream: suStream_t) -> None:
        """Call ``succlReduce``; raises ``RuntimeError`` on failure."""
        # `datatype` and `op` are plain ints converted to `ctypes.c_int`
        # by ctypes automatically (see succlAllReduce).
        self.SUCCL_CHECK(self._funcs["succlReduce"](sendbuff, recvbuff, count,
                                                    datatype, op, root, comm,
                                                    stream, None))

    def succlReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
                           count: int, datatype: int, op: int,
                           comm: succlComm_t, stream: suStream_t) -> None:
        """Call ``succlReduceScatter``; raises ``RuntimeError`` on failure."""
        # `datatype` and `op` are plain ints converted to `ctypes.c_int`
        # by ctypes automatically (see succlAllReduce).
        self.SUCCL_CHECK(self._funcs["succlReduceScatter"](sendbuff, recvbuff,
                                                           count, datatype, op,
                                                           comm, stream, None))

    def succlAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type,
                       count: int, datatype: int, comm: succlComm_t,
                       stream: suStream_t) -> None:
        """Call ``succlAllGather``; raises ``RuntimeError`` on failure."""
        # `datatype` actually should be `succlDataType_t`
        # which is an alias of `ctypes.c_int`
        # when we pass int to a function, it will be converted to `ctypes.c_int`
        # by ctypes automatically
        self.SUCCL_CHECK(self._funcs["succlAllGather"](sendbuff, recvbuff,
                                                       count, datatype, comm,
                                                       stream, None))

    def succlSend(self, sendbuff: buffer_type, count: int, datatype: int,
                  dest: int, comm: succlComm_t, stream: suStream_t) -> None:
        """Call ``succlSend``; raises ``RuntimeError`` on failure."""
        self.SUCCL_CHECK(self._funcs["succlSend"](sendbuff, count, datatype,
                                                  dest, comm, stream, None))

    def succlRecv(self, recvbuff: buffer_type, count: int, datatype: int,
                  src: int, comm: succlComm_t, stream: suStream_t) -> None:
        """Call ``succlRecv``; raises ``RuntimeError`` on failure."""
        self.SUCCL_CHECK(self._funcs["succlRecv"](recvbuff, count, datatype,
                                                  src, comm, stream, None))

    def succlBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type,
                       count: int, datatype: int, root: int, comm: succlComm_t,
                       stream: suStream_t) -> None:
        """Call ``succlBroadcast``; raises ``RuntimeError`` on failure."""
        self.SUCCL_CHECK(self._funcs["succlBroadcast"](sendbuff, recvbuff,
                                                       count, datatype, root,
                                                       comm, stream, None))

    def succlCommDestroy(self, comm: succlComm_t) -> None:
        """Call ``succlCommDestroy``; raises ``RuntimeError`` on failure."""
        self.SUCCL_CHECK(self._funcs["succlCommDestroy"](comm))

    def succlGroupStart(self) -> None:
        """Call ``succlGroupStart``; raises ``RuntimeError`` on failure."""
        self.SUCCL_CHECK(self._funcs["succlGroupStart"]())

    def succlGroupEnd(self) -> None:
        """Call ``succlGroupEnd``; raises ``RuntimeError`` on failure."""
        self.SUCCL_CHECK(self._funcs["succlGroupEnd"]())
|
||||
|
||||
|
||||
def find_sccl_library() -> str:
    """
    Return the SCCL library path given by the `VLLM_SCCL_SO_PATH`
    environment variable (read through `envs.VLLM_SCCL_SO_PATH`).

    No other lookup is attempted: when the variable is unset or empty a
    `ValueError` is raised.
    """
    library_path = envs.VLLM_SCCL_SO_PATH
    if not library_path:
        raise ValueError("SCCL lib file not found.")

    # manually load the sccl library
    logger.info(
        "Found sccl from environment variable VLLM_SCCL_SO_PATH=%s",
        library_path)
    return library_path
|
||||
|
||||
|
||||
__all__ = [
|
||||
"SCCLLibrary", "succlDataTypeEnum", "succlRedOpTypeEnum", "succlUniqueId",
|
||||
"succlComm_t", "suStream_t", "buffer_type"
|
||||
]
|
||||
17
vllm_br/distributed/kv_transfer/__init__.py
Normal file
17
vllm_br/distributed/kv_transfer/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from . import kv_connector # noqa: F401
|
||||
Binary file not shown.
17
vllm_br/distributed/kv_transfer/kv_connector/__init__.py
Normal file
17
vllm_br/distributed/kv_transfer/kv_connector/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from . import v1 # noqa: F401
|
||||
Binary file not shown.
17
vllm_br/distributed/kv_transfer/kv_connector/v1/__init__.py
Normal file
17
vllm_br/distributed/kv_transfer/kv_connector/v1/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from . import base, p2p # noqa: F401
|
||||
Binary file not shown.
Binary file not shown.
28
vllm_br/distributed/kv_transfer/kv_connector/v1/base.py
Normal file
28
vllm_br/distributed/kv_transfer/kv_connector/v1/base.py
Normal file
@@ -0,0 +1,28 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# from vllm.logger import logger
|
||||
# from vllm.v1.core.sched.output import SchedulerOutput
|
||||
# from vllm.v1.outputs import KVConnectorOutput
|
||||
|
||||
# class KVConnectorRole(enum.Enum):
|
||||
# # Connector running in the scheduler process
|
||||
# SCHEDULER = 0
|
||||
|
||||
# # Connector running in the worker process
|
||||
# WORKER = 1
|
||||
|
||||
# vllm.distributed.kv_transfer.kv_connector.v1.base.KVConnectorRole=KVConnectorRole
|
||||
@@ -0,0 +1,17 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
from . import p2p_succl_engine # noqa: F401
|
||||
from . import p2p_succl_connector, tensor_memory_pool # noqa: F401
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,535 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
import torch_br
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import (
|
||||
KVConnectorFactory)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.distributed.parallel_state import get_world_group
|
||||
from vllm.logger import logger
|
||||
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm_br.distributed.kv_transfer.kv_connector.v1.p2p.p2p_succl_engine import (
|
||||
P2pSucclEngine)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
@dataclass
class ReqMeta:
    """Per-request metadata exchanged between scheduler and worker."""

    # Identifier of the request.
    request_id: str
    # Block ids of the request, stored as a tensor.
    block_ids: torch.Tensor
    # Number of tokens of the request.
    num_tokens: int

    @staticmethod
    def make_meta(request_id: str, token_ids: list[int], block_ids: list[int],
                  block_size: int) -> "ReqMeta":
        """Build a ``ReqMeta`` from plain Python lists.

        Args:
            request_id: identifier of the request.
            token_ids: token ids; only their count is kept.
            block_ids: block ids, converted with ``torch.tensor``.
            block_size: accepted for interface compatibility; not used here.

        Returns:
            The populated ``ReqMeta``.
        """
        return ReqMeta(request_id, torch.tensor(block_ids), len(token_ids))
|
||||
|
||||
|
||||
@dataclass
class P2pSucclConnectorMetadata(KVConnectorMetadata):
    """Connector metadata holding the list of requests collected so far."""

    # One ReqMeta per request added through add_request().
    requests: list[ReqMeta]

    def __init__(self):
        # Explicit __init__ so the metadata always starts with an empty list.
        self.requests = []

    def add_request(
        self,
        request_id: str,
        token_ids: list[int],
        block_ids: list[int],
        block_size: int,
    ) -> None:
        """Append a ``ReqMeta`` built from the given request fields."""
        entry = ReqMeta.make_meta(request_id, token_ids, block_ids,
                                  block_size)
        self.requests.append(entry)
|
||||
|
||||
|
||||
class P2pSucclConnector(KVConnectorBase_V1):
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
    """Set up connector state; the transfer engine exists only on workers.

    Args:
        vllm_config: supplies ``cache_config.block_size`` and
            ``kv_transfer_config``.
        role: connector role; anything other than ``WORKER`` gets rank 0
            and no engine.
    """
    super().__init__(vllm_config=vllm_config, role=role)
    self._block_size = vllm_config.cache_config.block_size
    self._requests_need_load: dict[str, Any] = {}
    self.config = vllm_config.kv_transfer_config
    self.is_producer = self.config.is_kv_producer
    self.chunked_prefill: dict[str, Any] = {}

    if role == KVConnectorRole.WORKER:
        world = get_world_group()
        self._rank = world.rank
        self._local_rank = world.local_rank
        # The rank doubles as the port offset of this worker's engine.
        self.p2p_nccl_engine = P2pSucclEngine(
            local_rank=self._local_rank,
            config=self.config,
            hostname="",
            port_offset=self._rank,
        )
    else:
        self._rank = 0
        self._local_rank = 0
        self.p2p_nccl_engine = None
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext",
                  **kwargs) -> None:
    """Start loading the KV cache from the connector buffer to vLLM's
    paged KV buffer.

    Does nothing on the producer side or when the forward context has no
    attention metadata. Otherwise, for every request in the connector
    metadata and every layer that exposes a ``kv_cache`` attribute, the
    tensor is received from the engine and copied into that layer's cache.

    Args:
        forward_context (ForwardContext): the forward context.
        **kwargs: additional arguments for the load operation

    Note:
        The number of elements in kv_caches and layer_names should be
        the same.
    """

    # Only consumer/decode loads KV Cache
    if self.is_producer:
        return
    assert self.p2p_nccl_engine is not None

    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return

    def inject_kv_into_layer(
        layer: torch.Tensor,
        kv_cache: torch.Tensor,
        block_ids: torch.Tensor,
        request_id: str,
    ) -> None:
        """
        Inject KV cache data into a given attention layer tensor.

        This function updates `layer` in-place with values from `kv_cache`:
        - MLA (Multi-head Latent Attention) or FlashInfer layout: only
          plane 0 of `kv_cache` is copied.
        - FlashAttention layout (`layer.shape[0] == 2`): planes 0 and 1
          are copied.
        Any other layout is left untouched.

        If the number of provided block IDs does not match the number of KV
        blocks, only the overlapping portion is updated, and a warning is
        logged.

        Args:
            layer (torch.Tensor): The attention layer KV tensor to update.
            kv_cache (torch.Tensor): The KV cache tensor to inject.
            block_ids (torch.Tensor): Indices of the blocks to update.
            request_id (str): Request identifier used for logging.

        Returns:
            None. The function modifies `layer` in-place.
        """
        if (isinstance(attn_metadata, MLACommonMetadata)
                or layer.shape[1] == 2):  # MLA or FlashInfer
            planes = (0, )
        elif layer.shape[0] == 2:  # FlashAttention
            planes = (0, 1)
        else:
            return

        num_block = kv_cache.shape[1]
        if len(block_ids) != num_block:
            logger.warning(
                "kv_cache does not match, block_ids:%d, num_block:%d, "
                "request_id:%s", len(block_ids), num_block, request_id)
        block_len = min(len(block_ids), num_block)

        # Number of block_size-sized slots along dim 2 of the layer; a
        # block id maps to (row, offset) via divmod by this count.
        th_gran = layer.shape[2] // self._block_size
        for i, block_index in enumerate(block_ids[:block_len].tolist()):
            dst_0 = block_index // th_gran
            dst_1 = (block_index % th_gran) * self._block_size
            for plane in planes:
                layer[plane][dst_0][dst_1:dst_1 +
                                    self._block_size] = kv_cache[plane][i]

    # Get the metadata
    metadata: KVConnectorMetadata = \
        self._get_connector_metadata()
    assert isinstance(metadata, P2pSucclConnectorMetadata)

    if metadata is None:
        return

    # Load the KV for each request each layer
    for request in metadata.requests:
        request_id = request.request_id
        ip, port = self.parse_request_id(request_id, False)
        remote_address = ip + ":" + str(port + self._rank)
        for layer_name, layer in forward_context.no_compile_layers.items():
            # Only process layers that have kv_cache
            # attribute (attention layers) Skip non-attention
            # layers like FusedMoE
            kv_cache = getattr(layer, 'kv_cache', None)
            if kv_cache is None:
                continue

            layer = kv_cache[forward_context.virtual_engine]
            kv_cache = self.p2p_nccl_engine.recv_tensor(
                request.request_id + "#" + layer_name, remote_address)
            if kv_cache is None:
                logger.warning("🚧kv_cache is None, %s", request.request_id)
                continue

            inject_kv_into_layer(layer, kv_cache, request.block_ids,
                                 request.request_id)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
    """Wait until the KV for ``layer_name`` is loaded into the paged buffer.

    This implementation has nothing to wait for and returns immediately.

    Args:
        layer_name: the name of that layer
    """
    return None
|
||||
|
||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                  attn_metadata: "AttentionMetadata", **kwargs) -> None:
    """Start saving the KV cache of the layer from vLLM's paged buffer
    to the connector.

    Does nothing on the consumer side. For every request in the connector
    metadata, the request's blocks are gathered from ``kv_layer`` into a
    new tensor and handed to the engine under the id
    ``"<request_id>#<layer_name>"``.

    Args:
        layer_name (str): the name of the layer.
        kv_layer (torch.Tensor): the paged KV buffer of the current
            layer in vLLM.
        attn_metadata (AttentionMetadata): the attention metadata.
        **kwargs: additional arguments for the save operation.
    """

    # Only producer/prefill saves KV Cache
    if not self.is_producer:
        return
    assert self.p2p_nccl_engine is not None

    def extract_kv_from_layer(
        layer: torch.Tensor,
        block_ids: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """
        Extract KV cache slices from a given attention layer tensor.

        - MLA (Multi-head Latent Attention) or FlashInfer layout: only
          plane 0 is filled in the result.
        - FlashAttention layout (`layer.shape[0] == 2`): planes 0 and 1
          are filled.

        Args:
            layer (torch.Tensor): The KV cache from the attention layer.
            block_ids (torch.Tensor): Indices of blocks to extract.

        Returns:
            A tensor of shape
            ``[layer.shape[0], len(block_ids), block_size, layer.shape[3]]``
            holding the extracted slices, or None if the layout is
            unsupported.
        """
        if (isinstance(attn_metadata, MLACommonMetadata)
                or layer.shape[1] == 2):  # MLA or FlashInfer
            planes = (0, )
        elif layer.shape[0] == 2:  # FlashAttention
            planes = (0, 1)
        else:
            return None

        origin_shape = layer.shape
        shape = [
            origin_shape[0],
            len(block_ids), self._block_size, origin_shape[3]
        ]
        layer_send = torch_br._empty_ut_only(shape,
                                             dtype=layer.dtype,
                                             tensor_type='BUFFER_ANY',
                                             device=layer.device)
        # Number of block_size-sized slots along dim 2 of the layer; a
        # block id maps to (row, offset) via divmod by this count.
        th_gran = origin_shape[2] // self._block_size
        for i, block_index in enumerate(block_ids.tolist()):
            src_0 = block_index // th_gran
            src_1 = (block_index % th_gran) * self._block_size
            for plane in planes:
                layer_send[plane][i] = layer[plane][src_0][src_1:src_1 +
                                                           self._block_size]
        return layer_send

    connector_metadata = self._get_connector_metadata()
    assert isinstance(connector_metadata, P2pSucclConnectorMetadata)
    for request in connector_metadata.requests:
        request_id = request.request_id
        ip, port = self.parse_request_id(request_id, True)
        remote_address = ip + ":" + str(port + self._rank)

        kv_cache = extract_kv_from_layer(kv_layer, request.block_ids)
        if kv_cache is None:
            # Unsupported layout: never pass None to the engine as a tensor.
            logger.warning(
                "Unsupported KV layout %s for layer %s, skip sending "
                "request %s", tuple(kv_layer.shape), layer_name, request_id)
            continue
        self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
                                         kv_cache, remote_address)
|
||||
|
||||
def wait_for_save(self):
    """On the producer, call the engine's ``wait_for_sent()``; else no-op."""
    if not self.is_producer:
        return
    engine = self.p2p_nccl_engine
    assert engine is not None
    engine.wait_for_sent()
|
||||
|
||||
def get_finished(
        self, finished_req_ids: set[str],
        **kwargs) -> tuple[Optional[set[str]], Optional[set[str]]]:
    """
    Notifies worker-side connector ids of requests that have
    finished generating tokens.

    Delegates to the engine's ``get_finished`` with the finished ids and
    the static forward context from the compilation config.

    Returns:
        ids of requests that have finished asynchronous transfer,
        tuple of (sending/saving ids, recving/loading ids).
        The finished saves/sends req ids must belong to a set provided in a
        call to this method (this call or a prior one).
    """
    engine = self.p2p_nccl_engine
    assert engine is not None

    static_ctx = self._vllm_config.compilation_config.static_forward_context
    return engine.get_finished(finished_req_ids, static_ctx)
|
||||
|
||||
# ==============================
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
|
||||
def get_num_new_matched_tokens(
    self,
    request: "Request",
    num_computed_tokens: int,
) -> tuple[int, bool]:
    """
    Get number of new tokens that can be loaded from the
    external KV cache beyond the num_computed_tokens.

    Args:
        request (Request): the request object.
        num_computed_tokens (int): the number of locally
            computed tokens for this request

    Returns:
        A pair ``(count, False)``. ``count`` is 0 on the producer;
        otherwise it is the prompt length minus one minus
        ``num_computed_tokens``, clamped at zero.
    """
    if self.is_producer:
        return 0, False

    remaining = len(request.prompt_token_ids) - 1 - num_computed_tokens
    return max(remaining, 0), False
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
                             blocks: "KVCacheBlocks",
                             num_external_tokens: int):
    """
    Update KVConnector state after block allocation.

    On the consumer, a request with external tokens to load is recorded in
    ``_requests_need_load`` with the first entry of
    ``blocks.get_block_ids()``.
    """
    if self.is_producer or num_external_tokens <= 0:
        return
    first_group_block_ids = blocks.get_block_ids()[0]
    self._requests_need_load[request.request_id] = (request,
                                                    first_group_block_ids)
|
||||
|
||||
def build_connector_meta(
    self,
    scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
    """Build the connector metadata for this step.

    This function should NOT modify any fields in the scheduler_output.
    Also, calling this function will reset the state of the connector.

    Producer side: a request is added to the metadata only once its whole
    prompt is covered by computed plus scheduled tokens; until then its
    block ids and prompt token ids are accumulated in
    ``self.chunked_prefill``.
    Consumer side: requests recorded in ``self._requests_need_load`` are
    added, and that dict is cleared before returning.

    Args:
        scheduler_output (SchedulerOutput): the scheduler output object.

    Returns:
        A ``P2pSucclConnectorMetadata`` with the requests for this step.
    """

    meta = P2pSucclConnectorMetadata()

    # ---- newly scheduled requests ----
    for new_req in scheduler_output.scheduled_new_reqs:
        if self.is_producer:
            num_scheduled_tokens = (
                scheduler_output.num_scheduled_tokens)[new_req.req_id]
            num_tokens = num_scheduled_tokens + new_req.num_computed_tokens
            # the request's prompt is chunked prefill
            if num_tokens < len(new_req.prompt_token_ids):
                # 'CachedRequestData' has no attribute 'prompt_token_ids'
                # so the prompt is remembered here for later steps.
                self.chunked_prefill[new_req.req_id] = (
                    new_req.block_ids[0], new_req.prompt_token_ids)
                continue
            # the request's prompt is not chunked prefill
            meta.add_request(request_id=new_req.req_id,
                             token_ids=new_req.prompt_token_ids,
                             block_ids=new_req.block_ids[0],
                             block_size=self._block_size)
            continue
        # Consumer: only requests flagged by update_state_after_alloc().
        if new_req.req_id in self._requests_need_load:
            meta.add_request(request_id=new_req.req_id,
                             token_ids=new_req.prompt_token_ids,
                             block_ids=new_req.block_ids[0],
                             block_size=self._block_size)
            self._requests_need_load.pop(new_req.req_id)

    # ---- previously seen (cached) requests ----
    cached_reqs = scheduler_output.scheduled_cached_reqs
    for i, req_id in enumerate(cached_reqs.req_ids):
        num_computed_tokens = cached_reqs.num_computed_tokens[i]
        new_block_ids = cached_reqs.new_block_ids[i]
        resumed_from_preemption = cached_reqs.resumed_from_preemption[i]

        if self.is_producer:
            num_scheduled_tokens = (
                scheduler_output.num_scheduled_tokens)[req_id]
            num_tokens = (num_scheduled_tokens + num_computed_tokens)
            # A cached request on the producer must have been recorded as
            # a chunked prefill in an earlier step.
            assert req_id in self.chunked_prefill
            block_ids = new_block_ids[0]
            if not resumed_from_preemption:
                # Not resumed: prepend the blocks accumulated so far.
                block_ids = (self.chunked_prefill[req_id][0] + block_ids)
            prompt_token_ids = self.chunked_prefill[req_id][1]
            # the request's prompt is chunked prefill again
            if num_tokens < len(prompt_token_ids):
                self.chunked_prefill[req_id] = (block_ids,
                                                prompt_token_ids)
                continue
            # the request's prompt is all prefilled finally
            meta.add_request(request_id=req_id,
                             token_ids=prompt_token_ids,
                             block_ids=block_ids,
                             block_size=self._block_size)
            self.chunked_prefill.pop(req_id, None)
            continue

        # NOTE(rob): here we rely on the resumed requests being
        # the first N requests in the list scheduled_cache_reqs.
        if not resumed_from_preemption:
            break
        if req_id in self._requests_need_load:
            request, _ = self._requests_need_load.pop(req_id)
            total_tokens = num_computed_tokens + 1
            token_ids = request.all_token_ids[:total_tokens]

            # NOTE(rob): For resumed req, new_block_ids is all
            # of the block_ids for the request.
            block_ids = new_block_ids[0]

            meta.add_request(request_id=req_id,
                             token_ids=token_ids,
                             block_ids=block_ids,
                             block_size=self._block_size)

    # Anything still pending is dropped; the state is per step.
    self._requests_need_load.clear()
    return meta
|
||||
|
||||
def request_finished(
    self,
    request: "Request",
    block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
    """
    Called when a request has finished, before its blocks are freed.

    Discards any chunked-prefill bookkeeping kept for the request.

    Returns:
        Always ``(False, None)``: the first item tells the caller whether
        blocks must be kept until the request id comes back from
        get_finished(), the second carries optional KVTransferParams.
    """
    finished_id = request.request_id
    self.chunked_prefill.pop(finished_id, None)
    return False, None
|
||||
|
||||
# ==============================
|
||||
# Static methods
|
||||
# ==============================
|
||||
|
||||
@staticmethod
|
||||
def parse_request_id(request_id: str, is_prefill=True) -> tuple[str, int]:
|
||||
# Regular expression to match the string hostname and integer port
|
||||
if is_prefill:
|
||||
pattern = r"___decode_addr_(.*):(\d+)"
|
||||
else:
|
||||
pattern = r"___prefill_addr_(.*):(\d+)___"
|
||||
|
||||
# Use re.search to find the pattern in the request_id
|
||||
match = re.search(pattern, request_id)
|
||||
if match:
|
||||
# Extract the ranks
|
||||
ip = match.group(1)
|
||||
port = int(match.group(2))
|
||||
|
||||
return ip, port
|
||||
raise ValueError(
|
||||
f"Request id {request_id} does not contain hostname and port")
|
||||
|
||||
@staticmethod
|
||||
def check_tensors_except_dim(tensor1, tensor2, dim):
|
||||
shape1 = tensor1.size()
|
||||
shape2 = tensor2.size()
|
||||
|
||||
if len(shape1) != len(shape2) or not all(
|
||||
s1 == s2
|
||||
for i, (s1, s2) in enumerate(zip(shape1, shape2)) if i != dim):
|
||||
raise NotImplementedError(
|
||||
"Currently, only symmetric TP is supported. Asymmetric TP, PP,"
|
||||
"and others will be supported in future PRs.")
|
||||
|
||||
|
||||
# Import-time side effect: register this connector with the KV connector
# factory under the name "P2pSucclConnector", passing the module path and
# class name as strings.
KVConnectorFactory.register_connector(
    "P2pSucclConnector",
    "vllm_br.distributed.kv_transfer.kv_connector.v1.p2p.p2p_succl_connector",
    "P2pSucclConnector")
|
||||
@@ -0,0 +1,572 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import typing
|
||||
from collections import deque
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import msgpack
|
||||
import torch
|
||||
import torch_br
|
||||
import zmq
|
||||
from torch_br.supa._internal import get_tensor_info
|
||||
|
||||
from vllm.config import KVTransferConfig
|
||||
# import vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine
|
||||
from vllm.utils import get_ip
|
||||
from vllm_br.distributed.device_communicators.pysccl_wrapper import (
|
||||
SCCLLibrary, buffer_type, succlComm_t, succlDataTypeEnum, suStream_t)
|
||||
from vllm_br.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501
|
||||
TensorMemoryPool)
|
||||
from vllm_br.platform import SUPAPlatform
|
||||
|
||||
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Fallback size in GB of the TensorMemoryPool, used by P2pSucclEngine when
# "mem_pool_size_gb" is absent from the extra config.
DEFAULT_MEM_POOL_SIZE_GB = 1
|
||||
|
||||
|
||||
@contextmanager
def set_p2p_succl_context(num_channels: str):
    """Temporarily override SUCCL environment variables.

    Inside the ``with`` body ``SUCCL_MAX_NCHANNELS`` and
    ``SUCCL_MIN_NCHANNELS`` are set to ``num_channels`` and
    ``SUCCL_CUMEM_ENABLE`` to ``'1'``. On exit every tracked variable is
    restored to its previous value, or removed if it was unset.

    Args:
        num_channels: channel count, as the string written to the
            environment.
    """
    env_vars = [
        'SUCCL_MAX_NCHANNELS',
        'SUCCL_MIN_NCHANNELS',
        'SUCCL_CUMEM_ENABLE',
        'SUCCL_BUFFSIZE',
        'SUCCL_PROTO',  # LL,LL128,SIMPLE
        'SUCCL_ALGO',  # RING,TREE
    ]
    # Snapshot before touching anything so the finally block can restore.
    original_values: dict[str, Any] = {
        name: os.environ.get(name)
        for name in env_vars
    }

    logger.info("set_p2p_succl_context, original_values: %s", original_values)

    overrides = (
        ('SUCCL_MAX_NCHANNELS', num_channels),
        ('SUCCL_MIN_NCHANNELS', num_channels),
        ('SUCCL_CUMEM_ENABLE', '1'),
    )
    try:
        for name, value in overrides:
            os.environ[name] = value
        yield
    finally:
        for name in env_vars:
            previous = original_values[name]
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous
|
||||
|
||||
|
||||
@dataclass
class SendQueueItem:
    """One pending send: a tensor, its id, and the peer it goes to."""

    # Identifier under which the tensor is sent.
    tensor_id: str
    # Destination address string.
    remote_address: str
    # The tensor to send.
    tensor: torch.Tensor
|
||||
|
||||
|
||||
class P2pSucclEngine:
|
||||
|
||||
def __init__(self,
             local_rank: int,
             config: KVTransferConfig,
             hostname: str = "",
             port_offset: int = 0,
             library_path: Optional[str] = None) -> None:
    """Create the transfer engine and start its background threads.

    Binds a ZMQ ROUTER socket on ``hostname:(kv_port + port_offset)``,
    starts the listener thread, the async send thread when the send type
    is ``PUT_ASYNC``, and the ping thread when ``port_offset`` is 0 and a
    proxy address is configured.

    Args:
        local_rank: local device index; offset by the ``device_cursor``
            extra-config entry when a config is given.
        config: KV transfer configuration.
        hostname: address to bind; ``get_ip()`` is used when empty.
        port_offset: added to ``config.kv_port``; also stored as ``rank``.
        library_path: passed through to ``SCCLLibrary``.

    Raises:
        ValueError: if the resulting port is 0.
    """
    self.config = config
    self.rank = port_offset
    self.local_rank = local_rank
    self.device = torch.device(f"supa:{self.local_rank}")
    if config is not None:
        device_cursor = self.config.get_from_extra_config(
            "device_cursor", 0)
        self.device = torch.device(
            f"supa:{self.local_rank + int(device_cursor)}")
    SUPAPlatform.set_device(self.device)
    self.succl = SCCLLibrary(library_path)

    if not hostname:
        hostname = get_ip()
    port = int(self.config.kv_port) + port_offset
    if port == 0:
        raise ValueError("Port cannot be 0")
    self._hostname = hostname
    self._port = port

    # Each card corresponds to a ZMQ address.
    self.zmq_address = f"{self._hostname}:{self._port}"

    # The `http_port` must be consistent with the port of OpenAI.
    self.http_address = (
        f"{self._hostname}:"
        f"{self.config.kv_connector_extra_config['http_port']}")

    # If `proxy_ip` or `proxy_port` is `""`,
    # then the ping thread will not be enabled.
    proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
    proxy_port = self.config.get_from_extra_config("proxy_port", "")
    if proxy_ip == "" or proxy_port == "":
        self.proxy_address = ""
    else:
        # Formatted rather than concatenated so an integer port from the
        # extra config does not raise TypeError.
        self.proxy_address = f"{proxy_ip}:{proxy_port}"

    self.context = zmq.Context()
    self.router_socket = self.context.socket(zmq.ROUTER)
    self.router_socket.bind(f"tcp://{self.zmq_address}")

    self.poller = zmq.Poller()
    self.poller.register(self.router_socket, zmq.POLLIN)

    self.send_store_cv = threading.Condition()
    self.send_queue_cv = threading.Condition()
    self.recv_store_cv = threading.Condition()

    # Sending and receiving share a single stream.
    self.send_stream = torch_br.supa.Stream()
    self.recv_stream = self.send_stream

    mem_pool_size_gb = float(
        self.config.get_from_extra_config("mem_pool_size_gb",
                                          DEFAULT_MEM_POOL_SIZE_GB))
    self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb *
                                                    1024**3))  # GB

    # The sending type includes three mutually exclusive options:
    # PUT, GET, PUT_ASYNC.
    self.send_type = self.config.get_from_extra_config(
        "send_type", "PUT_ASYNC")
    if self.send_type == "GET":
        # tensor_id: torch.Tensor
        self.send_store: dict[str, torch.Tensor] = {}
    else:
        # PUT or PUT_ASYNC
        # tensor_id: torch.Tensor
        self.send_queue: deque[SendQueueItem] = deque()
        if self.send_type == "PUT_ASYNC":
            self._send_thread = threading.Thread(target=self.send_async,
                                                 daemon=True)
            self._send_thread.start()

    # tensor_id: torch.Tensor/(addr, dtype, shape)
    self.recv_store: dict[str, Any] = {}
    self.recv_request_id_to_tensor_ids: dict[str, set[str]] = {}
    self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
    self.socks: dict[str, Any] = {}  # remote_address: client socket
    self.comms: dict[str, Any] = {}  # remote_address: (succlComm_t, rank)

    self.buffer_size = 0
    self.buffer_size_threshold = float(self.config.kv_buffer_size)

    # The extra-config key keeps the "nccl_" prefix.
    self.succl_num_channels = self.config.get_from_extra_config(
        "nccl_num_channels", "8")

    self._listener_thread = threading.Thread(
        target=self.listen_for_requests, daemon=True)
    self._listener_thread.start()

    self._ping_thread = None
    if port_offset == 0 and self.proxy_address != "":
        self._ping_thread = threading.Thread(target=self.ping, daemon=True)
        self._ping_thread.start()

    logger.info(
        "💯P2pSucclEngine init, rank:%d, local_rank:%d, http_address:%s, "
        "zmq_address:%s, proxy_address:%s, send_type:%s, buffer_size_"
        "threshold:%.2f, nccl_num_channels:%s", self.rank, self.local_rank,
        self.http_address, self.zmq_address, self.proxy_address,
        self.send_type, self.buffer_size_threshold,
        self.succl_num_channels)
|
||||
|
||||
def create_connect(self, remote_address: typing.Optional[str] = None):
    """Return the ``(socket, (comm, rank))`` pair for ``remote_address``.

    On first use a DEALER socket identified by this engine's ZMQ address
    is connected to the peer. Unless a communicator for the peer already
    exists, a "NEW" message carrying a fresh unique id is sent and a
    2-rank communicator is initialised with this side as rank 0.

    Args:
        remote_address: peer address; must not be None.
    """
    assert remote_address is not None
    if remote_address in self.socks:
        return self.socks[remote_address], self.comms[remote_address]

    dealer = self.context.socket(zmq.DEALER)
    dealer.setsockopt_string(zmq.IDENTITY, self.zmq_address)
    dealer.connect(f"tcp://{remote_address}")
    self.socks[remote_address] = dealer

    if remote_address in self.comms:
        logger.info("👋comm exists, remote_address:%s, comms:%s",
                    remote_address, self.comms)
        return dealer, self.comms[remote_address]

    uid = self.succl.succlGetUniqueId()
    handshake = {"cmd": "NEW", "unique_id": bytes(uid.internal)}
    dealer.send(msgpack.dumps(handshake))

    # The connecting side is rank 0 of the pair.
    my_rank = 0
    SUPAPlatform.set_device(self.device)
    comm: succlComm_t = self.succl.succlCommInitRank(2, uid, my_rank)
    self.comms[remote_address] = (comm, my_rank)
    logger.info("🤝succlCommInitRank Success, %s👉%s, MyRank:%s",
                self.zmq_address, remote_address, my_rank)

    return self.socks[remote_address], self.comms[remote_address]
|
||||
|
||||
def send_tensor(
    self,
    tensor_id: str,
    tensor: torch.Tensor,
    remote_address: typing.Optional[str] = None,
) -> bool:
    """Send ``tensor`` according to the configured send type.

    - No ``remote_address``: the tensor goes straight into the local
      receive store and waiters are notified.
    - ``PUT``: sent synchronously via ``send_sync``.
    - ``PUT_ASYNC``: queued for the send thread.
    - otherwise (GET): kept in the send store, evicting the oldest
      entries until it fits under the buffer size threshold.

    Returns:
        False when a GET tensor is larger than the threshold, the result
        of ``send_sync`` for PUT, True otherwise.
    """
    if remote_address is None:
        with self.recv_store_cv:
            self.recv_store[tensor_id] = tensor
            self.recv_store_cv.notify()
        return True

    item = SendQueueItem(tensor_id=tensor_id,
                         remote_address=remote_address,
                         tensor=tensor)
    mode = self.send_type

    if mode == "PUT":
        return self.send_sync(item)

    if mode == "PUT_ASYNC":
        with self.send_queue_cv:
            self.send_queue.append(item)
            self.send_queue_cv.notify()
        return True

    # GET
    with self.send_store_cv:
        nbytes = tensor.element_size() * tensor.numel()
        if nbytes > self.buffer_size_threshold:
            logger.warning(
                "❗[GET]tensor_id:%s, tensor_size:%d, is greater than"
                "buffer size threshold :%d, skip send to %s, rank:%d",
                tensor_id, nbytes, self.buffer_size_threshold,
                remote_address, self.rank)
            return False

        # Evict in insertion order until the new tensor fits.
        while self.buffer_size + nbytes > self.buffer_size_threshold:
            assert len(self.send_store) > 0
            evicted_id = next(iter(self.send_store))
            evicted = self.send_store.pop(evicted_id)
            evicted_nbytes = evicted.element_size() * evicted.numel()
            self.buffer_size -= evicted_nbytes
            logger.debug(
                "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
                " buffer_size:%d, oldest_tensor_size:%d, rank:%d",
                remote_address, tensor_id, nbytes, self.buffer_size,
                evicted_nbytes, self.rank)

        self.send_store[tensor_id] = tensor
        self.buffer_size += nbytes
        logger.debug(
            "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
            "shape:%s, rank:%d, buffer_size:%d(%.2f%%)", remote_address,
            tensor_id, nbytes, tensor.shape, self.rank,
            self.buffer_size,
            self.buffer_size / self.buffer_size_threshold * 100)
    return True
|
||||
|
||||
def recv_tensor(
    self,
    tensor_id: str,
    remote_address: typing.Optional[str] = None,
) -> Optional[torch.Tensor]:
    """Obtain the tensor registered under ``tensor_id``.

    PUT / PUT_ASYNC: block until ``tensor_id`` appears in the receive
    store. A stored ``(addr, dtype, shape)`` tuple is loaded from the
    memory pool; a stored tensor has its byte size subtracted from
    ``buffer_size``; a stored None only logs a warning.

    GET: ask the peer over ZMQ for the tensor, then receive it on the
    receive stream.

    Args:
        tensor_id: identifier of the tensor.
        remote_address: peer address; required for GET.

    Returns:
        The tensor, or None when the stored value is None, when GET has no
        ``remote_address``, or when the peer answers with a non-zero
        ``ret``.
    """
    if self.send_type in ("PUT", "PUT_ASYNC"):
        start_time = time.time()
        with self.recv_store_cv:
            while tensor_id not in self.recv_store:
                self.recv_store_cv.wait()
            tensor = self.recv_store[tensor_id]

        if tensor is not None:
            if isinstance(tensor, tuple):
                # Stored as a memory-pool reference rather than a tensor.
                addr, dtype, shape = tensor
                tensor = self.pool.load_tensor(addr, dtype, shape,
                                               self.device)
            else:
                self.buffer_size -= (tensor.element_size() *
                                     tensor.numel())
        else:
            duration = time.time() - start_time
            logger.warning(
                "🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, "
                "rank:%d", remote_address, tensor_id, duration * 1000,
                self.rank)
        return tensor

    # GET
    if remote_address is None:
        return None

    if remote_address not in self.socks:
        self.create_connect(remote_address)

    sock = self.socks[remote_address]
    comm, rank = self.comms[remote_address]

    data = {"cmd": "GET", "tensor_id": tensor_id}
    sock.send(msgpack.dumps(data))

    message = sock.recv()
    data = msgpack.loads(message)
    if data["ret"] != 0:
        logger.warning("🔴[GET]Recv From %s, tensor_id: %s, ret: %d",
                       remote_address, tensor_id, data["ret"])
        return None

    with torch_br.supa.stream(self.recv_stream):
        tensor = torch_br._empty_ut_only(data["shape"],
                                         dtype=getattr(
                                             torch, data["dtype"]),
                                         tensor_type='BUFFER_ANY',
                                         device=self.device)

    # The peer is the other rank of the 2-rank communicator.
    self.recv(comm, tensor, rank ^ 1, self.recv_stream)

    return tensor
|
||||
|
||||
def listen_for_requests(self):
    """Serve control messages arriving on ``self.router_socket`` forever.

    Each message is a msgpack dict whose ``"cmd"`` selects the action:

    * ``"NEW"``: join a two-rank succl communicator as rank 1 and remember
      it under the sender's address.
    * ``"PUT"``: allocate a buffer, acknowledge, receive the tensor and
      publish it in ``self.recv_store``.
    * ``"GET"``: look the tensor up in ``self.send_store``, reply with its
      metadata (or ``{"ret": 1}``) and send it if found.

    Any other command is logged and ignored. The loop never returns.
    """
    while True:
        socks = dict(self.poller.poll())
        if self.router_socket not in socks:
            continue

        remote_address, message = self.router_socket.recv_multipart()
        data = msgpack.loads(message)
        if data["cmd"] == "NEW":
            unique_id = self.succl.unique_id_from_bytes(
                bytes(data["unique_id"]))

            # This side always takes rank 1 of the 2-rank communicator.
            rank = 1
            SUPAPlatform.set_device(self.device)
            comm: succlComm_t = self.succl.succlCommInitRank(
                2, unique_id, rank)
            self.comms[remote_address.decode()] = (comm, rank)
            logger.info("🤝suclCommInitRank Success, %s👈%s, MyRank:%s",
                        self.zmq_address, remote_address.decode(), rank)
        elif data["cmd"] == "PUT":
            tensor_id = data["tensor_id"]
            try:
                # The header's "tensor_type" field is not read here; the
                # buffer is always created as 'BUFFER_ANY'.
                with torch_br.supa.stream(self.recv_stream):
                    tensor = torch_br._empty_ut_only(
                        data["shape"],
                        dtype=getattr(torch, data["dtype"]),
                        tensor_type='BUFFER_ANY',
                        device=self.device)
                # b"0" tells the sender the buffer is ready.
                self.router_socket.send_multipart([remote_address, b"0"])
                comm, rank = self.comms[remote_address.decode()]
                self.recv(comm, tensor, rank ^ 1, self.recv_stream)
                tensor_size = tensor.element_size() * tensor.numel()
                if (self.buffer_size + tensor_size
                        > self.buffer_size_threshold):
                    # Store Tensor in memory pool
                    addr = self.pool.store_tensor(tensor)
                    # From here on the store entry is a descriptor tuple,
                    # not a tensor.
                    tensor = (addr, tensor.dtype, tensor.shape)
                    logger.warning(
                        "🔴[PUT]Recv Tensor, Out Of Threshold, "
                        "%s👈%s, data:%s, addr:%d", self.zmq_address,
                        remote_address.decode(), data, addr)
                else:
                    self.buffer_size += tensor_size

            # NOTE(review): this is the CUDA OOM exception type; confirm the
            # SUPA backend raises it. Also, if it is raised after b"0" was
            # already sent, the peer receives a second reply - verify.
            except torch.cuda.OutOfMemoryError:
                self.router_socket.send_multipart([remote_address, b"1"])
                tensor = None
                logger.warning(
                    "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
                    "data:%s", self.zmq_address, remote_address.decode(),
                    data)

            with self.recv_store_cv:
                # None is stored on failure so a waiter is still released.
                self.recv_store[tensor_id] = tensor
                self.have_received_tensor_id(tensor_id)
                # NOTE(review): notify() wakes one waiter only; with several
                # waiters on different ids notify_all() may be needed.
                self.recv_store_cv.notify()

        elif data["cmd"] == "GET":
            tensor_id = data["tensor_id"]
            with self.send_store_cv:
                tensor = self.send_store.pop(tensor_id, None)
                if tensor is not None:
                    data = {
                        "ret": 0,
                        "shape": tensor.shape,
                        "dtype": str(tensor.dtype).replace("torch.", ""),
                        "tensor_type": get_tensor_info(tensor)[0]['layout']
                    }
                    # LRU
                    # Pop + re-insert puts the entry last in insertion
                    # order (presumably send_store is evicted from the
                    # front - confirm).
                    self.send_store[tensor_id] = tensor
                    self.have_sent_tensor_id(tensor_id)
                else:
                    data = {"ret": 1}

            self.router_socket.send_multipart(
                [remote_address, msgpack.dumps(data)])

            # Only follow up with the payload when the lookup succeeded.
            if data["ret"] == 0:
                comm, rank = self.comms[remote_address.decode()]
                self.send(comm, tensor.to(self.device), rank ^ 1,
                          self.send_stream)
        else:
            logger.warning(
                "🚧Unexpected, Received message from %s, data:%s",
                remote_address, data)
|
||||
|
||||
def have_sent_tensor_id(self, tensor_id: str):
    """Index ``tensor_id`` under its request id in the sent-tensor map.

    The request id is the part of ``tensor_id`` before the first ``#``
    (the whole id when it contains no ``#``).
    """
    request_id, _, _ = tensor_id.partition('#')
    sent_ids = self.send_request_id_to_tensor_ids.setdefault(
        request_id, set())
    sent_ids.add(tensor_id)
|
||||
|
||||
def have_received_tensor_id(self, tensor_id: str):
    """Index ``tensor_id`` under its request id in the received-tensor map.

    The request id is the part of ``tensor_id`` before the first ``#``
    (the whole id when it contains no ``#``).
    """
    request_id, _, _ = tensor_id.partition('#')
    received_ids = self.recv_request_id_to_tensor_ids.setdefault(
        request_id, set())
    received_ids.add(tensor_id)
|
||||
|
||||
def send_async(self):
    """Drain ``self.send_queue`` forever, sending one item at a time.

    Blocks on ``self.send_queue_cv`` while the queue is empty. Each item is
    popped under the condition's lock and handed to ``send_sync`` after the
    lock is released. The loop never returns.
    """
    while True:
        with self.send_queue_cv:
            while not self.send_queue:
                self.send_queue_cv.wait()
            item = self.send_queue.popleft()
            # Wake a thread blocked in wait_for_sent() once the queue is
            # drained.
            # NOTE(review): this fires before send_sync(item) has run, so
            # "queue empty" does not imply the last tensor was sent.
            if not self.send_queue:
                self.send_queue_cv.notify()
        # Sent outside the lock so producers can keep enqueueing.
        self.send_sync(item)
|
||||
|
||||
def wait_for_sent(self):
    """Block until ``self.send_queue`` is empty (``PUT_ASYNC`` mode only).

    Does nothing for any other ``send_type``. The time spent waiting is
    logged at debug level.
    """
    if self.send_type != "PUT_ASYNC":
        return

    began = time.time()
    with self.send_queue_cv:
        # Re-check after every wake-up; the sender notifies on drain.
        while self.send_queue:
            self.send_queue_cv.wait()
    elapsed = time.time() - began
    logger.debug(
        "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
        " to be empty, rank:%d", elapsed * 1000, self.rank)
|
||||
|
||||
def send_sync(self, item: SendQueueItem) -> bool:
    """Push one queued tensor to its peer using the PUT handshake.

    A msgpack header describing the tensor is sent on the peer's socket;
    the tensor itself is transferred only when the peer replies ``b"0"``.

    Args:
        item: Queue entry providing ``remote_address``, ``tensor_id`` and
            ``tensor``.

    Returns:
        ``True`` when the tensor was sent; ``False`` when the item has no
        remote address or the peer replied with anything other than
        ``b"0"``.
    """
    peer = item.remote_address
    if peer is None:
        return False
    if peer not in self.socks:
        self.create_connect(peer)

    payload = item.tensor
    sock = self.socks[peer]
    comm, rank = self.comms[peer]

    header = {
        "cmd": "PUT",
        "tensor_id": item.tensor_id,
        "shape": payload.shape,
        "dtype": str(payload.dtype).replace("torch.", ""),
        "tensor_type": get_tensor_info(payload)[0]['layout']
    }
    sock.send(msgpack.dumps(header))

    ack = sock.recv()
    if ack == b"0":
        # The peer of a two-rank communicator is always the other rank.
        self.send(comm, payload.to(self.device), rank ^ 1, self.send_stream)
        if self.send_type == "PUT_ASYNC":
            self.have_sent_tensor_id(item.tensor_id)
        return True

    logger.error(
        "🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
        "MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s",
        self.zmq_address, peer, rank, header, payload.shape,
        payload.element_size() * payload.numel() / 1024**3,
        ack.decode())
    return False
|
||||
|
||||
def get_finished(
        self, finished_req_ids: set[str], no_compile_layers
) -> tuple[Optional[set[str]], Optional[set[str]]]:
    """Release per-request state for requests that finished generating.

    For every finished request, the request-level bookkeeping maps are
    dropped and each ``"<request_id>#<layer_name>"`` entry is removed from
    ``self.recv_store``. Entries that are ``(addr, dtype, shape)`` tuples
    have their memory-pool block freed.

    Args:
        finished_req_ids: Ids of requests that have finished.
        no_compile_layers: Iterable of layer names used to build tensor ids.

    Returns:
        ``(finished_sending, finished_recving)``. Both sets are always
        empty here, so the result is currently ``(None, None)``.
    """
    # Clear the buffer upon request completion.
    for request_id in finished_req_ids:
        with self.recv_store_cv:
            # Drop the bookkeeping once per request, whether or not any of
            # its tensors are in recv_store. Otherwise the send-side map is
            # never cleaned on a side that received nothing for the request.
            self.send_request_id_to_tensor_ids.pop(request_id, None)
            self.recv_request_id_to_tensor_ids.pop(request_id, None)
            dropped = [
                self.recv_store.pop(request_id + "#" + layer_name, None)
                for layer_name in no_compile_layers
            ]

        # Pool blocks are freed outside the lock.
        for entry in dropped:
            if isinstance(entry, tuple):
                addr, _, _ = entry
                self.pool.free(addr)

    # TODO:Retrieve requests that have already sent the KV cache.
    finished_sending: set[str] = set()

    # TODO:Retrieve requests that have already received the KV cache.
    finished_recving: set[str] = set()

    return finished_sending or None, finished_recving or None
|
||||
|
||||
def ping(self):
    """Announce this instance to the proxy every three seconds, forever.

    Opens a DEALER socket whose identity is ``self.zmq_address``, connects
    it to ``self.proxy_address`` and keeps sending a msgpack message with
    the instance type (``"P"`` for a KV producer, otherwise ``"D"``) and
    its HTTP and ZMQ addresses. The loop never returns.
    """
    sock = self.context.socket(zmq.DEALER)
    sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
    logger.debug("ping start, zmq_address:%s", self.zmq_address)
    sock.connect(f"tcp://{self.proxy_address}")

    role = "P" if self.config.is_kv_producer else "D"
    registration = {
        "type": role,
        "http_address": self.http_address,
        "zmq_address": self.zmq_address
    }
    while True:
        sock.send(msgpack.dumps(registration))
        time.sleep(3)
|
||||
|
||||
def send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
    """Send ``tensor`` to rank ``dst`` of ``comm`` and wait for completion.

    Args:
        comm: succl communicator handle.
        tensor: Tensor to send; must be on ``self.device``.
        dst: Destination rank within ``comm``.
        stream: Stream to issue the send on. A new ``torch_br.supa.Stream``
            is created when omitted.

    The stream is synchronized before returning.
    """
    assert tensor.device == self.device, (
        f"this succl communicator is created to work on {self.device}, "
        f"but the input tensor is on {tensor.device}")

    active_stream = torch_br.supa.Stream() if stream is None else stream
    with torch_br.supa.stream(active_stream):
        self.succl.succlSend(
            buffer_type(tensor.data_ptr()),
            tensor.numel(),
            succlDataTypeEnum.from_torch(tensor.dtype),
            dst,
            comm,
            suStream_t(active_stream.supa_stream),
        )
    active_stream.synchronize()
|
||||
|
||||
def recv(self, comm, tensor: torch.Tensor, src: int, stream=None):
    """Receive into ``tensor`` from rank ``src`` of ``comm`` and wait.

    Args:
        comm: succl communicator handle.
        tensor: Pre-allocated destination; must be on ``self.device``.
        src: Source rank within ``comm``.
        stream: Stream to issue the receive on. A new
            ``torch_br.supa.Stream`` is created when omitted.

    The stream is synchronized before returning.
    """
    assert tensor.device == self.device, (
        f"this succl communicator is created to work on {self.device}, "
        f"but the input tensor is on {tensor.device}")

    active_stream = torch_br.supa.Stream() if stream is None else stream
    with torch_br.supa.stream(active_stream):
        self.succl.succlRecv(
            buffer_type(tensor.data_ptr()),
            tensor.numel(),
            succlDataTypeEnum.from_torch(tensor.dtype),
            src,
            comm,
            suStream_t(active_stream.supa_stream),
        )
    active_stream.synchronize()
|
||||
|
||||
def close(self) -> None:
    """Join the engine's background threads.

    Joins the listener thread, then the sender thread when ``send_type``
    is ``"PUT_ASYNC"``, then the ping thread when one exists.
    """
    # NOTE(review): the thread targets visible in this file loop forever,
    # so these joins may never return - confirm how shutdown is signalled.
    self._listener_thread.join()

    uses_async_sender = self.send_type == "PUT_ASYNC"
    if uses_async_sender:
        self._send_thread.join()

    ping_thread = self._ping_thread
    if ping_thread is not None:
        ping_thread.join()
|
||||
@@ -0,0 +1,280 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import atexit
|
||||
import ctypes
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import logger
|
||||
|
||||
|
||||
@dataclass
class MemoryBlock:
    """A contiguous region of the pool's backing buffer."""
    size: int  # length of the block in bytes
    addr: int  # absolute start address of the block


class TensorMemoryPool:
    """A memory pool for managing pinned host memory allocations for tensors.

    This class implements a buddy allocation system to efficiently manage
    pinned host memory for tensor storage. It supports allocation,
    deallocation, and tensor storage/retrieval operations.

    Key Features:
    - Uses power-of-two block sizes for efficient buddy allocation
    - Supports splitting and merging of memory blocks
    - Provides methods to store CUDA tensors in pinned host memory
    - Allows loading tensors from pinned memory back to device
    - Automatically cleans up memory on destruction

    Attributes:
        max_block_size (int): Maximum block size (rounded up to a power of
            two); also the size of the whole pool
        min_block_size (int): Minimum block size (rounded up to a power of
            two)
        free_lists (dict): Free blocks, keyed by block size then by address
        allocated_blocks (dict): Currently allocated blocks, keyed by address
        base_tensor (torch.Tensor): Base pinned memory tensor
        base_address (int): Base memory address of the pinned memory region

    Example:
        >>> pool = TensorMemoryPool(max_block_size=1024*1024)
        >>> tensor = torch.randn(100, device='cuda')
        >>> addr = pool.store_tensor(tensor)
        >>> loaded_tensor = pool.load_tensor(addr, tensor.dtype,
        ...                                  tensor.shape, 'cuda')
        >>> pool.free(addr)
    """

    def __init__(self, max_block_size: int, min_block_size: int = 512):
        """Initializes the memory pool with given size constraints.

        Args:
            max_block_size (int): Maximum size of memory blocks to manage
            min_block_size (int, optional): Minimum size of memory blocks
                to manage. Defaults to 512.

        Raises:
            ValueError: If block sizes are invalid or max_block_size is less
                than min_block_size
        """
        if max_block_size <= 0 or min_block_size <= 0:
            raise ValueError("Block sizes must be positive")
        if max_block_size < min_block_size:
            raise ValueError(
                "Max block size must be greater than min block size")

        self.max_block_size = self._round_to_power_of_two(max_block_size)
        self.min_block_size = self._round_to_power_of_two(min_block_size)

        self.free_lists: dict[int, dict[int, MemoryBlock]] = {}
        self.allocated_blocks: dict[int, MemoryBlock] = {}

        self._initialize_free_lists()
        self._allocate_pinned_memory()

        atexit.register(self.cleanup)

    def _round_to_power_of_two(self, size: int) -> int:
        """Return the smallest power of two that is >= ``size``."""
        return 1 << (size - 1).bit_length()

    def _initialize_free_lists(self):
        """Create an empty free list for every block size in the pool."""
        size = self.max_block_size
        while size >= self.min_block_size:
            self.free_lists[size] = {}
            size //= 2

    def _allocate_pinned_memory(self):
        """Allocate the pinned backing buffer as a single free block."""
        # float32 elements are 4 bytes each, hence max_block_size // 4.
        self.base_tensor = torch.empty(self.max_block_size // 4,
                                       dtype=torch.float32,
                                       pin_memory=True)
        self.base_address = self.base_tensor.data_ptr()
        initial_block = MemoryBlock(size=self.max_block_size,
                                    addr=self.base_address)
        self.free_lists[self.max_block_size][
            initial_block.addr] = initial_block

        logger.debug("TensorMemoryPool, base_address:%d, max_block_size:%d",
                     self.base_address, self.max_block_size)

    def allocate(self, size: int) -> int:
        """Allocates a memory block of at least the requested size.

        Args:
            size (int): Minimum size of memory to allocate

        Returns:
            int: Address of the allocated memory block

        Raises:
            ValueError: If size is invalid or insufficient memory is available
        """
        if size <= 0:
            raise ValueError("Allocation size must be positive")

        required_size = self._round_to_power_of_two(
            max(size, self.min_block_size))
        if required_size > self.max_block_size:
            raise ValueError("Requested size exceeds maximum block size")

        # Take the smallest free block that fits and split it down.
        current_size = required_size
        while current_size <= self.max_block_size:
            if self.free_lists[current_size]:
                _, block = self.free_lists[current_size].popitem()
                self._split_block(block, required_size)
                self.allocated_blocks[block.addr] = block
                return block.addr
            current_size *= 2

        raise ValueError("Insufficient memory")

    def _split_block(self, block: MemoryBlock, required_size: int):
        """Halve ``block`` in place until it is ``required_size``.

        The upper half produced by each split is returned to the free list
        for its size.
        """
        while (block.size > required_size
               and block.size // 2 >= self.min_block_size):
            buddy_size = block.size // 2
            buddy_addr = block.addr + buddy_size

            buddy = MemoryBlock(size=buddy_size, addr=buddy_addr)
            block.size = buddy_size

            self.free_lists[buddy_size][buddy.addr] = buddy

    def free(self, addr: int):
        """Frees an allocated memory block.

        Args:
            addr (int): Address of the block to free

        Raises:
            ValueError: If address is invalid or not allocated
        """
        if addr not in self.allocated_blocks:
            raise ValueError("Invalid address to free")

        block = self.allocated_blocks.pop(addr)
        self._merge_buddies(block)

    def _merge_buddies(self, block: MemoryBlock):
        """Return ``block`` to the free lists, merging with free buddies."""
        MAX_MERGE_DEPTH = 30
        depth = 0

        while depth < MAX_MERGE_DEPTH:
            # A block whose offset is a multiple of 2 * size is the lower
            # half of its pair, so its buddy sits above it; otherwise below.
            buddy_offset = block.size if (block.addr - self.base_address) % (
                2 * block.size) == 0 else -block.size
            buddy_addr = block.addr + buddy_offset
            buddy = self.free_lists[block.size].get(buddy_addr)
            if buddy:
                del self.free_lists[buddy.size][buddy.addr]
                merged_addr = min(block.addr, buddy.addr)
                merged_size = block.size * 2
                block = MemoryBlock(size=merged_size, addr=merged_addr)
                depth += 1
            else:
                break
        self.free_lists[block.size][block.addr] = block

    def store_tensor(self, tensor: torch.Tensor) -> int:
        """Stores a CUDA tensor in pinned host memory.

        Args:
            tensor (torch.Tensor): CUDA tensor to store

        Returns:
            int: Address where the tensor is stored

        Raises:
            ValueError: If tensor is not on CUDA or allocation fails
        """
        # NOTE(review): this gate uses ``is_cuda``. If the pool is fed
        # tensors from another accelerator backend, confirm ``is_cuda`` is
        # True for them; otherwise every store raises here.
        if not tensor.is_cuda:
            raise ValueError("Only CUDA tensors can be stored")

        size = tensor.element_size() * tensor.numel()
        addr = self.allocate(size)
        block = self.allocated_blocks[addr]

        if block.size < size:
            self.free(addr)
            raise ValueError(
                f"Allocated block size {block.size} is smaller than "
                f"required size {size}")

        try:
            # View the block's raw bytes as a CPU tensor of matching shape.
            buffer = (ctypes.c_byte * block.size).from_address(block.addr)
            cpu_tensor = torch.frombuffer(buffer,
                                          dtype=tensor.dtype,
                                          count=tensor.numel()).reshape(
                                              tensor.shape)
        except ValueError as err:
            self.free(addr)
            raise ValueError(f"Failed to create tensor view: {err}") from err

        cpu_tensor.copy_(tensor)

        return addr

    def load_tensor(self, addr: int, dtype: torch.dtype,
                    shape: tuple[int, ...], device) -> torch.Tensor:
        """Loads a tensor from pinned host memory to the specified device.

        The block stays allocated; call ``free(addr)`` to release it.

        Args:
            addr (int): Address where tensor is stored
            dtype (torch.dtype): Data type of the tensor
            shape (tuple[int, ...]): Shape of the tensor
            device: Target device for the loaded tensor

        Returns:
            torch.Tensor: The loaded tensor on the specified device

        Raises:
            ValueError: If address is invalid or sizes don't match
        """
        if addr not in self.allocated_blocks:
            raise ValueError("Invalid address to load")

        block = self.allocated_blocks[addr]
        num_elements = math.prod(shape)
        dtype_size = torch.tensor([], dtype=dtype).element_size()
        required_size = num_elements * dtype_size

        if required_size > block.size:
            raise ValueError("Requested tensor size exceeds block size")

        buffer = (ctypes.c_byte * block.size).from_address(block.addr)
        cpu_tensor = torch.frombuffer(buffer, dtype=dtype,
                                      count=num_elements).reshape(shape)

        cuda_tensor = torch.empty(shape, dtype=dtype, device=device)

        cuda_tensor.copy_(cpu_tensor)

        return cuda_tensor

    def cleanup(self):
        """Cleans up all memory resources and resets the pool state.

        Safe to call repeatedly and on a partially constructed instance
        (``__del__`` runs even when ``__init__`` raised early).
        """
        for name in ('free_lists', 'allocated_blocks'):
            table = getattr(self, name, None)
            if table is not None:
                table.clear()
        if hasattr(self, 'base_tensor'):
            del self.base_tensor

    def __del__(self):
        self.cleanup()
|
||||
473
vllm_br/distributed/parallel_state.py
Normal file
473
vllm_br/distributed/parallel_state.py
Normal file
@@ -0,0 +1,473 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch_br
|
||||
|
||||
import vllm
|
||||
import vllm.distributed.parallel_state
|
||||
from vllm.distributed import GroupCoordinator
|
||||
from vllm.distributed.parallel_state import (_WORLD, TensorMetadata,
|
||||
_split_tensor_dict, get_pp_group,
|
||||
get_tp_group, get_world_group,
|
||||
init_model_parallel_group, logger)
|
||||
from vllm_br import envs
|
||||
|
||||
|
||||
@dataclass
class GraphCaptureContext:
    # Stream that graph capture runs on; graph_capture_ makes it the
    # current stream for the duration of the capture.
    stream: torch_br.supa.Stream
|
||||
|
||||
|
||||
@contextmanager
def graph_capture_(self,
                   graph_capture_context: Optional[GraphCaptureContext] = None
                   ):
    """Run the enclosed code on the graph-capture stream.

    Replacement for ``GroupCoordinator.graph_capture`` (installed by the
    assignment below this function).

    Args:
        graph_capture_context: Context whose stream should be used. When
            omitted, a context with a fresh ``torch_br.supa.Stream`` is
            created.

    Yields:
        The ``GraphCaptureContext`` in use.
    """
    if graph_capture_context is None:
        graph_capture_context = GraphCaptureContext(torch_br.supa.Stream())
    capture_stream = graph_capture_context.stream

    # Only the stream is switched here; no communicator-level capture
    # context (such as ``ca_comm.capture()``) is entered.

    # Make the capture stream wait for work already queued on the current
    # stream, so initialization finishes before capture starts elsewhere.
    active_stream = torch_br.supa.current_stream()
    if active_stream != capture_stream:
        capture_stream.wait_stream(active_stream)

    with torch_br.supa.stream(capture_stream):
        yield graph_capture_context


vllm.distributed.parallel_state.GroupCoordinator.graph_capture = graph_capture_
|
||||
|
||||
|
||||
@contextmanager
def graph_capture_supa(device: torch.device):
    """Context manager that should surround SUPA graph capture code.

    Creates a dedicated ``torch_br.supa.Stream`` on ``device``, wraps it in
    a ``GraphCaptureContext`` and enters the ``graph_capture`` context of
    the tensor-parallel group and then of the pipeline-parallel group with
    it. Capturing on a separate stream keeps the kernels being captured
    apart from kernels that may be launched in the background on the
    default stream.

    Args:
        device: Device on which the capture stream is created.

    Yields:
        The ``GraphCaptureContext``; currently it only carries the stream.
    """
    capture_ctx = GraphCaptureContext(torch_br.supa.Stream(device=device))
    with get_tp_group().graph_capture(capture_ctx):
        with get_pp_group().graph_capture(capture_ctx):
            yield capture_ctx


vllm.distributed.parallel_state.graph_capture = graph_capture_supa
|
||||
|
||||
|
||||
def is_global_first_rank() -> bool:
    """
    Check if the current process is the first rank globally across all
    parallelism strategies (PP, TP, DP, EP, etc.).

    Unlike group-specific checks like `get_tensor_model_parallel_rank() == 0`
    or `get_pp_group().is_first_rank`, this function checks the global rank
    across all parallelism dimensions.

    Returns:
        bool: True if this is the global first rank (rank 0), False otherwise.
              Returns True if distributed is not initialized (single process)
              or if the check itself raises.
    """
    try:
        # Read the world group from the vLLM module at call time. A name
        # bound with ``from ... import _WORLD`` is a snapshot taken at
        # import and does not follow later reassignment of the module
        # global (presumably vLLM sets it during distributed init - verify).
        world = vllm.distributed.parallel_state._WORLD
        if world is not None:
            return world.is_first_rank

        # If torch distributed is not initialized, assume single process
        if not torch.distributed.is_initialized():
            return True

        # Fallback to torch's global rank
        return torch.distributed.get_rank() == 0

    except Exception:
        # Deliberate best-effort: if anything goes wrong, assume this is
        # the first rank.
        return True
|
||||
|
||||
|
||||
def generate_multi_node_parallel_groups(
    total_procs: int,
    tp_size: int,
    pp_size: int,
    dp_size: int,
) -> dict:
    """Return the hard-coded rank groups for the supernode layout.

    Only one configuration is supported: 16 processes with ``tp_size=8``,
    ``pp_size=2`` and ``dp_size=1``.

    Args:
        total_procs: Total number of processes (world size).
        tp_size: Tensor-parallel size.
        pp_size: Pipeline-parallel size.
        dp_size: Data-parallel size.

    Returns:
        A dict with keys ``"tp_groups"``, ``"pp_groups"``, ``"dp_groups"``
        and ``"ep_groups"``, each a list of rank lists.

    Raises:
        ValueError: If the configuration is not the supported one.
    """
    supported = (total_procs == 16 and tp_size == 8 and pp_size == 2
                 and dp_size == 1)
    if not supported:
        raise ValueError(
            "Unsupported VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE parallel "
            f"config of tp_size: {tp_size} pp_size: {pp_size} "
            f"dp_size: {dp_size} (total_procs: {total_procs}). "
            "Currently only 'tp8pp2dp1' is allowed.")

    tp_groups = [[0, 1, 2, 3, 8, 9, 10, 11], [4, 5, 6, 7, 12, 13, 14, 15]]
    pp_groups = [[0, 4], [1, 5], [2, 6], [3, 7], [8, 12], [9, 13], [10, 14],
                 [11, 15]]
    dp_groups = [[rank] for rank in range(16)]
    # The expert-parallel groups have the same membership as the TP groups.
    ep_groups = [[0, 1, 2, 3, 8, 9, 10, 11], [4, 5, 6, 7, 12, 13, 14, 15]]
    return {
        "tp_groups": tp_groups,
        "pp_groups": pp_groups,
        "dp_groups": dp_groups,
        "ep_groups": ep_groups,
    }
|
||||
|
||||
|
||||
# Signature synced with the vLLM v0.11 API; the body may still need to be re-synced with the upstream vLLM implementation.
|
||||
def initialize_model_parallel_cross_tp(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    decode_context_model_parallel_size: Optional[int] = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.

    Builds the TP, DCP, PP, DP and EP groups and stores them in the
    corresponding ``vllm.distributed.parallel_state`` module globals. When
    ``envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE`` is set, the TP, PP, DP
    and EP rank lists come from ``generate_multi_node_parallel_groups``
    instead of the default layout; the DCP groups always use the default
    layout.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        decode_context_model_parallel_size: size of each decode context
            parallel group (used as the group width when reshaping ranks).
        backend: name of torch distributed communication backend. Defaults
            to the backend of the world group's device group.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    # The data-parallel size is not a parameter; it is read from the
    # current vLLM config when one is available.
    data_parallel_size = 1
    from vllm.config import get_current_vllm_config
    config = get_current_vllm_config()
    if config is not None:
        data_parallel_size = config.parallel_config.data_parallel_size

    # the layout order is: ExternalDP x DP x PP x TP
    # ExternalDP is the data parallel group that is not part of the model,
    # every dp rank can generate independently (in verl integration).
    # DP is the data parallel group that is part of the model,
    # all the ranks in the same DP group should generate simultaneously,
    # i.e. the `generate` call in the same DP group should be called together,
    # otherwise it will cause deadlock.
    # to get group_ranks for each dimension, transpose that dimension to the
    # last dimension, then reshape to 2D, then unbind the last dimension
    all_ranks = torch.arange(world_size).reshape(
        -1, data_parallel_size, pipeline_model_parallel_size,
        tensor_model_parallel_size)  # noqa
    if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
        # Raises ValueError for any configuration other than the single
        # supported supernode layout.
        groups = generate_multi_node_parallel_groups(
            world_size, tensor_model_parallel_size,
            pipeline_model_parallel_size, data_parallel_size)
        logger.info("supernode reorganized groups: %s", groups)
    # Build the tensor model-parallel groups.
    assert vllm.distributed.parallel_state._TP is None, (
        "tensor model parallel group is already initialized")
    if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
        group_ranks = groups['tp_groups']
    else:
        group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
    # message queue broadcaster is only used in tensor model parallel group
    vllm.distributed.parallel_state._TP = init_model_parallel_group(
        group_ranks,
        get_world_group().local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="tp")

    # Build the DCP model-parallel groups.
    # global _DCP
    assert vllm.distributed.parallel_state._DCP is None, (
        "decode context model parallel group is already initialized")
    # Note(hc): In the current implementation of decode context parallel,
    # dcp_size must not exceed tp_size, because the world size does not
    # change by DCP, it simply reuses the GPUs of TP group, and split one
    # TP group into tp_size//dcp_size DCP groups.
    # NOTE(review): these ranks come from the default layout even in
    # supernode mode - confirm that is intended for dcp sizes above 1.
    group_ranks = all_ranks.reshape(
        -1, decode_context_model_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    vllm.distributed.parallel_state._DCP = init_model_parallel_group(
        group_ranks,
        get_world_group().local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="dcp")

    # Build the pipeline model-parallel groups.
    assert vllm.distributed.parallel_state._PP is None, (
        "pipeline model parallel group is already initialized")
    if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
        group_ranks = groups['pp_groups']
    else:
        group_ranks = all_ranks.transpose(2, 3).reshape(
            -1, pipeline_model_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
    vllm.distributed.parallel_state._PP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="pp")

    # Build the data-parallel groups.
    assert vllm.distributed.parallel_state._DP is None, (
        "data parallel group is already initialized")
    if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
        group_ranks = groups['dp_groups']
    else:
        group_ranks = all_ranks.transpose(1, 3).reshape(
            -1, data_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
    vllm.distributed.parallel_state._DP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="dp")

    # Build the expert-parallel groups (DP x TP ranks of one PP stage in
    # the default layout).
    assert vllm.distributed.parallel_state._EP is None, (
        "expert parallel group is already initialized")
    if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
        group_ranks = groups['ep_groups']
    else:
        group_ranks = all_ranks.transpose(1, 2).reshape(
            -1, data_parallel_size * tensor_model_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
    vllm.distributed.parallel_state._EP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="ep")
    logger.info(
        "rank %s in world size %s is assigned as (br) "
        "DP rank %s, PP rank %s, TP rank %s, EP rank %s", rank, world_size,
        vllm.distributed.parallel_state._DP.rank_in_group,
        vllm.distributed.parallel_state._PP.rank_in_group,
        vllm.distributed.parallel_state._TP.rank_in_group,
        vllm.distributed.parallel_state._EP.rank_in_group)


# Install this function in place of vLLM's initialize_model_parallel.
vllm.distributed.parallel_state.initialize_model_parallel = initialize_model_parallel_cross_tp
|
||||
|
||||
|
||||
def send_tensor_dict(
    self,
    tensor_dict: dict[str, Union[torch.Tensor, Any]],
    dst: Optional[int] = None,
    all_gather_group: Optional["GroupCoordinator"] = None,
    all_gather_tensors: Optional[dict[str, bool]] = None,
) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
    """Send the input tensor dictionary to another rank of this group.

    The non-tensor part of the dictionary (metadata) is sent first with
    `self.send_object`; every non-empty tensor is then sent individually
    with `torch.distributed.send`.

    NOTE: `dst` is the rank *within this group* of the destination; it is
    translated to the rank passed to `torch.distributed` via
    `self.ranks[dst]`.

    Args:
        tensor_dict: The dictionary to send. Values that are tensors are
            transferred as tensors; the rest travels with the metadata.
        dst: Group-local rank of the receiver. Defaults to the next rank
            in the group, `(self.rank_in_group + 1) % self.world_size`.
        all_gather_group: The group for the all-gather operation. If
            provided, an optimization is enabled where each rank in the
            group sends a slice of a tensor and the receiver reconstructs
            it using an all-gather, which can improve performance. This is
            typically the tensor-parallel group.
        all_gather_tensors: A dictionary to specify which tensors should
            use the all-gather optimization, which is only effective when
            `all_gather_group` is provided. By default, this optimization
            is on for any tensor whose size is divisible by the
            `all_gather_group`'s world size. However, it should be disabled
            for tensors that are not fully replicated across the group
            (e.g., the residual tensor when sequence parallelism is
            enabled). This dictionary allows overriding the default
            behavior on a per-tensor basis.

    Returns:
        `tensor_dict` unchanged when nothing has to be sent (distributed
        is not initialized, or the group has a single rank); otherwise
        `None`.

    Raises:
        ValueError: If the custom CPU send/recv path is selected but no
            device communicator is available.
    """
    # Bypass the function if we are using only 1 GPU.
    if not torch.distributed.is_initialized() or self.world_size == 1:
        return tensor_dict
    all_gather_size = (1 if all_gather_group is None else
                       all_gather_group.world_size)
    all_gather_rank = (0 if all_gather_group is None else
                       all_gather_group.rank_in_group)

    group = self.device_group
    metadata_group = self.cpu_group

    if dst is None:
        dst = (self.rank_in_group + 1) % self.world_size
    assert dst < self.world_size, f"Invalid dst rank ({dst})"

    if self.use_cpu_custom_send_recv:
        # The device communicator handles the whole dictionary itself.
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
        self.device_communicator.send_tensor_dict(  # type: ignore
            tensor_dict, dst)
        return None

    assert isinstance(
        tensor_dict, dict), f"Expecting a dictionary, got {type(tensor_dict)}"
    metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
    # `metadata_list` lives in CPU memory.
    # `send_object_list` has serialization & deserialization,
    # all happening on CPU. Therefore, we can use the CPU group.
    self.send_object(metadata_list, dst=dst)

    # Keys of the tensor-valued entries, paired positionally with
    # `tensor_list` below (assumes `_split_tensor_dict` preserves the
    # dictionary's iteration order -- only the lengths are checked).
    tensor_keys = [
        k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor)
    ]
    assert len(tensor_keys) == len(tensor_list)

    for key, tensor in zip(tensor_keys, tensor_list):
        if tensor.numel() == 0:
            # Skip sending empty tensors.
            continue

        # send-allgather: send only a slice, then do allgather.
        # The per-tensor override is honoured only when a group was
        # actually supplied, as documented; without a group there is
        # nothing to all-gather over on the receiving side.
        if all_gather_group is None:
            use_all_gather = False
        else:
            use_all_gather = tensor.numel() % all_gather_size == 0
            if all_gather_tensors:
                use_all_gather = all_gather_tensors.get(key, use_all_gather)
        if use_all_gather:
            tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

        if tensor.is_cpu:
            # use metadata_group for CPU tensors
            torch.distributed.send(tensor,
                                   dst=self.ranks[dst],
                                   group=metadata_group)
        else:
            # ensure tensor is ready
            # NOTE(review): `torch.supa` is presumably the Biren device
            # runtime namespace registered by this plugin -- confirm.
            torch.supa.synchronize()
            # use group for GPU tensors
            torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
    return None
|
||||
|
||||
|
||||
def recv_tensor_dict(
    self,
    src: Optional[int] = None,
    all_gather_group: Optional["GroupCoordinator"] = None,
    all_gather_tensors: Optional[dict[str, bool]] = None,
) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
    """Receive a tensor dictionary from another rank of this group.

    The metadata list is received first with `self.recv_object`. For each
    `TensorMetadata` entry a buffer is allocated from its `size`, `dtype`
    and `device` and filled with `torch.distributed.recv`; every other
    entry is copied into the result as-is.

    NOTE: `src` is the rank *within this group* of the source; it is
    translated to the rank passed to `torch.distributed` via
    `self.ranks[src]`.

    Args:
        src: Group-local rank of the sender. Defaults to the previous rank
            in the group, `(self.rank_in_group - 1) % self.world_size`.
        all_gather_group: The group for the all-gather operation. If
            provided, an optimization is enabled where each rank in the
            group sends a slice of a tensor and the receiver reconstructs
            it using an all-gather, which can improve performance. This is
            typically the tensor-parallel group.
        all_gather_tensors: A dictionary to specify which tensors should
            use the all-gather optimization, which is only effective when
            `all_gather_group` is provided. By default, this optimization
            is on for any tensor whose size is divisible by the
            `all_gather_group`'s world size. However, it should be disabled
            for tensors that are not fully replicated across the group
            (e.g., the residual tensor when sequence parallelism is
            enabled). This dictionary allows overriding the default
            behavior on a per-tensor basis.

    Returns:
        `None` when there is nothing to receive (distributed is not
        initialized, or the group has a single rank); otherwise the
        reconstructed dictionary. On the custom CPU send/recv path the
        device communicator's result is returned directly.

    Raises:
        ValueError: If the custom CPU send/recv path is selected but no
            device communicator is available.
    """
    # Bypass the function if we are using only 1 GPU.
    if not torch.distributed.is_initialized() or self.world_size == 1:
        return None
    all_gather_size = (1 if all_gather_group is None else
                       all_gather_group.world_size)
    all_gather_rank = (0 if all_gather_group is None else
                       all_gather_group.rank_in_group)

    group = self.device_group
    metadata_group = self.cpu_group

    if src is None:
        src = (self.rank_in_group - 1) % self.world_size
    assert src < self.world_size, f"Invalid src rank ({src})"

    if self.use_cpu_custom_send_recv:
        # The device communicator handles the whole dictionary itself.
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
        return self.device_communicator.recv_tensor_dict(  # type: ignore
            src)

    recv_metadata_list = self.recv_object(src=src)
    tensor_dict: dict[str, Any] = {}
    for key, value in recv_metadata_list:
        if not isinstance(value, TensorMetadata):
            # Plain (non-tensor) entry: it already arrived with the metadata.
            tensor_dict[key] = value
            continue

        tensor = torch.empty(value.size,
                             dtype=value.dtype,
                             device=value.device)
        if tensor.numel() == 0:
            # Skip broadcasting empty tensors.
            tensor_dict[key] = tensor
            continue

        # send-allgather: send only a slice, then do allgather.
        # The per-tensor override is honoured only when a group was
        # actually supplied; otherwise the all-gather below would be
        # attempted on `None`.
        if all_gather_group is None:
            use_all_gather = False
        else:
            use_all_gather = tensor.numel() % all_gather_size == 0
            if all_gather_tensors:
                use_all_gather = all_gather_tensors.get(key, use_all_gather)

        if use_all_gather:
            # Receive only this rank's slice; the full tensor is rebuilt
            # by the all-gather after the receive.
            orig_shape = tensor.shape
            tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

        if tensor.is_cpu:
            # use metadata_group for CPU tensors
            torch.distributed.recv(tensor,
                                   src=self.ranks[src],
                                   group=metadata_group)
        else:
            # use group for GPU tensors
            torch.distributed.recv(tensor,
                                   src=self.ranks[src],
                                   group=group)
            # ensure recv is done
            # NOTE(review): `torch.supa` is presumably the Biren device
            # runtime namespace registered by this plugin -- confirm.
            torch.supa.synchronize()

        if use_all_gather:
            # do the allgather
            tensor = all_gather_group.all_gather(  # type: ignore
                tensor, dim=0)
            tensor = tensor.reshape(orig_shape)

        tensor_dict[key] = tensor
    return tensor_dict
|
||||
|
||||
|
||||
# Monkey-patch: install this module's `send_tensor_dict` function as the
# `send_tensor_dict` method of vLLM's `GroupCoordinator`.
vllm.distributed.GroupCoordinator.send_tensor_dict = send_tensor_dict
|
||||
# Monkey-patch: install this module's `recv_tensor_dict` function as the
# `recv_tensor_dict` method of vLLM's `GroupCoordinator`.
vllm.distributed.GroupCoordinator.recv_tensor_dict = recv_tensor_dict
|
||||
Reference in New Issue
Block a user