Support single batch overlap (#10422)
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
from contextlib import nullcontext
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import triton
|
||||
@@ -38,10 +39,12 @@ from sglang.srt.layers.quantization.modelopt_quant import (
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.offloader import get_offloader
|
||||
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
|
||||
from sglang.srt.utils import (
|
||||
ceil_div,
|
||||
dispose_tensor,
|
||||
get_bool_env_var,
|
||||
get_int_env_var,
|
||||
is_cuda,
|
||||
is_hip,
|
||||
is_npu,
|
||||
@@ -466,7 +469,11 @@ class DeepEPMoE(EPMoE):
|
||||
),
|
||||
)
|
||||
|
||||
def moe_impl(self, dispatch_output: DispatchOutput):
|
||||
def moe_impl(
|
||||
self,
|
||||
dispatch_output: DispatchOutput,
|
||||
down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None,
|
||||
):
|
||||
from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
|
||||
|
||||
if _use_aiter:
|
||||
@@ -481,7 +488,9 @@ class DeepEPMoE(EPMoE):
|
||||
return self.forward_deepgemm_contiguous(dispatch_output)
|
||||
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
|
||||
if get_moe_runner_backend().is_flashinfer_cutedsl():
|
||||
return self.forward_flashinfer_cutedsl(dispatch_output)
|
||||
return self.forward_flashinfer_cutedsl(
|
||||
dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
|
||||
)
|
||||
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
||||
return self.forward_deepgemm_masked(dispatch_output)
|
||||
else:
|
||||
@@ -495,12 +504,14 @@ class DeepEPMoE(EPMoE):
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
overlap_args: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
return self.deepep_dispatcher.combine(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
forward_batch=forward_batch,
|
||||
overlap_args=overlap_args,
|
||||
)
|
||||
|
||||
def forward_aiter(
|
||||
@@ -687,6 +698,7 @@ class DeepEPMoE(EPMoE):
|
||||
def forward_flashinfer_cutedsl(
|
||||
self,
|
||||
dispatch_output: DeepEPLLOutput,
|
||||
down_gemm_overlap_args: Optional[DownGemmOverlapArgs],
|
||||
):
|
||||
hidden_states, _, _, masked_m, _ = dispatch_output
|
||||
assert self.quant_method is not None
|
||||
@@ -697,6 +709,7 @@ class DeepEPMoE(EPMoE):
|
||||
x=hidden_states,
|
||||
masked_m=masked_m,
|
||||
moe_runner_config=self.moe_runner_config,
|
||||
down_gemm_overlap_args=down_gemm_overlap_args,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
@@ -30,6 +30,9 @@ def flashinfer_cutedsl_moe_masked(
|
||||
w2_blockscale: torch.Tensor,
|
||||
w2_alpha,
|
||||
masked_m: torch.Tensor,
|
||||
down_sm_count: Optional[int] = None,
|
||||
down_signals: Optional[torch.Tensor] = None,
|
||||
down_start_event: Optional[torch.cuda.Event] = None,
|
||||
):
|
||||
"""
|
||||
Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
|
||||
@@ -151,6 +154,9 @@ def flashinfer_cutedsl_moe_masked(
|
||||
masked_m,
|
||||
)
|
||||
|
||||
if down_start_event is not None:
|
||||
down_start_event.record()
|
||||
|
||||
# Gemm2
|
||||
out = torch.empty((num_experts, m, k), dtype=torch.bfloat16, device=a_q.device)
|
||||
out = out.permute(1, 2, 0) # requirement of kernel
|
||||
@@ -165,5 +171,13 @@ def flashinfer_cutedsl_moe_masked(
|
||||
sf_vec_size=sf_vec_size,
|
||||
alpha=w2_alpha.view(1, 1, num_experts),
|
||||
alpha_dtype=get_cute_dtype(w2_alpha),
|
||||
**(
|
||||
dict(
|
||||
sm_count=down_sm_count,
|
||||
dst_signals=down_signals,
|
||||
)
|
||||
if down_sm_count is not None or down_signals is not None
|
||||
else {}
|
||||
),
|
||||
) # in logical [m, k, l]
|
||||
return out.permute(2, 0, 1)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union
|
||||
|
||||
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||
from sglang.srt.layers.moe.token_dispatcher.base import (
|
||||
@@ -25,6 +26,9 @@ from sglang.srt.utils import (
|
||||
|
||||
_is_npu = is_npu()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.single_batch_overlap import CombineOverlapArgs
|
||||
|
||||
try:
|
||||
from deep_ep import Buffer, Config
|
||||
|
||||
@@ -310,6 +314,7 @@ class _DeepEPDispatcherImplBase:
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
overlap_args: Optional["CombineOverlapArgs"],
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -428,6 +433,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
overlap_args: Optional["CombineOverlapArgs"],
|
||||
):
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
deepep_post_reorder_triton_kernel,
|
||||
@@ -503,6 +509,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
||||
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
|
||||
"""
|
||||
self.return_recv_hook = return_recv_hook
|
||||
self.device_module = torch.get_device_module()
|
||||
|
||||
def dispatch_a(
|
||||
self,
|
||||
@@ -570,7 +577,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
||||
use_fp8 = True
|
||||
|
||||
buffer = self._get_buffer()
|
||||
packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
|
||||
packed_recv_hidden, self.packed_recv_count, self.handle, event, hook = (
|
||||
buffer.low_latency_dispatch(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
@@ -591,23 +598,29 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
||||
and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
|
||||
)
|
||||
)
|
||||
return packed_recv_hidden, packed_recv_count, event, hook
|
||||
return packed_recv_hidden, self.packed_recv_count, event, hook
|
||||
|
||||
def combine_a(
    self,
    hidden_states: torch.Tensor,
    topk_idx: torch.Tensor,
    topk_weights: torch.Tensor,
    overlap_args: Optional["CombineOverlapArgs"],
):
    """Start the low-latency combine and return the state needed to finish it.

    Delegates to ``self._combine_core`` and forwards ``overlap_args`` to it
    as a keyword argument.

    Args:
        hidden_states: Tensor handed to ``_combine_core`` for combining.
        topk_idx: Top-k index tensor, forwarded unchanged.
        topk_weights: Top-k weight tensor, forwarded unchanged.
        overlap_args: Optional overlap settings; ``None`` means no overlap.

    Returns:
        ``(hidden_states, event, hook, overlap_args)`` where the first three
        items come from ``_combine_core``. ``overlap_args`` is appended so
        the second phase can see it; the order matches ``combine_b``'s
        parameters (presumably the caller unpacks this tuple into it).
    """
    combined_hidden_states, event, hook = self._combine_core(
        hidden_states,
        topk_idx,
        topk_weights,
        overlap_args=overlap_args,
    )
    # Only the 4-tuple is returned: a leftover 3-tuple return in front of it
    # would make this one unreachable and drop overlap_args.
    return combined_hidden_states, event, hook, overlap_args
|
||||
|
||||
def combine_b(self, hidden_states, event, hook, overlap_args):
    """Finish the low-latency combine started by ``combine_a``.

    Args:
        hidden_states: Combined tensor to hand back to the caller.
        event: Object whose ``current_stream_wait()`` is called when
            ``self.return_recv_hook`` is false.
        hook: Callable invoked when ``self.return_recv_hook`` is true.
        overlap_args: Optional overlap settings. When not ``None``, its
            ``stream`` attribute is the stream the current stream waits on.

    Returns:
        ``hidden_states``, unchanged.
    """
    # Exactly one completion mechanism is used, chosen by return_recv_hook.
    if self.return_recv_hook:
        hook()
    else:
        event.current_stream_wait()

    if overlap_args is not None:
        # Make the current stream wait for work queued on the overlap stream
        # before the result is consumed.
        self.device_module.current_stream().wait_stream(overlap_args.stream)

    return hidden_states
|
||||
|
||||
def _combine_core(
|
||||
@@ -615,17 +628,35 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
overlap_args: Optional["CombineOverlapArgs"],
|
||||
):
|
||||
buffer = self._get_buffer()
|
||||
combined_hidden_states, event, hook = buffer.low_latency_combine(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
self.handle,
|
||||
async_finish=not self.return_recv_hook,
|
||||
return_recv_hook=self.return_recv_hook,
|
||||
)
|
||||
self.handle = None
|
||||
|
||||
ctx = nullcontext()
|
||||
if overlap_args is not None:
|
||||
overlap_args.stream.wait_event(overlap_args.wait_event)
|
||||
ctx = torch.cuda.stream(overlap_args.stream)
|
||||
|
||||
with ctx:
|
||||
combined_hidden_states, event, hook = buffer.low_latency_combine(
|
||||
x=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
handle=self.handle,
|
||||
async_finish=not self.return_recv_hook,
|
||||
return_recv_hook=self.return_recv_hook,
|
||||
**(
|
||||
dict(
|
||||
overlap=overlap_args.overlap,
|
||||
src_signals=overlap_args.signal,
|
||||
src_signal_expect_value=overlap_args.threshold,
|
||||
)
|
||||
if overlap_args is not None
|
||||
else {}
|
||||
),
|
||||
)
|
||||
|
||||
self.packed_recv_count = self.handle = None
|
||||
return combined_hidden_states, event, hook
|
||||
|
||||
def _get_buffer(self):
|
||||
@@ -727,12 +758,14 @@ class DeepEPDispatcher(BaseDispatcher):
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
overlap_args: Optional["CombineOverlapArgs"] = None,
|
||||
):
|
||||
self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
|
||||
inner_state = self._get_impl(forward_batch).combine_a(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
overlap_args=overlap_args,
|
||||
)
|
||||
self._combine_intermediate_state = forward_batch, inner_state
|
||||
|
||||
|
||||
@@ -108,6 +108,7 @@ MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
|
||||
MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
|
||||
DEEPEP_MODE: Optional[DeepEPMode] = None
|
||||
IS_TBO_ENABLED: Optional[bool] = None
|
||||
IS_SBO_ENABLED: Optional[bool] = None
|
||||
TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None
|
||||
DEEPEP_CONFIG: Optional[str] = None
|
||||
DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER: Optional[bool] = None
|
||||
@@ -119,6 +120,7 @@ def initialize_moe_config(server_args: ServerArgs):
|
||||
global DEEPEP_MODE
|
||||
global DEEPEP_CONFIG
|
||||
global IS_TBO_ENABLED
|
||||
global IS_SBO_ENABLED
|
||||
global TBO_TOKEN_DISTRIBUTION_THRESHOLD
|
||||
global DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
|
||||
|
||||
@@ -127,6 +129,7 @@ def initialize_moe_config(server_args: ServerArgs):
|
||||
DEEPEP_MODE = DeepEPMode(server_args.deepep_mode)
|
||||
DEEPEP_CONFIG = server_args.deepep_config or ""
|
||||
IS_TBO_ENABLED = server_args.enable_two_batch_overlap
|
||||
IS_SBO_ENABLED = server_args.enable_single_batch_overlap
|
||||
TBO_TOKEN_DISTRIBUTION_THRESHOLD = server_args.tbo_token_distribution_threshold
|
||||
DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER = (
|
||||
server_args.disable_flashinfer_cutlass_moe_fp4_allgather
|
||||
@@ -172,6 +175,13 @@ def is_tbo_enabled() -> bool:
|
||||
return IS_TBO_ENABLED
|
||||
|
||||
|
||||
def is_sbo_enabled() -> bool:
    """Return whether single batch overlap (SBO) is enabled.

    Reads the module-level ``IS_SBO_ENABLED`` flag. If the flag is still
    ``None`` it is set to ``False`` and that value is returned.
    """
    global IS_SBO_ENABLED
    if IS_SBO_ENABLED is not None:
        return IS_SBO_ENABLED
    # Flag never set: fall back to disabled and remember the choice.
    IS_SBO_ENABLED = False
    return IS_SBO_ENABLED
|
||||
|
||||
|
||||
def get_tbo_token_distribution_threshold() -> float:
|
||||
global TBO_TOKEN_DISTRIBUTION_THRESHOLD
|
||||
if TBO_TOKEN_DISTRIBUTION_THRESHOLD is None:
|
||||
|
||||
@@ -47,6 +47,7 @@ if TYPE_CHECKING:
|
||||
CombineInput,
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
|
||||
|
||||
if is_cuda():
|
||||
from sgl_kernel import scaled_fp4_quant
|
||||
@@ -1468,6 +1469,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
x: torch.Tensor,
|
||||
masked_m: torch.Tensor,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
down_gemm_overlap_args: Optional["DownGemmOverlapArgs"],
|
||||
) -> torch.Tensor:
|
||||
assert (
|
||||
moe_runner_config.activation == "silu"
|
||||
@@ -1495,5 +1497,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
w2_blockscale=layer.w2_blockscale_swizzled,
|
||||
w2_alpha=layer.g2_alphas,
|
||||
masked_m=masked_m,
|
||||
**(
|
||||
dict(
|
||||
down_sm_count=down_gemm_overlap_args.num_sms,
|
||||
down_signals=down_gemm_overlap_args.signal,
|
||||
down_start_event=down_gemm_overlap_args.start_event,
|
||||
)
|
||||
if down_gemm_overlap_args is not None
|
||||
else {}
|
||||
),
|
||||
)
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user