import torch
import torch.distributed
from collections import namedtuple
from typing import Any, Dict, List, Optional, Tuple, Union
from torch.distributed import ProcessGroup
from vllm.platforms import current_platform
from vllm.utils import supports_custom_op
from vllm.distributed.parallel_state import TensorMetadata
# memory recycler
MEMORY_RECYCLER_KEY = ['previous_hidden_states']
def _split_tensor_dict_concat(
tensor_dict: Dict[str, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
"""Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
by its metadata.
2. A list of tensors.
"""
# all_tensor_list = ['input_tokens','input_positions', 'slot_mapping','seq_lens_tensor', 'context_lens_tensor','block_tables', 'query_start_loc','seq_start_loc', 'selected_token_indices']
metadata_list: List[Tuple[str, Any]] = []
tensor_list: List[torch.Tensor] = []
all_tensor = []
all_tensor_numel = 0
for key, value in tensor_dict.items():
if isinstance(value, torch.Tensor):
# Note: we cannot use `value.device` here,
# because it contains not only the device type but also the device
# index (e.g. "cuda:0"). We only need the device type.
# The receiving side will set the device index.
device = value.device.type
if not value.is_cpu and value.numel() > 0:
value_bytes_tensor = value.view(torch.int8)
all_tensor.append(value_bytes_tensor.view([-1]))
all_tensor_numel += value_bytes_tensor.numel()
# tensor_list.append(value)
metadata_list.append(
(key, TensorMetadata(device, value.dtype, value.size())))
else:
metadata_list.append((key, value))
if len(all_tensor) != 0:
memory_recycler_dynamic_output = None
# Use all_tensor's total byte size to carve a reusable output buffer from the memory recycler, if one is available.
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler, DeepseekMTPMemoryRecycler
if isinstance(memory_recycler, DeepseekMTPMemoryRecycler):
memory_recycler_dynamic_output = memory_recycler.DYNAMIC_OUTPUT_BUFFER.view(torch.int8)[:all_tensor_numel]
if memory_recycler_dynamic_output is not None:
all_tensor = torch.concatenate(all_tensor, 0, out = memory_recycler_dynamic_output)
else:
all_tensor = torch.concatenate(all_tensor, 0)
tensor_list.append(all_tensor)
metadata_list.append(("all_tensor", TensorMetadata(all_tensor.device.type, all_tensor.dtype, all_tensor.size())))
return metadata_list, tensor_list
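# Illustrative sketch (never called): the byte-packing round trip behind
# _split_tensor_dict_concat. Each device tensor is reinterpreted as int8
# bytes, flattened, and concatenated; the dtype and shape recorded in the
# per-key TensorMetadata are enough for the receiving side to split the
# buffer back apart. Keys, dtypes, and shapes below are made up.
def _example_concat_and_split() -> None:
    a = torch.arange(4, dtype=torch.int64)
    b = torch.ones(2, 3, dtype=torch.float16)
    all_tensor = torch.cat(
        [a.view(torch.int8).view([-1]), b.view(torch.int8).view([-1])], 0)
    out, start = {}, 0
    for key, dtype, size, esize in [("a", a.dtype, a.size(), a.element_size()),
                                    ("b", b.dtype, b.size(), b.element_size())]:
        length = size.numel() * esize  # bytes this entry occupies in all_tensor
        out[key] = all_tensor[start:start + length].view(dtype).view(size)
        start += length
    assert torch.equal(out["a"], a) and torch.equal(out["b"], b)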
def all_gather_to_rank0(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
# For TPUs, use TPU communicator.
tpu_comm = self.tpu_communicator
if tpu_comm is not None and not tpu_comm.disabled:
return tpu_comm.all_gather(input_, dim)
# For HPUs, use HPU communicator.
hpu_comm = self.hpu_communicator
if hpu_comm is not None and not hpu_comm.disabled:
return hpu_comm.all_gather(input_, dim)
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# NOTE: we have to use concat-style all-gather here,
# stack-style all-gather has compatibility issues with
# torch.compile. See https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * world_size, ) + input_size[1:]
try:
total_bytes = input_.numel() * input_.element_size() * world_size
# The custom all_gather op currently only supports payloads smaller than 4 MiB.
if total_bytes < 4194304:
from torch_vacc.vacc.custom_ops import all_gather
output_tensor = all_gather(input_, self.rank_in_group, self.world_size, self.group_id,
dev_info = self.rank_device_infos)
if self.rank_in_group != 0:
output_tensor = None
else:
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(world_size *
input_size[dim], ) +
input_size[dim + 1:])
return output_tensor
except Exception as e:
print("all_gather by DSP run Fail, now use vccl-ops", e, input_.shape, input_.dtype)
# Allocate output tensor.
output_tensor = torch.empty(output_size,
dtype=input_.dtype,
device=input_.device)
# All-gather.
torch.distributed.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
if self.rank_in_group != 0:
output_tensor = None
else:
# Reshape
output_tensor = output_tensor.reshape((world_size, ) + input_size)
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(world_size *
input_size[dim], ) +
input_size[dim + 1:])
return output_tensor
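# Illustrative sketch of the reshape used above for concat-style all-gather:
# with world_size = 4 and an input of shape (2, 3) gathered along dim = 1,
# the flat gathered buffer of shape (8, 3) is rearranged into (2, 12).
# A CPU placeholder stands in for the gathered device tensor.
def _example_all_gather_reshape() -> None:
    world_size, dim = 4, 1
    input_ = torch.arange(6).reshape(2, 3)
    gathered = input_.repeat(world_size, 1)  # stands in for the (world_size * 2, 3) gather output
    out = gathered.reshape((world_size, ) + input_.shape)
    out = out.movedim(0, dim)
    out = out.reshape(input_.shape[:dim] +
                      (world_size * input_.shape[dim], ) +
                      input_.shape[dim + 1:])
    assert out.shape == (2, 12)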
def generate_group_id(self, group_id):
self.group_id = group_id
def generate_rank_device_infos(self):
import numpy as np
import os
# encoder rank_dev_list
def combine_arrays(a, b):
a = np.asarray(a, dtype=np.uint32)
b = np.asarray(b, dtype=np.uint32)
if len(a) != len(b):
raise ValueError("两个数组的长度必须一致。")
a_shifted = np.left_shift(a, 16)
combined = np.bitwise_or(a_shifted, b)
return combined.tolist()
# decoder rank_dev_list
def uncombine_array(array):
array = np.asarray(array, dtype=np.uint32)
o_0 = array >> 16
o_1 = array << 16 >> 16
return o_0, o_1
physical_devices = self.ranks
visible_devices = os.getenv('VACC_VISIBLE_DEVICES')
if visible_devices is not None:
device_list = visible_devices.split(',')
device_count = len(device_list)
assert device_count >= len(self.ranks), f'VACC_VISIBLE_DEVICES exposes only {device_count} devices for {len(self.ranks)} ranks; please make more devices visible'
physical_devices = [int(device_list[i]) for i in self.ranks]
# print("[vccl] logic_devices:physical_devices ", self.ranks, physical_devices)
logic_ranks = [self.ranks.index(rank) for rank in self.ranks]
self.rank_device_infos = combine_arrays(logic_ranks, physical_devices)
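# Illustrative sketch of the packing performed by combine_arrays above: the
# logical rank occupies the high 16 bits and the physical device id the low
# 16 bits of each 32-bit entry (uncombine_array reverses this). The rank and
# device id below are made up.
def _example_rank_device_packing() -> None:
    logic_rank, physical_device = 1, 5
    packed = (logic_rank << 16) | physical_device
    assert packed == 65541
    assert packed >> 16 == logic_rank          # high 16 bits: logical rank
    assert packed & 0xFFFF == physical_device  # low 16 bits: physical device id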
def get_bitwidth(dtype):
if dtype.is_floating_point:
return torch.finfo(dtype).bits
else:
return torch.iinfo(dtype).bits
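# Illustrative usage (never called): broadcast_tensor_dict below uses
# get_bitwidth to convert element counts into byte offsets when slicing the
# concatenated int8 "all_tensor" buffer, e.g. a float16 tensor of shape
# (2, 3) occupies 6 * 16 / 8 = 12 bytes.
def _example_get_bitwidth() -> None:
    assert get_bitwidth(torch.float16) == 16
    assert get_bitwidth(torch.int64) == 64
    assert torch.Size([2, 3]).numel() * int(get_bitwidth(torch.float16) / 8) == 12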
class GroupCoordinator:
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
"""
User-facing all-reduce function before we actually call the
all-reduce operation.
We need this because Dynamo does not support passing an arbitrary
object (`self` in this case) to a custom op. We need to pass the
group name as a string, then look up the group coordinator by name and
dispatch the all-reduce operation to it.
In addition, PyTorch custom ops do not support mutation or returning
a new tensor in the same op. So we need to figure out if the op is
in-place or out-of-place ahead of time.
"""
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return input_
if input_.is_cpu:
import intel_extension_for_pytorch as ipex
ipex.distributed.all_reduce(input_, group=self.device_group)
return input_
# vacc impl
# s0, s1 = input_.shape
# output_tensor = torch.empty([32, s0, s1],
# dtype=input_.dtype,
# device=input_.device)
# torch.distributed.all_gather_into_tensor(output_tensor,
# input_,
# group=self.device_group)
# input_ = output_tensor.sum(dim=0)
torch.distributed.all_reduce(input_, group=self.device_group)
return input_
def broadcast_tensor_dict(
self,
tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
src: int = 0,
group: Optional[ProcessGroup] = None,
metadata_group: Optional[ProcessGroup] = None
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
"""Broadcast the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
"""
# all_tensor_list = ['input_tokens','input_positions', 'slot_mapping','seq_lens_tensor', 'context_lens_tensor','block_tables', 'query_start_loc','seq_start_loc', 'selected_token_indices']
# Bypass the function if we are using only 1 GPU.
if (not torch.distributed.is_initialized() or self.world_size == 1):
return tensor_dict
group = self.device_group
metadata_group = self.cpu_group
assert src < self.world_size, f"Invalid src rank ({src})"
rank_in_group = self.rank_in_group
if rank_in_group == src:
metadata_list: List[Tuple[Any, Any]] = []
assert isinstance(
tensor_dict,
dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
# metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
metadata_list, tensor_list = _split_tensor_dict_concat(tensor_dict)
# `metadata_list` lives in CPU memory.
# `broadcast_object_list` has serialization & deserialization,
# all happening on CPU. Therefore, we can use the CPU group.
# metadata_list holds (key, value) pairs; for tensors, value is only metadata (device/dtype/shape), not the data itself.
self.broadcast_object(metadata_list, src=src)
async_handles = []
for tensor in tensor_list:
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
handle = torch.distributed.broadcast(tensor,
src=self.ranks[src],
group=metadata_group,
async_op=True)
async_handles.append(handle)
else:
# use group for GPU tensors
total_bytes = tensor.numel() * tensor.element_size()
use_dist = True
# The custom broadcast op currently only supports payloads smaller than 4 MiB.
if total_bytes < 4194304:
try:
from torch_vacc.vacc.custom_ops import broadcast
#print("send tensor is:", tensor.shape, tensor.dtype, self.rank)
broadcast(tensor, self.rank_in_group, self.world_size, root_rank=0, group_id=self.group_id,
dev_info = self.rank_device_infos)
use_dist = False
except Exception as e:
print("odsp broadcast run fail, now using distributed:", e)
if use_dist:
handle = torch.distributed.broadcast(tensor,
src=self.ranks[src],
group=group,
async_op=True)
async_handles.append(handle)
for async_handle in async_handles:
async_handle.wait()
else:
# other rank
metadata_list = self.broadcast_object(None, src=src)
tensor_dict = {}
async_handles = []
tensor_size = [] # list of [key, shape] for split all_tensor
dataType_list = []
for key, value in metadata_list:
# if rank_in_group == 1:
# print('rank1 k v ', key, value)
if isinstance(value, TensorMetadata):
tensor = None
# The concatenated buffer's dtype is fixed to int8.
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler, DeepseekMTPMemoryRecycler
if isinstance(memory_recycler, DeepseekMTPMemoryRecycler) and value.dtype == torch.int8:
tensor = memory_recycler.DYNAMIC_OUTPUT_BUFFER.view(value.dtype)[:value.size.numel()].view(value.size)
if tensor is None:
tensor = torch.empty(value.size,
dtype=value.dtype,
device=value.device)
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
tensor_dict[key] = tensor
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
handle = torch.distributed.broadcast(
tensor,
src=self.ranks[src],
group=metadata_group,
async_op=True)
async_handles.append(handle)
tensor_dict[key] = tensor
else:
# use group for GPU tensors
if key == "all_tensor":
total_bytes = tensor.numel() * tensor.element_size()
use_dist = True
# The custom broadcast op currently only supports payloads smaller than 4 MiB.
if total_bytes < 4194304:
try:
from torch_vacc.vacc.custom_ops import broadcast
tensor = broadcast(tensor, self.rank_in_group, self.world_size, root_rank=0, group_id=self.group_id,
dev_info = self.rank_device_infos)
use_dist = False
except Exception as e:
print("dsp brocast run fail, now using distributed:", e)
if use_dist:
handle = torch.distributed.broadcast(
tensor,  # the concatenated tensor
src=self.ranks[src],
group=group,
async_op=False)
# Split all_tensor back into per-key tensors using the recorded (key, shape) pairs and store them in tensor_dict.
start = 0
idx = 0
for ki_vi in tensor_size:
ki, vi = ki_vi
length = vi.numel() * int(get_bitwidth(dataType_list[idx]) / 8)
if ki in MEMORY_RECYCLER_KEY:
tensor_dict[ki] = tensor[start:start+length].view(dataType_list[idx]).view(vi)
else:
value_tensor = torch.empty(vi,
dtype=dataType_list[idx],
device=value.device)
recv_tensor = tensor[start:start+length].view(dataType_list[idx]).view(vi)
tensor_dict[ki] = value_tensor.copy_(recv_tensor)
start += length
idx += 1
else:
dataType_list.append(value.dtype)
tensor_size.append([key, value.size])
else:
tensor_dict[key] = value
for async_handle in async_handles:
async_handle.wait()
return tensor_dict
def all_gather(self, input_: torch.Tensor, dim: int = -1, output_: torch.Tensor = None) -> torch.Tensor:
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if self.use_custom_op_call:
return torch.ops.vllm.all_gather(input_,
dim,
world_size,
group_name=self.unique_name)
else:
# Use the output-buffer-reusing variant of all_gather when an output tensor is provided.
if output_ is not None:
return self.device_communicator.all_gather_into_tensor(input_, dim, output_)
return self._all_gather_out_place(input_, dim)
def recv_tensor_dict(
self,
src: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
"""Recv the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return None
all_gather_size = (1 if all_gather_group is None else
all_gather_group.world_size)
all_gather_rank = (0 if all_gather_group is None else
all_gather_group.rank_in_group)
group = self.device_group
metadata_group = self.cpu_group
if src is None:
src = (self.rank_in_group - 1) % self.world_size
assert src < self.world_size, f"Invalid src rank ({src})"
recv_metadata_list = self.recv_object(src=src)
tensor_dict: dict[str, Any] = {}
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import alloc_pipeline_parallel_recycler_buffer
memory_recycler_list = ["hidden_states", "residual"]
for key, value in recv_metadata_list:
if isinstance(value, TensorMetadata):
# Decide whether this tensor's memory can be reused:
# 1. the key is in [hidden_states, residual], i.e. the tensor comes from pipeline parallelism;
# 2. a tensor for this key can be allocated from the memory_recycling module.
use_create_recycler_tensor = False
tensor = None
value_tensor = None  # receives the full all_gather result
if key in memory_recycler_list:
tensor = alloc_pipeline_parallel_recycler_buffer(value.size, value.dtype, key)
if tensor is not None:
use_create_recycler_tensor = True
if not use_create_recycler_tensor:
tensor = torch.empty(value.size,
dtype=value.dtype,
device=value.device)
value_tensor = tensor
if tensor.numel() == 0:
# Skip receiving empty tensors.
tensor_dict[key] = tensor
continue
# send-allgather: send only a slice, then do allgather.
use_all_gather = (all_gather_group is not None
and tensor.numel() % all_gather_size == 0)
if use_all_gather:
orig_shape = tensor.shape
# A recycled buffer only needs a view, not a reshape.
if use_create_recycler_tensor:
tensor = tensor.view(all_gather_size,
-1)[all_gather_rank].contiguous()
else:
tensor = tensor.reshape(all_gather_size,
-1)[all_gather_rank]
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.recv(tensor,
src=self.ranks[src],
group=metadata_group)
else:
# use group for GPU tensors
torch.distributed.recv(tensor,
src=self.ranks[src],
group=group)
if use_all_gather:
# do the allgather
if use_create_recycler_tensor:
tensor = all_gather_group.all_gather( # type: ignore
tensor, dim=0, output_ = value_tensor)
tensor = tensor.view(orig_shape)
else:
tensor = all_gather_group.all_gather( # type: ignore
tensor, dim=0)
tensor = tensor.reshape(orig_shape)
tensor_dict[key] = tensor
else:
tensor_dict[key] = value
return tensor_dict
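# Illustrative sketch (module level, never called) of the "send-allgather"
# optimization used in recv_tensor_dict above: instead of receiving a full
# tensor from the previous pipeline rank, each rank receives only its
# 1/all_gather_size slice and the full tensor is rebuilt with an all-gather.
# Only the slice/reassembly shape math is shown here, on CPU, with a made-up
# tensor; no process group is involved.
def _example_send_allgather_slicing() -> None:
    all_gather_size = 4
    full = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
    orig_shape = full.shape
    slices = [full.reshape(all_gather_size, -1)[r] for r in range(all_gather_size)]
    # each rank would recv one slices[r] (6 elements) instead of all 24
    rebuilt = torch.cat(slices, dim=0).reshape(orig_shape)
    assert torch.equal(rebuilt, full)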