fix: resolve tuning fused moe issue (#9587)
This commit is contained in:
@@ -22,7 +22,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||
)
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
|
||||
-from sglang.srt.utils import is_hip, is_rocm
|
||||
+from sglang.srt.utils import is_hip
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
@@ -287,7 +287,7 @@ class BenchmarkWorker:
|
||||
)
|
||||
else:
|
||||
config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
|
||||
-with torch.cuda.device(self.device_id) if is_rocm() else nullcontext():
|
||||
+with torch.cuda.device(self.device_id) if is_hip() else nullcontext():
|
||||
kernel_time = benchmark_config(
|
||||
config,
|
||||
num_tokens,
|
||||
@@ -319,7 +319,7 @@ class BenchmarkWorker:
|
||||
) -> Dict[str, int]:
|
||||
best_config = None
|
||||
best_time = float("inf")
|
||||
-with torch.cuda.device(self.device_id) if is_rocm() else nullcontext():
|
||||
+with torch.cuda.device(self.device_id) if is_hip() else nullcontext():
|
||||
for config in tqdm(search_space):
|
||||
try:
|
||||
kernel_time = benchmark_config(
|
||||
|
||||
Reference in New Issue
Block a user