import torch
import torch.distributed
from torch.distributed import ProcessGroup

from collections import namedtuple
from typing import (Any, Dict, List, Optional, Tuple, Union)

from vllm.platforms import current_platform
from vllm.utils import supports_custom_op
from vllm.distributed.parallel_state import TensorMetadata

# Keys whose receive buffers come from the memory recycler.
MEMORY_RECYCLER_KEY = ['previous_hidden_states']


def _split_tensor_dict_concat(
    tensor_dict: Dict[str, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
    """Split the tensor dictionary into two parts:
    1. A list of (key, value) pairs. If the value is a tensor, it is
       replaced by its metadata.
    2. A list of tensors. All non-empty device tensors are flattened to
       int8 and concatenated into a single "all_tensor" entry.
    """
    # all_tensor_list = ['input_tokens', 'input_positions', 'slot_mapping',
    #     'seq_lens_tensor', 'context_lens_tensor', 'block_tables',
    #     'query_start_loc', 'seq_start_loc', 'selected_token_indices']
    metadata_list: List[Tuple[str, Any]] = []
    tensor_list: List[torch.Tensor] = []
    all_tensor = []
    all_tensor_numel = 0
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # Note: we cannot use `value.device` here,
            # because it contains not only the device type but also the device
            # index (e.g. "cuda:0"). We only need the device type.
            # The receiving side will set the device index.
            device = value.device.type
            if not value.is_cpu and value.numel() > 0:
                value_bytes_tensor = value.view(torch.int8)
                all_tensor.append(value_bytes_tensor.view([-1]))
                all_tensor_numel += value_bytes_tensor.numel()
                # tensor_list.append(value)
            metadata_list.append(
                (key, TensorMetadata(device, value.dtype, value.size())))
        else:
            metadata_list.append((key, value))
    if len(all_tensor) != 0:
        memory_recycler_dynamic_output = None
        # Compute the total byte size of all_tensor and, if a DeepseekMTP
        # memory recycler is active, concatenate into its reusable buffer.
        from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import (
            memory_recycler, DeepseekMTPMemoryRecycler)
        if isinstance(memory_recycler, DeepseekMTPMemoryRecycler):
            memory_recycler_dynamic_output = memory_recycler.DYNAMIC_OUTPUT_BUFFER.view(
                torch.int8)[:all_tensor_numel]
        if memory_recycler_dynamic_output is not None:
            all_tensor = torch.concatenate(all_tensor, 0,
                                           out=memory_recycler_dynamic_output)
        else:
            all_tensor = torch.concatenate(all_tensor, 0)
        tensor_list.append(all_tensor)
        metadata_list.append(
            ("all_tensor",
             TensorMetadata(all_tensor.device.type, all_tensor.dtype,
                            all_tensor.size())))
    return metadata_list, tensor_list
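
# Illustrative sketch (not part of the module): _split_tensor_dict_concat
# flattens every non-empty device tensor to int8 bytes and concatenates them
# into one "all_tensor", so a single broadcast replaces one per tensor. The
# function below, `_demo_byte_concat_roundtrip` (a hypothetical name used
# only for illustration), shows the byte-level round trip the sender and
# receiver perform.
def _demo_byte_concat_roundtrip():
    a = torch.arange(4, dtype=torch.float32)
    b = torch.arange(6, dtype=torch.int64)
    # Sender side: reinterpret each tensor as raw bytes and concatenate.
    flat = torch.cat([a.view(torch.int8).view(-1), b.view(torch.int8).view(-1)])
    # Receiver side: slice by byte length, then restore dtype and shape.
    n_a = a.numel() * a.element_size()
    a2 = flat[:n_a].view(torch.float32).view(a.shape)
    b2 = flat[n_a:].view(torch.int64).view(b.shape)
    assert torch.equal(a, a2) and torch.equal(b, b2)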

def all_gather_to_rank0(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
    world_size = self.world_size
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")

    # For TPUs, use TPU communicator.
    tpu_comm = self.tpu_communicator
    if tpu_comm is not None and not tpu_comm.disabled:
        return tpu_comm.all_gather(input_, dim)

    # For HPUs, use HPU communicator.
    hpu_comm = self.hpu_communicator
    if hpu_comm is not None and not hpu_comm.disabled:
        return hpu_comm.all_gather(input_, dim)

    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()
    input_size = input_.size()
    # NOTE: we have to use concat-style all-gather here; stack-style
    # all-gather has compatibility issues with torch.compile.
    # See https://github.com/pytorch/pytorch/issues/138795
    output_size = (input_size[0] * world_size, ) + input_size[1:]
    try:
        total_bytes = input_.numel() * input_.element_size() * world_size
        # The DSP custom op only supports payloads under 4 MiB for now.
        if total_bytes < 4194304:
            from torch_vacc.vacc.custom_ops import all_gather
            output_tensor = all_gather(input_, self.rank_in_group,
                                       self.world_size, self.group_id,
                                       dev_info=self.rank_device_infos)
            if self.rank_in_group != 0:
                output_tensor = None
            else:
                output_tensor = output_tensor.movedim(0, dim)
                output_tensor = output_tensor.reshape(
                    input_size[:dim] + (world_size * input_size[dim], ) +
                    input_size[dim + 1:])
            return output_tensor
    except Exception as e:
        print("all_gather by DSP failed, falling back to vccl ops:", e,
              input_.shape, input_.dtype)
    # Allocate output tensor.
    output_tensor = torch.empty(output_size,
                                dtype=input_.dtype,
                                device=input_.device)
    # All-gather.
    torch.distributed.all_gather_into_tensor(output_tensor,
                                             input_,
                                             group=self.device_group)
    if self.rank_in_group != 0:
        output_tensor = None
    else:
        # Reshape.
        output_tensor = output_tensor.reshape((world_size, ) + input_size)
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(
            input_size[:dim] + (world_size * input_size[dim], ) +
            input_size[dim + 1:])
    return output_tensor


def generate_group_id(self, group_id):
    self.group_id = group_id


def generate_rank_device_infos(self):
    import numpy as np
    import os

    # Encode the rank_dev_list: pack each (logical rank, physical device)
    # pair into one uint32 (rank in the high 16 bits, device in the low 16).
    def combine_arrays(a, b):
        a = np.asarray(a, dtype=np.uint32)
        b = np.asarray(b, dtype=np.uint32)
        if len(a) != len(b):
            raise ValueError("The two arrays must have the same length.")
        a_shifted = np.left_shift(a, 16)
        combined = np.bitwise_or(a_shifted, b)
        return combined.tolist()

    # Decode the rank_dev_list back into its two 16-bit halves.
    def uncombine_array(array):
        array = np.asarray(array, dtype=np.uint32)
        o_0 = array >> 16
        o_1 = array << 16 >> 16
        return o_0, o_1

    physical_devices = self.ranks
    visible_devices = os.getenv('VACC_VISIBLE_DEVICES')
    if visible_devices is not None:
        device_list = visible_devices.split(',')
        device_count = len(device_list)
        assert device_count >= len(self.ranks), (
            f'VACC_VISIBLE_DEVICES has {device_count} devices, fewer than '
            f'the {len(self.ranks)} ranks; please specify more devices')
        physical_devices = [int(device_list[i]) for i in self.ranks]
    # print("[vccl] logic_devices:physical_devices ", self.ranks, physical_devices)
    # Logical ranks are simply the positions within this group.
    logic_ranks = list(range(len(self.ranks)))
    self.rank_device_infos = combine_arrays(logic_ranks, physical_devices)
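
# Illustrative sketch (not part of the module): generate_rank_device_infos
# packs each 16-bit logical rank into the high half of a uint32 and the
# 16-bit physical device id into the low half; uncombine_array reverses it.
# `_demo_pack_unpack` is a hypothetical name, shown only to make the bit
# layout concrete with plain Python integers.
def _demo_pack_unpack():
    logical_rank, physical_device = 3, 17
    packed = (logical_rank << 16) | physical_device  # high half | low half
    assert packed >> 16 == logical_rank              # recover the logical rank
    assert packed & 0xFFFF == physical_device        # recover the device id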

def get_bitwidth(dtype):
    """Return the bit width of a torch dtype (float or integer)."""
    if dtype.is_floating_point:
        return torch.finfo(dtype).bits
    else:
        return torch.iinfo(dtype).bits
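
# Illustrative sketch (not part of the module): the receiving side of
# broadcast_tensor_dict uses get_bitwidth to turn each (dtype, shape) pair
# into a byte length for slicing the concatenated "all_tensor".
# `_demo_byte_length` is a hypothetical name used only for illustration.
def _demo_byte_length():
    size = torch.Size([2, 3])
    # 6 float16 elements * 16 bits / 8 = 12 bytes.
    assert size.numel() * get_bitwidth(torch.float16) // 8 == 12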

class GroupCoordinator:

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """
        User-facing all-reduce function before we actually call the
        all-reduce operation.

        We need this because Dynamo does not support passing an arbitrary
        object (`self` in this case) to a custom op. We need to pass the
        group name as a string, look up the group coordinator from the
        group name, and dispatch the all-reduce operation to that group
        coordinator.

        In addition, PyTorch custom ops do not support mutation or returning
        a new tensor in the same op. So we need to figure out if the op is
        in-place or out-of-place ahead of time.
        """
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_

        if input_.is_cpu:
            import intel_extension_for_pytorch as ipex
            ipex.distributed.all_reduce(input_, group=self.device_group)
            return input_

        # vacc impl (all-gather then sum), kept for reference:
        # s0, s1 = input_.shape
        # output_tensor = torch.empty([32, s0, s1],
        #                             dtype=input_.dtype,
        #                             device=input_.device)
        # torch.distributed.all_gather_into_tensor(output_tensor,
        #                                          input_,
        #                                          group=self.device_group)
        # input_ = output_tensor.sum(dim=0)
        torch.distributed.all_reduce(input_, group=self.device_group)
        return input_

    def broadcast_tensor_dict(
        self,
        tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
        src: int = 0,
        group: Optional[ProcessGroup] = None,
        metadata_group: Optional[ProcessGroup] = None
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
        """Broadcast the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
        """
        # all_tensor_list = ['input_tokens', 'input_positions', 'slot_mapping',
        #     'seq_lens_tensor', 'context_lens_tensor', 'block_tables',
        #     'query_start_loc', 'seq_start_loc', 'selected_token_indices']
        # Bypass the function if we are using only 1 GPU.
        if (not torch.distributed.is_initialized() or self.world_size == 1):
            return tensor_dict

        group = self.device_group
        metadata_group = self.cpu_group
        assert src < self.world_size, f"Invalid src rank ({src})"

        rank_in_group = self.rank_in_group
        if rank_in_group == src:
            metadata_list: List[Tuple[Any, Any]] = []
            assert isinstance(
                tensor_dict,
                dict), f"Expecting a dictionary, got {type(tensor_dict)}"
            # metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
            metadata_list, tensor_list = _split_tensor_dict_concat(tensor_dict)
            # `metadata_list` lives in CPU memory.
            # `broadcast_object_list` has serialization & deserialization,
            # all happening on CPU. Therefore, we can use the CPU group.
            # metadata_list holds (key, metadata) pairs: shapes only, no data.
            self.broadcast_object(metadata_list, src=src)
            async_handles = []
            for tensor in tensor_list:
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    continue
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    handle = torch.distributed.broadcast(tensor,
                                                         src=self.ranks[src],
                                                         group=metadata_group,
                                                         async_op=True)
                    async_handles.append(handle)
                else:
                    # use group for GPU tensors
                    total_bytes = tensor.numel() * tensor.element_size()
                    use_dist = True
                    # The DSP custom op only supports payloads under 4 MiB.
                    if total_bytes < 4194304:
                        try:
                            from torch_vacc.vacc.custom_ops import broadcast
                            # print("send tensor is:", tensor.shape, tensor.dtype, self.rank)
                            broadcast(tensor,
                                      self.rank_in_group,
                                      self.world_size,
                                      root_rank=0,
                                      group_id=self.group_id,
                                      dev_info=self.rank_device_infos)
                            use_dist = False
                        except Exception as e:
                            print("DSP broadcast failed, falling back to "
                                  "torch.distributed:", e)
                    if use_dist:
                        handle = torch.distributed.broadcast(
                            tensor,
                            src=self.ranks[src],
                            group=group,
                            async_op=True)
                        async_handles.append(handle)
            for async_handle in async_handles:
                async_handle.wait()
        else:
            # Receiving ranks.
            metadata_list = self.broadcast_object(None, src=src)
            tensor_dict = {}
            async_handles = []
            tensor_size = []  # list of [key, shape] for splitting all_tensor
            dataType_list = []
            for key, value in metadata_list:
                # if rank_in_group == 1:
                #     print('rank1 k v ', key, value)
                if isinstance(value, TensorMetadata):
                    tensor = None
                    # The recycler buffer is fixed to int8.
                    from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import (
                        memory_recycler, DeepseekMTPMemoryRecycler)
                    if isinstance(memory_recycler, DeepseekMTPMemoryRecycler
                                  ) and value.dtype == torch.int8:
                        tensor = memory_recycler.DYNAMIC_OUTPUT_BUFFER.view(
                            value.dtype)[:value.size.numel()].view(value.size)
                    if tensor is None:
                        tensor = torch.empty(value.size,
                                             dtype=value.dtype,
                                             device=value.device)
                    if tensor.numel() == 0:
                        # Skip broadcasting empty tensors.
                        tensor_dict[key] = tensor
                        continue
                    if tensor.is_cpu:
                        # use metadata_group for CPU tensors
                        handle = torch.distributed.broadcast(
                            tensor,
                            src=self.ranks[src],
                            group=metadata_group,
                            async_op=True)
                        async_handles.append(handle)
                        tensor_dict[key] = tensor
                    else:
                        # use group for GPU tensors
                        if key == "all_tensor":
                            total_bytes = tensor.numel() * tensor.element_size()
                            use_dist = True
                            # The DSP custom op only supports payloads under 4 MiB.
                            if total_bytes < 4194304:
                                try:
                                    from torch_vacc.vacc.custom_ops import broadcast
                                    tensor = broadcast(
                                        tensor,
                                        self.rank_in_group,
                                        self.world_size,
                                        root_rank=0,
                                        group_id=self.group_id,
                                        dev_info=self.rank_device_infos)
                                    use_dist = False
                                except Exception as e:
                                    print("DSP broadcast failed, falling back "
                                          "to torch.distributed:", e)
                            if use_dist:
                                # Receive the concatenated tensor synchronously.
                                torch.distributed.broadcast(
                                    tensor,
                                    src=self.ranks[src],
                                    group=group,
                                    async_op=False)
                            # Split all_tensor back into per-key tensors using
                            # the recorded (key, shape) pairs and dtypes.
                            start = 0
                            idx = 0
                            for ki, vi in tensor_size:
                                length = vi.numel() * int(
                                    get_bitwidth(dataType_list[idx]) / 8)
                                if ki in MEMORY_RECYCLER_KEY:
                                    tensor_dict[ki] = tensor[
                                        start:start + length].view(
                                            dataType_list[idx]).view(vi)
                                else:
                                    value_tensor = torch.empty(
                                        vi,
                                        dtype=dataType_list[idx],
                                        device=value.device)
                                    recv_tensor = tensor[
                                        start:start + length].view(
                                            dataType_list[idx]).view(vi)
                                    tensor_dict[ki] = value_tensor.copy_(
                                        recv_tensor)
                                start += length
                                idx += 1
                        else:
                            dataType_list.append(value.dtype)
                            tensor_size.append([key, value.size])
                else:
                    tensor_dict[key] = value
            for async_handle in async_handles:
                async_handle.wait()
        return tensor_dict

    def all_gather(self,
                   input_: torch.Tensor,
                   dim: int = -1,
                   output_: torch.Tensor = None) -> torch.Tensor:
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")

        if self.use_custom_op_call:
            return torch.ops.vllm.all_gather(input_,
                                             dim,
                                             world_size,
                                             group_name=self.unique_name)
        else:
            # Output-reusing variant of all_gather: gather directly into the
            # caller-provided buffer.
            if output_ is not None:
                return self.device_communicator.all_gather_into_tensor(
                    input_, dim, output_)
            return self._all_gather_out_place(input_, dim)
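
    # Usage sketch (hypothetical, for illustration): passing `output_` lets a
    # caller reuse a preallocated recycler buffer instead of allocating the
    # gathered tensor, as recv_tensor_dict below does for pipeline-parallel
    # hidden states:
    #
    #   out = torch.empty(world_size * n, dtype=x.dtype, device=x.device)
    #   gathered = group.all_gather(x, dim=0, output_=out)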

    def recv_tensor_dict(
        self,
        src: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
    ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
        """Recv the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None

        all_gather_size = (1 if all_gather_group is None else
                           all_gather_group.world_size)
        all_gather_rank = (0 if all_gather_group is None else
                           all_gather_group.rank_in_group)

        group = self.device_group
        metadata_group = self.cpu_group

        if src is None:
            src = (self.rank_in_group - 1) % self.world_size
        assert src < self.world_size, f"Invalid src rank ({src})"

        recv_metadata_list = self.recv_object(src=src)
        tensor_dict: dict[str, Any] = {}
        from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import (
            alloc_pipeline_parallel_recycler_buffer)
        memory_recycler_list = ["hidden_states", "residual"]
        for key, value in recv_metadata_list:
            if isinstance(value, TensorMetadata):
                # Decide whether the receive buffer can be recycled:
                # 1. the key is in ["hidden_states", "residual"], i.e. the
                #    tensor comes from pipeline parallelism, and
                # 2. a buffer for this key can be obtained from the
                #    memory_recycling module.
                use_create_recycler_tensor = False
                tensor = None
                value_tensor = None  # holds the full all_gather output
                if key in memory_recycler_list:
                    tensor = alloc_pipeline_parallel_recycler_buffer(
                        value.size, value.dtype, key)
                    if tensor is not None:
                        use_create_recycler_tensor = True
                if not use_create_recycler_tensor:
                    tensor = torch.empty(value.size,
                                         dtype=value.dtype,
                                         device=value.device)
                value_tensor = tensor
                if tensor.numel() == 0:
                    # Skip receiving empty tensors.
                    tensor_dict[key] = tensor
                    continue

                # send-allgather: send only a slice, then do allgather.
                use_all_gather = (all_gather_group is not None
                                  and tensor.numel() % all_gather_size == 0)

                if use_all_gather:
                    orig_shape = tensor.shape
                    # A recycled buffer needs only a view; no reshape copy.
                    if use_create_recycler_tensor:
                        tensor = tensor.view(
                            all_gather_size, -1)[all_gather_rank].contiguous()
                    else:
                        tensor = tensor.reshape(all_gather_size,
                                                -1)[all_gather_rank]

                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    torch.distributed.recv(tensor,
                                           src=self.ranks[src],
                                           group=metadata_group)
                else:
                    # use group for GPU tensors
                    torch.distributed.recv(tensor,
                                           src=self.ranks[src],
                                           group=group)
                if use_all_gather:
                    # do the allgather
                    if use_create_recycler_tensor:
                        tensor = all_gather_group.all_gather(  # type: ignore
                            tensor, dim=0, output_=value_tensor)
                        tensor = tensor.view(orig_shape)
                    else:
                        tensor = all_gather_group.all_gather(  # type: ignore
                            tensor, dim=0)
                        tensor = tensor.reshape(orig_shape)
                tensor_dict[key] = tensor
            else:
                tensor_dict[key] = value
        return tensor_dict
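
# Illustrative sketch (not part of the module): recv_tensor_dict's
# "send-allgather" path has the source rank send only a 1/world_size slice of
# each tensor; every receiver then all-gathers the slices back into the full
# tensor. The function below demonstrates the shape bookkeeping locally, with
# torch.cat standing in for the collective; all names are hypothetical.
def _demo_send_allgather_shapes():
    world_size = 4
    full = torch.arange(24, dtype=torch.float32).reshape(6, 4)
    # Each rank is responsible for one flattened slice of the tensor.
    slices = [full.reshape(world_size, -1)[r] for r in range(world_size)]
    # The all-gather (emulated here by torch.cat) restores the original data.
    gathered = torch.cat(slices, dim=0).reshape(full.shape)
    assert torch.equal(gathered, full)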