[ROCm] Logic to decide whether to use the manually unrolled kernel. (#3306)
@@ -22,7 +22,7 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.utils import get_device_name, is_hip
+from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
 
 is_hip_ = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
@@ -450,9 +450,16 @@ def w8a8_block_fp8_matmul(
         triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
     )
 
-    # Use manually unrolledx4 kernel on AMD GPU.
+    # Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
+    # Empirical testing shows the sweet spot lies when it's less than the # of
+    # compute units available on the device.
+    num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
+        N, config["BLOCK_SIZE_N"]
+    )
     kernel = (
-        _w8a8_block_fp8_matmul_unrolledx4 if is_hip_ == True else _w8a8_block_fp8_matmul
+        _w8a8_block_fp8_matmul_unrolledx4
+        if (is_hip_ == True and num_workgroups <= get_device_core_count())
+        else _w8a8_block_fp8_matmul
     )
 
     kernel[grid](
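For orientation, here is a minimal, self-contained sketch of the dispatch heuristic added in this hunk: the launch grid size (number of workgroups) is compared against the device's compute-unit count, and the unrolled-by-4 kernel is chosen only when the grid is small. The shapes, block sizes, and core count below are illustrative assumptions, and a local cdiv helper stands in for triton.cdiv.

def cdiv(a: int, b: int) -> int:
    # Ceiling division, mirroring triton.cdiv.
    return (a + b - 1) // b

# Illustrative values only (not taken from the commit).
M, N = 256, 4096                      # example GEMM output shape
BLOCK_SIZE_M, BLOCK_SIZE_N = 64, 128  # example tile sizes from the tuned config
device_core_count = 304               # example CU count; query the device at runtime

num_workgroups = cdiv(M, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N)  # 4 * 32 = 128

# Small grids leave compute units underused, which is where the manually
# unrolled kernel pays off; larger grids fall back to the generic kernel.
use_unrolled_kernel = num_workgroups <= device_core_count
print(num_workgroups, use_unrolled_kernel)  # 128 True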
@@ -1046,6 +1046,13 @@ def get_device_name(device_id: int = 0) -> str:
         return torch.hpu.get_device_name(device_id)
 
 
+def get_device_core_count(device_id: int = 0) -> int:
+    if hasattr(torch, "cuda") and torch.cuda.is_available():
+        return torch.cuda.get_device_properties(device_id).multi_processor_count
+
+    return 0
+
+
 def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
     major, minor = None, None
     if hasattr(torch, "cuda") and torch.cuda.is_available():
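A brief usage sketch for the new helper, assuming an environment where sglang is installed: torch's multi_processor_count should report the SM count on CUDA builds and the compute-unit count on ROCm builds, and the helper returns 0 when no GPU is visible.

# Usage sketch (requires sglang to be importable).
from sglang.srt.utils import get_device_core_count

cores = get_device_core_count()  # SM count (CUDA) / CU count (ROCm), 0 if no GPU
print(f"device core count: {cores}")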