Support low-latency dispatch (#10263)
Co-authored-by: Kaixi Hou <4001424+kaixih@users.noreply.github.com>
@@ -31,6 +31,10 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     sglang_per_token_group_quant_fp8,
 )
+from sglang.srt.layers.quantization.modelopt_quant import (
+    CUTEDSL_MOE_NVFP4_DISPATCH,
+    ModelOptNvFp4FusedMoEMethod,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.offloader import get_offloader
@@ -453,6 +457,13 @@ class DeepEPMoE(EPMoE):
             topk_idx=topk_idx,
             topk_weights=topk_weights,
             forward_batch=forward_batch,
+            input_global_scale=(
+                self.w13_input_scale_quant
+                if isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
+                and self.quant_method.enable_flashinfer_cutedsl_moe
+                and CUTEDSL_MOE_NVFP4_DISPATCH
+                else None
+            ),
         )
 
     def moe_impl(self, dispatch_output: DispatchOutput):
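This hunk threads a dispatch-time scale into `dispatch_a` only on the NVFP4 CuteDSL path. A minimal sketch of that gating, assuming the attribute names shown in the diff and with the `isinstance` check reduced to a `getattr` so it runs standalone:

```python
# Sketch only: mirrors the ternary added in the hunk above. nvfp4_dispatch
# stands in for the CUTEDSL_MOE_NVFP4_DISPATCH flag.
from typing import Optional

import torch


def pick_dispatch_scale(
    quant_method: object,
    w13_input_scale_quant: torch.Tensor,
    nvfp4_dispatch: bool,
) -> Optional[torch.Tensor]:
    # Supply a scale only when the CuteDSL NVFP4 MoE path is active.
    if getattr(quant_method, "enable_flashinfer_cutedsl_moe", False) and nvfp4_dispatch:
        return w13_input_scale_quant
    return None
```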
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union
 
 import torch
 from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
@@ -20,7 +20,7 @@ def get_cute_dtype(input: torch.Tensor) -> str:
 
 
 def flashinfer_cutedsl_moe_masked(
-    hidden_states: torch.Tensor,
+    hidden_states: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
     input_global_scale: torch.Tensor,
     w1: torch.Tensor,
     w1_blockscale: torch.Tensor,
@@ -36,7 +36,9 @@ def flashinfer_cutedsl_moe_masked(
     kernels.
 
     Args:
-        hidden_states (torch.Tensor): [num_experts, m, k], bf16
+        hidden_states: Either of the following cases:
+            * torch.Tensor: [num_experts, m, k], bf16
+            * tuple[torch.Tensor, torch.Tensor]: [num_experts, m, k // 2], uint8, [num_experts, m, k // 16], float8_e4m3fn
         input_global_scale (torch.Tensor): (l,)
         w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
         w1_blockscale (torch.Tensor): blockscale factors, e4m3,
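The tuple form documents the packed NVFP4 layout: two fp4 values share one uint8 byte (so the last data dimension is k // 2), and each group of 16 elements along k carries one float8_e4m3fn blockscale (so k // 16, matching the `sf_vec_size = 16` used later in this file). A shape-only sketch with invented sizes:

```python
# Shape-only illustration of the two accepted input forms; sizes are made up.
import torch

num_experts, m, k = 8, 64, 512

# Plain bf16 form: quantized inside flashinfer_cutedsl_moe_masked.
hidden_bf16 = torch.empty(num_experts, m, k, dtype=torch.bfloat16)

# Pre-quantized form: packed fp4 pairs plus per-16-element blockscales.
a_q = torch.empty(num_experts, m, k // 2, dtype=torch.uint8)
a_q_sf = torch.empty(num_experts, m, k // 16, dtype=torch.float8_e4m3fn)
hidden_packed = (a_q, a_q_sf)
```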
@@ -48,13 +50,10 @@ def flashinfer_cutedsl_moe_masked(
         masked_m (torch.Tensor): Masked dimension indices
 
     Notes:
-        - Assumes max(masked_m) <= m.
+        - Assumes max(masked_m) == m.
     """
 
     # === Assertions on dtypes ===
-    assert (
-        input_global_scale.dtype == torch.float32
-    ), f"input_global_scale must be float32, got {input_global_scale.dtype}"
     assert w1.dtype == torch.uint8, f"w1 must be uint8 (fp4 packed), got {w1.dtype}"
     assert (
         w1_blockscale.dtype == torch.float8_e4m3fn
@@ -75,7 +74,31 @@ def flashinfer_cutedsl_moe_masked(
 
     # === Assertions on shapes ===
     n = w2.shape[-1] * 2  # intermediate dimension
-    num_experts, m, k = hidden_states.shape
+    if isinstance(hidden_states, tuple):
+        assert (
+            input_global_scale is None
+        ), "input_global_scale must be None when the input is already quantized"
+
+        a_q = hidden_states[0].view(torch.uint8)
+        a_q_sf = hidden_states[1].view(torch.float8_e4m3fn)
+        m, k_by_2, num_experts = a_q.shape
+        k = k_by_2 * 2
+    else:
+        num_experts, m, k = hidden_states.shape
+
+        assert (
+            input_global_scale.dtype == torch.float32
+        ), f"input_global_scale must be float32, got {input_global_scale.dtype}"
+        assert input_global_scale.shape == (
+            num_experts,
+        ), f"input_global_scale must be (l,), got {input_global_scale.shape}"
+
+        a_q, a_q_sf = scaled_fp4_grouped_quant(
+            hidden_states,
+            input_global_scale,
+            masked_m,
+        )
+
     assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
     assert (
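The new branch accepts either form: a tuple means the dispatcher already produced NVFP4 data, so the scale must be absent and the buffers are only reinterpreted; a bare tensor is quantized here. A hedged sketch of that control flow, with `quantize_fn` standing in for `scaled_fp4_grouped_quant` (not implemented here):

```python
# Sketch of the dual input path; quantize_fn is a placeholder, not sglang API.
import torch


def normalize_moe_input(hidden_states, input_global_scale, masked_m, quantize_fn):
    if isinstance(hidden_states, tuple):
        # Pre-quantized upstream (e.g. by low-latency dispatch): no scale here.
        assert input_global_scale is None
        a_q = hidden_states[0].view(torch.uint8)
        a_q_sf = hidden_states[1].view(torch.float8_e4m3fn)
    else:
        # Quantize locally; the scale must be a float32 vector of length l.
        assert input_global_scale is not None
        a_q, a_q_sf = quantize_fn(hidden_states, input_global_scale, masked_m)
    return a_q, a_q_sf
```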
@@ -85,10 +108,6 @@ def flashinfer_cutedsl_moe_masked(
         k,
         n // 2,
     ), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n//2)}"
-
-    assert input_global_scale.shape == (
-        num_experts,
-    ), f"input_global_scale must be (l,), got {input_global_scale.shape}"
     assert w1_alpha.shape == (
         num_experts,
     ), f"w1_alpha must be (l,), got {w1_alpha.shape}"
@@ -99,27 +118,21 @@ def flashinfer_cutedsl_moe_masked(
         num_experts,
     ), f"w2_alpha must be (l,), got {w2_alpha.shape}"
 
-    aq, aq_sf = scaled_fp4_grouped_quant(
-        hidden_states,
-        input_global_scale,
-        masked_m,
-    )
+    # TODO(kaixih@nvidia): dtype should be based on inputs.
     gateup_output = torch.empty(
-        (num_experts, m, n * 2), dtype=hidden_states.dtype, device=aq.device
+        (num_experts, m, n * 2), dtype=torch.bfloat16, device=a_q.device
     )
     gateup_output = gateup_output.permute(1, 2, 0)  # requirement of kernel
     sf_vec_size = 16
-    assert aq_sf.dtype == torch.float8_e4m3fn
-    assert aq.dtype == torch.uint8
+    assert a_q_sf.dtype == torch.float8_e4m3fn
+    assert a_q.dtype == torch.uint8
     ab_dtype = "float4_e2m1fn"
     sf_dtype = "float8_e4m3fn"
-    c_dtype = get_cute_dtype(hidden_states)
+    c_dtype = "bfloat16"
 
     # Gemm1
 
     grouped_gemm_nt_masked(
-        (aq, aq_sf),
+        (a_q, a_q_sf),
         (w1.permute(1, 2, 0), w1_blockscale),
         gateup_output,
         masked_m,
@@ -139,7 +152,7 @@ def flashinfer_cutedsl_moe_masked(
     )
 
     # Gemm2
-    out = torch.empty_like(hidden_states)
+    out = torch.empty((num_experts, m, k), dtype=torch.bfloat16, device=a_q.device)
     out = out.permute(1, 2, 0)  # requirement of kernel
     grouped_gemm_nt_masked(
         (diq, diq_sf),
@@ -296,6 +296,7 @@ class _DeepEPDispatcherImplBase:
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
+        input_global_scale: Optional[torch.Tensor],
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
@@ -329,6 +330,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
+        input_global_scale: Optional[torch.Tensor],
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
@@ -505,6 +507,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
+        input_global_scale: Optional[torch.Tensor],
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
@@ -516,9 +519,8 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         ) // self.num_experts
         hidden_states, masked_m, event, hook = self._dispatch_core(
             hidden_states,
+            input_global_scale,
             topk_idx,
-            # TODO(shuw): pending https://github.com/deepseek-ai/DeepEP/pull/341
-            use_fp8=not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"),
         )
         return (
             hidden_states,
@@ -558,9 +560,15 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
     def _dispatch_core(
         self,
         hidden_states: torch.Tensor,
+        input_global_scale: Optional[torch.Tensor],
         topk_idx: torch.Tensor,
-        use_fp8: bool = False,
     ):
+        use_nvfp4 = use_fp8 = False
+        if input_global_scale is not None:
+            use_nvfp4 = True
+        elif not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"):
+            use_fp8 = True
+
         buffer = self._get_buffer()
         packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
             buffer.low_latency_dispatch(
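The removed `use_fp8` parameter becomes an internal decision with a clear precedence: a provided `input_global_scale` selects NVFP4, otherwise FP8 is used unless the bf16 escape hatch is set. A self-contained restatement (the `get_bool_env_var` stub approximates sglang's helper):

```python
# Sketch of the dispatch-dtype precedence introduced above.
import os


def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Stand-in for sglang.srt.utils.get_bool_env_var.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")


def dispatch_mode(input_global_scale) -> str:
    if input_global_scale is not None:
        return "nvfp4"  # use_nvfp4 = True
    if not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"):
        return "fp8"    # use_fp8 = True
    return "bf16"
```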
@@ -569,6 +577,12 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
                 self.num_max_dispatch_tokens_per_rank,
                 self.num_experts,
                 use_fp8=use_fp8,
+                **(dict(use_nvfp4=True) if use_nvfp4 else dict()),
+                **(
+                    dict(x_global_scale=input_global_scale)
+                    if input_global_scale is not None
+                    else dict()
+                ),
                 async_finish=not self.return_recv_hook,
                 return_recv_hook=self.return_recv_hook,
                 round_scale=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
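The `**dict(...)` splats pass `use_nvfp4` and `x_global_scale` only when they are needed, presumably so the call keeps working against DeepEP builds that predate those keyword arguments (the TODO later in this commit points at an unmerged DeepEP PR). The pattern in isolation:

```python
# Generic sketch of the conditional-kwargs pattern; fn models
# buffer.low_latency_dispatch and is not the real DeepEP signature.
def call_dispatch(fn, *, use_fp8, use_nvfp4, x_global_scale, **base_kwargs):
    extra = {}
    if use_nvfp4:
        extra["use_nvfp4"] = True
    if x_global_scale is not None:
        extra["x_global_scale"] = x_global_scale
    return fn(use_fp8=use_fp8, **extra, **base_kwargs)
```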
@@ -682,6 +696,7 @@ class DeepEPDispatcher(BaseDispatcher):
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
+        input_global_scale: Optional[torch.Tensor],
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         forward_batch: ForwardBatch,
@@ -689,6 +704,7 @@ class DeepEPDispatcher(BaseDispatcher):
         self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
         inner_state = self._get_impl(forward_batch).dispatch_a(
             hidden_states=hidden_states,
+            input_global_scale=input_global_scale,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
         )
@@ -80,6 +80,10 @@ CUTEDSL_MOE_SCALAR_INPUT_SCALE = get_bool_env_var(
 USE_CUTLASS_BACKEND_FOR_FP4_GEMM = get_bool_env_var(
     "SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM"
 )
+# TODO make it true by default when the DeepEP PR is merged
+CUTEDSL_MOE_NVFP4_DISPATCH = get_bool_env_var(
+    "SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH", "false"
+)
 
 # Supported activation schemes for the current configuration
 ACTIVATION_SCHEMES = ["static"]
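Since the flag is evaluated through `get_bool_env_var` at import time and defaults to false, it has to be set in the environment before sglang modules load. A minimal sketch:

```python
# Opt-in sketch: set the flag before importing sglang, e.g. at the top of a
# launcher script (a shell `export SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH=true`
# works the same way).
import os

os.environ["SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH"] = "true"
# ... import and launch sglang afterwards.
```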
@@ -1234,6 +1238,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
             w13_input_scale = _slice_scale(w13_input_scale)
             w2_input_scale = _slice_scale(w2_input_scale)
+
+            if CUTEDSL_MOE_NVFP4_DISPATCH:
+                assert torch.all(w13_input_scale == w13_input_scale[0])
+                w13_input_scale = w13_input_scale[0]
         else:
             w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
             w2_input_scale = layer.w2_input_scale
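The new assert encodes an invariant this mode appears to need: with NVFP4 dispatch, tokens are quantized once at dispatch time, before expert assignment is applied, so a single scalar scale has to be valid for every expert. That reading is inferred from the diff, not stated in the commit; a tiny sketch of the collapse:

```python
# Sketch: per-expert input scales must all agree before being collapsed to
# the one scalar used for dispatch-time quantization.
import torch

w13_input_scale = torch.full((8,), 0.25, dtype=torch.float32)  # 8 experts, equal scales
assert torch.all(w13_input_scale == w13_input_scale[0])
w13_input_scale = w13_input_scale[0]  # scalar scale
```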
@@ -1476,7 +1484,9 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
         out = flashinfer_cutedsl_moe_masked(
             hidden_states=x,
-            input_global_scale=layer.w13_input_scale_quant,
+            input_global_scale=(
+                None if CUTEDSL_MOE_NVFP4_DISPATCH else layer.w13_input_scale_quant
+            ),
             w1=layer.w13_weight,
             w1_blockscale=layer.w13_blockscale_swizzled,
             w1_alpha=layer.g1_alphas,
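This call site closes the loop: when NVFP4 dispatch is on, `x` already arrives as the quantized tuple, so the wrapper receives `None` and skips re-quantization; otherwise it gets the layer scale and quantizes internally. Illustrative names, not sglang API:

```python
# Sketch of the two resulting call configurations.
def moe_inputs(nvfp4_dispatch: bool, x, w13_input_scale_quant):
    if nvfp4_dispatch:
        return x, None               # x is (uint8 fp4 data, e4m3 blockscales)
    return x, w13_input_scale_quant  # x is bf16; quantized inside the kernel wrapper
```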
@@ -896,6 +896,7 @@ class DeepseekV2MoE(nn.Module):
         if self.ep_size > 1:
             self.experts.deepep_dispatcher.dispatch_a(
                 hidden_states=state.hidden_states_mlp_input,
+                input_global_scale=None,
                 topk_idx=state.pop("topk_idx_local"),
                 topk_weights=state.pop("topk_weights_local"),
                 forward_batch=state.forward_batch,