247 lines
7.9 KiB
Python
247 lines
7.9 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
import pytest
|
|
import torch
|
|
from sgl_kernel import scaled_fp4_quant
|
|
|
|
from sglang.srt.layers.activation import SiluAndMul
|
|
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
|
|
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
|
|
from sglang.srt.layers.moe.topk import select_experts
|
|
|
|
# Skip the whole module unless a GPU with compute capability >= 10.0 is
# present. The availability check comes first because
# torch.cuda.get_device_capability() raises on a machine without a usable CUDA
# device, which would surface as a collection error instead of a clean skip.
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (10, 0):
    pytest.skip(
        reason="Nvfp4 Requires compute capability of 10 or above.",
        allow_module_level=True,
    )
|
|
|
|
# Range limits used when deriving global quantization scales: the largest
# FP4 (E2M1) magnitude and the FP8 (E4M3) maximum.
FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = 448.0

# Lookup table from a 3-bit E2M1 magnitude code (0..7) to its float value;
# the sign bit is handled separately by the decoder.
kE2M1ToFloat = torch.tensor(
    (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0),
    dtype=torch.float32,
)
|
|
|
|
|
|
def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
    """Rearrange tiled ("swizzled") block scale factors into row-major order.

    The input is read as a sequence of tiles, each covering 128 rows by
    4 scale columns and stored as a (32, 4, 4) sub-block. The tiles are
    permuted so that the result is indexed as ``[row, scale_column]``.

    Args:
        a_sf_swizzled: Scale factors in the tiled layout. It must hold
            ``ceil(m / 128) * ceil(k / (4 * block_size)) * 512`` elements.
        m: Number of valid rows.
        k: Number of quantized elements per row (not the number of scales).
        block_size: Number of elements that share one scale factor.

    Returns:
        A tensor of shape ``(m, ceil(k / block_size))`` with the padding rows
        and padding scale columns removed.
    """
    m_tiles = (m + 128 - 1) // 128
    # One tile spans 4 scale columns, i.e. 4 * block_size elements along k.
    elems_per_k_tile = block_size * 4
    k_tiles = (k + elems_per_k_tile - 1) // elems_per_k_tile

    tiled = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
    tiled = torch.permute(tiled, (0, 1, 4, 3, 2, 5))
    linear = tiled.reshape(m_tiles * 128, k_tiles * elems_per_k_tile // block_size)

    # Columns index scale factors, so the valid width is k / block_size.
    # Slicing with k itself never removes the padding columns that appear
    # when the scale count is not a multiple of 4.
    num_scales = (k + block_size - 1) // block_size
    return linear[0:m, 0:num_scales]
|
|
|
|
|
|
def dequantize_nvfp4_to_dtype(
    tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
):
    """Reconstruct a high-precision tensor from packed FP4 data and its scales.

    Args:
        tensor_fp4: 2-D uint8 tensor; every byte packs two FP4 values.
        tensor_sf: Block scale factors in the swizzled layout; reinterpreted
            here as ``torch.float8_e4m3fn``.
        global_scale: Global scale that the block scales are divided by.
        dtype: Output dtype.
        device: Accepted for call-site compatibility; not used in this body.
        block_size: Number of consecutive elements sharing one block scale.

    Returns:
        Tensor of shape ``(rows, 2 * packed_cols)`` in ``dtype``.
    """
    # Each uint8 carries two FP4 values, so the logical width is doubled.
    assert tensor_fp4.dtype == torch.uint8
    rows, packed_cols = tensor_fp4.shape
    cols = 2 * packed_cols

    # Decode the nibbles and group them so each group matches one block scale.
    blocks = break_fp4_bytes(tensor_fp4, dtype).reshape(
        rows, cols // block_size, block_size
    )

    # View the scales as FP8, undo the swizzle, then fold out the global scale.
    linear_sf = convert_swizzled_to_linear(
        tensor_sf.view(torch.float8_e4m3fn), rows, cols, block_size
    )
    block_scales = linear_sf.to(torch.float32) / global_scale

    # Broadcast one scale over every element of its block.
    scaled = blocks * block_scales.unsqueeze(-1)
    return scaled.reshape(rows, cols).to(dtype=dtype)
|
|
|
|
|
|
def break_fp4_bytes(a, dtype):
    """Unpack a 2-D uint8 tensor of paired FP4 codes into real values.

    Every byte yields two outputs, the low nibble first and the high nibble
    second. Within a nibble, bit 3 is the sign and bits 0-2 index the
    ``kE2M1ToFloat`` magnitude table.

    Args:
        a: uint8 tensor of shape ``(rows, packed_cols)``.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``(rows, 2 * packed_cols)`` in ``dtype``.
    """
    assert a.dtype == torch.uint8
    rows, packed_cols = a.shape

    packed = a.reshape(-1)
    low_nibbles = packed & 0x0F
    high_nibbles = (packed & 0xF0) >> 4
    # Interleave so the output order is low, high, low, high, ...
    nibbles = torch.stack((low_nibbles, high_nibbles), dim=1).flatten()

    is_negative = (nibbles & 0x08).to(torch.bool)
    magnitude_idx = (nibbles & 0x07).to(torch.long)

    # Look magnitudes up on the input's device, then apply the sign bit.
    table = kE2M1ToFloat.to(device=a.device)
    magnitudes = table[magnitude_idx]
    values = torch.where(is_negative, -magnitudes, magnitudes)

    return values.reshape(rows, packed_cols * 2).to(dtype=dtype)
|
|
|
|
|
|
# (m, n, k) problem sizes fed to the parametrized test, grouped by m.
MNK_FACTORS = [
    *(
        (2, n, k)
        for n, k in ((1024, 1024), (1024, 1536), (3072, 1024), (3072, 1536))
    ),
    *(
        (64, n, k)
        for n, k in ((1024, 1024), (1024, 1536), (3072, 1024), (2048, 1024))
    ),
    *((224, n, k) for n, k in ((1024, 1024), (1024, 1536))),
]
|
|
|
|
|
|
# Reference implementation of torch_moe
|
|
def torch_moe(a, w1, w2, score, topk, expert_map):
    """Compute a mixture-of-experts forward pass with plain torch ops.

    Args:
        a: Input activations of shape ``(tokens, hidden)``.
        w1: Per-expert first projection; ``w1[i]`` is applied transposed.
        w2: Per-expert second projection; ``w2[i]`` is applied transposed and
            ``w2.shape[1]`` is the output width.
        score: Router logits; softmaxed in float32 before top-k selection.
        topk: Number of experts selected per token.
        expert_map: Optional index tensor used to remap the selected expert
            ids; ``None`` leaves them unchanged.

    Returns:
        Tensor of shape ``(tokens, w2.shape[1])``: the routing-weighted sum of
        the selected experts' outputs for each token.
    """
    num_tokens, hidden = a.shape
    out_features = w2.shape[1]

    # Routing: softmax over experts, then keep the top-k weights and ids.
    routing = torch.softmax(score, dim=-1, dtype=torch.float32)
    weights, expert_ids = torch.topk(routing, topk)
    weights = weights.view(-1)
    expert_ids = expert_ids.view(-1)
    if expert_map is not None:
        expert_ids = expert_map[expert_ids]

    # Duplicate every token once per selected expert.
    tokens = a.view(num_tokens, -1, hidden).repeat(1, topk, 1).reshape(-1, hidden)
    result = torch.zeros(
        num_tokens * topk, out_features, dtype=tokens.dtype, device=tokens.device
    )

    for expert in range(w1.shape[0]):
        selected = expert_ids == expert
        if not selected.any():
            continue
        gated = SiluAndMul()(tokens[selected] @ w1[expert].transpose(0, 1))
        result[selected] = gated @ w2[expert].transpose(0, 1)

    # Weight each expert output by its routing weight and reduce over top-k.
    per_token = result.view(num_tokens, -1, out_features)
    scaled = per_token * weights.view(num_tokens, -1, 1).to(result.dtype)
    return scaled.sum(dim=1)
|
|
|
|
|
|
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [40, 64, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@torch.inference_mode()
def test_cutlass_fp4_moe_no_graph(
    m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
):
    """Check ``cutlass_moe_fp4`` against a dequantized torch reference.

    Random activations and expert weights are created, the weights are
    quantized to FP4 with ``scaled_fp4_quant``, and the kernel output is
    compared (atol = rtol = 1e-1) with ``torch_moe`` run on the dequantized
    activations and weights.

    Args:
        m: Number of tokens.
        n: Intermediate size per partition (``w1`` has ``2 * n`` rows).
        k: Hidden size.
        e: Number of experts.
        topk: Experts selected per token.
        dtype: Activation and weight dtype before quantization.
    """

    def round_up(x, y):
        return (x + y - 1) // y * y

    quant_blocksize = 16

    # Keep the randn order (a, w1, w2, then score) so seeded values are stable.
    torch.manual_seed(7)
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

    # Block-scale buffers: rows padded to a multiple of 128, scale columns to
    # a multiple of 4.
    w1_blockscale = torch.empty(
        (e, round_up(2 * n, 128), round_up(k // quant_blocksize, 4)),
        device="cuda",
        dtype=torch.float8_e4m3fn,
    )
    w2_blockscale = torch.empty(
        (e, round_up(k, 128), round_up(n // quant_blocksize, 4)),
        device="cuda",
        dtype=torch.float8_e4m3fn,
    )

    w1_q = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8)
    w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8)

    # Every expert shares one global scale derived from the maximum magnitude
    # of the whole weight tensor. The reduction is loop-invariant, so it is
    # computed once instead of once per expert.
    w1_amax = torch.abs(w1).max().to(torch.float32)
    w2_amax = torch.abs(w2).max().to(torch.float32)
    w1_gs = torch.empty((e,), device="cuda", dtype=torch.float32)
    w2_gs = torch.empty((e,), device="cuda", dtype=torch.float32)
    w1_gs.fill_(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax)
    w2_gs.fill_(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax)

    for expert in range(e):
        w1_q[expert], w1_blockscale[expert] = scaled_fp4_quant(
            w1[expert], w1_gs[expert]
        )
        w2_q[expert], w2_blockscale[expert] = scaled_fp4_quant(
            w2[expert], w2_gs[expert]
        )

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    topk_weights, topk_ids, _ = select_experts(
        hidden_states=a,
        router_logits=score,
        top_k=topk,
    )

    # The kernel is given unit activation global scales.
    a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
    a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32)

    params = CutlassMoEParams(
        CutlassMoEType.BlockscaledFP4,
        device=a.device,
        num_experts=e,
        intermediate_size_per_partition=n,
        hidden_size=k,
    )
    cutlass_output = cutlass_moe_fp4(
        a=a,
        a1_gscale=a1_gs,
        w1_fp4=w1_q,
        w1_blockscale=w1_blockscale,
        w1_alphas=(1 / w1_gs),
        a2_gscale=a2_gs,
        w2_fp4=w2_q,
        w2_blockscale=w2_blockscale,
        w2_alphas=(1 / w2_gs),
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        params=params,
        apply_router_weight_on_input=False,
    )

    # Reference check: round-trip the activations through FP4. The global
    # scale uses the maximum *magnitude* (matching the weight scales above)
    # and is computed in float32, so a dominant negative value cannot push a
    # block scale past the FP8 range and the division cannot overflow in the
    # low-precision dtype.
    a_amax = torch.abs(a).max().to(torch.float32)
    a_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
    a_fp4, a_scale_interleaved = scaled_fp4_quant(a, a_global_scale)
    a_in_dtype = dequantize_nvfp4_to_dtype(
        a_fp4,
        a_scale_interleaved,
        a_global_scale,
        dtype=a.dtype,
        device=a.device,
        block_size=quant_blocksize,
    )

    # Dequantize every expert's weights so the reference sees the same
    # quantization error as the kernel.
    w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
    w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
    for idx in range(e):
        w1_d[idx] = dequantize_nvfp4_to_dtype(
            w1_q[idx],
            w1_blockscale[idx],
            w1_gs[idx],
            dtype=w1.dtype,
            device=w1.device,
            block_size=quant_blocksize,
        )
        w2_d[idx] = dequantize_nvfp4_to_dtype(
            w2_q[idx],
            w2_blockscale[idx],
            w2_gs[idx],
            dtype=w2.dtype,
            device=w2.device,
            block_size=quant_blocksize,
        )

    torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None)

    torch.testing.assert_close(torch_output, cutlass_output, atol=1e-1, rtol=1e-1)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run one large configuration directly, without going through pytest.
    test_cutlass_fp4_moe_no_graph(
        m=224,
        n=1024,
        k=1024,
        e=256,
        topk=8,
        dtype=torch.half,
    )
|