Minor update for ROCm variable style (#5562)

This commit is contained in:
Zhaoyi Li
2025-04-20 01:45:27 -05:00
committed by GitHub
parent e2574ee986
commit c555d794f7
2 changed files with 10 additions and 10 deletions

View File

@@ -20,7 +20,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
)
from sglang.srt.utils import is_hip
_is_hip_ = is_hip()
_is_hip = is_hip()
class BenchmarkConfig(TypedDict):
@@ -112,8 +112,8 @@ def benchmark_config(
)
if use_fp8_w8a8:
w1 = w1.to(torch.float8_e4m3fnuz if _is_hip_ else torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fnuz if _is_hip_ else torch.float8_e4m3fn)
w1 = w1.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
@@ -204,7 +204,7 @@ def get_configs_compute_bound() -> List[Dict[str, int]]:
# TODO(woosuk): Increase the search space and use a performance model to
# prune the search space.
configs: List[BenchmarkConfig] = []
if _is_hip_:
if _is_hip:
configs = get_rocm_configs_compute_bound()
else:
for num_stages in [2, 3, 4, 5]: