[6/N] MoE Refactor: Clean up MoE-related configs (#8849)
This commit is contained in:
@@ -6,7 +6,7 @@ import torch
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
per_tensor_quant_mla_fp8,
|
||||
per_token_group_quant_fp8,
|
||||
@@ -498,11 +498,13 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
with torch.inference_mode():
|
||||
ref_out = torch_w8a8_block_fp8_moe(
|
||||
a, w1, w2, w1_s, w2_s, score, topk, block_size
|
||||
)
|
||||
topk_output = select_experts(
|
||||
hidden_states=a,
|
||||
router_logits=score,
|
||||
top_k=topk,
|
||||
renormalize=False,
|
||||
topk_config=TopKConfig(top_k=topk, renormalize=False),
|
||||
)
|
||||
out = fused_moe(
|
||||
a,
|
||||
@@ -514,9 +516,6 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
|
||||
w2_scale=w2_s,
|
||||
block_shape=block_size,
|
||||
)
|
||||
ref_out = torch_w8a8_block_fp8_moe(
|
||||
a, w1, w2, w1_s, w2_s, score, topk, block_size
|
||||
)
|
||||
|
||||
self.assertTrue(
|
||||
torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
|
||||
|
||||
@@ -12,7 +12,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
run_moe_ep_preproess,
|
||||
silu_and_mul_triton_kernel,
|
||||
)
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
@@ -22,35 +22,26 @@ def ep_moe(
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
topk_config: TopKConfig,
|
||||
# ep config
|
||||
num_experts: int = 256,
|
||||
fp8_dtype: torch.types = torch.float8_e4m3fn,
|
||||
num_experts_per_partition: int = 128,
|
||||
start_expert_id: int = 0,
|
||||
end_expert_id: int = 127,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
use_fp8_w8a8: bool = False,
|
||||
w1_scale_inv: Optional[torch.Tensor] = None,
|
||||
w2_scale_inv: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
):
|
||||
use_blockwise_fp8 = block_shape is not None
|
||||
topk_weights, topk_ids, _ = select_experts(
|
||||
top_k = topk_config.top_k
|
||||
topk_output = select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
# correction_bias=correction_bias, #skip this in test
|
||||
custom_routing_function=custom_routing_function,
|
||||
topk_config=topk_config,
|
||||
)
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
|
||||
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts)
|
||||
|
||||
@@ -294,14 +285,18 @@ class TestW8A8BlockFP8EPMoE(CustomTestCase):
|
||||
start_id = cur_rank * num_experts_per_partition
|
||||
end_id = start_id + num_experts_per_partition - 1
|
||||
|
||||
topk_config = TopKConfig(
|
||||
top_k=topk,
|
||||
renormalize=False,
|
||||
)
|
||||
|
||||
with torch.inference_mode():
|
||||
out = ep_moe(
|
||||
hidden_states=a,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
router_logits=score,
|
||||
top_k=topk,
|
||||
renormalize=False,
|
||||
topk_config=topk_config,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale_inv=w1_s,
|
||||
w2_scale_inv=w2_s,
|
||||
@@ -316,8 +311,7 @@ class TestW8A8BlockFP8EPMoE(CustomTestCase):
|
||||
w1=w1_ref,
|
||||
w2=w2_ref,
|
||||
router_logits=score,
|
||||
top_k=topk,
|
||||
renormalize=False,
|
||||
topk_config=topk_config,
|
||||
use_fp8_w8a8=False,
|
||||
w1_scale_inv=None,
|
||||
w2_scale_inv=None,
|
||||
|
||||
@@ -6,7 +6,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
|
||||
|
||||
|
||||
def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Tensor:
|
||||
@@ -100,11 +100,12 @@ def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
|
||||
s_strides2 = c_strides2
|
||||
|
||||
score = torch.randn((M, E), dtype=dtype, device=device)
|
||||
topk_weights, topk_ids, _ = select_experts(
|
||||
topk_output = select_experts(
|
||||
hidden_states=a,
|
||||
router_logits=score,
|
||||
top_k=topk,
|
||||
topk_config=TopKConfig(top_k=topk, renormalize=False),
|
||||
)
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
expert_map = torch.arange(E, dtype=torch.int32, device=device)
|
||||
expert_map[local_e:] = E
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from sgl_kernel import scaled_fp4_quant
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
|
||||
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
|
||||
|
||||
if torch.cuda.get_device_capability() < (10, 0):
|
||||
pytest.skip(
|
||||
@@ -163,11 +163,12 @@ def check_moe(
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
|
||||
topk_weights, topk_ids, _ = select_experts(
|
||||
topk_output = select_experts(
|
||||
hidden_states=a,
|
||||
router_logits=score,
|
||||
top_k=topk,
|
||||
topk_config=TopKConfig(top_k=topk, renormalize=False),
|
||||
)
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
|
||||
a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
|
||||
a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
|
||||
|
||||
Reference in New Issue
Block a user