# enginex-vastai-va16-vllm/vllm_vacc/vllm/distributed/parallel_state.py
from collections import namedtuple
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed
from torch.distributed import ProcessGroup

from vllm.distributed.parallel_state import TensorMetadata
from vllm.platforms import current_platform
from vllm.utils import supports_custom_op

# Keys whose receive buffers are served from the memory recycler instead of
# being freshly allocated.
MEMORY_RECYCLER_KEY = ['previous_hidden_states']
def _split_tensor_dict_concat(
    tensor_dict: Dict[str, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
    """Split the tensor dictionary into two parts:
    1. A list of (key, value) pairs. If the value is a tensor, it is
       replaced by its metadata.
    2. A list of tensors. All non-CPU tensors are flattened to int8 and
       concatenated into a single "all_tensor" entry so that one collective
       call can move every payload at once.
    """
    # Typical tensor keys: 'input_tokens', 'input_positions', 'slot_mapping',
    # 'seq_lens_tensor', 'context_lens_tensor', 'block_tables',
    # 'query_start_loc', 'seq_start_loc', 'selected_token_indices'.
    metadata_list: List[Tuple[str, Any]] = []
    tensor_list: List[torch.Tensor] = []
    all_tensor = []
    all_tensor_numel = 0
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # Note: we cannot use `value.device` here, because it contains
            # not only the device type but also the device index
            # (e.g. "cuda:0"). We only need the device type; the receiving
            # side will set the device index.
            device = value.device.type
            if not value.is_cpu and value.numel() > 0:
                value_bytes_tensor = value.view(torch.int8)
                all_tensor.append(value_bytes_tensor.view([-1]))
                all_tensor_numel += value_bytes_tensor.numel()
            metadata_list.append(
                (key, TensorMetadata(device, value.dtype, value.size())))
        else:
            metadata_list.append((key, value))
    if len(all_tensor) != 0:
        memory_recycler_dynamic_output = None
        # Compute the total size of all_tensor and, if possible, reuse the
        # recycler's dynamic output buffer as the concat destination.
        from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import (
            DeepseekMTPMemoryRecycler, memory_recycler)
        if isinstance(memory_recycler, DeepseekMTPMemoryRecycler):
            memory_recycler_dynamic_output = \
                memory_recycler.DYNAMIC_OUTPUT_BUFFER.view(
                    torch.int8)[:all_tensor_numel]
        if memory_recycler_dynamic_output is not None:
            all_tensor = torch.concatenate(
                all_tensor, 0, out=memory_recycler_dynamic_output)
        else:
            all_tensor = torch.concatenate(all_tensor, 0)
        tensor_list.append(all_tensor)
        metadata_list.append(
            ("all_tensor",
             TensorMetadata(all_tensor.device.type, all_tensor.dtype,
                            all_tensor.size())))
    return metadata_list, tensor_list
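
# Illustrative sketch (not part of the original module, values hypothetical):
# for a dict with one 4-element fp16 device tensor "x" and one scalar "n",
# the function would yield roughly
#   metadata_list == [("x", TensorMetadata("vacc", torch.float16, (4,))),
#                     ("n", 7),
#                     ("all_tensor", TensorMetadata("vacc", torch.int8, (8,)))]
#   tensor_list   == [<8-byte int8 tensor holding x's raw bytes>]
# so a single broadcast of "all_tensor" carries every device payload.
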
def all_gather_to_rank0(self, input_: torch.Tensor,
                        dim: int = -1) -> torch.Tensor:
    world_size = self.world_size
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    # For TPUs, use the TPU communicator.
    tpu_comm = self.tpu_communicator
    if tpu_comm is not None and not tpu_comm.disabled:
        return tpu_comm.all_gather(input_, dim)
    # For HPUs, use the HPU communicator.
    hpu_comm = self.hpu_communicator
    if hpu_comm is not None and not hpu_comm.disabled:
        return hpu_comm.all_gather(input_, dim)
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()
    input_size = input_.size()
    # NOTE: we have to use concat-style all-gather here;
    # stack-style all-gather has compatibility issues with
    # torch.compile. See https://github.com/pytorch/pytorch/issues/138795
    output_size = (input_size[0] * world_size, ) + input_size[1:]
    try:
        total_bytes = input_.numel() * input_.element_size() * world_size
        # The DSP custom op only supports payloads below 4 MiB for now.
        if total_bytes < 4194304:
            from torch_vacc.vacc.custom_ops import all_gather
            output_tensor = all_gather(input_, self.rank_in_group,
                                       self.world_size, self.group_id,
                                       dev_info=self.rank_device_infos)
            if self.rank_in_group != 0:
                output_tensor = None
            else:
                output_tensor = output_tensor.movedim(0, dim)
                output_tensor = output_tensor.reshape(
                    input_size[:dim] + (world_size * input_size[dim], ) +
                    input_size[dim + 1:])
            return output_tensor
    except Exception as e:
        print("all_gather by DSP failed, falling back to vccl ops:", e,
              input_.shape, input_.dtype)
    # Allocate output tensor.
    output_tensor = torch.empty(output_size,
                                dtype=input_.dtype,
                                device=input_.device)
    # All-gather.
    torch.distributed.all_gather_into_tensor(output_tensor,
                                             input_,
                                             group=self.device_group)
    if self.rank_in_group != 0:
        output_tensor = None
    else:
        # Reshape so the gathered chunks land along `dim`.
        output_tensor = output_tensor.reshape((world_size, ) + input_size)
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(
            input_size[:dim] + (world_size * input_size[dim], ) +
            input_size[dim + 1:])
    return output_tensor
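
# Shape sketch for the reshape chain above (illustrative, assuming
# world_size == 2, input shape (3, 4) and dim == 1): the concat-style gather
# produces (6, 4), reshaped to (2, 3, 4); movedim(0, 1) gives (3, 2, 4); the
# final reshape merges the gathered axis into dim 1, yielding
# (3, 8) == input_size[:1] + (2 * input_size[1],) + input_size[2:].
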
def generate_group_id(self, group_id):
    self.group_id = group_id
def generate_rank_device_infos(self):
    import os

    import numpy as np

    # Encoder: pack each (logical rank, physical device) pair into one
    # uint32, with the logical rank in the high 16 bits.
    def combine_arrays(a, b):
        a = np.asarray(a, dtype=np.uint32)
        b = np.asarray(b, dtype=np.uint32)
        if len(a) != len(b):
            raise ValueError("The two arrays must have the same length.")
        a_shifted = np.left_shift(a, 16)
        combined = np.bitwise_or(a_shifted, b)
        return combined.tolist()

    # Decoder: split the packed uint32 back into its two 16-bit halves.
    def uncombine_array(array):
        array = np.asarray(array, dtype=np.uint32)
        o_0 = array >> 16
        o_1 = array << 16 >> 16
        return o_0, o_1

    physical_devices = self.ranks
    visible_devices = os.getenv('VACC_VISIBLE_DEVICES')
    if visible_devices is not None:
        device_list = visible_devices.split(',')
        device_count = len(device_list)
        assert device_count >= len(self.ranks), (
            f'VACC_VISIBLE_DEVICES:{device_count} is less than '
            f'ranks:{len(self.ranks)}, please designate more devices')
        physical_devices = [int(device_list[i]) for i in self.ranks]
    # print("[vccl] logic_devices:physical_devices ",
    #       self.ranks, physical_devices)
    logic_ranks = [self.ranks.index(rank) for rank in self.ranks]
    self.rank_device_infos = combine_arrays(logic_ranks, physical_devices)
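
# Round-trip sketch of the packing scheme (illustrative, values hypothetical):
#   >>> combine_arrays([0, 1], [4, 5])
#   [4, 65541]                      # 1 << 16 | 5 == 65541
#   >>> uncombine_array([4, 65541])
#   (array([0, 1], dtype=uint32), array([4, 5], dtype=uint32))
# Both the logical rank and the physical device id survive a single uint32,
# relying on uint32 wraparound in the `<< 16 >> 16` low-half extraction.
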
def get_bitwidth(dtype):
    """Return the bit width of a torch dtype."""
    if dtype.is_floating_point:
        return torch.finfo(dtype).bits
    else:
        return torch.iinfo(dtype).bits
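
# e.g. get_bitwidth(torch.float16) == 16 and get_bitwidth(torch.int64) == 64;
# dividing by 8 gives the per-element byte count used when splitting the
# concatenated int8 buffer in broadcast_tensor_dict below.
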
class GroupCoordinator:

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """
        User-facing all-reduce function called before we actually invoke
        the all-reduce operation.

        We need this because Dynamo does not support passing an arbitrary
        object (`self` in this case) to a custom op. We need to pass the
        group name as a string, look up the group coordinator from that
        name, and dispatch the all-reduce operation to it.

        In addition, PyTorch custom ops do not support mutation or
        returning a new tensor in the same op, so we need to figure out
        whether the op is in-place or out-of-place ahead of time.
        """
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_
        if input_.is_cpu:
            import intel_extension_for_pytorch as ipex
            ipex.distributed.all_reduce(input_, group=self.device_group)
            return input_
        # vacc impl (kept for reference): emulate all-reduce with an
        # all-gather followed by a sum.
        # s0, s1 = input_.shape
        # output_tensor = torch.empty([32, s0, s1],
        #                             dtype=input_.dtype,
        #                             device=input_.device)
        # torch.distributed.all_gather_into_tensor(output_tensor,
        #                                          input_,
        #                                          group=self.device_group)
        # input_ = output_tensor.sum(dim=0)
        torch.distributed.all_reduce(input_, group=self.device_group)
        return input_
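
    # Numeric sketch of the commented-out emulation above (illustrative):
    # with world_size == 2 and per-rank inputs [1, 2] and [10, 20],
    # all_gather_into_tensor yields [[1, 2], [10, 20]] on every rank, and
    # .sum(dim=0) == [11, 22], matching all_reduce with the SUM op.
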
    def broadcast_tensor_dict(
        self,
        tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
        src: int = 0,
        group: Optional[ProcessGroup] = None,
        metadata_group: Optional[ProcessGroup] = None
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
        """Broadcast the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
        """
        # Typical tensor keys: 'input_tokens', 'input_positions',
        # 'slot_mapping', 'seq_lens_tensor', 'context_lens_tensor',
        # 'block_tables', 'query_start_loc', 'seq_start_loc',
        # 'selected_token_indices'.
        # Bypass the function if we are using only 1 GPU.
        if (not torch.distributed.is_initialized() or self.world_size == 1):
            return tensor_dict
        group = self.device_group
        metadata_group = self.cpu_group
        assert src < self.world_size, f"Invalid src rank ({src})"
        rank_in_group = self.rank_in_group
        if rank_in_group == src:
            metadata_list: List[Tuple[Any, Any]] = []
            assert isinstance(
                tensor_dict,
                dict), f"Expecting a dictionary, got {type(tensor_dict)}"
            metadata_list, tensor_list = _split_tensor_dict_concat(tensor_dict)
            # `metadata_list` lives in CPU memory.
            # `broadcast_object_list` has serialization & deserialization,
            # all happening on CPU. Therefore, we can use the CPU group.
            # `metadata_list` contains (key, value) pairs where each tensor
            # value has been replaced by its metadata (shape only, no data).
            self.broadcast_object(metadata_list, src=src)
            async_handles = []
            for tensor in tensor_list:
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    continue
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    handle = torch.distributed.broadcast(
                        tensor,
                        src=self.ranks[src],
                        group=metadata_group,
                        async_op=True)
                    async_handles.append(handle)
                else:
                    # use group for GPU tensors
                    total_bytes = tensor.numel() * tensor.element_size()
                    use_dist = True
                    # The DSP custom op only supports payloads below 4 MiB
                    # for now.
                    if total_bytes < 4194304:
                        try:
                            from torch_vacc.vacc.custom_ops import broadcast
                            broadcast(tensor,
                                      self.rank_in_group,
                                      self.world_size,
                                      root_rank=0,
                                      group_id=self.group_id,
                                      dev_info=self.rank_device_infos)
                            use_dist = False
                        except Exception as e:
                            print("DSP broadcast failed, falling back to "
                                  "torch.distributed:", e)
                    if use_dist:
                        handle = torch.distributed.broadcast(
                            tensor,
                            src=self.ranks[src],
                            group=group,
                            async_op=True)
                        async_handles.append(handle)
            for async_handle in async_handles:
                async_handle.wait()
        else:
            # Receiving rank.
            metadata_list = self.broadcast_object(None, src=src)
            tensor_dict = {}
            async_handles = []
            tensor_size = []  # list of [key, shape] for splitting all_tensor
            dataType_list = []
            for key, value in metadata_list:
                if isinstance(value, TensorMetadata):
                    tensor = None
                    # The concatenated buffer is always int8; try to serve
                    # it from the memory recycler.
                    from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import (
                        DeepseekMTPMemoryRecycler, memory_recycler)
                    if (isinstance(memory_recycler, DeepseekMTPMemoryRecycler)
                            and value.dtype == torch.int8):
                        tensor = memory_recycler.DYNAMIC_OUTPUT_BUFFER.view(
                            value.dtype)[:value.size.numel()].view(value.size)
                    if tensor is None:
                        tensor = torch.empty(value.size,
                                             dtype=value.dtype,
                                             device=value.device)
                    if tensor.numel() == 0:
                        # Skip broadcasting empty tensors.
                        tensor_dict[key] = tensor
                        continue
                    if tensor.is_cpu:
                        # use metadata_group for CPU tensors
                        handle = torch.distributed.broadcast(
                            tensor,
                            src=self.ranks[src],
                            group=metadata_group,
                            async_op=True)
                        async_handles.append(handle)
                        tensor_dict[key] = tensor
                    else:
                        # use group for GPU tensors
                        if key == "all_tensor":
                            total_bytes = (tensor.numel() *
                                           tensor.element_size())
                            use_dist = True
                            # The DSP custom op only supports payloads
                            # below 4 MiB for now.
                            if total_bytes < 4194304:
                                try:
                                    from torch_vacc.vacc.custom_ops import (
                                        broadcast)
                                    tensor = broadcast(
                                        tensor,
                                        self.rank_in_group,
                                        self.world_size,
                                        root_rank=0,
                                        group_id=self.group_id,
                                        dev_info=self.rank_device_infos)
                                    use_dist = False
                                except Exception as e:
                                    print("DSP broadcast failed, falling "
                                          "back to torch.distributed:", e)
                            if use_dist:
                                handle = torch.distributed.broadcast(
                                    tensor,  # the concatenated tensor
                                    src=self.ranks[src],
                                    group=group,
                                    async_op=False)
                            # Split all_tensor back into per-key tensors by
                            # (key, shape) and store them in tensor_dict.
                            start = 0
                            idx = 0
                            for ki, vi in tensor_size:
                                length = vi.numel() * (
                                    get_bitwidth(dataType_list[idx]) // 8)
                                if ki in MEMORY_RECYCLER_KEY:
                                    # Recycled keys can view into the shared
                                    # buffer directly.
                                    tensor_dict[ki] = tensor[
                                        start:start + length].view(
                                            dataType_list[idx]).view(vi)
                                else:
                                    value_tensor = torch.empty(
                                        vi,
                                        dtype=dataType_list[idx],
                                        device=value.device)
                                    recv_tensor = tensor[
                                        start:start + length].view(
                                            dataType_list[idx]).view(vi)
                                    tensor_dict[ki] = value_tensor.copy_(
                                        recv_tensor)
                                start += length
                                idx += 1
                        else:
                            dataType_list.append(value.dtype)
                            tensor_size.append([key, value.size])
                else:
                    tensor_dict[key] = value
            for async_handle in async_handles:
                async_handle.wait()
        return tensor_dict
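
    # Offset-arithmetic sketch for the split above (illustrative): a
    # (3,)-shaped float32 entry occupies 3 * (32 // 8) == 12 bytes of the
    # int8 buffer, so consecutive entries are recovered with
    # tensor[start:start + 12].view(torch.float32).view(3) and start += 12.
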
    def all_gather(self,
                   input_: torch.Tensor,
                   dim: int = -1,
                   output_: torch.Tensor = None) -> torch.Tensor:
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape "
            f"{input_.size()}")
        if self.use_custom_op_call:
            return torch.ops.vllm.all_gather(input_,
                                             dim,
                                             world_size,
                                             group_name=self.unique_name)
        else:
            # Use the output-reuse variant of all_gather when an output
            # buffer is provided.
            if output_ is not None:
                return self.device_communicator.all_gather_into_tensor(
                    input_, dim, output_)
            return self._all_gather_out_place(input_, dim)
    def recv_tensor_dict(
        self,
        src: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
    ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
        """Receive the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None
        all_gather_size = (1 if all_gather_group is None else
                           all_gather_group.world_size)
        all_gather_rank = (0 if all_gather_group is None else
                           all_gather_group.rank_in_group)
        group = self.device_group
        metadata_group = self.cpu_group
        if src is None:
            src = (self.rank_in_group - 1) % self.world_size
        assert src < self.world_size, f"Invalid src rank ({src})"
        recv_metadata_list = self.recv_object(src=src)
        tensor_dict: dict[str, Any] = {}
        from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import (
            alloc_pipeline_parallel_recycler_buffer)
        memory_recycler_list = ["hidden_states", "residual"]
        for key, value in recv_metadata_list:
            if isinstance(value, TensorMetadata):
                # Decide whether memory reuse applies:
                # 1. the key is in [hidden_states, residual], i.e. the
                #    tensor comes from pipeline parallelism;
                # 2. a buffer for this key can be obtained from the
                #    memory_recycling module.
                use_create_recycler_tensor = False
                tensor = None
                value_tensor = None  # receives the full all_gather result
                if key in memory_recycler_list:
                    tensor = alloc_pipeline_parallel_recycler_buffer(
                        value.size, value.dtype, key)
                    if tensor is not None:
                        use_create_recycler_tensor = True
                if not use_create_recycler_tensor:
                    tensor = torch.empty(value.size,
                                         dtype=value.dtype,
                                         device=value.device)
                value_tensor = tensor
                if tensor.numel() == 0:
                    # Skip receiving empty tensors.
                    tensor_dict[key] = tensor
                    continue
                # send-allgather: send only a slice, then do allgather.
                use_all_gather = (all_gather_group is not None
                                  and tensor.numel() % all_gather_size == 0)
                if use_all_gather:
                    orig_shape = tensor.shape
                    # With memory reuse a view is enough; no reshape needed.
                    if use_create_recycler_tensor:
                        tensor = tensor.view(
                            all_gather_size,
                            -1)[all_gather_rank].contiguous()
                    else:
                        tensor = tensor.reshape(all_gather_size,
                                                -1)[all_gather_rank]
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    torch.distributed.recv(tensor,
                                           src=self.ranks[src],
                                           group=metadata_group)
                else:
                    # use group for GPU tensors
                    torch.distributed.recv(tensor,
                                           src=self.ranks[src],
                                           group=group)
                if use_all_gather:
                    # do the allgather
                    if use_create_recycler_tensor:
                        tensor = all_gather_group.all_gather(  # type: ignore
                            tensor, dim=0, output_=value_tensor)
                        tensor = tensor.view(orig_shape)
                    else:
                        tensor = all_gather_group.all_gather(  # type: ignore
                            tensor, dim=0)
                        tensor = tensor.reshape(orig_shape)
                tensor_dict[key] = tensor
            else:
                tensor_dict[key] = value
        return tensor_dict
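
    # send-allgather sketch (illustrative, all_gather_size == 2): a (4, 6)
    # tensor has 24 elements, so each rank receives a 12-element slice via
    # tensor.reshape(2, -1)[all_gather_rank]; the subsequent all_gather along
    # dim 0 restores all 24 elements, and reshape(orig_shape) recovers (4, 6).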