[Model] Support DeepSeek-V4
This commit is contained in:
286
vllm_mlu/distributed/parallel_state.py
Normal file
286
vllm_mlu/distributed/parallel_state.py
Normal file
@@ -0,0 +1,286 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed.parallel_state import (
|
||||
GroupCoordinator,
|
||||
GraphCaptureContext,
|
||||
get_pp_group,
|
||||
get_tp_group,
|
||||
)
|
||||
from vllm.distributed.mlu_parallel_state import(
|
||||
get_moe_expert_parallel_world_size,
|
||||
get_moe_expert_parallel_rank,
|
||||
get_moe_expert_parallel_group,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@dataclass
class MLUGraphCaptureContext:
    """Context object for MLU graph capture.

    MLU analogue of vLLM's ``GraphCaptureContext``; it currently only
    carries the dedicated stream that the capture runs on.
    """

    # Side stream the graph capture runs on (see mlu_graph_capture).
    stream: torch.mlu.Stream
|
||||
|
||||
|
||||
@contextmanager
def mlu_graph_capture(device: torch.device):
    """
    `mlu_graph_capture` is a context manager which should surround the code
    that is capturing the MLU graph. Its main purpose is to ensure that
    some operations will be run after the graph is captured, before the graph
    is replayed. It yields an `MLUGraphCaptureContext` object which contains
    the necessary data for the graph capture. Currently, it only contains the
    stream that the graph capture is running on: a fresh side stream created
    on `device`, so that the capture runs on a separate stream from the
    default stream. This explicitly distinguishes the kernels to capture
    from other kernels possibly launched in the background on the default
    stream. The TP and PP group coordinators are entered into graph-capture
    mode for the duration of the context.
    """
    context = MLUGraphCaptureContext(torch.mlu.Stream(device=device))
    with get_tp_group().graph_capture(context), get_pp_group().graph_capture(context):
        yield context
|
||||
|
||||
|
||||
@contextmanager
def vllm__distributed__parallel_state__GroupCoordinator__graph_capture(
    self,
    graph_capture_context: GraphCaptureContext | None = None,
):
    """MLU replacement for ``GroupCoordinator.graph_capture``.

    Hijacked onto vLLM's ``GroupCoordinator`` (see the
    ``MluHijackObject.apply_hijack`` call at module bottom) so that graph
    capture uses ``torch.mlu`` streams and the MLU custom-allreduce
    communicator instead of the CUDA equivalents.

    Yields the (possibly freshly created) ``GraphCaptureContext`` while the
    capture stream is current.
    """
    if graph_capture_context is None:
        # No context supplied: capture on a brand-new side stream.
        stream = torch.mlu.Stream()
        graph_capture_context = GraphCaptureContext(stream)
    else:
        stream = graph_capture_context.stream

    # Mirrors the upstream CUDA-only implementation, so this is not
    # abstracted into the base class.
    maybe_ca_context = nullcontext()
    from vllm_mlu.distributed.device_communicators.mlu_communicator import (
        MLUCommunicator,
    )

    # If the device communicator has a custom-allreduce channel, its
    # capture() context must wrap the graph capture as well.
    if self.device_communicator is not None:
        assert isinstance(self.device_communicator, MLUCommunicator)
        ca_comm = self.device_communicator.ca_comm
        if ca_comm is not None:
            maybe_ca_context = ca_comm.capture()  # type: ignore

    # ensure all initialization operations complete before attempting to
    # capture the graph on another stream
    curr_stream = torch.mlu.current_stream()
    if curr_stream != stream:
        stream.wait_stream(curr_stream)

    with torch.mlu.stream(stream), maybe_ca_context:
        yield graph_capture_context
|
||||
|
||||
@dataclass
class CnclEPBuffer:
    """Pre-allocated communication buffers returned by ``mlu_ops.moe_all2all_create``.

    The buffers are owned by the native all2all handle; this dataclass only
    keeps Python references to them so they stay alive for the lifetime of
    the owning ``CnclEP`` instance.
    """

    # Staging buffers for the dispatch (token scatter) phase.
    dispatch_send_token_tensor: torch.Tensor
    dispatch_recv_token_tensor: torch.Tensor
    # Staging buffers for the combine (token gather) phase.
    combine_send_token_tensor: torch.Tensor
    combine_recv_token_tensor: torch.Tensor
|
||||
|
||||
class CnclEP:
    """CNCL-backed expert-parallel all-to-all communicator for MoE layers.

    Thin wrapper around the ``mlu_ops.moe_all2all_*`` kernels: construction
    allocates a native handle plus pre-registered send/recv buffers, the EP
    ranks all-gather their buffer metadata so peers can address each other,
    and subsequent :meth:`dispatch` / :meth:`combine` calls reuse those
    buffers to shuffle tokens between expert ranks.
    """

    def __init__(self,
                 dispatch_token_size: int,
                 combine_token_size: int,
                 max_num_tokens_per_rank: int,
                 num_global_experts: int,
                 use_quant_dispatch: bool = True) -> None:
        # EP topology comes from the globally initialized MoE expert-parallel
        # group.
        nranks = get_moe_expert_parallel_world_size()
        rank = get_moe_expert_parallel_rank()
        moe_ep_group = get_moe_expert_parallel_group()
        self.max_num_tokens_per_rank = max_num_tokens_per_rank
        self.use_quant_dispatch = use_quant_dispatch

        # moe_all2all_create returns an opaque native handle, this rank's
        # exchange metadata to share with peers, and the four communication
        # buffers (exchange_info_size is unused here).
        (
            handle,
            exchange_info_size,
            exchange_info,
            dispatch_send_token_tensor,
            dispatch_recv_token_tensor,
            combine_send_token_tensor,
            combine_recv_token_tensor
        ) = mlu_ops.moe_all2all_create(dispatch_token_size,
                                       combine_token_size,
                                       num_global_experts,
                                       max_num_tokens_per_rank,
                                       rank,
                                       nranks)
        self.handle = handle
        # Keep Python references so the buffers outlive graph replays.
        self.buffer = CnclEPBuffer(
            dispatch_send_token_tensor,
            dispatch_recv_token_tensor,
            combine_send_token_tensor,
            combine_recv_token_tensor)

        # Gather every rank's exchange info (row r = rank r's metadata) so
        # each rank can address its peers' buffers.
        assert exchange_info.ndim == 1, "exchange_info should be 1D"
        all_exchange_info = torch.empty((nranks, exchange_info.size(0)),
                                        dtype=exchange_info.dtype,
                                        device=exchange_info.device)
        exchange_info = exchange_info.unsqueeze(0)
        # NOTE(review): the all-gather runs over the EP group's cpu_group —
        # presumably exchange_info is a CPU tensor; confirm against
        # moe_all2all_create's return device.
        torch.distributed.all_gather_into_tensor(all_exchange_info,
                                                 exchange_info,
                                                 group=moe_ep_group.cpu_group,
                                                 async_op=False)
        mlu_ops.moe_all2all_init(self.handle, all_exchange_info)
        # Make sure every rank finished init before anyone starts
        # communicating through the handle.
        torch.distributed.barrier(group=moe_ep_group.cpu_group)

    def dispatch(self,
                 token_byte: int,
                 token_num: int,
                 send_layout: torch.Tensor,
                 send_token_num: torch.Tensor,
                 recv_layout: torch.Tensor,
                 recv_token_num: torch.Tensor,
                 send_token: Optional[torch.Tensor] = None,
                 recv_token: Optional[torch.Tensor] = None,
                 ) -> None:
        """Scatter tokens to their expert ranks via the all2all handle.

        The passed tensors are modified in place, so they can be used
        directly after dispatch finishes.
        """
        mlu_ops.moe_all2all_dispatch(self.handle,
                                     token_byte,
                                     token_num,
                                     send_layout,
                                     send_token_num,
                                     recv_layout,
                                     recv_token_num,
                                     send_token,
                                     recv_token)

    def combine(self,
                token_byte: int,
                token_num: int,
                send_src_layout: torch.Tensor,
                send_dst_layout: torch.Tensor,
                send_token: Optional[torch.Tensor] = None,
                recv_token: Optional[torch.Tensor] = None,
                ) -> None:
        """Gather expert outputs back to the token-owning ranks (inverse of
        :meth:`dispatch`); tensors are modified in place."""
        mlu_ops.moe_all2all_combine(self.handle,
                                    token_byte,
                                    token_num,
                                    send_src_layout,
                                    send_dst_layout,
                                    send_token,
                                    recv_token)

    def destroy(self) -> None:
        """Release the native all2all handle and its registered buffers."""
        mlu_ops.moe_all2all_destroy(self.handle)
|
||||
|
||||
# Process-wide CnclEP singletons, created by init_cnclep():
# _CNCLEP serves the quantized dispatch path, _CNCLEP_BF16 the bf16 path.
_CNCLEP: CnclEP | None = None
_CNCLEP_BF16: CnclEP | None = None
|
||||
|
||||
def get_cnclep(use_quant_dispatch: bool = True) -> CnclEP:
    """Return the initialized global CnclEP singleton.

    Selects the quantized-dispatch instance when ``use_quant_dispatch`` is
    true, the bf16 instance otherwise. Raises ``AssertionError`` if the
    selected singleton was never created via ``init_cnclep``.
    """
    if not use_quant_dispatch:
        assert _CNCLEP_BF16 is not None, "cnclep_bf16 is not initialized"
        return _CNCLEP_BF16
    assert _CNCLEP is not None, "cnclep is not initialized"
    return _CNCLEP
|
||||
|
||||
def init_cnclep(dispatch_token_size: int,
                combine_token_size: int,
                max_num_tokens_per_rank: int,
                num_global_experts: int,
                use_quant_dispatch: bool = True):
    """Create the global CnclEP singleton for the requested dispatch path.

    Populates ``_CNCLEP`` when ``use_quant_dispatch`` is true, otherwise
    ``_CNCLEP_BF16``. Raises ``AssertionError`` if that singleton has
    already been initialized.
    """
    global _CNCLEP, _CNCLEP_BF16
    if use_quant_dispatch:
        assert _CNCLEP is None, "cnclep has been initialized"
    else:
        assert _CNCLEP_BF16 is None, "cnclep_bf16 has been initialized"
    # Construct once instead of duplicating the five-argument call in both
    # branches (the original repeated the whole CnclEP(...) expression).
    ep = CnclEP(dispatch_token_size,
                combine_token_size,
                max_num_tokens_per_rank,
                num_global_experts,
                use_quant_dispatch)
    if use_quant_dispatch:
        _CNCLEP = ep
    else:
        _CNCLEP_BF16 = ep
|
||||
|
||||
def cnclep_dispatch(token_byte: int,
                    token_num: int,
                    send_layout: torch.Tensor,
                    send_token_num: torch.Tensor,
                    recv_layout: torch.Tensor,
                    recv_token_num: torch.Tensor,
                    send_token: Optional[torch.Tensor] = None,
                    recv_token: Optional[torch.Tensor] = None,
                    use_quant_dispatch: bool = True,
                    ):
    """Dispatch tokens through the global CnclEP communicator.

    Selects the quantized or bf16 singleton based on ``use_quant_dispatch``
    and forwards all arguments to its ``dispatch`` method. Raises
    ``AssertionError`` if the selected singleton was never created via
    ``init_cnclep``.
    """
    # Select the singleton once instead of duplicating the whole call in two
    # branches, and fail with a clear message (matching get_cnclep) rather
    # than an AttributeError on None when init_cnclep was never called.
    ep = _CNCLEP if use_quant_dispatch else _CNCLEP_BF16
    assert ep is not None, "cnclep is not initialized"
    ep.dispatch(token_byte,
                token_num,
                send_layout,
                send_token_num,
                recv_layout,
                recv_token_num,
                send_token,
                recv_token)
|
||||
|
||||
def cnclep_combine(token_byte: int,
                   token_num: int,
                   send_src_layout: torch.Tensor,
                   send_dst_layout: torch.Tensor,
                   send_token: Optional[torch.Tensor] = None,
                   recv_token: Optional[torch.Tensor] = None,
                   use_quant_dispatch: bool = True,
                   ):
    """Combine expert outputs through the global CnclEP communicator.

    Selects the quantized or bf16 singleton based on ``use_quant_dispatch``
    and forwards all arguments to its ``combine`` method. Raises
    ``AssertionError`` if the selected singleton was never created via
    ``init_cnclep``.
    """
    # Same shape as cnclep_dispatch: pick the singleton once, assert it
    # exists instead of dereferencing a possibly-None global.
    ep = _CNCLEP if use_quant_dispatch else _CNCLEP_BF16
    assert ep is not None, "cnclep is not initialized"
    ep.combine(token_byte,
               token_num,
               send_src_layout,
               send_dst_layout,
               send_token,
               recv_token)
|
||||
|
||||
|
||||
def destroy_cnclep():
    """Destroy both global CnclEP singletons (quantized and bf16), if set.

    Each existing instance has its native handle released via ``destroy()``
    and its module-level reference reset to ``None``. Safe to call when
    neither singleton was ever initialized.
    """
    global _CNCLEP, _CNCLEP_BF16

    if _CNCLEP:
        _CNCLEP.destroy()
        _CNCLEP = None
    if _CNCLEP_BF16:
        _CNCLEP_BF16.destroy()
        _CNCLEP_BF16 = None
|
||||
|
||||
|
||||
# Replace vLLM's GroupCoordinator.graph_capture with the MLU-stream-aware
# implementation defined above (module import has this side effect).
MluHijackObject.apply_hijack(GroupCoordinator,
                             GroupCoordinator.graph_capture,
                             vllm__distributed__parallel_state__GroupCoordinator__graph_capture)
|
||||
|
||||
Reference in New Issue
Block a user