From 7ab84948d87d2c264cccc4ae8c1db339b9efea6a Mon Sep 17 00:00:00 2001 From: "Wen-Heng (Jack) Chung" Date: Tue, 4 Feb 2025 21:12:20 -0600 Subject: [PATCH] [ROCm] Logic to decide whether to use manually unrolled kernel. (#3306) --- python/sglang/srt/layers/quantization/fp8_kernel.py | 13 ++++++++++--- python/sglang/srt/utils.py | 7 +++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index 8443f8dd6..ddd614fdf 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -22,7 +22,7 @@ import torch import triton import triton.language as tl -from sglang.srt.utils import get_device_name, is_hip +from sglang.srt.utils import get_device_core_count, get_device_name, is_hip is_hip_ = is_hip() fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn @@ -450,9 +450,16 @@ def w8a8_block_fp8_matmul( triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) - # Use manually unrolledx4 kernel on AMD GPU. + # Use manually unrolledx4 kernel on AMD GPU when the grid size is small. + # Empirical testing shows the sweet spot lies when it's less than the # of + # compute units available on the device. 
+ num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv( + N, config["BLOCK_SIZE_N"] + ) kernel = ( - _w8a8_block_fp8_matmul_unrolledx4 if is_hip_ == True else _w8a8_block_fp8_matmul + _w8a8_block_fp8_matmul_unrolledx4 + if (is_hip_ == True and num_workgroups <= get_device_core_count()) + else _w8a8_block_fp8_matmul ) kernel[grid]( diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index ebb346bbc..b1c49f527 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1046,6 +1046,13 @@ def get_device_name(device_id: int = 0) -> str: return torch.hpu.get_device_name(device_id) +def get_device_core_count(device_id: int = 0) -> int: + if hasattr(torch, "cuda") and torch.cuda.is_available(): + return torch.cuda.get_device_properties(device_id).multi_processor_count + + return 0 + + def get_device_capability(device_id: int = 0) -> Tuple[int, int]: major, minor = None, None if hasattr(torch, "cuda") and torch.cuda.is_available():