Add DeepSeek V3/R1 shared experts fusion (#4918)
@@ -12,12 +12,14 @@
 # limitations under the License.
 # ==============================================================================

+import os
 from typing import Callable, Optional

 import torch
 import torch.nn.functional as F

 from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip

 _is_cuda = is_cuda()
@@ -102,11 +104,13 @@ def grouped_topk(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
+    n_share_experts_fusion: int = 0,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

     scores = torch.softmax(gating_output, dim=-1)
     num_token = scores.shape[0]
+    num_experts = scores.shape[1]
     group_scores = (
         scores.view(num_token, num_expert_group, -1).max(dim=-1).values
     )  # [n, n_group]
@@ -122,9 +126,25 @@ def grouped_topk(
     )  # [n, e]
     tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
     topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+    if n_share_experts_fusion:
+        topk_ids[:, -1] = torch.randint(
+            low=num_experts,
+            high=num_experts + n_share_experts_fusion,
+            size=(topk_ids.size(0),),
+            dtype=topk_ids.dtype,
+            device=topk_ids.device,
+        )
+        topk_weights[:, -1] = (
+            topk_weights[:, :-1].sum(dim=-1) / 2.5
+        )  # 2.5 is the routed_scaling_factor.

     if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+        topk_weights_sum = (
+            topk_weights.sum(dim=-1, keepdim=True)
+            if n_share_experts_fusion == 0
+            else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+        )
+        topk_weights = topk_weights / topk_weights_sum

     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)

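In short, the commit routes DeepSeek V3/R1's shared experts through the same fused MoE kernel as the routed experts: the shared expert is appended to the expert table as n_share_experts_fusion extra replicas, and the last top-k slot is hijacked to always select one of them, chosen uniformly at random to spread load. A minimal standalone sketch of that hijack (sizes and variable names are illustrative; the caller presumably requests one extra top-k slot so no routed expert is displaced):

    import torch

    # Illustrative sizes: 8 routed experts, 2 fused shared-expert replicas,
    # top-4 routing over 3 tokens.
    num_experts, n_share_experts_fusion, topk, num_token = 8, 2, 4, 3
    routed_scaling_factor = 2.5  # hard-coded as 2.5 in the diff

    scores = torch.softmax(torch.randn(num_token, num_experts), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False)

    # Overwrite the last slot with a random fused shared-expert replica
    # (IDs num_experts .. num_experts + n_share_experts_fusion - 1).
    topk_ids[:, -1] = torch.randint(
        num_experts, num_experts + n_share_experts_fusion, (num_token,)
    )
    # Give the shared expert the combined routed weight, pre-divided by the
    # routed_scaling_factor that is applied downstream.
    topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor

    print(topk_ids)      # last column is always >= num_experts
    print(topk_weights)  # last column = sum of the rest / 2.5

Picking the replica with torch.randint rather than a fixed ID spreads traffic across the fused copies, which is presumably the point of keeping several replicas in the table.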
@@ -137,11 +157,13 @@ def biased_grouped_topk_impl(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
+    n_share_experts_fusion: int = 0,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

     scores = gating_output.sigmoid()
     num_token = scores.shape[0]
+    num_experts = scores.shape[1]
     scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0)
     group_scores = (
         scores_for_choice.view(num_token, num_expert_group, -1)
@@ -164,8 +186,25 @@ def biased_grouped_topk_impl(
     _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
     topk_weights = scores.gather(1, topk_ids)

+    if n_share_experts_fusion:
+        topk_ids[:, -1] = torch.randint(
+            low=num_experts,
+            high=num_experts + n_share_experts_fusion,
+            size=(topk_ids.size(0),),
+            dtype=topk_ids.dtype,
+            device=topk_ids.device,
+        )
+        topk_weights[:, -1] = (
+            topk_weights[:, :-1].sum(dim=-1) / 2.5
+        )  # 2.5 is the routed_scaling_factor.

     if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+        topk_weights_sum = (
+            topk_weights.sum(dim=-1, keepdim=True)
+            if n_share_experts_fusion == 0
+            else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+        )
+        topk_weights = topk_weights / topk_weights_sum

     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)

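The renormalization change is the subtle part of both functions: with fusion enabled, the weights are divided by the sum of the routed columns only, so the routed weights still sum to 1 while the fused column keeps a fixed 1/2.5 of the routed mass. A quick numeric check with made-up weights:

    import torch

    # Made-up routed weights for one token, plus the fused slot already set
    # to sum(routed) / 2.5 as in the diff.
    routed = torch.tensor([0.5, 0.3, 0.2])
    fused = routed.sum() / 2.5                     # 1.0 / 2.5 = 0.4
    topk_weights = torch.cat([routed, fused.view(1)]).unsqueeze(0)

    # Fusion-aware renormalization: divide by the routed sum only.
    topk_weights_sum = topk_weights[:, :-1].sum(dim=-1, keepdim=True)
    normalized = topk_weights / topk_weights_sum

    print(normalized)  # tensor([[0.5000, 0.3000, 0.2000, 0.4000]])

After the downstream multiply by routed_scaling_factor = 2.5, the shared expert thus recovers an effective weight of 1.0, matching the unfused formulation where the shared-expert output is added unscaled.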
@@ -179,6 +218,7 @@ def biased_grouped_topk(
     num_expert_group: int = 0,
     topk_group: int = 0,
     compiled: bool = True,
+    n_share_experts_fusion: int = 0,
 ):
     biased_grouped_topk_fn = (
         torch.compile(
@@ -195,6 +235,7 @@ def biased_grouped_topk(
         renormalize,
         num_expert_group,
         topk_group,
+        n_share_experts_fusion=n_share_experts_fusion,
     )

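The biased_grouped_topk wrapper appears only truncated in these hunks; it chooses between a torch.compile-d and an eager call of biased_grouped_topk_impl based on the compiled flag, with get_compiler_backend() supplying the backend. A hedged sketch of that dispatch pattern with a stubbed impl, not the actual sglang code:

    import torch

    def topk_impl(x: torch.Tensor, k: int):
        # Stand-in for biased_grouped_topk_impl: plain top-k over scores.
        return torch.topk(x, k=k, dim=-1, sorted=False)

    def topk_dispatch(x: torch.Tensor, k: int, compiled: bool = True):
        # Same shape as the wrapper in the diff: compile if requested
        # (sglang also passes backend=get_compiler_backend()), else eager.
        fn = (
            torch.compile(topk_impl, dynamic=True)
            if compiled
            else topk_impl
        )
        return fn(x, k)

    weights, ids = topk_dispatch(torch.randn(4, 16), k=2)

torch.compile caches the compiled artifact internally, so rebuilding the dispatch on each call stays cheap after the first compilation.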
@@ -210,7 +251,10 @@ def select_experts(
     correction_bias: Optional[torch.Tensor] = None,
     torch_native: bool = False,
 ):
-    # DeekSeekv2 uses grouped_top_k
+    n_share_experts_fusion = 0
+    if global_server_args_dict["n_share_experts_fusion"] is not None:
+        n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+    # DeepSeek V2/V3/R1 series models use grouped_top_k
     if use_grouped_topk:
         assert topk_group is not None
         assert num_expert_group is not None
@@ -222,6 +266,7 @@ def select_experts(
             renormalize=renormalize,
             num_expert_group=num_expert_group,
             topk_group=topk_group,
+            n_share_experts_fusion=n_share_experts_fusion,
         )
     else:
         topk_weights, topk_ids = biased_grouped_topk(
@@ -232,6 +277,7 @@ def select_experts(
             renormalize=renormalize,
             num_expert_group=num_expert_group,
             topk_group=topk_group,
+            n_share_experts_fusion=n_share_experts_fusion,
         )
     elif torch_native and custom_routing_function is None:
         topk_weights, topk_ids = fused_topk_native(

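What makes the topk_ids[:, -1] hijack resolve to a real expert is the weight layout, which this file does not show: the shared expert's weights are replicated and appended after the routed experts during weight loading, so IDs in [num_experts, num_experts + n_share_experts_fusion) hit shared-expert copies through the ordinary grouped-GEMM path. A sketch under that assumption, with all names and sizes illustrative:

    import torch

    num_experts, n_share, d = 8, 2, 16  # illustrative sizes

    routed_w = torch.randn(num_experts, d, d)  # per-expert weights [E, d, d]
    shared_w = torch.randn(d, d)               # one shared expert

    # Append n_share replicas of the shared expert after the routed experts,
    # giving a fused table of num_experts + n_share experts.
    fused_w = torch.cat([routed_w, shared_w.expand(n_share, d, d)], dim=0)

    # Any topk_ids entry in [num_experts, num_experts + n_share) now selects
    # a shared-expert replica through the same fused MoE kernel.
    assert fused_w.shape[0] == num_experts + n_share

With this layout, select_experts only has to thread n_share_experts_fusion from global_server_args_dict into the grouped top-k variants, as the hunks above show; no separate shared-expert forward pass is needed.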