@@ -7,6 +7,7 @@ try:
 except ImportError:
     use_deepep = False
 
+from enum import IntEnum, auto
 from typing import Optional, Tuple
 
 import torch
@@ -19,70 +20,95 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 
 
-_buffer_normal = None
-_buffer_low_latency = None
-
-
-def _get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
-    """
-    Copy from DeepEP example usage in model inference prefilling.
-    https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-model-training-or-inference-prefilling
-    """
-    global _buffer_normal
-
-    num_nvl_bytes, num_rdma_bytes = 0, 0
-    for config in (
-        Buffer.get_dispatch_config(group.size()),
-        Buffer.get_combine_config(group.size()),
-    ):
-        num_nvl_bytes = max(
-            config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes
-        )
-        num_rdma_bytes = max(
-            config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes
-        )
-
-    if (
-        _buffer_normal is None
-        or _buffer_normal.group != group
-        or _buffer_normal.num_nvl_bytes < num_nvl_bytes
-        or _buffer_normal.num_rdma_bytes < num_rdma_bytes
-    ):
-        _buffer_normal = Buffer(group, num_nvl_bytes, num_rdma_bytes)
-    return _buffer_normal
-
-
-def _get_buffer_low_latency(
-    group: dist.ProcessGroup,
-    num_max_dispatch_tokens_per_rank: int,
-    hidden: int,
-    num_experts: int,
-):
-    """
-    Copy from DeepEP example usage in model inference decoding.
-    https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
-    """
-
-    global _buffer_low_latency
-    num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(
-        num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts
-    )
-
-    if (
-        _buffer_low_latency is None
-        or _buffer_low_latency.group != group
-        or not _buffer_low_latency.low_latency_mode
-        or _buffer_low_latency.num_rdma_bytes < num_rdma_bytes
-    ):
-        assert num_experts % group.size() == 0
-        _buffer_low_latency = Buffer(
-            group,
-            num_rdma_bytes=num_rdma_bytes,
-            low_latency_mode=True,
-            num_qps_per_rank=num_experts // group.size(),
-        )
-    return _buffer_low_latency
+class DeepEPDispatchMode(IntEnum):
+    NORMAL = auto()
+    LOW_LATENCY = auto()
+
+
+class DeepEPBuffer:
+
+    _buffer: Optional[Buffer] = None
+    _dispatch_mode: Optional[DeepEPDispatchMode] = None
+    _hidden_size: Optional[int] = None
+    _num_max_dispatch_tokens_per_rank: Optional[int] = None
+    _num_experts: Optional[int] = None
+
+    @classmethod
+    def get_deepep_buffer(
+        cls,
+        group: dist.ProcessGroup,
+        hidden_size: int,
+        param_bytes: int,
+        deepep_mode: DeepEPMode,
+        num_max_dispatch_tokens_per_rank: int = None,
+        num_experts: int = None,
+    ):
+        if cls._buffer is not None:
+            return cls._buffer
+
+        cls._hidden_size = hidden_size
+        cls._num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
+        cls._num_experts = num_experts
+
+        num_nvl_bytes, num_rdma_bytes = 0, 0
+        if deepep_mode.enable_normal():
+            hidden_bytes = hidden_size * param_bytes
+            for config in (
+                Buffer.get_dispatch_config(group.size()),
+                Buffer.get_combine_config(group.size()),
+            ):
+                num_nvl_bytes = max(
+                    config.get_nvl_buffer_size_hint(hidden_bytes, group.size()),
+                    num_nvl_bytes,
+                )
+                num_rdma_bytes = max(
+                    config.get_rdma_buffer_size_hint(hidden_bytes, group.size()),
+                    num_rdma_bytes,
+                )
+        if deepep_mode.enable_low_latency():
+            assert num_max_dispatch_tokens_per_rank is not None
+            assert num_experts is not None and num_experts % group.size() == 0
+            num_rdma_bytes = max(
+                Buffer.get_low_latency_rdma_size_hint(
+                    num_max_dispatch_tokens_per_rank,
+                    hidden_size,
+                    group.size(),
+                    num_experts,
+                ),
+                num_rdma_bytes,
+            )
+
+        cls._buffer = Buffer(
+            group,
+            num_nvl_bytes,
+            num_rdma_bytes,
+            low_latency_mode=deepep_mode.enable_low_latency(),
+            num_qps_per_rank=(
+                num_experts // group.size() if deepep_mode.enable_low_latency() else 1
+            ),
+        )
+        return cls._buffer
+
+    @classmethod
+    def clean_buffer(cls):
+        if not cls._buffer.low_latency_mode:
+            return
+        cls._buffer.clean_low_latency_buffer(
+            cls._num_max_dispatch_tokens_per_rank,
+            cls._hidden_size,
+            cls._num_experts,
+        )
+
+    @classmethod
+    def set_dispatch_mode_as_normal(cls):
+        cls._dispatch_mode = DeepEPDispatchMode.NORMAL
+
+    @classmethod
+    def set_dispatch_mode_as_low_latency(cls):
+        if cls._dispatch_mode == DeepEPDispatchMode.NORMAL:
+            cls.clean_buffer()
+        cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
 
 
 class _DeepEPDispatcherImplBase:
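
Note on the hunk above: the two module-level singletons are replaced by one class-level `Buffer` sized up front for every mode the server may use, so normal (prefill) and low-latency (decode) dispatch share a single allocation. A stubbed sketch of the sizing rule, with the DeepEP size hints replaced by plain integers (all values illustrative, not from the patch):

def size_shared_rdma(enable_normal: bool, enable_low_latency: bool,
                     normal_hint: int, low_latency_hint: int) -> int:
    # Take the max of what each enabled mode needs, exactly as
    # get_deepep_buffer() does with the real DeepEP hints.
    num_rdma_bytes = 0
    if enable_normal:
        num_rdma_bytes = max(normal_hint, num_rdma_bytes)
    if enable_low_latency:
        num_rdma_bytes = max(low_latency_hint, num_rdma_bytes)
    return num_rdma_bytes

# With both modes enabled, the buffer is sized for the larger requirement:
assert size_shared_rdma(True, True, 1 << 20, 1 << 22) == 1 << 22
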
@@ -95,6 +121,7 @@ class _DeepEPDispatcherImplBase:
         num_local_experts: int,
         hidden_size: int,
         params_dtype: torch.dtype,
+        deepep_mode: DeepEPMode,
     ):
         if not use_deepep:
             raise ImportError(
@@ -109,7 +136,10 @@ class _DeepEPDispatcherImplBase:
         self.num_local_experts = num_local_experts
         self.hidden_size = hidden_size
         self.params_dtype = params_dtype
+        self.deepep_mode = deepep_mode
+
         self.params_bytes = 2
+        self.num_max_dispatch_tokens_per_rank = 128
 
         self.handle = None
 
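
Note: `params_bytes = 2` assumes 16-bit (bf16/fp16) token payloads, and the normal-mode size hints in the `DeepEPBuffer` hunk are fed `hidden_bytes = hidden_size * param_bytes`. A worked example with an illustrative hidden size:

hidden_size = 4096   # illustrative, model dependent
param_bytes = 2      # bytes per element for bf16/fp16
hidden_bytes = hidden_size * param_bytes
assert hidden_bytes == 8192  # activation payload bytes per token
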
@@ -118,8 +148,6 @@ class _DeepEPDispatcherImplBase:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
-        num_max_dispatch_tokens_per_rank: int,
     ):
         raise NotImplementedError
 
@@ -137,14 +165,14 @@ class _DeepEPDispatcherImplBase:
     def combine_b(self, *args, **kwargs):
         raise NotImplementedError
 
+    def _get_buffer(self) -> Buffer:
+        raise NotImplementedError
+
 
 class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
     def __init__(self, async_finish: bool, **kwargs):
         super().__init__(**kwargs)
 
-        self.buffer_normal = _get_buffer_normal(
-            self.group, self.hidden_size * self.params_bytes
-        )
         self.async_finish = async_finish
         self.src2dst = None
 
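
Note: `_DeepEPDispatcherImplNormal` no longer allocates its own buffer in `__init__`; both impls now fetch the shared one on demand through the new `_get_buffer()` hook. A dependency-free sketch of the lazy-singleton behavior this relies on:

allocations = []

class SharedBufferStub:
    _buffer = None

    @classmethod
    def get(cls):
        # Mirrors get_deepep_buffer(): allocate once, then reuse.
        if cls._buffer is None:
            allocations.append("alloc")
            cls._buffer = object()
        return cls._buffer

# Repeated lookups from either dispatcher hit the same instance:
assert SharedBufferStub.get() is SharedBufferStub.get()
assert allocations == ["alloc"]
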
@@ -153,24 +181,18 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
-        num_max_dispatch_tokens_per_rank: int,
     ):
         topk_idx = topk_idx.to(torch.int64)
         previous_event = Buffer.capture() if self.async_finish else None
-        return hidden_states, topk_idx, topk_weights, num_experts, previous_event
+        return hidden_states, topk_idx, topk_weights, previous_event
 
-    def dispatch_b(
-        self, hidden_states, topk_idx, topk_weights, num_experts, previous_event
-    ):
+    def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
         (
             hidden_states,
             topk_idx,
             topk_weights,
             event,
-        ) = self._dispatch_core(
-            hidden_states, topk_idx, topk_weights, num_experts, previous_event
-        )
+        ) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
         event.current_stream_wait() if self.async_finish else ()
         if hidden_states.shape[0] > 0:
             reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
@@ -181,7 +203,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
                 (0,), device=hidden_states.device, dtype=torch.int64
             )
             seg_indptr = torch.zeros(
-                (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+                (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
             )
 
         masked_m = expected_m = None
@@ -201,18 +223,18 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         x: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
         previous_event,
     ):
+        buffer = self._get_buffer()
         (
             num_tokens_per_rank,
             num_tokens_per_rdma_rank,
             num_tokens_per_expert,
             is_token_in_rank,
             previous_event,
-        ) = self.buffer_normal.get_dispatch_layout(
+        ) = buffer.get_dispatch_layout(
             topk_idx,
-            num_experts,
+            self.num_experts,
             previous_event=previous_event,
             async_finish=self.async_finish,
             allocate_on_comm_stream=previous_event is not None,
@@ -221,6 +243,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         # FIXME: `handle` should be transmitted with tokens from dispatch to combine.
         # However, doing this would incur an unknown synchronization error, but keeping
         # `handle` as a member variable works.
+
         (
             recv_x,
             recv_topk_idx,
@@ -228,7 +251,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             _,  # num_recv_tokens_per_expert_list
             self.handle,
             event,
-        ) = self.buffer_normal.dispatch(
+        ) = buffer.dispatch(
             x,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
@@ -327,7 +350,8 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         return hidden_states
 
     def _combine_core(self, x: torch.Tensor, previous_event):
-        combined_x, _, event = self.buffer_normal.combine(
+        buffer = self._get_buffer()
+        combined_x, _, event = buffer.combine(
             x,
             self.handle,
             async_finish=self.async_finish,
@@ -336,6 +360,17 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         )
         return combined_x, event
 
+    def _get_buffer(self):
+        DeepEPBuffer.set_dispatch_mode_as_normal()
+        return DeepEPBuffer.get_deepep_buffer(
+            self.group,
+            self.hidden_size,
+            self.params_bytes,
+            self.deepep_mode,
+            self.num_max_dispatch_tokens_per_rank,
+            self.num_experts,
+        )
+
 
 class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
     def __init__(self, return_recv_hook: bool, **kwargs):
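
Note: `_get_buffer()` also flips the class-level dispatch mode. Per `set_dispatch_mode_as_low_latency()` in the `DeepEPBuffer` hunk, only the NORMAL to LOW_LATENCY transition triggers `clean_buffer()`. A dependency-free sketch of that state machine:

from enum import IntEnum, auto

class Mode(IntEnum):
    NORMAL = auto()
    LOW_LATENCY = auto()

class ModeTracker:
    _mode = None
    cleanups = 0

    @classmethod
    def as_normal(cls):
        cls._mode = Mode.NORMAL

    @classmethod
    def as_low_latency(cls):
        if cls._mode == Mode.NORMAL:
            cls.cleanups += 1  # stands in for clean_low_latency_buffer(...)
        cls._mode = Mode.LOW_LATENCY

ModeTracker.as_normal()        # prefill step
ModeTracker.as_low_latency()   # decode step: cleans once
ModeTracker.as_low_latency()   # stays in low latency: no extra cleanup
assert ModeTracker.cleanups == 1
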
@@ -345,14 +380,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
         https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
         """
-        # TODO(ch-wan): allow users to set this value
-        self.num_max_dispatch_tokens_per_rank = 128
-        self.buffer_low_latency = _get_buffer_low_latency(
-            self.group,
-            self.num_max_dispatch_tokens_per_rank,
-            self.hidden_size,
-            self.num_experts,
-        )
         self.return_recv_hook = return_recv_hook
 
     def dispatch_a(
@@ -360,21 +387,16 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
-        num_max_dispatch_tokens_per_rank: int,
     ):
+        buffer = self._get_buffer()
         topk_idx = topk_idx.to(torch.int64)
         expected_m = (
-            hidden_states.shape[0]
-            * self.buffer_low_latency.group_size
-            * topk_idx.shape[1]
-            + num_experts
-        ) // num_experts
+            hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1]
+            + self.num_experts
+        ) // self.num_experts
         hidden_states, masked_m, event, hook = self._dispatch_core(
             hidden_states,
             topk_idx,
-            num_max_dispatch_tokens_per_rank,
-            num_experts,
             use_fp8=True,
         )
         return (
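
Note: `expected_m` above is a ceiling-style estimate of how many tokens land on each expert after dispatch. With illustrative numbers (8 local tokens, a group of 16 ranks, top-8 routing, 256 experts; none of these values come from the patch):

num_local_tokens, group_size, topk, num_experts = 8, 16, 8, 256
expected_m = (num_local_tokens * group_size * topk + num_experts) // num_experts
assert expected_m == 5  # (1024 + 256) // 256
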
@@ -415,8 +437,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
-        num_max_dispatch_tokens_per_rank: int,
-        num_experts: int,
         use_fp8: bool = False,
     ):
         """
@@ -451,13 +471,13 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
 
         const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
         """
-
+        buffer = self._get_buffer()
         packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
-            self.buffer_low_latency.low_latency_dispatch(
+            buffer.low_latency_dispatch(
                 hidden_states,
                 topk_idx,
-                num_max_dispatch_tokens_per_rank,
-                num_experts,
+                self.num_max_dispatch_tokens_per_rank,
+                self.num_experts,
                 use_fp8=use_fp8,
                 async_finish=not self.return_recv_hook,
                 return_recv_hook=self.return_recv_hook,
@@ -488,19 +508,29 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
-        combined_hidden_states, event, hook = (
-            self.buffer_low_latency.low_latency_combine(
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                self.handle,
-                async_finish=not self.return_recv_hook,
-                return_recv_hook=self.return_recv_hook,
-            )
-        )
+        buffer = self._get_buffer()
+        combined_hidden_states, event, hook = buffer.low_latency_combine(
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            self.handle,
+            async_finish=not self.return_recv_hook,
+            return_recv_hook=self.return_recv_hook,
+        )
         self.handle = None
         return combined_hidden_states, event, hook
 
+    def _get_buffer(self):
+        DeepEPBuffer.set_dispatch_mode_as_low_latency()
+        return DeepEPBuffer.get_deepep_buffer(
+            self.group,
+            self.hidden_size,
+            self.params_bytes,
+            self.deepep_mode,
+            self.num_max_dispatch_tokens_per_rank,
+            self.num_experts,
+        )
+
 
 class DeepEPDispatcher:
     def __init__(
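
Note: the combine path keeps `handle` as a member between dispatch and combine (see the FIXME in an earlier hunk) and now resets it after use. A minimal sketch of that lifecycle, with DeepEP's handle replaced by a placeholder tuple:

class HandleLifecycle:
    def __init__(self):
        self.handle = None

    def dispatch(self):
        self.handle = ("layout", "metadata")  # stand-in for DeepEP's handle

    def combine(self):
        assert self.handle is not None, "combine must follow dispatch"
        self.handle = None  # consumed; the next combine needs a new dispatch

d = HandleLifecycle()
d.dispatch()
d.combine()
assert d.handle is None
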
@@ -526,18 +556,19 @@ class DeepEPDispatcher:
             num_local_experts=num_local_experts,
             hidden_size=hidden_size,
             params_dtype=params_dtype,
+            deepep_mode=deepep_mode,
         )
 
-        if self.deepep_mode.enable_normal():
-            self._normal_dispatcher = _DeepEPDispatcherImplNormal(
-                async_finish=async_finish,
-                **common_kwargs,
-            )
         if self.deepep_mode.enable_low_latency():
             self._low_latency_dispatcher = _DeepEPDispatcherImplLowLatency(
                 return_recv_hook=return_recv_hook,
                 **common_kwargs,
             )
+        if self.deepep_mode.enable_normal():
+            self._normal_dispatcher = _DeepEPDispatcherImplNormal(
+                async_finish=async_finish,
+                **common_kwargs,
+            )
 
     def dispatch(self, *args, **kwargs) -> Tuple:
         self.dispatch_a(*args, **kwargs)
@@ -548,16 +579,12 @@ class DeepEPDispatcher:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
-        num_max_dispatch_tokens_per_rank: int = 128,
         forward_mode: ForwardMode = None,
     ):
         inner_state = self._get_impl(forward_mode).dispatch_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
-            num_experts=num_experts,
-            num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
         )
         self._dispatch_intermediate_state = forward_mode, inner_state
 
@@ -589,7 +616,7 @@ class DeepEPDispatcher:
         del self._combine_intermediate_state
         return self._get_impl(forward_mode).combine_b(*inner_state)
 
-    def _get_impl(self, forward_mode: ForwardMode) -> "_DeepEPDispatcherImplBase":
+    def _get_impl(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
         resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
         if resolved_deepep_mode == DeepEPMode.normal:
             return self._normal_dispatcher
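
Note on `_get_impl`: `DeepEPMode.resolve(forward_mode)` selects which impl serves a batch. A stubbed sketch of the selection logic, assuming (as the usage above suggests, though this diff does not show it) that an auto mode resolves to low latency for decode batches and to normal otherwise:

def resolve(deepep_mode: str, is_decode: bool) -> str:
    # Hypothetical stand-in for DeepEPMode.resolve(forward_mode).
    if deepep_mode != "auto":
        return deepep_mode
    return "low_latency" if is_decode else "normal"

assert resolve("auto", is_decode=True) == "low_latency"
assert resolve("auto", is_decode=False) == "normal"
assert resolve("normal", is_decode=True) == "normal"
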