[feat]Support fusion kernel for constructing quant input and scale factor for fp8_blockwise_scaled_grouped_mm (#8023)
This commit is contained in:
@@ -1166,3 +1166,88 @@ def scaled_fp8_quant(
|
|||||||
) # True for static
|
) # True for static
|
||||||
|
|
||||||
return output, scale
|
return output, scale
|
||||||
|
|
||||||
|
|
||||||
|
# Fused per-token-group FP8 quantization for grouped (MoE) GEMM inputs.
#
# Launch grid: (number of K groups, number of experts). Each program handles
# one BLOCK_K-wide column group for every row that belongs to one expert and
# walks those rows in tiles of BLOCK_M.
#
# For every (row, column-group) pair the kernel:
#   1. computes the absolute maximum over the BLOCK_K values (floored at 1e-4),
#   2. writes the values rescaled by 448 / amax to `a_fp8` as float8e4nv,
#   3. writes the dequantization scale amax / 448 to `sfa`.
#
# `sfa` is written MN-major *per expert*: an expert's scales start at element
# `expert_offset * k`, and inside that segment the scale of (row r, group g)
# lives at `g * m + r`, where `m` is the expert's row count.
#
# Arguments:
#   a              input activations, addressed as row-major (M, K)
#   expert_offsets first row of each expert inside `a`, one entry per expert
#   problem_sizes  addressed with a row stride of 3; entry 0 of each row is
#                  the expert's row count `m`
#   a_fp8          output buffer, addressed exactly like `a`
#   sfa            output scale buffer (see the layout note above)
#   K              number of columns of `a`
#   BLOCK_K        quantization group size along K
#   M_ALIGNMENT    divisor the caller guarantees for `m` and the expert offset
#   BLOCK_M        rows per tile, chosen by the autotuner
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": block_m}, num_warps=num_warps)
        for block_m in [16, 32, 64, 128]
        for num_warps in [2, 4, 8]
    ],
    key=["K", "BLOCK_K", "M_ALIGNMENT"],
)
@triton.jit
def _per_token_group_quant_fp8_hopper_moe_mn_major(
    a,  # (M, K):(K, 1)
    expert_offsets,  # (num_experts,)
    problem_sizes,  # (num_experts, 3)
    a_fp8,  # (M, K):(K, 1)
    sfa,  # (M, k)
    K: tl.constexpr,
    BLOCK_K: tl.constexpr,
    M_ALIGNMENT: tl.constexpr,
    BLOCK_M: tl.constexpr,  # tune
):
    k_offset = tl.program_id(0)
    expert_id = tl.program_id(1)

    m = tl.load(problem_sizes + expert_id * 3)
    # Widen the offset so that `offset * K` is computed in 64-bit.
    current_expert_offset = tl.load(expert_offsets + expert_id).to(tl.int64)
    # Rebind the hinted values instead of relying on the hint being attached
    # to the original value as a side effect.
    m = tl.multiple_of(m, M_ALIGNMENT)
    current_expert_offset = tl.multiple_of(current_expert_offset, M_ALIGNMENT)

    # Everything below that does not depend on the row tile is computed once.
    coord_k = k_offset * BLOCK_K + tl.arange(0, BLOCK_K)
    col_mask = (coord_k < K)[None, :]
    k = tl.cdiv(K, BLOCK_K)
    a_base = a + current_expert_offset * K
    a_fp8_base = a_fp8 + current_expert_offset * K
    sfa_base = sfa + current_expert_offset * k + k_offset * m  # MN-Major with sfa

    for i in tl.range(tl.cdiv(m, BLOCK_M)):
        coord_m = i * BLOCK_M + tl.arange(0, BLOCK_M)
        row_mask = coord_m < m
        tile_offsets = coord_m[:, None] * K + coord_k[None, :]
        a_mask = row_mask[:, None] & col_mask

        # `other=0.0` gives masked-off lanes a defined value; without it they
        # are undefined and would still take part in the row-wise reduction.
        inp = tl.load(a_base + tile_offsets, mask=a_mask, other=0.0).to(
            tl.float32
        )  # [BLOCK_M, BLOCK_K]
        inp_amax = tl.max(tl.abs(inp), axis=1)  # [BLOCK_M,]
        # Floor the amax so that all-zero groups do not divide by zero.
        inp_amax = tl.clamp(inp_amax, min=1e-4, max=float("inf"))
        inp_fp8 = (inp * (448.0 / inp_amax[:, None])).to(tl.float8e4nv)

        # Store fp8
        tl.store(a_fp8_base + tile_offsets, inp_fp8, mask=a_mask)

        # Store sfa
        tl.store(sfa_base + coord_m, inp_amax / 448.0, mask=row_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def per_token_group_quant_fp8_hopper_moe_mn_major(
    A: torch.Tensor,
    expert_offsets: torch.Tensor,
    problem_sizes: torch.Tensor,
    group_size: int,
    expert_tokens_alignment: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize stacked expert inputs to FP8 and build their scale factors.

    Launches ``_per_token_group_quant_fp8_hopper_moe_mn_major`` on a grid of
    ``(K // group_size, num_experts)`` programs, so quantization and scale
    construction happen in a single kernel.

    Args:
        A: Contiguous 2-D tensor ``(M, K)`` holding the rows of all experts
            stacked along the first dimension.
        expert_offsets: One entry per expert giving the first row of that
            expert inside ``A``.
        problem_sizes: Contiguous ``(num_experts, 3)`` tensor; the kernel
            reads column 0 as the number of rows of each expert.
        group_size: Number of consecutive elements along ``K`` that share one
            scale. Must divide ``K``. It is used as a ``tl.arange`` extent in
            the kernel, so presumably it must be a power of two — confirm
            against the Triton version in use.
        expert_tokens_alignment: Value that every expert's row count and row
            offset is guaranteed to be a multiple of. It is passed to the
            kernel as a compiler hint; the default of 1 promises nothing.

    Returns:
        A tuple ``(a_q, sfa)``:

        * ``a_q`` has the shape of ``A`` and dtype ``fp8_dtype``.
        * ``sfa`` is a float32 tensor allocated as ``(M, K // group_size)``.
          The kernel fills it MN-major per expert: the segment of an expert
          starts at ``offset * (K // group_size)`` elements, and inside it the
          scale of row ``r`` / group ``g`` is stored at ``g * m + r``. The
          ``(M, k)`` shape therefore describes the size of the buffer, not
          the position of individual scales.

    Raises:
        AssertionError: If ``A`` is not a contiguous 2-D tensor, if ``K`` is
            not divisible by ``group_size``, or if ``problem_sizes`` is not a
            contiguous ``(num_experts, 3)`` tensor.
    """
    assert A.dim() == 2
    assert A.is_contiguous(), "`A` is not contiguous"
    assert (
        A.shape[-1] % group_size == 0
    ), "the last dimension of `A` must be divisible by `group_size`"
    # The kernel addresses `problem_sizes` with a hard-coded row stride of 3.
    assert (
        problem_sizes.dim() == 2
        and problem_sizes.shape[1] == 3
        and problem_sizes.is_contiguous()
    ), "`problem_sizes` must be a contiguous tensor of shape (num_experts, 3)"

    M, K = A.shape
    k = K // group_size
    num_experts = problem_sizes.shape[0]

    a_q = torch.empty_like(A, dtype=fp8_dtype)
    sfa = torch.empty((M, k), device=A.device, dtype=torch.float32)

    # One program per (K group, expert); BLOCK_M is supplied by the autotuner.
    grid = (k, num_experts)
    _per_token_group_quant_fp8_hopper_moe_mn_major[grid](
        A,
        expert_offsets,
        problem_sizes,
        a_q,
        sfa,
        K,
        group_size,
        expert_tokens_alignment,
    )
    return a_q, sfa
|
||||||
|
|||||||
@@ -5,6 +5,10 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
from sgl_kernel import fp8_blockwise_scaled_grouped_mm
|
from sgl_kernel import fp8_blockwise_scaled_grouped_mm
|
||||||
|
|
||||||
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
|
per_token_group_quant_fp8_hopper_moe_mn_major,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def cdiv(a: int, b: int) -> int:
    """Return ``a`` divided by ``b``, rounded up to the next integer."""
    quotient, remainder = divmod(a, b)
    # Floor division already rounds down; step up once if anything is left.
    if remainder:
        quotient += 1
    return quotient
|
||||||
@@ -104,8 +108,11 @@ def is_sm90_supported(device=None) -> bool:
|
|||||||
)
|
)
|
||||||
@pytest.mark.parametrize("num_experts", [8, 16])
|
@pytest.mark.parametrize("num_experts", [8, 16])
|
||||||
@pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16])
|
@pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16])
|
||||||
def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
|
@pytest.mark.parametrize("use_custom_kernel", [True, False])
|
||||||
|
def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kernel):
|
||||||
cc = torch.cuda.get_device_capability(None)[0]
|
cc = torch.cuda.get_device_capability(None)[0]
|
||||||
|
if cc == 10 and use_custom_kernel:
|
||||||
|
return
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
alignment = 16
|
alignment = 16
|
||||||
n_g = alignment * random.randint(1, 5) * 128
|
n_g = alignment * random.randint(1, 5) * 128
|
||||||
@@ -116,6 +123,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
|
|||||||
layout_sfa = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
|
layout_sfa = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
|
||||||
layout_sfb = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
|
layout_sfb = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
|
||||||
|
|
||||||
|
a_original_tensors = []
|
||||||
a_tensors = []
|
a_tensors = []
|
||||||
b_tensors = []
|
b_tensors = []
|
||||||
a_scales_tensors = []
|
a_scales_tensors = []
|
||||||
@@ -136,6 +144,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
|
|||||||
b_g, b_scale = per_block_cast_to_fp8(
|
b_g, b_scale = per_block_cast_to_fp8(
|
||||||
b
|
b
|
||||||
) # bg -- (K, N):(N, 1), b_scale() -- (k, n):(n, 1)
|
) # bg -- (K, N):(N, 1), b_scale() -- (k, n):(n, 1)
|
||||||
|
a_original_tensors.append(a)
|
||||||
a_tensors.append(a_g)
|
a_tensors.append(a_g)
|
||||||
b_tensors.append(b_g)
|
b_tensors.append(b_g)
|
||||||
a_scales_tensors.append(a_scale)
|
a_scales_tensors.append(a_scale)
|
||||||
@@ -143,22 +152,15 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
|
|||||||
|
|
||||||
baseline = torch.mm(a, b)
|
baseline = torch.mm(a, b)
|
||||||
baseline_tensors.append(baseline)
|
baseline_tensors.append(baseline)
|
||||||
|
a_original_stack = torch.empty(
|
||||||
|
(expert_offsets[-1], k_g), device=device, dtype=out_dtype
|
||||||
|
)
|
||||||
a_stack = torch.empty(
|
a_stack = torch.empty(
|
||||||
(expert_offsets[-1], k_g), device=device, dtype=torch.float8_e4m3fn
|
(expert_offsets[-1], k_g), device=device, dtype=torch.float8_e4m3fn
|
||||||
)
|
)
|
||||||
b_stack = torch.empty(
|
b_stack = torch.empty(
|
||||||
(num_experts, n_g, k_g), device=device, dtype=torch.float8_e4m3fn
|
(num_experts, n_g, k_g), device=device, dtype=torch.float8_e4m3fn
|
||||||
)
|
)
|
||||||
|
|
||||||
for g in range(num_experts):
|
|
||||||
# Matrix A is Row-Major.
|
|
||||||
a_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[
|
|
||||||
g
|
|
||||||
] # a_stack[expert_offsets[g] : expert_offsets[g + 1]] -- (M, K):(K, 1)
|
|
||||||
b_stack[g] = b_tensors[g].t() # b_stack[g] -- (N, K):(K, 1)
|
|
||||||
b_stack = b_stack.transpose(1, 2) # Transpose Matrix B to Column-Major.
|
|
||||||
|
|
||||||
a_scale_stack = torch.empty(
|
a_scale_stack = torch.empty(
|
||||||
(expert_offsets[-1] * (k_g // 128)), device=device, dtype=torch.float32
|
(expert_offsets[-1] * (k_g // 128)), device=device, dtype=torch.float32
|
||||||
)
|
)
|
||||||
@@ -167,6 +169,14 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
|
|||||||
)
|
)
|
||||||
|
|
||||||
for g in range(num_experts):
|
for g in range(num_experts):
|
||||||
|
# Matrix A is Row-Major.
|
||||||
|
a_original_stack[expert_offsets[g] : expert_offsets[g + 1]] = (
|
||||||
|
a_original_tensors[g]
|
||||||
|
)
|
||||||
|
a_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[
|
||||||
|
g
|
||||||
|
] # a_stack[expert_offsets[g] : expert_offsets[g + 1]] -- (M, K):(K, 1)
|
||||||
|
b_stack[g] = b_tensors[g].t() # b_stack[g] -- (N, K):(K, 1)
|
||||||
if cc == 9:
|
if cc == 9:
|
||||||
# For SM90, we need MN-Major scale factor
|
# For SM90, we need MN-Major scale factor
|
||||||
# a_scales_tensors[g] -- (M, k):(k, 1)
|
# a_scales_tensors[g] -- (M, k):(k, 1)
|
||||||
@@ -185,9 +195,20 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
|
|||||||
g
|
g
|
||||||
] # b_scale_stack[g] -- (k, n):(n, 1), we need transpose & contiguous later
|
] # b_scale_stack[g] -- (k, n):(n, 1), we need transpose & contiguous later
|
||||||
a_scale_stack = a_scale_stack.view(expert_offsets[-1], k_g // 128)
|
a_scale_stack = a_scale_stack.view(expert_offsets[-1], k_g // 128)
|
||||||
|
b_stack = b_stack.transpose(1, 2) # Transpose Matrix B to Column-Major.
|
||||||
if cc == 10:
|
if cc == 10:
|
||||||
b_scale_stack = b_scale_stack.transpose(1, 2).contiguous()
|
b_scale_stack = b_scale_stack.transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
if use_custom_kernel:
|
||||||
|
# Replace a_stack, a_scale_stack with custom kernel output
|
||||||
|
a_stack, a_scale_stack = per_token_group_quant_fp8_hopper_moe_mn_major(
|
||||||
|
a_original_stack,
|
||||||
|
expert_offsets[:-1],
|
||||||
|
problem_sizes,
|
||||||
|
128,
|
||||||
|
expert_tokens_alignment=alignment,
|
||||||
|
)
|
||||||
|
|
||||||
c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype)
|
c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype)
|
||||||
a_strides = torch.full(
|
a_strides = torch.full(
|
||||||
(num_experts,), a_stack.stride(0), device=device, dtype=torch.int64
|
(num_experts,), a_stack.stride(0), device=device, dtype=torch.int64
|
||||||
|
|||||||
Reference in New Issue
Block a user