[ROCm] Add ROCm tuning config to block gemm and Re-tune for AMD Radeon Graphics (#3418)

Co-authored-by: Bruce Xue <yigex@xilinx.com>
Co-authored-by: HAI <hixiao@gmail.com>
This commit is contained in:
yigex
2025-02-11 15:55:04 +08:00
committed by GitHub
parent 5f0e7de339
commit fdf04a1426
11 changed files with 432 additions and 385 deletions

View File

@@ -23,8 +23,13 @@ import torch
import triton
from tqdm import tqdm
from sglang.srt.layers.quantization.fp8_kernel import _w8a8_block_fp8_matmul
from sglang.srt.utils import get_device_name
from sglang.srt.layers.quantization.fp8_kernel import (
_w8a8_block_fp8_matmul,
_w8a8_block_fp8_matmul_unrolledx4,
)
from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
is_hip_ = is_hip()
DTYPE_MAP = {
"float32": torch.float32,
@@ -80,7 +85,19 @@ def w8a8_block_fp8_matmul(
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
_w8a8_block_fp8_matmul[grid](
# Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
# Empirical testing shows the sweet spot lies when it's less than the # of
# compute units available on the device.
num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
N, config["BLOCK_SIZE_N"]
)
kernel = (
_w8a8_block_fp8_matmul_unrolledx4
if (is_hip_ == True and num_workgroups <= get_device_core_count())
else _w8a8_block_fp8_matmul
)
kernel[grid](
A,
B,
C,
@@ -107,14 +124,15 @@ def w8a8_block_fp8_matmul(
return C
def get_configs_compute_bound():
def get_rocm_configs_compute_bound():
configs = []
for num_stages in [2, 3, 4, 5]:
for block_m in [16, 32, 64, 128, 256]:
for block_k in [64, 128]:
for block_n in [32, 64, 128, 256]:
waves_per_eu_range = 0
for num_stages in [2]:
for block_m in [32, 64, 128, 256]:
for block_k in [32, 64, 128, 256]:
for block_n in [16, 32, 64, 128, 256]:
for num_warps in [4, 8]:
for group_size in [1, 16, 32, 64]:
for group_size in [1, 4, 8, 16, 32]:
configs.append(
{
"BLOCK_SIZE_M": block_m,
@@ -123,11 +141,36 @@ def get_configs_compute_bound():
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
"waves_per_eu": waves_per_eu_range,
}
)
return configs
def get_configs_compute_bound():
configs = []
if is_hip_:
configs = get_rocm_configs_compute_bound()
else:
for num_stages in [2, 3, 4, 5]:
for block_m in [16, 32, 64, 128, 256]:
for block_k in [64, 128]:
for block_n in [32, 64, 128, 256]:
for num_warps in [4, 8]:
for group_size in [1, 16, 32, 64]:
configs.append(
{
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
}
)
return configs
def get_weight_shapes(tp_size):
# NOTE(HandH1998): The weight shapes only works for DeepSeek-V3. Modify them, if you tune for another different model.
# cannot TP
@@ -190,14 +233,18 @@ def benchmark_config(
def tune(M, N, K, block_size, out_dtype, search_space):
factor_for_scale = 1e-2
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_info = torch.finfo(torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
A_fp32 = (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(
torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
)
B_fp32 = (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(
torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
)
block_n, block_k = block_size[0], block_size[1]
n_tiles = (N + block_n - 1) // block_n