# enginex-mlu590-vllm/vllm_mlu/distributed/parallel_state.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from typing import Optional

import torch

from vllm.distributed.parallel_state import (
    GroupCoordinator,
    GraphCaptureContext,
    get_pp_group,
    get_tp_group,
)
from vllm.distributed.mlu_parallel_state import (
    get_moe_expert_parallel_world_size,
    get_moe_expert_parallel_rank,
    get_moe_expert_parallel_group,
)
from vllm.logger import init_logger

from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu import _mlu_ops as mlu_ops

logger = init_logger(__name__)


@dataclass
class MLUGraphCaptureContext:
    stream: torch.mlu.Stream


@contextmanager
def mlu_graph_capture(device: torch.device):
    """
    `mlu_graph_capture` is a context manager that should surround the code
    capturing the MLU graph. Its main purpose is to ensure that some
    operations will be run after the graph is captured, before the graph is
    replayed. It yields an `MLUGraphCaptureContext` object which contains
    the data necessary for the capture; currently, that is only the stream
    the capture runs on. This stream is set as the current MLU stream when
    the context manager is entered and reset to the default stream when it
    is exited. This ensures the capture runs on a separate stream from the
    default one, explicitly distinguishing the kernels to capture from other
    kernels that may be launched in the background on the default stream.
    """
    context = MLUGraphCaptureContext(torch.mlu.Stream(device=device))
    with get_tp_group().graph_capture(context), \
            get_pp_group().graph_capture(context):
        yield context
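
# Illustrative usage sketch (not part of this module). The names `model` and
# `inputs` are placeholders, and the `torch.mlu.MLUGraph` / `torch.mlu.graph`
# capture API is an assumption, on the premise that torch_mlu mirrors the
# torch.cuda graph interface:
#
#     device = torch.device("mlu:0")
#     with mlu_graph_capture(device) as ctx:
#         graph = torch.mlu.MLUGraph()
#         with torch.mlu.graph(graph, stream=ctx.stream):
#             output = model(inputs)  # kernels recorded into `graph`
#     graph.replay()                  # later: replay the captured kernels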


@contextmanager
def vllm__distributed__parallel_state__GroupCoordinator__graph_capture(
    self,
    graph_capture_context: GraphCaptureContext | None = None,
):
    if graph_capture_context is None:
        stream = torch.mlu.Stream()
        graph_capture_context = GraphCaptureContext(stream)
    else:
        stream = graph_capture_context.stream

    # This mirrors vLLM's CUDA-only implementation, which is not abstracted
    # into the base class.
    maybe_ca_context = nullcontext()
    from vllm_mlu.distributed.device_communicators.mlu_communicator import (
        MLUCommunicator,
    )
    if self.device_communicator is not None:
        assert isinstance(self.device_communicator, MLUCommunicator)
        ca_comm = self.device_communicator.ca_comm
        if ca_comm is not None:
            maybe_ca_context = ca_comm.capture()  # type: ignore

    # Ensure all initialization operations complete before attempting to
    # capture the graph on another stream.
    curr_stream = torch.mlu.current_stream()
    if curr_stream != stream:
        stream.wait_stream(curr_stream)

    with torch.mlu.stream(stream), maybe_ca_context:
        yield graph_capture_context


@dataclass
class CnclEPBuffer:
    dispatch_send_token_tensor: torch.Tensor
    dispatch_recv_token_tensor: torch.Tensor
    combine_send_token_tensor: torch.Tensor
    combine_recv_token_tensor: torch.Tensor


class CnclEP:

    def __init__(self,
                 dispatch_token_size: int,
                 combine_token_size: int,
                 max_num_tokens_per_rank: int,
                 num_global_experts: int,
                 use_quant_dispatch: bool = True) -> None:
        nranks = get_moe_expert_parallel_world_size()
        rank = get_moe_expert_parallel_rank()
        moe_ep_group = get_moe_expert_parallel_group()
        self.max_num_tokens_per_rank = max_num_tokens_per_rank
        self.use_quant_dispatch = use_quant_dispatch
        (
            handle,
            exchange_info_size,
            exchange_info,
            dispatch_send_token_tensor,
            dispatch_recv_token_tensor,
            combine_send_token_tensor,
            combine_recv_token_tensor,
        ) = mlu_ops.moe_all2all_create(dispatch_token_size,
                                       combine_token_size,
                                       num_global_experts,
                                       max_num_tokens_per_rank,
                                       rank,
                                       nranks)
        self.handle = handle
        self.buffer = CnclEPBuffer(
            dispatch_send_token_tensor,
            dispatch_recv_token_tensor,
            combine_send_token_tensor,
            combine_recv_token_tensor)

        # Gather every rank's exchange info so each rank can finish
        # initializing its all2all handle with its peers' information.
        assert exchange_info.ndim == 1, "exchange_info should be 1D"
        all_exchange_info = torch.empty((nranks, exchange_info.size(0)),
                                        dtype=exchange_info.dtype,
                                        device=exchange_info.device)
        exchange_info = exchange_info.unsqueeze(0)
        torch.distributed.all_gather_into_tensor(all_exchange_info,
                                                 exchange_info,
                                                 group=moe_ep_group.cpu_group,
                                                 async_op=False)
        mlu_ops.moe_all2all_init(self.handle, all_exchange_info)
        # Make sure all ranks have finished initialization before any rank
        # starts issuing dispatch/combine calls.
        torch.distributed.barrier(group=moe_ep_group.cpu_group)

    def dispatch(self,
                 token_byte: int,
                 token_num: int,
                 send_layout: torch.Tensor,
                 send_token_num: torch.Tensor,
                 recv_layout: torch.Tensor,
                 recv_token_num: torch.Tensor,
                 send_token: Optional[torch.Tensor] = None,
                 recv_token: Optional[torch.Tensor] = None,
                 ) -> None:
        '''
        The passed-in tensors are modified in place, so they can be used
        directly once the dispatch finishes.
        '''
        mlu_ops.moe_all2all_dispatch(self.handle,
                                     token_byte,
                                     token_num,
                                     send_layout,
                                     send_token_num,
                                     recv_layout,
                                     recv_token_num,
                                     send_token,
                                     recv_token)

    def combine(self,
                token_byte: int,
                token_num: int,
                send_src_layout: torch.Tensor,
                send_dst_layout: torch.Tensor,
                send_token: Optional[torch.Tensor] = None,
                recv_token: Optional[torch.Tensor] = None,
                ) -> None:
        mlu_ops.moe_all2all_combine(self.handle,
                                    token_byte,
                                    token_num,
                                    send_src_layout,
                                    send_dst_layout,
                                    send_token,
                                    recv_token)

    def destroy(self) -> None:
        mlu_ops.moe_all2all_destroy(self.handle)
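
# Illustrative lifecycle sketch (not executed here). The argument values and
# the exact meaning of the layout tensors are assumptions inferred from the
# signatures above; the `mlu_ops` kernels are the authoritative contract.
#
#     ep = CnclEP(dispatch_token_size=hidden_bytes,   # hypothetical sizes
#                 combine_token_size=hidden_bytes,
#                 max_num_tokens_per_rank=256,
#                 num_global_experts=64)
#     ep.dispatch(token_byte, token_num, send_layout, send_token_num,
#                 recv_layout, recv_token_num)  # route tokens to experts
#     ...                                       # run local experts
#     ep.combine(token_byte, token_num, send_src_layout, send_dst_layout)
#     ep.destroy()                              # release the all2all handle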


_CNCLEP: CnclEP | None = None
_CNCLEP_BF16: CnclEP | None = None


def get_cnclep(use_quant_dispatch: bool = True) -> CnclEP:
    if use_quant_dispatch:
        assert _CNCLEP is not None, "cnclep is not initialized"
        return _CNCLEP
    else:
        assert _CNCLEP_BF16 is not None, "cnclep_bf16 is not initialized"
        return _CNCLEP_BF16


def init_cnclep(dispatch_token_size: int,
                combine_token_size: int,
                max_num_tokens_per_rank: int,
                num_global_experts: int,
                use_quant_dispatch: bool = True):
    if use_quant_dispatch:
        global _CNCLEP
        assert _CNCLEP is None, "cnclep has been initialized"
        _CNCLEP = CnclEP(dispatch_token_size,
                         combine_token_size,
                         max_num_tokens_per_rank,
                         num_global_experts,
                         use_quant_dispatch)
    else:
        global _CNCLEP_BF16
        assert _CNCLEP_BF16 is None, "cnclep_bf16 has been initialized"
        _CNCLEP_BF16 = CnclEP(dispatch_token_size,
                              combine_token_size,
                              max_num_tokens_per_rank,
                              num_global_experts,
                              use_quant_dispatch)


def cnclep_dispatch(token_byte: int,
                    token_num: int,
                    send_layout: torch.Tensor,
                    send_token_num: torch.Tensor,
                    recv_layout: torch.Tensor,
                    recv_token_num: torch.Tensor,
                    send_token: Optional[torch.Tensor] = None,
                    recv_token: Optional[torch.Tensor] = None,
                    use_quant_dispatch: bool = True,
                    ):
    # Route through the accessor so an uninitialized buffer fails with a
    # clear assertion instead of an AttributeError on None.
    get_cnclep(use_quant_dispatch).dispatch(token_byte,
                                            token_num,
                                            send_layout,
                                            send_token_num,
                                            recv_layout,
                                            recv_token_num,
                                            send_token,
                                            recv_token)


def cnclep_combine(token_byte: int,
                   token_num: int,
                   send_src_layout: torch.Tensor,
                   send_dst_layout: torch.Tensor,
                   send_token: Optional[torch.Tensor] = None,
                   recv_token: Optional[torch.Tensor] = None,
                   use_quant_dispatch: bool = True,
                   ):
    get_cnclep(use_quant_dispatch).combine(token_byte,
                                           token_num,
                                           send_src_layout,
                                           send_dst_layout,
                                           send_token,
                                           recv_token)


def destroy_cnclep():
    global _CNCLEP
    if _CNCLEP:
        _CNCLEP.destroy()
        _CNCLEP = None
    global _CNCLEP_BF16
    if _CNCLEP_BF16:
        _CNCLEP_BF16.destroy()
        _CNCLEP_BF16 = None
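
# Illustrative sketch of the module-level API (not executed here); argument
# values are hypothetical placeholders:
#
#     # One buffer per payload format: quantized dispatch and/or bf16.
#     init_cnclep(dispatch_token_size, combine_token_size,
#                 max_num_tokens_per_rank, num_global_experts,
#                 use_quant_dispatch=True)
#     init_cnclep(dispatch_token_size_bf16, combine_token_size,
#                 max_num_tokens_per_rank, num_global_experts,
#                 use_quant_dispatch=False)
#     ...
#     destroy_cnclep()  # tears down whichever buffers were initialized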


# Replace vLLM's CUDA-oriented GroupCoordinator.graph_capture with the MLU
# implementation above.
MluHijackObject.apply_hijack(
    GroupCoordinator,
    GroupCoordinator.graph_capture,
    vllm__distributed__parallel_state__GroupCoordinator__graph_capture)