[1/N] MoE Refactor: refactor select_experts (#7966)
This commit is contained in:
@@ -6,6 +6,7 @@ import torch
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
per_tensor_quant_mla_fp8,
|
||||
per_token_group_quant_fp8,
|
||||
@@ -497,13 +498,17 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
with torch.inference_mode():
|
||||
topk_output = select_experts(
|
||||
hidden_states=a,
|
||||
router_logits=score,
|
||||
top_k=topk,
|
||||
renormalize=False,
|
||||
)
|
||||
out = fused_moe(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False,
|
||||
topk_output,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
|
||||
@@ -40,7 +40,7 @@ def ep_moe(
|
||||
block_shape: Optional[List[int]] = None,
|
||||
):
|
||||
use_blockwise_fp8 = block_shape is not None
|
||||
topk_weights, topk_ids = select_experts(
|
||||
topk_weights, topk_ids, _ = select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
|
||||
@@ -100,12 +100,10 @@ def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
|
||||
s_strides2 = c_strides2
|
||||
|
||||
score = torch.randn((M, E), dtype=dtype, device=device)
|
||||
topk_weights, topk_ids = select_experts(
|
||||
topk_weights, topk_ids, _ = select_experts(
|
||||
hidden_states=a,
|
||||
router_logits=score,
|
||||
top_k=topk,
|
||||
use_grouped_topk=False,
|
||||
renormalize=False,
|
||||
)
|
||||
expert_map = torch.arange(E, dtype=torch.int32, device=device)
|
||||
expert_map[local_e:] = E
|
||||
|
||||
@@ -159,12 +159,10 @@ def test_cutlass_fp4_moe_no_graph(
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
topk_weights, topk_ids, _ = select_experts(
|
||||
hidden_states=a,
|
||||
router_logits=score,
|
||||
top_k=topk,
|
||||
use_grouped_topk=False,
|
||||
renormalize=False,
|
||||
)
|
||||
|
||||
a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
|
||||
|
||||
Reference in New Issue
Block a user