Support single batch overlap (#10422)
This commit is contained in:
@@ -294,6 +294,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
|
|||||||
| `--enable-dp-lm-head` | Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention. | False |
|
| `--enable-dp-lm-head` | Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention. | False |
|
||||||
| `--enable-two-batch-overlap` | Enabling two micro batches to overlap. | False |
|
| `--enable-two-batch-overlap` | Enabling two micro batches to overlap. | False |
|
||||||
| `--tbo-token-distribution-threshold` | The threshold of token distribution between two batches in micro-batch-overlap, determines whether to two-batch-overlap or two-chunk-overlap. Set to 0 denote disable two-chunk-overlap. | 0.48 |
|
| `--tbo-token-distribution-threshold` | The threshold of token distribution between two batches in micro-batch-overlap, determines whether to two-batch-overlap or two-chunk-overlap. Set to 0 denote disable two-chunk-overlap. | 0.48 |
|
||||||
|
| `--enable-single-batch-overlap` | Enabling single batch overlap. | False |
|
||||||
| `--enable-torch-compile` | Optimize the model with torch.compile. Experimental feature. | False |
|
| `--enable-torch-compile` | Optimize the model with torch.compile. Experimental feature. | False |
|
||||||
| `--torch-compile-max-bs` | Set the maximum batch size when using torch compile. | 32 |
|
| `--torch-compile-max-bs` | Set the maximum batch size when using torch compile. | 32 |
|
||||||
| `--torchao-config` | Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row. | |
|
| `--torchao-config` | Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row. | |
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, List, Optional, Union
|
from contextlib import nullcontext
|
||||||
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
@@ -38,10 +39,12 @@ from sglang.srt.layers.quantization.modelopt_quant import (
|
|||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.offloader import get_offloader
|
from sglang.srt.offloader import get_offloader
|
||||||
|
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
ceil_div,
|
ceil_div,
|
||||||
dispose_tensor,
|
dispose_tensor,
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
|
get_int_env_var,
|
||||||
is_cuda,
|
is_cuda,
|
||||||
is_hip,
|
is_hip,
|
||||||
is_npu,
|
is_npu,
|
||||||
@@ -466,7 +469,11 @@ class DeepEPMoE(EPMoE):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def moe_impl(self, dispatch_output: DispatchOutput):
|
def moe_impl(
|
||||||
|
self,
|
||||||
|
dispatch_output: DispatchOutput,
|
||||||
|
down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None,
|
||||||
|
):
|
||||||
from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
|
from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
|
||||||
|
|
||||||
if _use_aiter:
|
if _use_aiter:
|
||||||
@@ -481,7 +488,9 @@ class DeepEPMoE(EPMoE):
|
|||||||
return self.forward_deepgemm_contiguous(dispatch_output)
|
return self.forward_deepgemm_contiguous(dispatch_output)
|
||||||
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
|
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
|
||||||
if get_moe_runner_backend().is_flashinfer_cutedsl():
|
if get_moe_runner_backend().is_flashinfer_cutedsl():
|
||||||
return self.forward_flashinfer_cutedsl(dispatch_output)
|
return self.forward_flashinfer_cutedsl(
|
||||||
|
dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
|
||||||
|
)
|
||||||
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
||||||
return self.forward_deepgemm_masked(dispatch_output)
|
return self.forward_deepgemm_masked(dispatch_output)
|
||||||
else:
|
else:
|
||||||
@@ -495,12 +504,14 @@ class DeepEPMoE(EPMoE):
|
|||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
|
overlap_args: Optional[Dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
return self.deepep_dispatcher.combine(
|
return self.deepep_dispatcher.combine(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
topk_idx=topk_idx,
|
topk_idx=topk_idx,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
forward_batch=forward_batch,
|
forward_batch=forward_batch,
|
||||||
|
overlap_args=overlap_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_aiter(
|
def forward_aiter(
|
||||||
@@ -687,6 +698,7 @@ class DeepEPMoE(EPMoE):
|
|||||||
def forward_flashinfer_cutedsl(
|
def forward_flashinfer_cutedsl(
|
||||||
self,
|
self,
|
||||||
dispatch_output: DeepEPLLOutput,
|
dispatch_output: DeepEPLLOutput,
|
||||||
|
down_gemm_overlap_args: Optional[DownGemmOverlapArgs],
|
||||||
):
|
):
|
||||||
hidden_states, _, _, masked_m, _ = dispatch_output
|
hidden_states, _, _, masked_m, _ = dispatch_output
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
@@ -697,6 +709,7 @@ class DeepEPMoE(EPMoE):
|
|||||||
x=hidden_states,
|
x=hidden_states,
|
||||||
masked_m=masked_m,
|
masked_m=masked_m,
|
||||||
moe_runner_config=self.moe_runner_config,
|
moe_runner_config=self.moe_runner_config,
|
||||||
|
down_gemm_overlap_args=down_gemm_overlap_args,
|
||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|||||||
@@ -30,6 +30,9 @@ def flashinfer_cutedsl_moe_masked(
|
|||||||
w2_blockscale: torch.Tensor,
|
w2_blockscale: torch.Tensor,
|
||||||
w2_alpha,
|
w2_alpha,
|
||||||
masked_m: torch.Tensor,
|
masked_m: torch.Tensor,
|
||||||
|
down_sm_count: Optional[int] = None,
|
||||||
|
down_signals: Optional[torch.Tensor] = None,
|
||||||
|
down_start_event: Optional[torch.cuda.Event] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
|
Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
|
||||||
@@ -151,6 +154,9 @@ def flashinfer_cutedsl_moe_masked(
|
|||||||
masked_m,
|
masked_m,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if down_start_event is not None:
|
||||||
|
down_start_event.record()
|
||||||
|
|
||||||
# Gemm2
|
# Gemm2
|
||||||
out = torch.empty((num_experts, m, k), dtype=torch.bfloat16, device=a_q.device)
|
out = torch.empty((num_experts, m, k), dtype=torch.bfloat16, device=a_q.device)
|
||||||
out = out.permute(1, 2, 0) # requirement of kernel
|
out = out.permute(1, 2, 0) # requirement of kernel
|
||||||
@@ -165,5 +171,13 @@ def flashinfer_cutedsl_moe_masked(
|
|||||||
sf_vec_size=sf_vec_size,
|
sf_vec_size=sf_vec_size,
|
||||||
alpha=w2_alpha.view(1, 1, num_experts),
|
alpha=w2_alpha.view(1, 1, num_experts),
|
||||||
alpha_dtype=get_cute_dtype(w2_alpha),
|
alpha_dtype=get_cute_dtype(w2_alpha),
|
||||||
|
**(
|
||||||
|
dict(
|
||||||
|
sm_count=down_sm_count,
|
||||||
|
dst_signals=down_signals,
|
||||||
|
)
|
||||||
|
if down_sm_count is not None or down_signals is not None
|
||||||
|
else {}
|
||||||
|
),
|
||||||
) # in logical [m, k, l]
|
) # in logical [m, k, l]
|
||||||
return out.permute(2, 0, 1)
|
return out.permute(2, 0, 1)
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from contextlib import nullcontext
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union
|
||||||
|
|
||||||
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||||
from sglang.srt.layers.moe.token_dispatcher.base import (
|
from sglang.srt.layers.moe.token_dispatcher.base import (
|
||||||
@@ -25,6 +26,9 @@ from sglang.srt.utils import (
|
|||||||
|
|
||||||
_is_npu = is_npu()
|
_is_npu = is_npu()
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.single_batch_overlap import CombineOverlapArgs
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from deep_ep import Buffer, Config
|
from deep_ep import Buffer, Config
|
||||||
|
|
||||||
@@ -310,6 +314,7 @@ class _DeepEPDispatcherImplBase:
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
|
overlap_args: Optional["CombineOverlapArgs"],
|
||||||
):
|
):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@@ -428,6 +433,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
|
overlap_args: Optional["CombineOverlapArgs"],
|
||||||
):
|
):
|
||||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||||
deepep_post_reorder_triton_kernel,
|
deepep_post_reorder_triton_kernel,
|
||||||
@@ -503,6 +509,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
|||||||
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
|
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
|
||||||
"""
|
"""
|
||||||
self.return_recv_hook = return_recv_hook
|
self.return_recv_hook = return_recv_hook
|
||||||
|
self.device_module = torch.get_device_module()
|
||||||
|
|
||||||
def dispatch_a(
|
def dispatch_a(
|
||||||
self,
|
self,
|
||||||
@@ -570,7 +577,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
|||||||
use_fp8 = True
|
use_fp8 = True
|
||||||
|
|
||||||
buffer = self._get_buffer()
|
buffer = self._get_buffer()
|
||||||
packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
|
packed_recv_hidden, self.packed_recv_count, self.handle, event, hook = (
|
||||||
buffer.low_latency_dispatch(
|
buffer.low_latency_dispatch(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
topk_idx,
|
topk_idx,
|
||||||
@@ -591,23 +598,29 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
|||||||
and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
|
and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return packed_recv_hidden, packed_recv_count, event, hook
|
return packed_recv_hidden, self.packed_recv_count, event, hook
|
||||||
|
|
||||||
def combine_a(
|
def combine_a(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
|
overlap_args: Optional["CombineOverlapArgs"],
|
||||||
):
|
):
|
||||||
hidden_states, event, hook = self._combine_core(
|
hidden_states, event, hook = self._combine_core(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
topk_idx,
|
topk_idx,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
|
overlap_args=overlap_args,
|
||||||
)
|
)
|
||||||
return hidden_states, event, hook
|
return hidden_states, event, hook, overlap_args
|
||||||
|
|
||||||
def combine_b(self, hidden_states, event, hook):
|
def combine_b(self, hidden_states, event, hook, overlap_args):
|
||||||
hook() if self.return_recv_hook else event.current_stream_wait()
|
hook() if self.return_recv_hook else event.current_stream_wait()
|
||||||
|
|
||||||
|
if overlap_args is not None:
|
||||||
|
self.device_module.current_stream().wait_stream(overlap_args.stream)
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def _combine_core(
|
def _combine_core(
|
||||||
@@ -615,17 +628,35 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
|
overlap_args: Optional["CombineOverlapArgs"],
|
||||||
):
|
):
|
||||||
buffer = self._get_buffer()
|
buffer = self._get_buffer()
|
||||||
combined_hidden_states, event, hook = buffer.low_latency_combine(
|
|
||||||
hidden_states,
|
ctx = nullcontext()
|
||||||
topk_idx,
|
if overlap_args is not None:
|
||||||
topk_weights,
|
overlap_args.stream.wait_event(overlap_args.wait_event)
|
||||||
self.handle,
|
ctx = torch.cuda.stream(overlap_args.stream)
|
||||||
async_finish=not self.return_recv_hook,
|
|
||||||
return_recv_hook=self.return_recv_hook,
|
with ctx:
|
||||||
)
|
combined_hidden_states, event, hook = buffer.low_latency_combine(
|
||||||
self.handle = None
|
x=hidden_states,
|
||||||
|
topk_idx=topk_idx,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
handle=self.handle,
|
||||||
|
async_finish=not self.return_recv_hook,
|
||||||
|
return_recv_hook=self.return_recv_hook,
|
||||||
|
**(
|
||||||
|
dict(
|
||||||
|
overlap=overlap_args.overlap,
|
||||||
|
src_signals=overlap_args.signal,
|
||||||
|
src_signal_expect_value=overlap_args.threshold,
|
||||||
|
)
|
||||||
|
if overlap_args is not None
|
||||||
|
else {}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.packed_recv_count = self.handle = None
|
||||||
return combined_hidden_states, event, hook
|
return combined_hidden_states, event, hook
|
||||||
|
|
||||||
def _get_buffer(self):
|
def _get_buffer(self):
|
||||||
@@ -727,12 +758,14 @@ class DeepEPDispatcher(BaseDispatcher):
|
|||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
|
overlap_args: Optional["CombineOverlapArgs"] = None,
|
||||||
):
|
):
|
||||||
self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
|
self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
|
||||||
inner_state = self._get_impl(forward_batch).combine_a(
|
inner_state = self._get_impl(forward_batch).combine_a(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
topk_idx=topk_idx,
|
topk_idx=topk_idx,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
|
overlap_args=overlap_args,
|
||||||
)
|
)
|
||||||
self._combine_intermediate_state = forward_batch, inner_state
|
self._combine_intermediate_state = forward_batch, inner_state
|
||||||
|
|
||||||
|
|||||||
@@ -108,6 +108,7 @@ MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
|
|||||||
MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
|
MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
|
||||||
DEEPEP_MODE: Optional[DeepEPMode] = None
|
DEEPEP_MODE: Optional[DeepEPMode] = None
|
||||||
IS_TBO_ENABLED: Optional[bool] = None
|
IS_TBO_ENABLED: Optional[bool] = None
|
||||||
|
IS_SBO_ENABLED: Optional[bool] = None
|
||||||
TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None
|
TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None
|
||||||
DEEPEP_CONFIG: Optional[str] = None
|
DEEPEP_CONFIG: Optional[str] = None
|
||||||
DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER: Optional[bool] = None
|
DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER: Optional[bool] = None
|
||||||
@@ -119,6 +120,7 @@ def initialize_moe_config(server_args: ServerArgs):
|
|||||||
global DEEPEP_MODE
|
global DEEPEP_MODE
|
||||||
global DEEPEP_CONFIG
|
global DEEPEP_CONFIG
|
||||||
global IS_TBO_ENABLED
|
global IS_TBO_ENABLED
|
||||||
|
global IS_SBO_ENABLED
|
||||||
global TBO_TOKEN_DISTRIBUTION_THRESHOLD
|
global TBO_TOKEN_DISTRIBUTION_THRESHOLD
|
||||||
global DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
|
global DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
|
||||||
|
|
||||||
@@ -127,6 +129,7 @@ def initialize_moe_config(server_args: ServerArgs):
|
|||||||
DEEPEP_MODE = DeepEPMode(server_args.deepep_mode)
|
DEEPEP_MODE = DeepEPMode(server_args.deepep_mode)
|
||||||
DEEPEP_CONFIG = server_args.deepep_config or ""
|
DEEPEP_CONFIG = server_args.deepep_config or ""
|
||||||
IS_TBO_ENABLED = server_args.enable_two_batch_overlap
|
IS_TBO_ENABLED = server_args.enable_two_batch_overlap
|
||||||
|
IS_SBO_ENABLED = server_args.enable_single_batch_overlap
|
||||||
TBO_TOKEN_DISTRIBUTION_THRESHOLD = server_args.tbo_token_distribution_threshold
|
TBO_TOKEN_DISTRIBUTION_THRESHOLD = server_args.tbo_token_distribution_threshold
|
||||||
DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER = (
|
DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER = (
|
||||||
server_args.disable_flashinfer_cutlass_moe_fp4_allgather
|
server_args.disable_flashinfer_cutlass_moe_fp4_allgather
|
||||||
@@ -172,6 +175,13 @@ def is_tbo_enabled() -> bool:
|
|||||||
return IS_TBO_ENABLED
|
return IS_TBO_ENABLED
|
||||||
|
|
||||||
|
|
||||||
|
def is_sbo_enabled() -> bool:
|
||||||
|
global IS_SBO_ENABLED
|
||||||
|
if IS_SBO_ENABLED is None:
|
||||||
|
IS_SBO_ENABLED = False
|
||||||
|
return IS_SBO_ENABLED
|
||||||
|
|
||||||
|
|
||||||
def get_tbo_token_distribution_threshold() -> float:
|
def get_tbo_token_distribution_threshold() -> float:
|
||||||
global TBO_TOKEN_DISTRIBUTION_THRESHOLD
|
global TBO_TOKEN_DISTRIBUTION_THRESHOLD
|
||||||
if TBO_TOKEN_DISTRIBUTION_THRESHOLD is None:
|
if TBO_TOKEN_DISTRIBUTION_THRESHOLD is None:
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ if TYPE_CHECKING:
|
|||||||
CombineInput,
|
CombineInput,
|
||||||
StandardDispatchOutput,
|
StandardDispatchOutput,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
|
||||||
|
|
||||||
if is_cuda():
|
if is_cuda():
|
||||||
from sgl_kernel import scaled_fp4_quant
|
from sgl_kernel import scaled_fp4_quant
|
||||||
@@ -1468,6 +1469,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
masked_m: torch.Tensor,
|
masked_m: torch.Tensor,
|
||||||
moe_runner_config: MoeRunnerConfig,
|
moe_runner_config: MoeRunnerConfig,
|
||||||
|
down_gemm_overlap_args: Optional["DownGemmOverlapArgs"],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert (
|
assert (
|
||||||
moe_runner_config.activation == "silu"
|
moe_runner_config.activation == "silu"
|
||||||
@@ -1495,5 +1497,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
|||||||
w2_blockscale=layer.w2_blockscale_swizzled,
|
w2_blockscale=layer.w2_blockscale_swizzled,
|
||||||
w2_alpha=layer.g2_alphas,
|
w2_alpha=layer.g2_alphas,
|
||||||
masked_m=masked_m,
|
masked_m=masked_m,
|
||||||
|
**(
|
||||||
|
dict(
|
||||||
|
down_sm_count=down_gemm_overlap_args.num_sms,
|
||||||
|
down_signals=down_gemm_overlap_args.signal,
|
||||||
|
down_start_event=down_gemm_overlap_args.start_event,
|
||||||
|
)
|
||||||
|
if down_gemm_overlap_args is not None
|
||||||
|
else {}
|
||||||
|
),
|
||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from torch import nn
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
from sglang.srt import single_batch_overlap
|
||||||
from sglang.srt.distributed import (
|
from sglang.srt.distributed import (
|
||||||
get_moe_expert_parallel_world_size,
|
get_moe_expert_parallel_world_size,
|
||||||
get_pp_group,
|
get_pp_group,
|
||||||
@@ -101,6 +102,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
|||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
|
from sglang.srt.single_batch_overlap import SboFlags
|
||||||
from sglang.srt.two_batch_overlap import (
|
from sglang.srt.two_batch_overlap import (
|
||||||
MaybeTboDeepEPDispatcher,
|
MaybeTboDeepEPDispatcher,
|
||||||
model_forward_maybe_tbo,
|
model_forward_maybe_tbo,
|
||||||
@@ -806,7 +808,8 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
if hidden_states.shape[0] > 0:
|
if hidden_states.shape[0] > 0:
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits = self.gate(hidden_states)
|
router_logits = self.gate(hidden_states)
|
||||||
shared_output = self._forward_shared_experts(hidden_states)
|
if not SboFlags.fuse_shared_experts_inside_sbo():
|
||||||
|
shared_output = self._forward_shared_experts(hidden_states)
|
||||||
topk_weights, topk_idx, _ = self.topk(
|
topk_weights, topk_idx, _ = self.topk(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
router_logits,
|
router_logits,
|
||||||
@@ -820,12 +823,18 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
hidden_states.device
|
hidden_states.device
|
||||||
)
|
)
|
||||||
|
|
||||||
final_hidden_states = self.experts(
|
final_hidden_states, sbo_shared_output = single_batch_overlap.execute_sbo(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
topk_idx=topk_idx,
|
topk_idx=topk_idx,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
forward_batch=forward_batch,
|
forward_batch=forward_batch,
|
||||||
|
# SBO args
|
||||||
|
forward_shared_experts=lambda: self._forward_shared_experts(hidden_states),
|
||||||
|
experts=self.experts,
|
||||||
|
alt_stream=self.alt_stream,
|
||||||
)
|
)
|
||||||
|
if sbo_shared_output is not None:
|
||||||
|
shared_output = sbo_shared_output
|
||||||
|
|
||||||
if shared_output is not None:
|
if shared_output is not None:
|
||||||
x = shared_output
|
x = shared_output
|
||||||
@@ -843,7 +852,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
def _forward_shared_experts(
|
def _forward_shared_experts(
|
||||||
self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None
|
self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None
|
||||||
):
|
):
|
||||||
if self.num_fused_shared_experts == 0:
|
if (hidden_states.shape[0] > 0) and (self.num_fused_shared_experts == 0):
|
||||||
return self.shared_experts(
|
return self.shared_experts(
|
||||||
hidden_states, gemm_output_zero_allocator=gemm_output_zero_allocator
|
hidden_states, gemm_output_zero_allocator=gemm_output_zero_allocator
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -377,6 +377,7 @@ class ServerArgs:
|
|||||||
enable_dp_attention: bool = False
|
enable_dp_attention: bool = False
|
||||||
enable_dp_lm_head: bool = False
|
enable_dp_lm_head: bool = False
|
||||||
enable_two_batch_overlap: bool = False
|
enable_two_batch_overlap: bool = False
|
||||||
|
enable_single_batch_overlap: bool = False
|
||||||
tbo_token_distribution_threshold: float = 0.48
|
tbo_token_distribution_threshold: float = 0.48
|
||||||
enable_torch_compile: bool = False
|
enable_torch_compile: bool = False
|
||||||
torch_compile_max_bs: int = 32
|
torch_compile_max_bs: int = 32
|
||||||
@@ -2457,6 +2458,11 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enabling two micro batches to overlap.",
|
help="Enabling two micro batches to overlap.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-single-batch-overlap",
|
||||||
|
action="store_true",
|
||||||
|
help="Let computation and communication overlap within one micro batch.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--tbo-token-distribution-threshold",
|
"--tbo-token-distribution-threshold",
|
||||||
type=float,
|
type=float,
|
||||||
|
|||||||
151
python/sglang/srt/single_batch_overlap.py
Normal file
151
python/sglang/srt/single_batch_overlap.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe import get_moe_runner_backend
|
||||||
|
from sglang.srt.layers.moe.utils import is_sbo_enabled
|
||||||
|
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||||
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
from sglang.srt.utils import get_int_env_var
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE
|
||||||
|
|
||||||
|
|
||||||
|
class SboFlags:
|
||||||
|
# TODO may have: "enable_dispatch_shared_one_stream_overlap", "enable_dispatch_gateup_gemm_two_stream_overlap", ...
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def enable_combine_down_gemm_two_stream_overlap(cls):
|
||||||
|
return (
|
||||||
|
is_sbo_enabled()
|
||||||
|
# currently only cutedsl backend supports it
|
||||||
|
and get_moe_runner_backend().is_flashinfer_cutedsl()
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def enable_combine_shared_two_stream_overlap(cls):
|
||||||
|
return is_sbo_enabled()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def fuse_shared_experts_inside_sbo(cls):
|
||||||
|
# TODO after antgroup's PR, should be `... or cls.enable_dispatch_shared_one_stream_overlap()`
|
||||||
|
return cls.enable_combine_shared_two_stream_overlap()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CombineOverlapArgs:
|
||||||
|
# this "overlap" flag means overlapping with down gemm, not the general two-stream overlap
|
||||||
|
overlap: bool
|
||||||
|
stream: torch.cuda.Stream
|
||||||
|
wait_event: torch.cuda.Event
|
||||||
|
num_sms: int
|
||||||
|
signal: Optional[torch.Tensor] = None
|
||||||
|
threshold: int = -1
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DownGemmOverlapArgs:
|
||||||
|
num_sms: int
|
||||||
|
signal: torch.Tensor
|
||||||
|
start_event: torch.cuda.Event
|
||||||
|
|
||||||
|
|
||||||
|
def execute_sbo(
|
||||||
|
forward_shared_experts: Callable[[], Any],
|
||||||
|
experts: "DeepEPMoE",
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
topk_idx: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
alt_stream: Optional = None,
|
||||||
|
):
|
||||||
|
shared_output = None
|
||||||
|
|
||||||
|
dispatch_output = experts.dispatch(
|
||||||
|
hidden_states, topk_idx, topk_weights, forward_batch
|
||||||
|
)
|
||||||
|
|
||||||
|
combine_overlap_args, down_gemm_overlap_args, meta_overlap_args = (
|
||||||
|
_compute_overlap_args(dispatch_output, alt_stream)
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = experts.moe_impl(
|
||||||
|
dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
|
||||||
|
)
|
||||||
|
if (e := meta_overlap_args.get("record_event_after_down")) is not None:
|
||||||
|
e.record()
|
||||||
|
|
||||||
|
if SboFlags.enable_combine_shared_two_stream_overlap():
|
||||||
|
# TODO reduce sm for non-deepgemm
|
||||||
|
with deep_gemm_wrapper.configure_deep_gemm_num_sms(
|
||||||
|
meta_overlap_args["compute_num_sms"]
|
||||||
|
):
|
||||||
|
shared_output = forward_shared_experts()
|
||||||
|
|
||||||
|
hidden_states = experts.combine(
|
||||||
|
hidden_states,
|
||||||
|
dispatch_output.topk_idx,
|
||||||
|
dispatch_output.topk_weights,
|
||||||
|
forward_batch,
|
||||||
|
overlap_args=combine_overlap_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
return hidden_states, shared_output
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_overlap_args(dispatch_output, alt_stream):
|
||||||
|
if not (
|
||||||
|
SboFlags.enable_combine_down_gemm_two_stream_overlap()
|
||||||
|
or SboFlags.enable_combine_shared_two_stream_overlap()
|
||||||
|
):
|
||||||
|
return None, None, {}
|
||||||
|
|
||||||
|
hidden_states = dispatch_output.hidden_states_fp8
|
||||||
|
if isinstance(hidden_states, tuple):
|
||||||
|
hidden_states = hidden_states[0]
|
||||||
|
|
||||||
|
num_local_experts, num_tokens_static, hidden_dim = hidden_states.shape
|
||||||
|
|
||||||
|
total_num_sms = torch.cuda.get_device_properties(
|
||||||
|
device="cuda"
|
||||||
|
).multi_processor_count
|
||||||
|
communicate_num_sms = get_int_env_var("SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS", 32)
|
||||||
|
compute_num_sms = total_num_sms - communicate_num_sms
|
||||||
|
|
||||||
|
assert alt_stream is not None
|
||||||
|
combine_wait_event = torch.cuda.Event()
|
||||||
|
combine_overlap_args = CombineOverlapArgs(
|
||||||
|
overlap=False,
|
||||||
|
num_sms=communicate_num_sms,
|
||||||
|
stream=alt_stream,
|
||||||
|
wait_event=combine_wait_event,
|
||||||
|
)
|
||||||
|
meta_overlap_args = dict(
|
||||||
|
compute_num_sms=compute_num_sms,
|
||||||
|
)
|
||||||
|
down_gemm_overlap_args = None
|
||||||
|
|
||||||
|
if SboFlags.enable_combine_down_gemm_two_stream_overlap():
|
||||||
|
# TODO use zero_allocator to remove this `torch.zeros` call
|
||||||
|
# NOTE ours v2 use uint32 not int32 currently
|
||||||
|
combine_signal = torch.zeros(
|
||||||
|
num_local_experts, dtype=torch.uint32, device=hidden_states.device
|
||||||
|
)
|
||||||
|
|
||||||
|
down_gemm_overlap_args = DownGemmOverlapArgs(
|
||||||
|
signal=combine_signal,
|
||||||
|
start_event=combine_wait_event,
|
||||||
|
num_sms=compute_num_sms,
|
||||||
|
)
|
||||||
|
combine_overlap_args.overlap = True
|
||||||
|
combine_overlap_args.signal = combine_signal
|
||||||
|
combine_overlap_args.threshold = compute_num_sms
|
||||||
|
else:
|
||||||
|
meta_overlap_args |= dict(
|
||||||
|
record_event_after_down=combine_wait_event,
|
||||||
|
)
|
||||||
|
|
||||||
|
return combine_overlap_args, down_gemm_overlap_args, meta_overlap_args
|
||||||
Reference in New Issue
Block a user