[fix]Update unitest for fp8_blockwise_scaled_grouped_mm kernel (#7932)
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
import random
|
import random
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
@@ -20,6 +21,44 @@ def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py
|
||||||
|
def calc_diff(x, y):
|
||||||
|
x, y = x.double(), y.double()
|
||||||
|
denominator = (x * x + y * y).sum()
|
||||||
|
sim = 2 * (x * y).sum() / denominator
|
||||||
|
return 1 - sim
|
||||||
|
|
||||||
|
|
||||||
|
def ceil_div(x: int, y: int) -> int:
|
||||||
|
return (x + y - 1) // y
|
||||||
|
|
||||||
|
|
||||||
|
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
assert x.dim() == 2
|
||||||
|
m, n = x.shape
|
||||||
|
pad_size = (128 - (n % 128)) % 128
|
||||||
|
x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
|
||||||
|
x_view = x.view(m, -1, 128)
|
||||||
|
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
||||||
|
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
|
||||||
|
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
|
||||||
|
|
||||||
|
|
||||||
|
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
assert x.dim() == 2
|
||||||
|
m, n = x.shape
|
||||||
|
x_padded = torch.zeros(
|
||||||
|
(ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device
|
||||||
|
)
|
||||||
|
x_padded[:m, :n] = x
|
||||||
|
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
|
||||||
|
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||||
|
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
|
||||||
|
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(
|
||||||
|
x_view.size(0), x_view.size(2)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def baseline_scaled_mm(
|
def baseline_scaled_mm(
|
||||||
a: torch.Tensor,
|
a: torch.Tensor,
|
||||||
b: torch.Tensor,
|
b: torch.Tensor,
|
||||||
@@ -55,7 +94,7 @@ def is_sm100_supported(device=None) -> bool:
|
|||||||
|
|
||||||
def is_sm90_supported(device=None) -> bool:
|
def is_sm90_supported(device=None) -> bool:
|
||||||
return (torch.cuda.get_device_capability(device)[0] == 9) and (
|
return (torch.cuda.get_device_capability(device)[0] == 9) and (
|
||||||
torch.version.cuda >= "12.8"
|
torch.version.cuda >= "12.3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -66,14 +105,12 @@ def is_sm90_supported(device=None) -> bool:
|
|||||||
@pytest.mark.parametrize("num_experts", [8, 16])
|
@pytest.mark.parametrize("num_experts", [8, 16])
|
||||||
@pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16])
|
@pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16])
|
||||||
def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
|
def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
|
||||||
|
cc = torch.cuda.get_device_capability(None)[0]
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
alignment = 16
|
alignment = 16
|
||||||
n_g = alignment * random.randint(1, 5) * 128
|
n_g = alignment * random.randint(1, 5) * 128
|
||||||
k_g = alignment * random.randint(1, 5) * 128
|
k_g = alignment * random.randint(1, 5) * 128
|
||||||
|
|
||||||
scale_a_group_shape = (1, 128)
|
|
||||||
scale_b_group_shape = (128, 128)
|
|
||||||
|
|
||||||
expert_offsets = torch.zeros((num_experts + 1), device=device, dtype=torch.int32)
|
expert_offsets = torch.zeros((num_experts + 1), device=device, dtype=torch.int32)
|
||||||
problem_sizes = torch.zeros((num_experts, 3), device=device, dtype=torch.int32)
|
problem_sizes = torch.zeros((num_experts, 3), device=device, dtype=torch.int32)
|
||||||
layout_sfa = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
|
layout_sfa = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
|
||||||
@@ -90,20 +127,21 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
|
|||||||
expert_offsets[g + 1] = expert_offsets[g] + m_g
|
expert_offsets[g + 1] = expert_offsets[g] + m_g
|
||||||
problem_sizes[g][:] = torch.tensor([m_g, n_g, k_g], device=device)
|
problem_sizes[g][:] = torch.tensor([m_g, n_g, k_g], device=device)
|
||||||
|
|
||||||
a_g = to_fp8(torch.randn((m_g, k_g), device=device))
|
a = torch.randn((m_g, k_g), device=device, dtype=out_dtype) # (M, K):(K, 1)
|
||||||
b_g = to_fp8(torch.randn((n_g, k_g), device=device).t())
|
b = torch.randn((n_g, k_g), device=device, dtype=out_dtype).t() # (K, N):(1, K)
|
||||||
|
|
||||||
|
a_g, a_scale = per_token_cast_to_fp8(
|
||||||
|
a
|
||||||
|
) # ag -- (M, K):(K, 1), a_scale() -- (M, k):(k, 1)
|
||||||
|
b_g, b_scale = per_block_cast_to_fp8(
|
||||||
|
b
|
||||||
|
) # bg -- (K, N):(N, 1), b_scale() -- (k, n):(n, 1)
|
||||||
a_tensors.append(a_g)
|
a_tensors.append(a_g)
|
||||||
b_tensors.append(b_g)
|
b_tensors.append(b_g)
|
||||||
|
a_scales_tensors.append(a_scale)
|
||||||
|
b_scales_tensors.append(b_scale)
|
||||||
|
|
||||||
scale_a_shape = scale_shape(a_g.shape, scale_a_group_shape)
|
baseline = torch.mm(a, b)
|
||||||
scale_b_shape = scale_shape(b_g.shape, scale_b_group_shape)
|
|
||||||
|
|
||||||
a_scales_tensors.append(torch.randn(scale_a_shape, device=device) * 0.001)
|
|
||||||
b_scales_tensors.append(torch.randn(scale_b_shape, device=device) * 0.001)
|
|
||||||
|
|
||||||
baseline = baseline_scaled_mm(
|
|
||||||
a_g, b_g, a_scales_tensors[-1], b_scales_tensors[-1], out_dtype
|
|
||||||
)
|
|
||||||
baseline_tensors.append(baseline)
|
baseline_tensors.append(baseline)
|
||||||
|
|
||||||
a_stack = torch.empty(
|
a_stack = torch.empty(
|
||||||
@@ -114,21 +152,41 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
|
|||||||
)
|
)
|
||||||
|
|
||||||
for g in range(num_experts):
|
for g in range(num_experts):
|
||||||
a_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[g]
|
# Matrix A is Row-Major.
|
||||||
b_stack[g] = b_tensors[g].t()
|
a_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[
|
||||||
b_stack = b_stack.transpose(1, 2)
|
g
|
||||||
|
] # a_stack[expert_offsets[g] : expert_offsets[g + 1]] -- (M, K):(K, 1)
|
||||||
|
b_stack[g] = b_tensors[g].t() # b_stack[g] -- (N, K):(K, 1)
|
||||||
|
b_stack = b_stack.transpose(1, 2) # Transpose Matrix B to Column-Major.
|
||||||
|
|
||||||
a_scale_stack = torch.empty(
|
a_scale_stack = torch.empty(
|
||||||
(expert_offsets[-1], k_g // 128), device=device, dtype=torch.float32
|
(expert_offsets[-1] * (k_g // 128)), device=device, dtype=torch.float32
|
||||||
)
|
)
|
||||||
b_scale_stack = torch.empty(
|
b_scale_stack = torch.empty(
|
||||||
(num_experts, n_g // 128, k_g // 128), device=device, dtype=torch.float32
|
(num_experts, k_g // 128, n_g // 128), device=device, dtype=torch.float32
|
||||||
)
|
)
|
||||||
|
|
||||||
for g in range(num_experts):
|
for g in range(num_experts):
|
||||||
a_scale_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_scales_tensors[g]
|
if cc == 9:
|
||||||
b_scale_stack[g] = b_scales_tensors[g].t()
|
# For SM90, we need MN-Major scale factor
|
||||||
b_scale_stack = b_scale_stack.transpose(1, 2)
|
# a_scales_tensors[g] -- (M, k):(k, 1)
|
||||||
|
# a_scales_tensors[g].t().contiguous() -- (k, M):(M, 1)
|
||||||
|
a_scale_stack[
|
||||||
|
expert_offsets[g] * (k_g // 128) : expert_offsets[g + 1] * (k_g // 128)
|
||||||
|
] = (a_scales_tensors[g].t().contiguous().view(-1))
|
||||||
|
b_scale_stack[g] = b_scales_tensors[g] # b_scale_stack[g] -- (k, n):(n, 1)
|
||||||
|
elif cc == 10:
|
||||||
|
# For SM100, we need K-Major scale factor
|
||||||
|
# a_scales_tensors[g] -- (M, k):(k, 1)
|
||||||
|
a_scale_stack[
|
||||||
|
expert_offsets[g] * (k_g // 128) : expert_offsets[g + 1] * (k_g // 128)
|
||||||
|
] = a_scales_tensors[g].view(-1)
|
||||||
|
b_scale_stack[g] = b_scales_tensors[
|
||||||
|
g
|
||||||
|
] # b_scale_stack[g] -- (k, n):(n, 1), we need transpose & contiguous later
|
||||||
|
a_scale_stack = a_scale_stack.view(expert_offsets[-1], k_g // 128)
|
||||||
|
if cc == 10:
|
||||||
|
b_scale_stack = b_scale_stack.transpose(1, 2).contiguous()
|
||||||
|
|
||||||
c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype)
|
c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype)
|
||||||
a_strides = torch.full(
|
a_strides = torch.full(
|
||||||
@@ -168,8 +226,11 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
|
|||||||
for g in range(num_experts):
|
for g in range(num_experts):
|
||||||
baseline = baseline_tensors[g]
|
baseline = baseline_tensors[g]
|
||||||
actual = c_out[expert_offsets[g] : expert_offsets[g + 1]]
|
actual = c_out[expert_offsets[g] : expert_offsets[g + 1]]
|
||||||
torch.testing.assert_close(actual, baseline, rtol=1e-2, atol=1e-3)
|
diff = calc_diff(actual, baseline)
|
||||||
print(f"num_experts={num_experts}, out_dtype={out_dtype}: OK")
|
assert diff < 0.001
|
||||||
|
print(
|
||||||
|
f"cc={cc}0 num_experts={num_experts}, out_dtype={out_dtype}, diff={diff:.5f}: OK"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user