Support dispatching logical to physical experts (#6385)
This commit is contained in:
@@ -22,6 +22,10 @@ from sglang.srt.managers.expert_distribution import (
|
||||
ExpertDistributionRecorder,
|
||||
get_global_expert_distribution_recorder,
|
||||
)
|
||||
from sglang.srt.managers.expert_location_dispatch import (
|
||||
ExpertLocationDispatchInfo,
|
||||
topk_ids_logical_to_physical,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
|
||||
|
||||
@@ -100,6 +104,7 @@ def grouped_topk(
|
||||
n_share_experts_fusion: int = 0,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||
):
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||
|
||||
@@ -140,6 +145,7 @@ def grouped_topk(
|
||||
topk_weights = topk_weights / topk_weights_sum
|
||||
|
||||
topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
||||
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
|
||||
_mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
|
||||
return topk_weights, topk_ids
|
||||
|
||||
@@ -155,6 +161,7 @@ def biased_grouped_topk_impl(
|
||||
n_share_experts_fusion: int = 0,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||
):
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||
|
||||
@@ -202,6 +209,7 @@ def biased_grouped_topk_impl(
|
||||
topk_weights = topk_weights / topk_weights_sum
|
||||
|
||||
topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
||||
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
|
||||
_mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
|
||||
return topk_weights, topk_ids
|
||||
|
||||
@@ -232,6 +240,7 @@ def biased_grouped_topk(
|
||||
n_share_experts_fusion: int = 0,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||
):
|
||||
assert (
|
||||
routed_scaling_factor is not None
|
||||
@@ -252,6 +261,8 @@ def biased_grouped_topk(
|
||||
n_share_experts_fusion,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
# TODO merge into kernel for this branch
|
||||
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
|
||||
# TODO will fuse this into kernel, thus use slow manual operation now
|
||||
torch.compile(
|
||||
_mask_topk_ids_padded_region, dynamic=True, backend=get_compiler_backend()
|
||||
@@ -276,6 +287,7 @@ def biased_grouped_topk(
|
||||
n_share_experts_fusion=n_share_experts_fusion,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
num_token_non_padded=num_token_non_padded,
|
||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||
)
|
||||
|
||||
|
||||
@@ -292,6 +304,7 @@ def select_experts(
|
||||
torch_native: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||
):
|
||||
n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
|
||||
# DeepSeek V2/V3/R1 series models use grouped_top_k
|
||||
@@ -309,6 +322,7 @@ def select_experts(
|
||||
n_share_experts_fusion=n_share_experts_fusion,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
num_token_non_padded=num_token_non_padded,
|
||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = biased_grouped_topk(
|
||||
@@ -322,11 +336,13 @@ def select_experts(
|
||||
n_share_experts_fusion=n_share_experts_fusion,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
num_token_non_padded=num_token_non_padded,
|
||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||
)
|
||||
elif torch_native and custom_routing_function is None:
|
||||
assert (
|
||||
num_token_non_padded is None
|
||||
), "num_token_non_padded is not yet supported in fused_topk_native"
|
||||
assert expert_location_dispatch_info is None
|
||||
topk_weights, topk_ids = fused_topk_native(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
@@ -337,6 +353,7 @@ def select_experts(
|
||||
assert (
|
||||
num_token_non_padded is None
|
||||
), "num_token_non_padded is not yet supported in fused_topk"
|
||||
assert expert_location_dispatch_info is None
|
||||
topk_weights, topk_ids = fused_topk(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
@@ -347,6 +364,7 @@ def select_experts(
|
||||
assert (
|
||||
num_token_non_padded is None
|
||||
), "num_token_non_padded is not yet supported in custom_routing_function"
|
||||
assert expert_location_dispatch_info is None
|
||||
topk_weights, topk_ids = custom_routing_function(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
|
||||
Reference in New Issue
Block a user