Pass a_scale from the FP8 quantization result instead of hard-coding it to 1.0f (#10241)
Co-authored-by: Yichen Wang <yichen.wang@bytedance.com>
Co-authored-by: Jinwu Guo <641876696@qq.com>
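This change threads the activation scale produced by per-tensor FP8 quantization into the CUTLASS W4A8 grouped GEMM epilogue instead of hard-coding alpha to 1.0f. A minimal sketch of why the scale matters (the helper name and constants here are illustrative, not part of the diff): per-tensor quantization computes a_q ≈ a / a_scale, so the FP8 GEMM result has to be multiplied back by a_scale; dropping that factor is only correct when a_scale happens to be 1.

import torch

def quant_dequant_sketch(a: torch.Tensor):
    # Assumed dynamic per-tensor scheme: 448.0 is the float8_e4m3fn maximum,
    # matching the clamp the old test code below used.
    fp8_max = 448.0
    a_scale = a.abs().max().to(torch.float32) / fp8_max
    a_q = torch.clamp(a / a_scale, -fp8_max, fp8_max).to(torch.float8_e4m3fn)
    # The GEMM output must be scaled by a_scale afterwards; the kernel now
    # applies this in its epilogue via alpha_ptr instead of a fixed 1.0f.
    return a_q, a_scale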
@@ -147,8 +147,8 @@ def cutlass_w4a8_moe(
         k,
     )

-    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half)
-    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.half)
+    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16)
+    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.bfloat16)

     cutlass_w4a8_moe_mm(
         c1,
@@ -166,7 +166,7 @@ def cutlass_w4a8_moe(
         topk,
     )

-    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half)
+    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16)
     silu_and_mul(c1, intermediate)

     intermediate_q = torch.empty(
@@ -209,7 +209,7 @@ void cutlass_w4a8_group_gemm_caller(

   Args arguments;
   decltype(arguments.epilogue.thread) fusion_args;
-  fusion_args.alpha = 1.0f;
+  fusion_args.alpha = 0;
   fusion_args.beta = 0;
+  fusion_args.alpha_ptr = a_scales.data_ptr<float>();
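For the hunk above: the caller sets the scalar alpha to 0 and instead passes a device pointer to a_scales, so the epilogue picks up the per-tensor activation scale at runtime. Roughly, the epilogue computes D = alpha * (A @ B) + beta * C; with beta = 0 only the alpha term remains. A Python sketch of that math (not CUTLASS code; the assumption is that a non-null alpha_ptr overrides the scalar alpha):

import torch

def epilogue_sketch(acc: torch.Tensor, a_scales: torch.Tensor) -> torch.Tensor:
    # acc: raw FP8 x INT4 accumulator; a_scales: one float32 value on device.
    alpha = a_scales.reshape(()).to(torch.float32)
    beta = 0.0  # no source/C term is added
    return alpha * acc + beta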
@@ -1,6 +1,6 @@
 import pytest
 import torch
-from sgl_kernel import cutlass_w4a8_moe_mm
+from sgl_kernel import cutlass_w4a8_moe_mm, sgl_per_tensor_quant_fp8
 from utils import is_hopper


@@ -67,7 +67,6 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
     if debug:
         a = torch.ones(m, k, dtype=torch.bfloat16, device=device)
         ref_w = torch.ones(num_experts, n, k, dtype=torch.int8, device=device)
-        a_scale = torch.ones(1, dtype=torch.float, device=device)
         ref_w_scale = torch.ones(num_experts, n, k // 128, dtype=dtype, device=device)
     else:
         a = torch.randn(m, k, dtype=dtype, device=device)
@@ -75,7 +74,6 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
             -8, 8, (num_experts, n, k), dtype=torch.int8, device=device
         )
         affine_coeff = 0.005
-        a_scale = torch.randn(1, dtype=torch.float32).cuda() * 0.02
         ref_w_scale = (
             torch.randn(num_experts, n, k // 128, dtype=dtype, device=device)
             * affine_coeff
@@ -93,7 +91,7 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
     s_strides = c_strides

     # Quantize input
-    a_q = torch.clamp((a / a_scale), -448.0, 448.0).to(torch.float8_e4m3fn).to(device)
+    a_q, a_scale = _per_tensor_quant_fp8(a)

     # Create output tensor
     c = torch.empty((m, n), dtype=torch.bfloat16, device=device)
@@ -117,7 +115,7 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
     # Reference implementation
     experts_selection_result = torch.full((m,), 0)
     c_ref = ref_grouped_gemm(
-        c, a, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result
+        c, a_q, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result
     )

     # Compare results
@@ -138,17 +136,29 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
         raise


-# @pytest.mark.skipif(
-#     not is_hopper(),
-#     reason="cutlass_w4a8_moe_mm is only supported on sm90",
-# )
+def _per_tensor_quant_fp8(
+    x: torch.Tensor,
+    dtype: torch.dtype = torch.float8_e4m3fn,
+):
+    assert x.is_contiguous(), "`x` is not contiguous"
+
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_s = torch.empty(
+        1,
+        device=x.device,
+        dtype=torch.float32,
+    )
+    sgl_per_tensor_quant_fp8(x, x_q, x_s, is_static=False)
+    return x_q, x_s
+
+
 @pytest.mark.skipif(
-    True,
-    reason="TODO(rainj-me): fix cu129 binary issue on hopper cu126",
+    not is_hopper(),
+    reason="cutlass_w4a8_moe_mm is only supported on sm90",
 )
-@pytest.mark.parametrize("batch_size", [2, 4, 8, 16])
-@pytest.mark.parametrize("k", [256, 512, 1024])
-@pytest.mark.parametrize("n", [1024, 2048, 7168])
+@pytest.mark.parametrize("batch_size", [2, 4, 8, 16, 32])
+@pytest.mark.parametrize("k", [512, 1024, 2048, 4096, 7168])
+@pytest.mark.parametrize("n", [256, 512, 1024, 2048])
 @pytest.mark.parametrize("num_experts", [2, 4, 6, 8])
 def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
     torch.manual_seed(0)
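Illustrative usage of the _per_tensor_quant_fp8 helper added above (shapes are examples, and this assumes the test module's imports and a CUDA device): quantize once, then reuse the returned scale for both the kernel call and the reference, rather than drawing a random scale as the old test did.

import torch

a = torch.randn(16, 512, dtype=torch.bfloat16, device="cuda")
a_q, a_scale = _per_tensor_quant_fp8(a)  # a_q: float8_e4m3fn, a_scale: shape (1,), float32
# Dequantizing should roughly recover the input, up to FP8 rounding error.
a_roundtrip = a_q.to(torch.float32) * a_scale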
@@ -163,7 +173,6 @@ def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
     if debug:
         a = torch.ones(batch_size, k, dtype=torch.bfloat16, device=device)
         ref_w = torch.ones(num_experts, n, k, dtype=torch.int8, device=device)
-        a_scale = torch.ones(1, dtype=torch.float, device=device)
         ref_w_scale = torch.ones(num_experts, n, k // 128, dtype=dtype, device=device)
     else:
         a = torch.randn(batch_size, k, dtype=dtype, device=device)
@@ -171,7 +180,6 @@ def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
             -8, 8, (num_experts, n, k), dtype=torch.int8, device=device
         )
         affine_coeff = 0.005
-        a_scale = torch.randn(1, dtype=torch.float32).cuda() * 0.02
         ref_w_scale = (
             torch.randn(num_experts, n, k // 128, dtype=dtype, device=device)
             * affine_coeff
@@ -202,12 +210,8 @@ def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
     expert_offsets = torch.tensor(expert_offsets, dtype=torch.int32, device=device)

     # Permute input and quantize
-    a_perm = a[permutation]
-    a_q_perm = (
-        torch.clamp((a_perm / a_scale), -448.0, 448.0)
-        .to(torch.float8_e4m3fn)
-        .to(device)
-    )
+    a_q, a_scale = _per_tensor_quant_fp8(a)
+    a_q_perm = a_q[permutation]

     # Create stride tensors
     a_strides = torch.full((num_experts, 3), k, device=device, dtype=torch.int64)
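One detail in the hunk above: the input is now quantized before it is permuted, where the old code permuted first. With a per-tensor scale the two orders are interchangeable, since reordering rows does not change the tensor-wide statistic the scale is derived from (assuming a dynamic, amax-style scale). A small sanity check along those lines, with illustrative shapes:

import torch

a = torch.randn(8, 256, dtype=torch.bfloat16, device="cuda")
perm = torch.randperm(8, device="cuda")
a_q_full, s1 = _per_tensor_quant_fp8(a)
a_q_perm, s2 = _per_tensor_quant_fp8(a[perm].contiguous())
# The scales should match, and the quantized rows should agree up to the permutation.
assert torch.equal(s1, s2)
assert torch.equal(a_q_full[perm].to(torch.float32), a_q_perm.to(torch.float32))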
@@ -238,7 +242,7 @@ def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
     c = c.to(dtype)

     c_ref = ref_grouped_gemm(
-        c, a, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result
+        c, a_q, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result
     )

     # Compare results
@@ -256,10 +260,11 @@ def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
         raise


-def ref_grouped_gemm(c, a, a_scale, w, w_scale, num_experts, experts_selection_result):
+def ref_grouped_gemm(
+    c, a_q, a_scale, w, w_scale, num_experts, experts_selection_result
+):
     dtype = torch.bfloat16
     c_ref = torch.zeros_like(c)
-    a_q = torch.clamp((a / a_scale), -448.0, 448.0).to(torch.float8_e4m3fn)
     for i in range(num_experts):
         token_idx = torch.where(experts_selection_result == i)[0]
         if len(token_idx) == 0:
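The rest of ref_grouped_gemm is not shown in this hunk; for orientation, a per-expert reference consistent with the shapes used in these tests could look like the sketch below (group size 128 for the INT4 weight scales; this is a hedged reconstruction, not the file's actual implementation).

import torch

def ref_expert_matmul_sketch(a_q, a_scale, w_i, w_scale_i):
    # w_i: (n, k) int8 holding int4 values; w_scale_i: (n, k // 128) group scales.
    w_deq = w_i.to(torch.float32) * w_scale_i.to(torch.float32).repeat_interleave(128, dim=1)
    # Dequantize the FP8 activations with the per-tensor scale and accumulate in fp32.
    return (a_q.to(torch.float32) @ w_deq.t()) * a_scale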