Add DeepSeek V3/R1 shared experts fusion (#4918)

Xiaoyu Zhang
2025-04-04 16:59:29 +08:00
committed by GitHub
parent 6ff9c6a5e7
commit 924ca7c92c
14 changed files with 536 additions and 36 deletions

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
}
}
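
The two JSON files above are Triton fused-MoE tuning tables: each top-level key is a token-batch size M, and each value holds the tile sizes (BLOCK_SIZE_M/N/K, GROUP_SIZE_M) and pipeline settings (num_warps, num_stages) benchmarked as fastest for that M; the file name (not shown here) typically encodes the expert count, intermediate size, dtype, and GPU. A minimal sketch of how such a per-M table is usually consumed (the loader and helper names are illustrative assumptions, not the actual sglang API):

import json

def load_tuning_table(path):
    # JSON keys arrive as strings, so cast them back to int batch sizes.
    with open(path) as f:
        return {int(m): cfg for m, cfg in json.load(f).items()}

def pick_config(table, num_tokens):
    # Use the config benchmarked at the batch size closest to the actual M;
    # e.g. a 100-token batch would fall back to the "96" entry above.
    best_m = min(table, key=lambda m: abs(m - num_tokens))
    return table[best_m]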

View File

@@ -13,11 +13,6 @@ import triton
import triton.language as tl
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.layers.quantization.int8_kernel import (
per_token_group_quant_int8,
per_token_quant_int8,
)
from sglang.srt.utils import (
direct_register_custom_op,
get_bool_env_var,
@@ -42,9 +37,6 @@ if _is_cuda:
from sgl_kernel import gelu_and_mul, silu_and_mul
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_fp8,
)
else:
from vllm import _custom_ops as vllm_ops
@@ -764,6 +756,16 @@ def invoke_fused_moe_kernel(
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
) -> None:
from sglang.srt.layers.quantization.int8_kernel import (
per_token_group_quant_int8,
per_token_quant_int8,
)
if _is_cuda:
from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_fp8,
)
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
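
The hunk above moves the int8/fp8 quantization imports from module scope into the body of invoke_fused_moe_kernel, so they are resolved on first invocation rather than at import time. Deferring imports this way commonly keeps module import cheap or breaks an import cycle; the exact motivation here is an assumption. The pattern in isolation, with json standing in for the deferred quantization modules:

def invoke_kernel():
    # Deferred import: resolved on the first call and cached in sys.modules,
    # so subsequent calls pay no re-import cost.
    import json
    return json.dumps({"ok": True})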

View File

@@ -12,12 +12,14 @@
# limitations under the License.
# ==============================================================================
import os
from typing import Callable, Optional
import torch
import torch.nn.functional as F
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
_is_cuda = is_cuda()
@@ -102,11 +104,13 @@ def grouped_topk(
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
n_share_experts_fusion: int = 0,
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
scores = torch.softmax(gating_output, dim=-1)
num_token = scores.shape[0]
num_experts = scores.shape[1]
group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
) # [n, n_group]
@@ -122,9 +126,25 @@ def grouped_topk(
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
if n_share_experts_fusion:
topk_ids[:, -1] = torch.randint(
low=num_experts,
high=num_experts + n_share_experts_fusion,
size=(topk_ids.size(0),),
dtype=topk_ids.dtype,
device=topk_ids.device,
)
topk_weights[:, -1] = (
topk_weights[:, :-1].sum(dim=-1) / 2.5
) # 2.5 is the routed_scaling_factor.
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights_sum = (
topk_weights.sum(dim=-1, keepdim=True)
if n_share_experts_fusion == 0
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
)
topk_weights = topk_weights / topk_weights_sum
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
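
This hunk is the core of the fusion: when n_share_experts_fusion > 0, the caller reserves one extra top-k slot, the shared expert is replicated as expert ids num_experts .. num_experts + n_share_experts_fusion - 1, and each token's last slot is pointed at one replica chosen uniformly at random, presumably to spread the shared-expert load across the replicas. The slot's weight is derived from the routed weights via the DeepSeek routed_scaling_factor of 2.5, and renormalization then sums over the routed slots only. A self-contained sketch with illustrative sizes:

import torch

num_experts, topk, n_share = 256, 9, 8  # illustrative; topk includes the fused slot
topk_weights = torch.rand(4, topk)      # [num_tokens, topk]
topk_ids = torch.randint(0, num_experts, (4, topk))

# Redirect the last slot to a random replica of the shared expert.
topk_ids[:, -1] = torch.randint(num_experts, num_experts + n_share, (4,))
topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / 2.5  # routed_scaling_factor

# Renormalize over the routed slots only, as the diff does when renormalize=True.
topk_weights = topk_weights / topk_weights[:, :-1].sum(dim=-1, keepdim=True)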
@@ -137,11 +157,13 @@ def biased_grouped_topk_impl(
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
n_share_experts_fusion: int = 0,
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
scores = gating_output.sigmoid()
num_token = scores.shape[0]
num_experts = scores.shape[1]
scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0)
group_scores = (
scores_for_choice.view(num_token, num_expert_group, -1)
@@ -164,8 +186,25 @@ def biased_grouped_topk_impl(
_, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
topk_weights = scores.gather(1, topk_ids)
if n_share_experts_fusion:
topk_ids[:, -1] = torch.randint(
low=num_experts,
high=num_experts + n_share_experts_fusion,
size=(topk_ids.size(0),),
dtype=topk_ids.dtype,
device=topk_ids.device,
)
topk_weights[:, -1] = (
topk_weights[:, :-1].sum(dim=-1) / 2.5
) # 2.5 is the routed_scaling_factor.
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights_sum = (
topk_weights.sum(dim=-1, keepdim=True)
if n_share_experts_fusion == 0
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
)
topk_weights = topk_weights / topk_weights_sum
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
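
biased_grouped_topk_impl applies the same fusion step on the sigmoid-plus-correction-bias routing path. Note a consequence of dividing by the routed-slot sum instead of the full sum when renormalize is set: the fused slot always normalizes to exactly 1/2.5 = 0.4, whatever the routed weights were. A quick illustrative check:

import torch

w = torch.tensor([[0.8, 0.6, 0.4, 0.2, 0.0]])  # 4 routed slots + 1 fused slot
w[:, -1] = w[:, :-1].sum(dim=-1) / 2.5         # fused weight = 2.0 / 2.5 = 0.8
w = w / w[:, :-1].sum(dim=-1, keepdim=True)    # routed slots now sum to 1.0
assert torch.allclose(w[:, -1], torch.tensor(0.4))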
@@ -179,6 +218,7 @@ def biased_grouped_topk(
num_expert_group: int = 0,
topk_group: int = 0,
compiled: bool = True,
n_share_experts_fusion: int = 0,
):
biased_grouped_topk_fn = (
torch.compile(
@@ -195,6 +235,7 @@ def biased_grouped_topk(
renormalize,
num_expert_group,
topk_group,
n_share_experts_fusion=n_share_experts_fusion,
)
@@ -210,7 +251,10 @@ def select_experts(
correction_bias: Optional[torch.Tensor] = None,
torch_native: bool = False,
):
# DeepSeekV2 uses grouped_top_k
n_share_experts_fusion = 0
if global_server_args_dict["n_share_experts_fusion"] is not None:
n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
# DeepSeek V2/V3/R1 series models use grouped_top_k
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
@@ -222,6 +266,7 @@ def select_experts(
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
n_share_experts_fusion=n_share_experts_fusion,
)
else:
topk_weights, topk_ids = biased_grouped_topk(
@@ -232,6 +277,7 @@ def select_experts(
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
n_share_experts_fusion=n_share_experts_fusion,
)
elif torch_native and custom_routing_function is None:
topk_weights, topk_ids = fused_topk_native(