diff --git a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
index 6f406e7ce..a84a4a6e2 100644
--- a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
+++ b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
@@ -20,7 +20,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
 )
 from sglang.srt.utils import is_hip
 
-_is_hip_ = is_hip()
+_is_hip = is_hip()
 
 
 class BenchmarkConfig(TypedDict):
@@ -112,8 +112,8 @@ def benchmark_config(
     )
 
     if use_fp8_w8a8:
-        w1 = w1.to(torch.float8_e4m3fnuz if _is_hip_ else torch.float8_e4m3fn)
-        w2 = w2.to(torch.float8_e4m3fnuz if _is_hip_ else torch.float8_e4m3fn)
+        w1 = w1.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
+        w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
 
     input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
 
@@ -204,7 +204,7 @@ def get_configs_compute_bound() -> List[Dict[str, int]]:
     # TODO(woosuk): Increase the search space and use a performance model to
     # prune the search space.
     configs: List[BenchmarkConfig] = []
-    if _is_hip_:
+    if _is_hip:
         configs = get_rocm_configs_compute_bound()
     else:
         for num_stages in [2, 3, 4, 5]:
diff --git a/benchmark/kernels/quantization/tuning_block_wise_kernel.py b/benchmark/kernels/quantization/tuning_block_wise_kernel.py
index 197939f02..7b0dfb47a 100644
--- a/benchmark/kernels/quantization/tuning_block_wise_kernel.py
+++ b/benchmark/kernels/quantization/tuning_block_wise_kernel.py
@@ -33,7 +33,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 from sglang.srt.layers.quantization.int8_kernel import _w8a8_block_int8_matmul
 from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
 
-is_hip_ = is_hip()
+_is_hip = is_hip()
 
 DTYPE_MAP = {
     "float32": torch.float32,
@@ -99,7 +99,7 @@ def w8a8_block_matmul(
     if A.dtype == torch.float8_e4m3fnuz or A.dtype == torch.float8_e4m3fn:
         kernel = (
             _w8a8_block_fp8_matmul_unrolledx4
-            if (is_hip_ == True and num_workgroups <= get_device_core_count())
+            if (_is_hip == True and num_workgroups <= get_device_core_count())
             else _w8a8_block_fp8_matmul
         )
     else:
@@ -157,7 +157,7 @@ def get_rocm_configs_compute_bound():
 
 def get_configs_compute_bound():
     configs = []
-    if is_hip_:
+    if _is_hip:
         configs = get_rocm_configs_compute_bound()
     else:
         for num_stages in [2, 3, 4, 5]:
@@ -244,7 +244,7 @@ def tune(M, N, K, block_size, out_dtype, search_space, input_type):
 
     if input_type == "fp8":
         fp8_info = torch.finfo(
-            torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
+            torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
         )
         fp8_max, fp8_min = fp8_info.max, fp8_info.min
 
@@ -252,14 +252,14 @@ def tune(M, N, K, block_size, out_dtype, search_space, input_type):
             (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
         )
         A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(
-            torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
+            torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
         )
 
         B_fp32 = (
             (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
         )
         B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(
-            torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
+            torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
         )
     else:
         int8_info = torch.iinfo(torch.int8)
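
For context, this diff renames `_is_hip_` / `is_hip_` to a single consistent `_is_hip` across both tuning scripts. The pattern both files converge on is a module-level flag computed once via `is_hip()` and used to select the fp8 dtype (ROCm's fp8 format is `float8_e4m3fnuz`, CUDA's is `float8_e4m3fn`). A minimal sketch of that pattern, using only identifiers that appear in the diff, not part of the patch itself:

```python
import torch

from sglang.srt.utils import is_hip

# Computed once at import time; the leading underscore marks it module-private.
_is_hip = is_hip()  # True on ROCm/HIP builds, False on CUDA

# Pick the platform-appropriate fp8 dtype, as the patched files do.
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
```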