[6/N] MoE Refactor: Cleanup MoE-related configs (#8849)
This commit is contained in:
@@ -11,6 +11,7 @@ import triton
|
||||
from ray.experimental.tqdm_ray import tqdm
|
||||
from transformers import AutoConfig
|
||||
|
||||
+from sglang.srt.layers.moe.fused_moe_triton import override_config
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||
fused_moe,
|
||||
get_config_dtype_str,
|
||||
@@ -18,7 +19,8 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||
get_default_config,
|
||||
get_moe_configs,
|
||||
)
|
||||
-from sglang.srt.layers.moe.topk import select_experts
|
||||
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
+from sglang.srt.layers.moe.topk import TopKConfig, select_experts
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
_is_hip = is_hip()
|
||||
@@ -117,17 +119,23 @@ def benchmark_config(
|
||||
w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
|
||||
|
||||
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
||||
-topk_output = select_experts(x, input_gating, topk, renormalize=True)
|
||||
+topk_config = TopKConfig(
|
||||
+top_k=topk,
|
||||
+renormalize=True,
|
||||
+)
|
||||
+topk_output = select_experts(x, input_gating, topk_config)
|
||||
|
||||
def prepare(i: int):
|
||||
input_gating = gating_output[i]
|
||||
-new_topk_output = select_experts(x, input_gating, topk, renormalize=True)
|
||||
+new_topk_output = select_experts(x, input_gating, topk_config)
|
||||
topk_output.topk_weights.copy_(new_topk_output.topk_weights)
|
||||
topk_output.topk_ids.copy_(new_topk_output.topk_ids)
|
||||
topk_output.router_logits.copy_(new_topk_output.router_logits)
|
||||
|
||||
def run():
|
||||
-from sglang.srt.layers.moe.fused_moe_triton import override_config
|
||||
+moe_runner_config = MoeRunnerConfig(
|
||||
+inplace=True,
|
||||
+)
|
||||
|
||||
with override_config(config):
|
||||
fused_moe(
|
||||
@@ -135,7 +143,7 @@ def benchmark_config(
|
||||
w1,
|
||||
w2,
|
||||
topk_output,
|
||||
-inplace=True,
|
||||
+moe_runner_config=moe_runner_config,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
|
||||
Reference in New Issue
Block a user