apply fused moe gate in ds v3/r1 (#5371)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import math
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
|
||||
@@ -25,6 +26,8 @@ from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
|
||||
_is_cuda = is_cuda()
|
||||
_is_hip = is_hip()
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import moe_fused_gate
|
||||
|
||||
expert_distribution_recorder = ExpertDistributionRecorder()
|
||||
|
||||
@@ -209,6 +212,10 @@ def biased_grouped_topk_impl(
|
||||
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
||||
|
||||
|
||||
def is_power_of_two(n):
|
||||
return n > 0 and math.log2(n).is_integer()
|
||||
|
||||
|
||||
def biased_grouped_topk(
|
||||
hidden_states: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
@@ -220,23 +227,37 @@ def biased_grouped_topk(
|
||||
compiled: bool = True,
|
||||
n_share_experts_fusion: int = 0,
|
||||
):
|
||||
biased_grouped_topk_fn = (
|
||||
torch.compile(
|
||||
biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
|
||||
# TODO: moe_fused_gate kernel is not supported for n_share_experts_fusion > 0 now.
|
||||
if (
|
||||
_is_cuda
|
||||
and n_share_experts_fusion == 0
|
||||
and is_power_of_two(correction_bias.shape[0])
|
||||
):
|
||||
return moe_fused_gate(
|
||||
gating_output,
|
||||
correction_bias,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
topk,
|
||||
)
|
||||
else:
|
||||
biased_grouped_topk_fn = (
|
||||
torch.compile(
|
||||
biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
|
||||
)
|
||||
if compiled
|
||||
else biased_grouped_topk_impl
|
||||
)
|
||||
return biased_grouped_topk_fn(
|
||||
hidden_states,
|
||||
gating_output,
|
||||
correction_bias,
|
||||
topk,
|
||||
renormalize,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
n_share_experts_fusion=n_share_experts_fusion,
|
||||
)
|
||||
if compiled
|
||||
else biased_grouped_topk_impl
|
||||
)
|
||||
return biased_grouped_topk_fn(
|
||||
hidden_states,
|
||||
gating_output,
|
||||
correction_bias,
|
||||
topk,
|
||||
renormalize,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
n_share_experts_fusion=n_share_experts_fusion,
|
||||
)
|
||||
|
||||
|
||||
def select_experts(
|
||||
|
||||
Reference in New Issue
Block a user