[1/N] MoE Refactor: refactor select_experts (#7966)

This commit is contained in:
Cheng Wan
2025-07-19 00:51:15 -07:00
committed by GitHub
parent cfab0ff6e2
commit 15ad6c9086
39 changed files with 556 additions and 871 deletions

View File

@@ -5,6 +5,7 @@ import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import select_experts
from sglang.test.test_utils import CustomTestCase
@@ -171,14 +172,18 @@ class TestW8A8BlockINT8FusedMoE(CustomTestCase):
score = torch.randn((M, E), dtype=dtype)
topk_output = select_experts(
hidden_states=a,
router_logits=score,
top_k=topk,
)
with torch.inference_mode():
out = fused_moe(
a,
w1,
w2,
score,
topk,
renormalize=False,
topk_output,
use_int8_w8a8=True,
w1_scale=w1_s,
w2_scale=w2_s,

View File

@@ -6,6 +6,7 @@ from tqdm import tqdm
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.utils import is_hip
@@ -132,13 +133,17 @@ class TestFusedMOE(CustomTestCase):
input_scale=a2_scale,
)
topk_output = select_experts(
hidden_states=a,
router_logits=score,
top_k=topk,
)
sglang_output = fused_moe(
a,
w1,
w2,
score,
topk,
renormalize=False,
topk_output,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
@@ -166,7 +171,13 @@ class TestFusedMOE(CustomTestCase):
w2 = self.create_random_cuda_tensor((e, k, n), dtype)
score = self.create_random_cuda_tensor((m, e), dtype)
triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
topk_output = select_experts(
hidden_states=a,
router_logits=score,
top_k=topk,
)
triton_output = fused_moe(a, w1, w2, topk_output)
torch_output = self.torch_naive_moe(a, w1, w2, score, topk)
torch.testing.assert_close(
triton_output, torch_output, rtol=rtol, atol=atol

View File

@@ -5,6 +5,7 @@ import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
from sglang.test.test_utils import CustomTestCase
@@ -114,13 +115,16 @@ class TestW8A8Int8FusedMoE(CustomTestCase):
with torch.inference_mode():
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
topk_output = select_experts(
hidden_states=a,
router_logits=score,
top_k=topk,
)
out = fused_moe(
a,
w1,
w2,
score,
topk,
renormalize=False,
topk_output,
use_fp8_w8a8=False, # Not using fp8
use_int8_w8a16=False, # Not using int8-w8a16
use_int8_w8a8=True, # Using int8-w8a8

View File

@@ -5,6 +5,7 @@ import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.test.test_utils import CustomTestCase
@@ -126,13 +127,16 @@ class TestW8A8FP8FusedMoE(CustomTestCase):
with torch.inference_mode():
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
topk_output = select_experts(
hidden_states=a,
router_logits=score,
top_k=topk,
)
out = fused_moe(
a,
w1,
w2,
score,
topk,
renormalize=False,
topk_output,
use_fp8_w8a8=True, # using fp8
use_int8_w8a16=False,
use_int8_w8a8=False,

View File

@@ -5,6 +5,7 @@ import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import select_experts
NUM_EXPERTS = [8, 64]
TOP_KS = [2, 6]
@@ -219,13 +220,17 @@ def test_fused_moe_wn16(
if has_zp:
w_qzeros[expert_id] = qzeros
topk_output = select_experts(
hidden_states=a,
router_logits=score,
top_k=topk,
)
triton_output = fused_moe(
a,
w1_qweight,
w2_qweight,
score,
topk,
renormalize=False,
topk_output,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8,
w1_scale=w1_scales,