[6/N] MoE Refactor: Cleanup MoE-related configs (#8849)
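This commit migrates the fused-MoE tests to the consolidated top-k API: every `select_experts` callsite passes a single `topk_config=TopKConfig(...)` argument in place of the bare `top_k` keyword, and the imports from `sglang.srt.layers.moe.topk` gain `TopKConfig`. Two tests also reorder the fused and reference computations without changing either call; the notes between the hunks below point these moves out.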
@@ -5,7 +5,7 @@ import torch

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
-from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.moe.topk import TopKConfig, select_experts
 from sglang.test.test_utils import CustomTestCase

@@ -175,10 +175,13 @@ class TestW8A8BlockINT8FusedMoE(CustomTestCase):
         topk_output = select_experts(
             hidden_states=a,
             router_logits=score,
-            top_k=topk,
+            topk_config=TopKConfig(top_k=topk, renormalize=False),
         )

         with torch.inference_mode():
+            ref_out = torch_w8a8_block_int8_moe(
+                a, w1, w2, w1_s, w2_s, score, topk, block_size
+            )
             out = fused_moe(
                 a,
                 w1,

@@ -189,9 +192,6 @@ class TestW8A8BlockINT8FusedMoE(CustomTestCase):
                 w2_scale=w2_s,
                 block_shape=block_size,
             )
-            ref_out = torch_w8a8_block_int8_moe(
-                a, w1, w2, w1_s, w2_s, score, topk, block_size
-            )

         self.assertTrue(
             torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))

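The two hunks above also hoist the reference computation: ref_out now comes from torch_w8a8_block_int8_moe before the fused_moe call inside the existing torch.inference_mode() block, rather than after it, so both operands of the final comparison are produced in reading order.
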
@@ -5,7 +5,7 @@ import torch

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
-from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.moe.topk import TopKConfig, select_experts
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
 from sglang.test.test_utils import CustomTestCase

@@ -118,7 +118,7 @@ class TestW8A8Int8FusedMoE(CustomTestCase):
         topk_output = select_experts(
             hidden_states=a,
             router_logits=score,
-            top_k=topk,
+            topk_config=TopKConfig(top_k=topk, renormalize=False),
         )
         out = fused_moe(
             a,

@@ -6,7 +6,7 @@ from tqdm import tqdm

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
-from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.moe.topk import TopKConfig, select_experts
 from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.utils import is_hip

@@ -136,19 +136,7 @@ class TestFusedMOE(CustomTestCase):
         topk_output = select_experts(
             hidden_states=a,
             router_logits=score,
-            top_k=topk,
-        )
-
-        sglang_output = fused_moe(
-            a,
-            w1,
-            w2,
-            topk_output,
-            use_fp8_w8a8=True,
-            w1_scale=w1_scale,
-            w2_scale=w2_scale,
-            a1_scale=a1_scale,
-            a2_scale=a2_scale,
+            topk_config=TopKConfig(top_k=topk, renormalize=False),
         )

         torch_output = self.torch_naive_moe(

@@ -162,6 +150,18 @@ class TestFusedMOE(CustomTestCase):
             a1_scale,
             a2_scale,
         )
+
+        sglang_output = fused_moe(
+            a,
+            w1,
+            w2,
+            topk_output,
+            use_fp8_w8a8=True,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
+        )
         torch.testing.assert_close(
             sglang_output, torch_output, rtol=rtol, atol=atol
         )

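This pair of hunks is a pure move: the sglang_output = fused_moe(...) call shifts from before the torch_naive_moe reference to after it, with its arguments unchanged.
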
@@ -174,7 +174,7 @@ class TestFusedMOE(CustomTestCase):
         topk_output = select_experts(
             hidden_states=a,
             router_logits=score,
-            top_k=topk,
+            topk_config=TopKConfig(top_k=topk, renormalize=False),
         )

         triton_output = fused_moe(a, w1, w2, topk_output)

@@ -5,7 +5,7 @@ import torch

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
-from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.moe.topk import TopKConfig, select_experts
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.test.test_utils import CustomTestCase

@@ -130,7 +130,7 @@ class TestW8A8FP8FusedMoE(CustomTestCase):
         topk_output = select_experts(
             hidden_states=a,
             router_logits=score,
-            top_k=topk,
+            topk_config=TopKConfig(top_k=topk, renormalize=False),
         )
         out = fused_moe(
             a,

@@ -5,7 +5,7 @@ import torch

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
-from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.moe.topk import TopKConfig, select_experts

 NUM_EXPERTS = [8, 64]
 TOP_KS = [2, 6]

@@ -223,7 +223,7 @@ def test_fused_moe_wn16(
     topk_output = select_experts(
         hidden_states=a,
         router_logits=score,
-        top_k=topk,
+        topk_config=TopKConfig(top_k=topk),
     )

     triton_output = fused_moe(
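
For reference, a minimal sketch of the call pattern these tests converge on. The shapes, dtype, and top_k value below are illustrative assumptions rather than values from the diff, and running the snippet requires a GPU build of sglang. Note that the wn16 test above constructs TopKConfig without renormalize=False, relying on the config's default.

import torch

from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import TopKConfig, select_experts

# Illustrative shapes: 8 tokens, hidden size 512, 64 experts, intermediate
# size 1024 (w1 fuses the gate and up projections, hence 2 * 1024 rows).
a = torch.randn(8, 512, device="cuda", dtype=torch.bfloat16)
w1 = torch.randn(64, 2 * 1024, 512, device="cuda", dtype=torch.bfloat16)
w2 = torch.randn(64, 512, 1024, device="cuda", dtype=torch.bfloat16)
score = torch.randn(8, 64, device="cuda", dtype=torch.bfloat16)  # router logits

# Previously: select_experts(..., top_k=2). The routing options now travel
# together in a single TopKConfig object.
topk_output = select_experts(
    hidden_states=a,
    router_logits=score,
    topk_config=TopKConfig(top_k=2, renormalize=False),
)
out = fused_moe(a, w1, w2, topk_output)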