[1/N] MoE Refactor: refactor select_experts (#7966)
This commit is contained in:
@@ -6,6 +6,7 @@ from tqdm import tqdm
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
|
||||
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
||||
from sglang.srt.utils import is_hip
|
||||
@@ -132,13 +133,17 @@ class TestFusedMOE(CustomTestCase):
|
||||
input_scale=a2_scale,
|
||||
)
|
||||
|
||||
topk_output = select_experts(
|
||||
hidden_states=a,
|
||||
router_logits=score,
|
||||
top_k=topk,
|
||||
)
|
||||
|
||||
sglang_output = fused_moe(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False,
|
||||
topk_output,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
@@ -166,7 +171,13 @@ class TestFusedMOE(CustomTestCase):
|
||||
w2 = self.create_random_cuda_tensor((e, k, n), dtype)
|
||||
score = self.create_random_cuda_tensor((m, e), dtype)
|
||||
|
||||
triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
|
||||
topk_output = select_experts(
|
||||
hidden_states=a,
|
||||
router_logits=score,
|
||||
top_k=topk,
|
||||
)
|
||||
|
||||
triton_output = fused_moe(a, w1, w2, topk_output)
|
||||
torch_output = self.torch_naive_moe(a, w1, w2, score, topk)
|
||||
torch.testing.assert_close(
|
||||
triton_output, torch_output, rtol=rtol, atol=atol
|
||||
|
||||
Reference in New Issue
Block a user