Fix tuning_fused_moe_triton.py (#8175)
This commit is contained in:
@@ -18,6 +18,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
|||||||
get_default_config,
|
get_default_config,
|
||||||
get_moe_configs,
|
get_moe_configs,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
from sglang.srt.utils import is_hip
|
from sglang.srt.utils import is_hip
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
@@ -115,10 +116,15 @@ def benchmark_config(
|
|||||||
w1 = w1.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
|
w1 = w1.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
|
||||||
w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
|
w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
|
||||||
|
|
||||||
input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
|
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
||||||
|
topk_output = select_experts(x, input_gating, topk, renormalize=True)
|
||||||
|
|
||||||
def prepare(i: int):
|
def prepare(i: int):
|
||||||
input_gating.copy_(gating_output[i])
|
input_gating = gating_output[i]
|
||||||
|
new_topk_output = select_experts(x, input_gating, topk, renormalize=True)
|
||||||
|
topk_output.topk_weights.copy_(new_topk_output.topk_weights)
|
||||||
|
topk_output.topk_ids.copy_(new_topk_output.topk_ids)
|
||||||
|
topk_output.router_logits.copy_(new_topk_output.router_logits)
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import override_config
|
from sglang.srt.layers.moe.fused_moe_triton import override_config
|
||||||
@@ -128,9 +134,7 @@ def benchmark_config(
|
|||||||
x,
|
x,
|
||||||
w1,
|
w1,
|
||||||
w2,
|
w2,
|
||||||
input_gating,
|
topk_output,
|
||||||
topk,
|
|
||||||
renormalize=True,
|
|
||||||
inplace=True,
|
inplace=True,
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
use_int8_w8a8=use_int8_w8a8,
|
use_int8_w8a8=use_int8_w8a8,
|
||||||
|
|||||||
Reference in New Issue
Block a user