[feat] Support a fused kernel for constructing the quantized input and scale factors for fp8_blockwise_scaled_grouped_mm (#8023)
This commit is contained in:
@@ -5,6 +5,10 @@ import pytest
|
||||
import torch
|
||||
from sgl_kernel import fp8_blockwise_scaled_grouped_mm
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
per_token_group_quant_fp8_hopper_moe_mn_major,
|
||||
)
|
||||
|
||||
|
||||
def cdiv(a: int, b: int) -> int:
    """Return the ceiling of ``a / b`` using only integer arithmetic.

    Python's ``//`` rounds toward negative infinity, so floor-dividing by
    the negated divisor and then negating the quotient rounds toward
    positive infinity instead, without going through floating point.

    Args:
        a: Dividend.
        b: Divisor; must be non-zero.

    Returns:
        The smallest integer greater than or equal to ``a / b``.

    Raises:
        ZeroDivisionError: If ``b`` is zero.
    """
    return -(a // -b)
|
||||
@@ -104,8 +108,11 @@ def is_sm90_supported(device=None) -> bool:
|
||||
)
|
||||
@pytest.mark.parametrize("num_experts", [8, 16])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16])
|
||||
def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
|
||||
@pytest.mark.parametrize("use_custom_kernel", [True, False])
|
||||
def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kernel):
|
||||
cc = torch.cuda.get_device_capability(None)[0]
|
||||
if cc == 10 and use_custom_kernel:
|
||||
return
|
||||
device = "cuda"
|
||||
alignment = 16
|
||||
n_g = alignment * random.randint(1, 5) * 128
|
||||
@@ -116,6 +123,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
|
||||
layout_sfa = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
|
||||
layout_sfb = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
|
||||
|
||||
a_original_tensors = []
|
||||
a_tensors = []
|
||||
b_tensors = []
|
||||
a_scales_tensors = []
|
||||
@@ -136,6 +144,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
|
||||
b_g, b_scale = per_block_cast_to_fp8(
|
||||
b
|
||||
) # bg -- (K, N):(N, 1), b_scale() -- (k, n):(n, 1)
|
||||
a_original_tensors.append(a)
|
||||
a_tensors.append(a_g)
|
||||
b_tensors.append(b_g)
|
||||
a_scales_tensors.append(a_scale)
|
||||
@@ -143,22 +152,15 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
|
||||
|
||||
baseline = torch.mm(a, b)
|
||||
baseline_tensors.append(baseline)
|
||||
|
||||
a_original_stack = torch.empty(
|
||||
(expert_offsets[-1], k_g), device=device, dtype=out_dtype
|
||||
)
|
||||
a_stack = torch.empty(
|
||||
(expert_offsets[-1], k_g), device=device, dtype=torch.float8_e4m3fn
|
||||
)
|
||||
b_stack = torch.empty(
|
||||
(num_experts, n_g, k_g), device=device, dtype=torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
for g in range(num_experts):
|
||||
# Matrix A is Row-Major.
|
||||
a_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[
|
||||
g
|
||||
] # a_stack[expert_offsets[g] : expert_offsets[g + 1]] -- (M, K):(K, 1)
|
||||
b_stack[g] = b_tensors[g].t() # b_stack[g] -- (N, K):(K, 1)
|
||||
b_stack = b_stack.transpose(1, 2) # Transpose Matrix B to Column-Major.
|
||||
|
||||
a_scale_stack = torch.empty(
|
||||
(expert_offsets[-1] * (k_g // 128)), device=device, dtype=torch.float32
|
||||
)
|
||||
@@ -167,6 +169,14 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
|
||||
)
|
||||
|
||||
for g in range(num_experts):
|
||||
# Matrix A is Row-Major.
|
||||
a_original_stack[expert_offsets[g] : expert_offsets[g + 1]] = (
|
||||
a_original_tensors[g]
|
||||
)
|
||||
a_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[
|
||||
g
|
||||
] # a_stack[expert_offsets[g] : expert_offsets[g + 1]] -- (M, K):(K, 1)
|
||||
b_stack[g] = b_tensors[g].t() # b_stack[g] -- (N, K):(K, 1)
|
||||
if cc == 9:
|
||||
# For SM90, we need MN-Major scale factor
|
||||
# a_scales_tensors[g] -- (M, k):(k, 1)
|
||||
@@ -185,9 +195,20 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
|
||||
g
|
||||
] # b_scale_stack[g] -- (k, n):(n, 1), we need transpose & contiguous later
|
||||
a_scale_stack = a_scale_stack.view(expert_offsets[-1], k_g // 128)
|
||||
b_stack = b_stack.transpose(1, 2) # Transpose Matrix B to Column-Major.
|
||||
if cc == 10:
|
||||
b_scale_stack = b_scale_stack.transpose(1, 2).contiguous()
|
||||
|
||||
if use_custom_kernel:
|
||||
# Replace a_stack, a_scale_stack with custom kernel output
|
||||
a_stack, a_scale_stack = per_token_group_quant_fp8_hopper_moe_mn_major(
|
||||
a_original_stack,
|
||||
expert_offsets[:-1],
|
||||
problem_sizes,
|
||||
128,
|
||||
expert_tokens_alignment=alignment,
|
||||
)
|
||||
|
||||
c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype)
|
||||
a_strides = torch.full(
|
||||
(num_experts,), a_stack.stride(0), device=device, dtype=torch.int64
|
||||
|
||||
Reference in New Issue
Block a user