Pass a_scale from the fp8 quantization result instead of hard-coding it to 1.0f (#10241)
Co-authored-by: Yichen Wang <yichen.wang@bytedance.com>
Co-authored-by: Jinwu Guo <641876696@qq.com>
This commit is contained in:
@@ -147,8 +147,8 @@ def cutlass_w4a8_moe(
|
|||||||
k,
|
k,
|
||||||
)
|
)
|
||||||
|
|
||||||
c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half)
|
c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16)
|
||||||
c2 = torch.zeros((m * topk, k), device=device, dtype=torch.half)
|
c2 = torch.zeros((m * topk, k), device=device, dtype=torch.bfloat16)
|
||||||
|
|
||||||
cutlass_w4a8_moe_mm(
|
cutlass_w4a8_moe_mm(
|
||||||
c1,
|
c1,
|
||||||
@@ -166,7 +166,7 @@ def cutlass_w4a8_moe(
|
|||||||
topk,
|
topk,
|
||||||
)
|
)
|
||||||
|
|
||||||
intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half)
|
intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16)
|
||||||
silu_and_mul(c1, intermediate)
|
silu_and_mul(c1, intermediate)
|
||||||
|
|
||||||
intermediate_q = torch.empty(
|
intermediate_q = torch.empty(
|
||||||
|
|||||||
@@ -209,7 +209,7 @@ void cutlass_w4a8_group_gemm_caller(
|
|||||||
|
|
||||||
Args arguments;
|
Args arguments;
|
||||||
decltype(arguments.epilogue.thread) fusion_args;
|
decltype(arguments.epilogue.thread) fusion_args;
|
||||||
fusion_args.alpha = 1.0f;
|
fusion_args.alpha = 0;
|
||||||
fusion_args.beta = 0;
|
fusion_args.beta = 0;
|
||||||
fusion_args.alpha_ptr = a_scales.data_ptr<float>();
|
fusion_args.alpha_ptr = a_scales.data_ptr<float>();
|
||||||
;
|
;
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from sgl_kernel import cutlass_w4a8_moe_mm
|
from sgl_kernel import cutlass_w4a8_moe_mm, sgl_per_tensor_quant_fp8
|
||||||
from utils import is_hopper
|
from utils import is_hopper
|
||||||
|
|
||||||
|
|
||||||
@@ -67,7 +67,6 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
|
|||||||
if debug:
|
if debug:
|
||||||
a = torch.ones(m, k, dtype=torch.bfloat16, device=device)
|
a = torch.ones(m, k, dtype=torch.bfloat16, device=device)
|
||||||
ref_w = torch.ones(num_experts, n, k, dtype=torch.int8, device=device)
|
ref_w = torch.ones(num_experts, n, k, dtype=torch.int8, device=device)
|
||||||
a_scale = torch.ones(1, dtype=torch.float, device=device)
|
|
||||||
ref_w_scale = torch.ones(num_experts, n, k // 128, dtype=dtype, device=device)
|
ref_w_scale = torch.ones(num_experts, n, k // 128, dtype=dtype, device=device)
|
||||||
else:
|
else:
|
||||||
a = torch.randn(m, k, dtype=dtype, device=device)
|
a = torch.randn(m, k, dtype=dtype, device=device)
|
||||||
@@ -75,7 +74,6 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
|
|||||||
-8, 8, (num_experts, n, k), dtype=torch.int8, device=device
|
-8, 8, (num_experts, n, k), dtype=torch.int8, device=device
|
||||||
)
|
)
|
||||||
affine_coeff = 0.005
|
affine_coeff = 0.005
|
||||||
a_scale = torch.randn(1, dtype=torch.float32).cuda() * 0.02
|
|
||||||
ref_w_scale = (
|
ref_w_scale = (
|
||||||
torch.randn(num_experts, n, k // 128, dtype=dtype, device=device)
|
torch.randn(num_experts, n, k // 128, dtype=dtype, device=device)
|
||||||
* affine_coeff
|
* affine_coeff
|
||||||
@@ -93,7 +91,7 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
|
|||||||
s_strides = c_strides
|
s_strides = c_strides
|
||||||
|
|
||||||
# Quantize input
|
# Quantize input
|
||||||
a_q = torch.clamp((a / a_scale), -448.0, 448.0).to(torch.float8_e4m3fn).to(device)
|
a_q, a_scale = _per_tensor_quant_fp8(a)
|
||||||
|
|
||||||
# Create output tensor
|
# Create output tensor
|
||||||
c = torch.empty((m, n), dtype=torch.bfloat16, device=device)
|
c = torch.empty((m, n), dtype=torch.bfloat16, device=device)
|
||||||
@@ -117,7 +115,7 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
|
|||||||
# Reference implementation
|
# Reference implementation
|
||||||
experts_selection_result = torch.full((m,), 0)
|
experts_selection_result = torch.full((m,), 0)
|
||||||
c_ref = ref_grouped_gemm(
|
c_ref = ref_grouped_gemm(
|
||||||
c, a, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result
|
c, a_q, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compare results
|
# Compare results
|
||||||
@@ -138,17 +136,29 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
# @pytest.mark.skipif(
|
def _per_tensor_quant_fp8(
|
||||||
# not is_hopper(),
|
x: torch.Tensor,
|
||||||
# reason="cutlass_w4a8_moe_mm is only supported on sm90",
|
dtype: torch.dtype = torch.float8_e4m3fn,
|
||||||
# )
|
):
|
||||||
|
assert x.is_contiguous(), "`x` is not contiguous"
|
||||||
|
|
||||||
|
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
|
||||||
|
x_s = torch.empty(
|
||||||
|
1,
|
||||||
|
device=x.device,
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
sgl_per_tensor_quant_fp8(x, x_q, x_s, is_static=False)
|
||||||
|
return x_q, x_s
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
True,
|
not is_hopper(),
|
||||||
reason="TODO(rainj-me): fix cu129 binary issue on hopper cu126",
|
reason="cutlass_w4a8_moe_mm is only supported on sm90",
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("batch_size", [2, 4, 8, 16])
|
@pytest.mark.parametrize("batch_size", [2, 4, 8, 16, 32])
|
||||||
@pytest.mark.parametrize("k", [256, 512, 1024])
|
@pytest.mark.parametrize("k", [512, 1024, 2048, 4096, 7168])
|
||||||
@pytest.mark.parametrize("n", [1024, 2048, 7168])
|
@pytest.mark.parametrize("n", [256, 512, 1024, 2048])
|
||||||
@pytest.mark.parametrize("num_experts", [2, 4, 6, 8])
|
@pytest.mark.parametrize("num_experts", [2, 4, 6, 8])
|
||||||
def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
|
def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
@@ -163,7 +173,6 @@ def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
|
|||||||
if debug:
|
if debug:
|
||||||
a = torch.ones(batch_size, k, dtype=torch.bfloat16, device=device)
|
a = torch.ones(batch_size, k, dtype=torch.bfloat16, device=device)
|
||||||
ref_w = torch.ones(num_experts, n, k, dtype=torch.int8, device=device)
|
ref_w = torch.ones(num_experts, n, k, dtype=torch.int8, device=device)
|
||||||
a_scale = torch.ones(1, dtype=torch.float, device=device)
|
|
||||||
ref_w_scale = torch.ones(num_experts, n, k // 128, dtype=dtype, device=device)
|
ref_w_scale = torch.ones(num_experts, n, k // 128, dtype=dtype, device=device)
|
||||||
else:
|
else:
|
||||||
a = torch.randn(batch_size, k, dtype=dtype, device=device)
|
a = torch.randn(batch_size, k, dtype=dtype, device=device)
|
||||||
@@ -171,7 +180,6 @@ def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
|
|||||||
-8, 8, (num_experts, n, k), dtype=torch.int8, device=device
|
-8, 8, (num_experts, n, k), dtype=torch.int8, device=device
|
||||||
)
|
)
|
||||||
affine_coeff = 0.005
|
affine_coeff = 0.005
|
||||||
a_scale = torch.randn(1, dtype=torch.float32).cuda() * 0.02
|
|
||||||
ref_w_scale = (
|
ref_w_scale = (
|
||||||
torch.randn(num_experts, n, k // 128, dtype=dtype, device=device)
|
torch.randn(num_experts, n, k // 128, dtype=dtype, device=device)
|
||||||
* affine_coeff
|
* affine_coeff
|
||||||
@@ -202,12 +210,8 @@ def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
|
|||||||
expert_offsets = torch.tensor(expert_offsets, dtype=torch.int32, device=device)
|
expert_offsets = torch.tensor(expert_offsets, dtype=torch.int32, device=device)
|
||||||
|
|
||||||
# Permute input and quantize
|
# Permute input and quantize
|
||||||
a_perm = a[permutation]
|
a_q, a_scale = _per_tensor_quant_fp8(a)
|
||||||
a_q_perm = (
|
a_q_perm = a_q[permutation]
|
||||||
torch.clamp((a_perm / a_scale), -448.0, 448.0)
|
|
||||||
.to(torch.float8_e4m3fn)
|
|
||||||
.to(device)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create stride tensors
|
# Create stride tensors
|
||||||
a_strides = torch.full((num_experts, 3), k, device=device, dtype=torch.int64)
|
a_strides = torch.full((num_experts, 3), k, device=device, dtype=torch.int64)
|
||||||
@@ -238,7 +242,7 @@ def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
|
|||||||
c = c.to(dtype)
|
c = c.to(dtype)
|
||||||
|
|
||||||
c_ref = ref_grouped_gemm(
|
c_ref = ref_grouped_gemm(
|
||||||
c, a, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result
|
c, a_q, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compare results
|
# Compare results
|
||||||
@@ -256,10 +260,11 @@ def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def ref_grouped_gemm(c, a, a_scale, w, w_scale, num_experts, experts_selection_result):
|
def ref_grouped_gemm(
|
||||||
|
c, a_q, a_scale, w, w_scale, num_experts, experts_selection_result
|
||||||
|
):
|
||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
c_ref = torch.zeros_like(c)
|
c_ref = torch.zeros_like(c)
|
||||||
a_q = torch.clamp((a / a_scale), -448.0, 448.0).to(torch.float8_e4m3fn)
|
|
||||||
for i in range(num_experts):
|
for i in range(num_experts):
|
||||||
token_idx = torch.where(experts_selection_result == i)[0]
|
token_idx = torch.where(experts_selection_result == i)[0]
|
||||||
if len(token_idx) == 0:
|
if len(token_idx) == 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user