Fix tuning_fused_moe_triton.py (#8175)
This commit is contained in:
@@ -18,6 +18,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||
get_default_config,
|
||||
get_moe_configs,
|
||||
)
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
_is_hip = is_hip()
|
||||
@@ -115,10 +116,15 @@ def benchmark_config(
|
||||
w1 = w1.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
|
||||
w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
|
||||
|
||||
input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
|
||||
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
||||
topk_output = select_experts(x, input_gating, topk, renormalize=True)
|
||||
|
||||
def prepare(i: int):
|
||||
input_gating.copy_(gating_output[i])
|
||||
input_gating = gating_output[i]
|
||||
new_topk_output = select_experts(x, input_gating, topk, renormalize=True)
|
||||
topk_output.topk_weights.copy_(new_topk_output.topk_weights)
|
||||
topk_output.topk_ids.copy_(new_topk_output.topk_ids)
|
||||
topk_output.router_logits.copy_(new_topk_output.router_logits)
|
||||
|
||||
def run():
|
||||
from sglang.srt.layers.moe.fused_moe_triton import override_config
|
||||
@@ -128,9 +134,7 @@ def benchmark_config(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
renormalize=True,
|
||||
topk_output,
|
||||
inplace=True,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
|
||||
Reference in New Issue
Block a user