from __future__ import annotations

import logging
from contextlib import nullcontext
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union

from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.moe.token_dispatcher.base import (
    BaseDispatcher,
    BaseDispatcherConfig,
    CombineInput,
    CombineInputFormat,
    DispatchOutput,
    DispatchOutputFormat,
)
from sglang.srt.layers.moe.utils import (
    DeepEPMode,
    get_deepep_config,
    get_moe_runner_backend,
    is_tbo_enabled,
)
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.utils import (
    get_bool_env_var,
    get_int_env_var,
    is_hip,
    is_npu,
    load_json_config,
)

_is_npu = is_npu()

if TYPE_CHECKING:
    from sglang.srt.single_batch_overlap import CombineOverlapArgs

try:
    from deep_ep import Buffer, Config

    if not _is_npu:
        from sglang.srt.layers.quantization.fp8_kernel import (
            sglang_per_token_group_quant_fp8,
        )

    use_deepep = True
except ImportError:
    use_deepep = False

from enum import Enum, IntEnum, auto

import torch
import torch.distributed as dist

from sglang.srt.model_executor.forward_batch_info import ForwardBatch

_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()

logger = logging.getLogger(__name__)


class DeepEPNormalOutput(NamedTuple):
    """DeepEP normal dispatch output."""

    hidden_states: torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
    # When FP8 communication is used, this is a (hidden_states, hidden_states_scale) tuple.
    topk_idx: torch.Tensor
    topk_weights: torch.Tensor
    num_recv_tokens_per_expert: List[int]

    @property
    def format(self) -> DispatchOutputFormat:
        return DispatchOutputFormat.DEEPEP_NORMAL


class DeepEPLLOutput(NamedTuple):
    """DeepEP low latency dispatch output."""

    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
    topk_idx: torch.Tensor
    topk_weights: torch.Tensor
    masked_m: torch.Tensor
    expected_m: int

    @property
    def format(self) -> DispatchOutputFormat:
        return DispatchOutputFormat.DEEPEP_LL


assert isinstance(DeepEPNormalOutput, DispatchOutput)
assert isinstance(DeepEPLLOutput, DispatchOutput)


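# Illustrative sketch (not part of the dispatch path): assuming 2 ranks, 8 experts
# and top-2 routing, a normal dispatch on one rank might produce roughly
#
#   out = DeepEPNormalOutput(
#       hidden_states=recv_x,              # [num_recv_tokens, hidden_size] or (tensor, scale)
#       topk_idx=recv_topk_idx,            # [num_recv_tokens, 2]
#       topk_weights=recv_topk_weights,    # [num_recv_tokens, 2]
#       num_recv_tokens_per_expert=[...],  # one entry per local expert
#   )
#   assert out.format == DispatchOutputFormat.DEEPEP_NORMAL
#
# whereas the low-latency path returns per-expert padded buffers together with
# masked_m (the valid token count per local expert) and expected_m (an estimate
# used to size the grouped GEMMs).

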
class DeepEPNormalCombineInput(NamedTuple):
    """DeepEP normal combine input."""

    pass

    @property
    def format(self) -> CombineInputFormat:
        return CombineInputFormat.DEEPEP_NORMAL


class DeepEPLLCombineInput(NamedTuple):
    """DeepEP low latency combine input."""

    pass

    @property
    def format(self) -> CombineInputFormat:
        return CombineInputFormat.DEEPEP_LL


assert isinstance(DeepEPNormalCombineInput, CombineInput)
assert isinstance(DeepEPLLCombineInput, CombineInput)


class DeepEPDispatchMode(IntEnum):
    NORMAL = auto()
    LOW_LATENCY = auto()


class DeepEPBuffer:
    _buffer = None
    _dispatch_mode: Optional[DeepEPDispatchMode] = None
    _hidden_size: Optional[int] = None
    _num_max_dispatch_tokens_per_rank: Optional[int] = None
    _num_experts: Optional[int] = None

    @classmethod
    def get_deepep_buffer(
        cls,
        group: dist.ProcessGroup,
        hidden_size: int,
        param_bytes: int,
        deepep_mode: DeepEPMode,
        num_max_dispatch_tokens_per_rank: int = -1,
        num_experts: int = -1,
    ):
        if cls._buffer is not None:
            return cls._buffer

        cls._hidden_size = hidden_size
        cls._num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
        cls._num_experts = num_experts

        num_nvl_bytes, num_rdma_bytes = 0, 0
        if deepep_mode.enable_normal():
            hidden_bytes = hidden_size * param_bytes
            for config in (
                DeepEPConfig.get_instance().normal_dispatch_config
                or Buffer.get_dispatch_config(group.size()),
                DeepEPConfig.get_instance().normal_combine_config
                or Buffer.get_combine_config(group.size()),
            ):
                num_nvl_bytes = max(
                    config.get_nvl_buffer_size_hint(hidden_bytes, group.size()),
                    num_nvl_bytes,
                )
                num_rdma_bytes = max(
                    config.get_rdma_buffer_size_hint(hidden_bytes, group.size()),
                    num_rdma_bytes,
                )
        if deepep_mode.enable_low_latency():
            assert num_max_dispatch_tokens_per_rank != -1
            assert num_experts != -1 and num_experts % group.size() == 0
            num_rdma_bytes = max(
                Buffer.get_low_latency_rdma_size_hint(
                    num_max_dispatch_tokens_per_rank,
                    hidden_size,
                    group.size(),
                    num_experts,
                ),
                num_rdma_bytes,
            )

        # Calculate num_qps_per_rank consistently with DeepEP's test script logic:
        if deepep_mode == DeepEPMode.NORMAL:
            # refer: https://github.com/deepseek-ai/DeepEP/blob/main/tests/test_internode.py#L235
            num_qps_per_rank = DeepEPConfig.get_instance().num_sms
        elif deepep_mode == DeepEPMode.LOW_LATENCY:
            # refer: https://github.com/deepseek-ai/DeepEP/blob/main/tests/test_low_latency.py#L176
            num_qps_per_rank = num_experts // group.size()
        elif deepep_mode == DeepEPMode.AUTO:
            # Both low-latency and normal mode need to run.
            # refer: https://github.com/deepseek-ai/DeepEP/blob/main/tests/test_internode.py#L235
            num_qps_per_rank = max(
                DeepEPConfig.get_instance().num_sms, num_experts // group.size()
            )
        else:
            raise NotImplementedError

        if not _is_npu:
            total_num_sms = torch.cuda.get_device_properties(
                device="cuda"
            ).multi_processor_count
            if (
                (deepep_mode != DeepEPMode.LOW_LATENCY)
                and not is_tbo_enabled()
                and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
            ):
                logger.warning(
                    f"Only using {DeepEPConfig.get_instance().num_sms} SMs for DeepEP communication. "
                    f"This may result in highly suboptimal performance. "
                    f"Consider using --deepep-config to change the behavior."
                )

        cls._buffer = Buffer(
            group,
            num_nvl_bytes,
            num_rdma_bytes,
            low_latency_mode=deepep_mode.enable_low_latency(),
            num_qps_per_rank=num_qps_per_rank,
            # TODO: can be False when unneeded
            allow_mnnvl=True,
        )
        return cls._buffer

    @classmethod
    def clean_buffer(cls):
        if not cls._buffer.low_latency_mode:
            return
        cls._buffer.clean_low_latency_buffer(
            cls._num_max_dispatch_tokens_per_rank,
            cls._hidden_size,
            cls._num_experts,
        )

    @classmethod
    def set_dispatch_mode_as_normal(cls):
        cls._dispatch_mode = DeepEPDispatchMode.NORMAL

    @classmethod
    def set_dispatch_mode_as_low_latency(cls):
        if cls._dispatch_mode == DeepEPDispatchMode.NORMAL:
            cls.clean_buffer()
        cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY


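# Worked example for the buffer sizing above (illustrative numbers, not values from
# this module): with an expert-parallel group of 8 ranks and 256 experts,
# LOW_LATENCY mode picks num_qps_per_rank = 256 // 8 = 32, NORMAL mode uses
# DeepEPConfig's num_sms (Buffer.num_sms by default), and AUTO takes the max of
# the two so a single Buffer can serve both dispatch paths.

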
class DeepEPConfig(BaseDispatcherConfig):
    _instance = None

    def __init__(self):
        config_str = get_deepep_config()
        if config_str:
            config_parsed = load_json_config(config_str)
            if torch.distributed.get_rank() == 0:
                logger.info(f"Use DeepEP Config: {config_parsed}")
            config_dispatch = config_parsed["normal_dispatch"]
            config_combine = config_parsed["normal_combine"]

            self.normal_dispatch_config = Config(**config_dispatch)
            self.normal_combine_config = Config(**config_combine)

            assert config_dispatch["num_sms"] == config_combine["num_sms"]
            self.num_sms = config_dispatch["num_sms"]
        else:
            self.normal_dispatch_config = None
            self.normal_combine_config = None
            self.num_sms = Buffer.num_sms

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = DeepEPConfig()
        return cls._instance


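# Sketch of the JSON accepted via --deepep-config / get_deepep_config(), based only
# on the keys read above; the per-section fields are whatever deep_ep.Config accepts,
# shown here as a hypothetical "num_sms"-bearing dict with the remaining keys omitted:
#
#   {
#       "normal_dispatch": {"num_sms": 24, ...},
#       "normal_combine": {"num_sms": 24, ...}
#   }
#
# Both sections must agree on "num_sms", which in turn drives num_qps_per_rank above.

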
class _DeepEPDispatcherImplBase:
    def __init__(
        self,
        group: torch.distributed.ProcessGroup,
        router_topk: int,
        permute_fusion: bool,
        num_experts: int,
        num_local_experts: int,
        hidden_size: int,
        params_dtype: torch.dtype,
        deepep_mode: DeepEPMode,
    ):
        if not use_deepep:
            raise ImportError(
                "DeepEP is not installed. Please install the DeepEP package from "
                "https://github.com/deepseek-ai/deepep."
            )

        self.group = group
        self.router_topk = router_topk
        self.permute_fusion = permute_fusion
        self.num_experts = num_experts
        self.num_local_experts = num_local_experts
        self.hidden_size = hidden_size
        self.params_dtype = params_dtype
        self.deepep_mode = deepep_mode

        self.params_bytes = 2
        self.num_max_dispatch_tokens_per_rank = get_int_env_var(
            "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 128
        )
        # DeepEP internode_ll dispatch uses FINISHED_SUM_TAG=1024, and its logic
        # requires the number of tokens sent from one rank to another to stay below it.
        assert self.num_max_dispatch_tokens_per_rank <= 1024

        self.handle = None

    def dispatch_a(
        self,
        hidden_states: torch.Tensor,
        input_global_scale: Optional[torch.Tensor],
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
    ):
        raise NotImplementedError

    def dispatch_b(self, *args, **kwargs):
        raise NotImplementedError

    def combine_a(
        self,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        overlap_args: Optional["CombineOverlapArgs"],
    ):
        raise NotImplementedError

    def combine_b(self, *args, **kwargs):
        raise NotImplementedError

    def _get_buffer(self):
        raise NotImplementedError


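# The dispatch/combine APIs are split into *_a (launch) and *_b (wait/collect)
# halves so callers can overlap communication with other work. A minimal usage
# sketch (hypothetical caller code, not part of this module):
#
#   inner_state = impl.dispatch_a(hidden_states, input_global_scale, topk_idx, topk_weights)
#   ...  # overlap: launch unrelated compute here
#   dispatch_output = impl.dispatch_b(*inner_state)
#
# DeepEPDispatcher below wires these halves together and tracks the current
# stage so they are always called in order.

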
class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
    def __init__(self, async_finish: bool, **kwargs):
        super().__init__(**kwargs)

        self.async_finish = async_finish
        self.src2dst = None

    def dispatch_a(
        self,
        hidden_states: torch.Tensor,
        input_global_scale: Optional[torch.Tensor],
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
    ):
        topk_idx = topk_idx.to(torch.int64)
        if (
            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
            and not get_moe_runner_backend().is_cutlass()
        ):
            # TODO: hard-coded 128-group block quantization; use fp8 communication
            hidden_states = sglang_per_token_group_quant_fp8(
                hidden_states,
                128,
                column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
                scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
            )
        previous_event = Buffer.capture() if self.async_finish else None
        return hidden_states, topk_idx, topk_weights, previous_event

    def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
        (
            hidden_states,
            topk_idx,
            topk_weights,
            num_recv_tokens_per_expert,
            event,
        ) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
        event.current_stream_wait() if self.async_finish else ()
        return DeepEPNormalOutput(
            hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
        )

    def _dispatch_core(
        self,
        x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        previous_event,
    ):
        buffer = self._get_buffer()
        (
            num_tokens_per_rank,
            num_tokens_per_rdma_rank,
            num_tokens_per_expert,
            is_token_in_rank,
            previous_event,
        ) = buffer.get_dispatch_layout(
            topk_idx,
            self.num_experts,
            previous_event=previous_event,
            async_finish=self.async_finish,
            allocate_on_comm_stream=previous_event is not None,
        )
        # FIXME: `handle` should be transmitted with tokens from dispatch to combine.
        # However, doing this would incur an unknown synchronization error; keeping
        # `handle` as a member variable works.

        (
            recv_x,
            recv_topk_idx,
            recv_topk_weights,
            num_recv_tokens_per_expert,
            self.handle,
            event,
        ) = buffer.dispatch(
            x,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            num_tokens_per_rank=num_tokens_per_rank,
            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
            is_token_in_rank=is_token_in_rank,
            num_tokens_per_expert=num_tokens_per_expert,
            previous_event=previous_event,
            async_finish=self.async_finish,
            allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
            expert_alignment=128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1,
            config=DeepEPConfig.get_instance().normal_dispatch_config,
        )
        get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
            num_recv_tokens_per_expert,
            num_tokens_per_rank=num_tokens_per_rank,
            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
            num_tokens_per_expert=num_tokens_per_expert,
        )

        return (
            recv_x,
            recv_topk_idx,
            recv_topk_weights,
            num_recv_tokens_per_expert,
            event,
        )

    def combine_a(
        self,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        overlap_args: Optional["CombineOverlapArgs"],
    ):
        from sglang.srt.layers.moe.ep_moe.kernels import (
            deepep_post_reorder_triton_kernel,
        )

        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
            output = hidden_states
        else:
            if hidden_states.shape[0] > 0:
                num_tokens = self.src2dst.shape[0] // self.router_topk
                output = torch.empty(
                    (num_tokens, hidden_states.shape[1]),
                    device=hidden_states.device,
                    dtype=hidden_states.dtype,
                )
                deepep_post_reorder_triton_kernel[(num_tokens,)](
                    hidden_states,
                    output,
                    self.src2dst,
                    topk_idx,
                    topk_weights,
                    self.router_topk,
                    hidden_states.shape[1],
                    BLOCK_SIZE=512,
                )
            else:
                output = torch.zeros(
                    (0, hidden_states.shape[1]),
                    device=hidden_states.device,
                    dtype=hidden_states.dtype,
                )
        previous_event = Buffer.capture() if self.async_finish else None
        return output, previous_event

    def combine_b(self, output, previous_event):
        hidden_states, event = self._combine_core(output, previous_event)
        event.current_stream_wait() if self.async_finish else ()
        self.handle = None
        self.src2dst = None
        return hidden_states

    def _combine_core(self, x: torch.Tensor, previous_event):
        buffer = self._get_buffer()
        combined_x, _, event = buffer.combine(
            x,
            self.handle,
            async_finish=self.async_finish,
            previous_event=previous_event,
            allocate_on_comm_stream=previous_event is not None,
            config=DeepEPConfig.get_instance().normal_combine_config,
        )
        return combined_x, event

    def _get_buffer(self):
        DeepEPBuffer.set_dispatch_mode_as_normal()

        return DeepEPBuffer.get_deepep_buffer(
            self.group,
            self.hidden_size,
            self.params_bytes,
            self.deepep_mode,
            self.num_max_dispatch_tokens_per_rank,
            self.num_experts,
        )


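# Note on the FP8 path in dispatch_a above: sglang_per_token_group_quant_fp8
# quantizes each contiguous group of 128 hidden elements with its own scale, so
# (as an illustrative example, assuming a 7168-wide hidden state) each token
# carries 7168 / 128 = 56 scales alongside its FP8 payload, and dispatch
# transmits the (quantized_tensor, scales) pair instead of bf16 activations.

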
class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
    def __init__(self, return_recv_hook: bool, **kwargs):
        super().__init__(**kwargs)

        """
        num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
        https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
        """
        self.return_recv_hook = return_recv_hook
        self.device_module = torch.get_device_module()

    def dispatch_a(
        self,
        hidden_states: torch.Tensor,
        input_global_scale: Optional[torch.Tensor],
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
    ):
        buffer = self._get_buffer()
        topk_idx = topk_idx.to(torch.int64)
        expected_m = (
            hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1]
            + self.num_experts
        ) // self.num_experts
        hidden_states, masked_m, event, hook = self._dispatch_core(
            hidden_states,
            input_global_scale,
            topk_idx,
        )
        return (
            hidden_states,
            topk_idx,
            topk_weights,
            masked_m,
            expected_m,
            event,
            hook,
        )

    def dispatch_b(
        self,
        hidden_states,
        topk_idx,
        topk_weights,
        masked_m,
        expected_m,
        event,
        hook,
    ):
        hook() if self.return_recv_hook else event.current_stream_wait()

        get_global_expert_distribution_recorder().on_deepep_dispatch_low_latency(
            masked_m
        )

        deepep_output = DeepEPLLOutput(
            hidden_states,
            topk_idx,
            topk_weights,
            masked_m,
            expected_m,
        )
        return deepep_output

    def _dispatch_core(
        self,
        hidden_states: torch.Tensor,
        input_global_scale: Optional[torch.Tensor],
        topk_idx: torch.Tensor,
    ):
        use_nvfp4 = use_fp8 = False
        if input_global_scale is not None:
            use_nvfp4 = True
        elif not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"):
            use_fp8 = True

        buffer = self._get_buffer()
        packed_recv_hidden, self.packed_recv_count, self.handle, event, hook = (
            buffer.low_latency_dispatch(
                hidden_states,
                topk_idx,
                self.num_max_dispatch_tokens_per_rank,
                self.num_experts,
                use_fp8=use_fp8,
                **(dict(use_nvfp4=True) if use_nvfp4 else dict()),
                **(
                    dict(x_global_scale=input_global_scale)
                    if input_global_scale is not None
                    else dict()
                ),
                async_finish=not self.return_recv_hook,
                return_recv_hook=self.return_recv_hook,
                round_scale=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
                and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
                use_ue8m0=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
                and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
            )
        )
        return packed_recv_hidden, self.packed_recv_count, event, hook

    def combine_a(
        self,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        overlap_args: Optional["CombineOverlapArgs"],
    ):
        hidden_states, event, hook = self._combine_core(
            hidden_states,
            topk_idx,
            topk_weights,
            overlap_args=overlap_args,
        )
        return hidden_states, event, hook, overlap_args

    def combine_b(self, hidden_states, event, hook, overlap_args):
        if overlap_args is not None:
            overlap_args.stream.wait_stream(self.device_module.current_stream())

        hook() if self.return_recv_hook else event.current_stream_wait()

        if overlap_args is not None:
            self.device_module.current_stream().wait_stream(overlap_args.stream)

        return hidden_states

    def _combine_core(
        self,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        overlap_args: Optional["CombineOverlapArgs"],
    ):
        buffer = self._get_buffer()

        ctx = nullcontext()
        if overlap_args is not None:
            overlap_args.stream.wait_event(overlap_args.wait_event)
            ctx = torch.cuda.stream(overlap_args.stream)

        with ctx:
            combined_hidden_states, event, hook = buffer.low_latency_combine(
                x=hidden_states,
                topk_idx=topk_idx,
                topk_weights=topk_weights,
                handle=self.handle,
                async_finish=not self.return_recv_hook,
                return_recv_hook=self.return_recv_hook,
                **(
                    dict(
                        overlap=overlap_args.overlap,
                        src_signals=overlap_args.signal,
                        src_signal_expect_value=overlap_args.threshold,
                    )
                    if overlap_args is not None
                    else {}
                ),
            )

        self.packed_recv_count = self.handle = None
        return combined_hidden_states, event, hook

    def _get_buffer(self):
        DeepEPBuffer.set_dispatch_mode_as_low_latency()
        return DeepEPBuffer.get_deepep_buffer(
            self.group,
            self.hidden_size,
            self.params_bytes,
            self.deepep_mode,
            self.num_max_dispatch_tokens_per_rank,
            self.num_experts,
        )


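# Worked example for expected_m in dispatch_a above (illustrative numbers): with
# 128 local tokens, a group_size of 8, top-8 routing and 256 experts,
# expected_m = (128 * 8 * 8 + 256) // 256 = 33, i.e. a rough estimate of how many
# tokens each local expert receives; masked_m then reports the actual per-expert
# counts after the low-latency dispatch completes.

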
@dataclass
class _Stage(Enum):
    INITIAL = auto()
    AFTER_DISPATCH_A = auto()
    AFTER_DISPATCH_B = auto()
    AFTER_COMBINE_A = auto()


class DeepEPDispatcher(BaseDispatcher):
    def __init__(
        self,
        group: torch.distributed.ProcessGroup,
        router_topk: int,
        permute_fusion: bool = False,
        num_experts: int = None,
        num_local_experts: int = None,
        hidden_size: int = None,
        params_dtype: torch.dtype = None,
        deepep_mode: DeepEPMode = DeepEPMode.AUTO,
        async_finish: bool = False,
        return_recv_hook: bool = False,
    ):
        self.deepep_mode = deepep_mode

        common_kwargs = dict(
            group=group,
            router_topk=router_topk,
            permute_fusion=permute_fusion,
            num_experts=num_experts,
            num_local_experts=num_local_experts,
            hidden_size=hidden_size,
            params_dtype=params_dtype,
            deepep_mode=deepep_mode,
        )

        if self.deepep_mode.enable_low_latency():
            self._low_latency_dispatcher = _DeepEPDispatcherImplLowLatency(
                return_recv_hook=return_recv_hook,
                **common_kwargs,
            )
        if self.deepep_mode.enable_normal():
            self._normal_dispatcher = _DeepEPDispatcherImplNormal(
                async_finish=async_finish,
                **common_kwargs,
            )

        self._stage = _Stage.INITIAL

    def dispatch(self, *args, **kwargs) -> DispatchOutput:
        self.dispatch_a(*args, **kwargs)
        ret = self.dispatch_b()
        return ret

    def dispatch_a(
        self,
        hidden_states: torch.Tensor,
        input_global_scale: Optional[torch.Tensor],
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        forward_batch: ForwardBatch,
    ):
        self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
        inner_state = self._get_impl(forward_batch).dispatch_a(
            hidden_states=hidden_states,
            input_global_scale=input_global_scale,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
        )
        self._dispatch_intermediate_state = forward_batch, inner_state

    def dispatch_b(self):
        self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
        forward_batch, inner_state = self._dispatch_intermediate_state
        del self._dispatch_intermediate_state
        return self._get_impl(forward_batch).dispatch_b(*inner_state)

    def combine(self, *args, **kwargs) -> Tuple:
        self.combine_a(*args, **kwargs)
        ret = self.combine_b()
        return ret

    def combine_a(
        self,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        forward_batch: ForwardBatch,
        overlap_args: Optional["CombineOverlapArgs"] = None,
    ):
        self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
        inner_state = self._get_impl(forward_batch).combine_a(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            overlap_args=overlap_args,
        )
        self._combine_intermediate_state = forward_batch, inner_state

    def combine_b(self):
        self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
        forward_batch, inner_state = self._combine_intermediate_state
        del self._combine_intermediate_state
        return self._get_impl(forward_batch).combine_b(*inner_state)

    def _get_impl(self, forward_batch: ForwardBatch) -> _DeepEPDispatcherImplBase:
        resolved_deepep_mode = self.deepep_mode.resolve(
            forward_batch.is_extend_in_batch
        )
        if resolved_deepep_mode == DeepEPMode.NORMAL:
            return self._normal_dispatcher
        elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
            return self._low_latency_dispatcher
        else:
            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")

    def _update_stage(self, old_stage, new_stage):
        assert self._stage == old_stage
        self._stage = new_stage
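
# End-to-end usage sketch (hypothetical caller; parameter values are illustrative
# assumptions rather than defaults from this module):
#
#   dispatcher = DeepEPDispatcher(
#       group=ep_group,
#       router_topk=8,
#       num_experts=256,
#       num_local_experts=256 // ep_group.size(),
#       hidden_size=7168,
#       params_dtype=torch.bfloat16,
#       deepep_mode=DeepEPMode.AUTO,
#   )
#   dispatch_output = dispatcher.dispatch(
#       hidden_states, input_global_scale, topk_idx, topk_weights, forward_batch
#   )
#   ...  # run the local experts on dispatch_output
#   hidden_states = dispatcher.combine(
#       expert_output, topk_idx, topk_weights, forward_batch
#   )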