[1/2] Add Kernel support for Cutlass based Fused FP4 MoE (#6093)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
"""Cutlass MoE kernel."""
|
||||
"""CUTLASS based Fused MoE kernels."""
|
||||
|
||||
import functools
|
||||
import json
|
||||
@@ -14,8 +14,10 @@ _is_cuda = is_cuda()
|
||||
if _is_cuda:
|
||||
import sgl_kernel
|
||||
from sgl_kernel import (
|
||||
cutlass_fp4_group_mm,
|
||||
fp8_blockwise_scaled_grouped_mm,
|
||||
prepare_moe_input,
|
||||
scaled_fp4_experts_quant,
|
||||
silu_and_mul,
|
||||
)
|
||||
|
||||
@@ -205,3 +207,178 @@ def cutlass_fused_experts(
|
||||
return (
|
||||
c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)
|
||||
).sum(dim=1)
|
||||
|
||||
|
||||
FLOAT4_E2M1_MAX = 6.0
|
||||
FLOAT8_E4M3_MAX = 448.0
|
||||
|
||||
|
||||
def cutlass_moe_fp4(
|
||||
a: torch.Tensor,
|
||||
a1_gscale: torch.Tensor,
|
||||
w1_fp4: torch.Tensor,
|
||||
w1_blockscale: torch.Tensor,
|
||||
w1_alphas: torch.Tensor,
|
||||
a2_gscale: torch.Tensor,
|
||||
w2_fp4: torch.Tensor,
|
||||
w2_blockscale: torch.Tensor,
|
||||
w2_alphas: torch.Tensor,
|
||||
ab_strides_13: torch.Tensor,
|
||||
ab_strides_2: torch.Tensor,
|
||||
c_strides_13: torch.Tensor,
|
||||
c_strides_2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
device: torch.device,
|
||||
):
|
||||
"""
|
||||
MoE implementation for FP4 Inputs
|
||||
|
||||
# Gemm 1
|
||||
a: Input tensor: [m, k] (half/bfloat16)
|
||||
a1_gscale: Activation scale per expert: [e] (float32)
|
||||
w1(gate up) (not an argument to cutlass_moe_fp4): [e, 2 * n, k]
|
||||
w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1)
|
||||
(Note: `n` is the up projection output dim, `k` is the input dim in
|
||||
full precision)
|
||||
w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3)
|
||||
(Block size = 16 for NVFP4)
|
||||
|
||||
# Gemm 2
|
||||
a2_gscale: Activation scale per expert: [e]
|
||||
w2(down projection) (not an argument to cutlass_moe_fp4): [e, k, n]
|
||||
w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1)
|
||||
w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3
|
||||
|
||||
Strides for activations, weights and output in logical number of elements.
|
||||
The activations & output stride is the number of elements to the next row.
|
||||
The weights stride is the number of elements to the next row per expert.
|
||||
For example, if the weight is [e, n, k], then the b_stride is a tensor of
|
||||
shape [e] with each element being k. Similarly for activations, if the
|
||||
shape is [m, k], then the a_stride has shape [e] with each value k.
|
||||
Similarly for output, if the output is [m, n], then the c_stride is a
|
||||
tensor of shape [e] with each element being k.
|
||||
|
||||
Note: cutlass_fp4_group_mm is designed to accept the strides of
|
||||
activations and weights to be the same, so it is passed in as a single
|
||||
tensor.
|
||||
ab_strides_13: [e] dtype: int64 [Gemm 1: Activation / Weight strides]
|
||||
ab_strides_2: [e] dtype: int64 [Gemm 2: Activation / Weight strides]
|
||||
c_strides_13: [e] dtype: int64 [Gemm 1: Output Strides]
|
||||
c_strides_2: [e] dtype: int64 [Gemm 1: Output Strides]
|
||||
|
||||
topk_weights: [m, topk] dtype: float8
|
||||
topk_ids: [m, topk] dtype: float8
|
||||
|
||||
m, n, k: Unquantized weight shapes, dtype: int
|
||||
e: number of experts for the current rank, dtype: int
|
||||
assumes that topk < k < n to satisfy - up/down projection expectations.
|
||||
"""
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8"
|
||||
assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8"
|
||||
assert (
|
||||
w1_fp4.ndim == 3
|
||||
and w2_fp4.ndim == 3
|
||||
and w1_blockscale.ndim == 3
|
||||
and w2_blockscale.ndim == 3
|
||||
), "All Weights must be of rank 3 for cutlass_moe_fp4"
|
||||
m_a, k_a = a.shape
|
||||
e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
|
||||
e_w2, k_w2, half_n_w2 = w2_fp4.shape
|
||||
|
||||
assert e_w1 == e_w2 and e_w1 == e, (
|
||||
"Number of experts must match",
|
||||
" between weights.",
|
||||
)
|
||||
assert (
|
||||
k_a // 2 == half_k_w1 and k == k_w2
|
||||
), "Hidden size mismatch between a, w1 and w2"
|
||||
assert nx2_w1 == n * 2 and half_n_w2 == n // 2, "mismatch in " "expected `n`"
|
||||
assert m == m_a, "input shape mismatch"
|
||||
assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
|
||||
assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
|
||||
assert (
|
||||
topk_weights.shape[0] == m and topk_ids.shape[0] == m
|
||||
), "topk must be provided for each row of a"
|
||||
|
||||
out_dtype = a.dtype
|
||||
num_topk = topk_ids.shape[1]
|
||||
|
||||
expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
|
||||
# Problem size: (num_experts, (m,2n,k))
|
||||
problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device)
|
||||
# Problem size: (num_experts, (m,n,k))
|
||||
problem_sizes2 = torch.empty((e, 3), dtype=torch.int32, device=device)
|
||||
|
||||
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
|
||||
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
|
||||
|
||||
# problem shapes should have [m, n, k]
|
||||
# Note that problem sizes are based on logical number of elements.
|
||||
blockscale_offsets = torch.empty(e + 1, dtype=torch.int32, device=device)
|
||||
prepare_moe_input(
|
||||
topk_ids,
|
||||
expert_offsets,
|
||||
problem_sizes1,
|
||||
problem_sizes2,
|
||||
a_map,
|
||||
c_map,
|
||||
e,
|
||||
n,
|
||||
k,
|
||||
blockscale_offsets,
|
||||
)
|
||||
|
||||
rep_a_fp4, rep_a_blockscale = scaled_fp4_experts_quant(
|
||||
a, a1_gscale, expert_offsets, blockscale_offsets, num_topk, expert_map=a_map
|
||||
)
|
||||
|
||||
c1 = cutlass_fp4_group_mm(
|
||||
rep_a_fp4,
|
||||
w1_fp4,
|
||||
rep_a_blockscale,
|
||||
w1_blockscale,
|
||||
w1_alphas,
|
||||
ab_strides_13,
|
||||
c_strides_13,
|
||||
problem_sizes1,
|
||||
expert_offsets[:-1],
|
||||
blockscale_offsets[:-1],
|
||||
out_dtype,
|
||||
device,
|
||||
)
|
||||
del rep_a_fp4, rep_a_blockscale
|
||||
# hidden size dimension is split to one halfpytho sized tensor.
|
||||
intermediate = torch.empty(
|
||||
(m * num_topk, w1_fp4.shape[1] // 2), device=device, dtype=out_dtype
|
||||
)
|
||||
|
||||
silu_and_mul(c1, intermediate)
|
||||
|
||||
int_fp4, int_blockscale = scaled_fp4_experts_quant(
|
||||
intermediate, a2_gscale, expert_offsets, blockscale_offsets, num_topk
|
||||
)
|
||||
c2 = cutlass_fp4_group_mm(
|
||||
int_fp4,
|
||||
w2_fp4,
|
||||
int_blockscale,
|
||||
w2_blockscale,
|
||||
w2_alphas,
|
||||
ab_strides_2,
|
||||
c_strides_2,
|
||||
problem_sizes2,
|
||||
expert_offsets[:-1],
|
||||
blockscale_offsets[:-1],
|
||||
out_dtype,
|
||||
device,
|
||||
)
|
||||
del int_fp4, int_blockscale
|
||||
out = (
|
||||
c2[c_map].view(m, num_topk, k) * topk_weights.view(m, num_topk, 1).half()
|
||||
).sum(dim=1)
|
||||
return out.to(dtype=out_dtype)
|
||||
|
||||
247
python/sglang/test/test_fp4_moe.py
Normal file
247
python/sglang/test/test_fp4_moe.py
Normal file
@@ -0,0 +1,247 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import scaled_fp4_quant
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
|
||||
if torch.cuda.get_device_capability() < (10, 0):
|
||||
pytest.skip(
|
||||
reason="Nvfp4 Requires compute capability of 10 or above.",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
kE2M1ToFloat = torch.tensor(
|
||||
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
|
||||
)
|
||||
|
||||
FLOAT8_E4M3_MAX = 448.0
|
||||
FLOAT4_E2M1_MAX = 6.0
|
||||
|
||||
|
||||
def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
|
||||
m_tiles = (m + 128 - 1) // 128
|
||||
f = block_size * 4
|
||||
k_tiles = (k + f - 1) // f
|
||||
tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
|
||||
tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
|
||||
out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
|
||||
return out[0:m, 0:k]
|
||||
|
||||
|
||||
def dequantize_nvfp4_to_dtype(
|
||||
tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
|
||||
):
|
||||
"""Dequantize the fp4 tensor back to high precision."""
|
||||
# Two fp4 values are packed into one uint8.
|
||||
assert tensor_fp4.dtype == torch.uint8
|
||||
m, packed_k = tensor_fp4.shape
|
||||
k = packed_k * 2
|
||||
tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
|
||||
tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
|
||||
tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
|
||||
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
|
||||
tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
|
||||
|
||||
# scale the tensor
|
||||
out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
|
||||
return out.to(dtype=dtype)
|
||||
|
||||
|
||||
def break_fp4_bytes(a, dtype):
|
||||
assert a.dtype == torch.uint8
|
||||
m, n = a.shape
|
||||
|
||||
# Vectorized nibble processing
|
||||
a_flat = a.flatten()
|
||||
high = (a_flat & 0xF0) >> 4 # Upper nibbles
|
||||
low = a_flat & 0x0F # Lower nibbles
|
||||
|
||||
# Combine nibbles for batch processing
|
||||
combined = torch.stack((low, high), dim=1).flatten()
|
||||
|
||||
# Vectorized sign and magnitude extraction
|
||||
signs = (combined & 0x08).to(torch.bool) # Sign bits
|
||||
abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices
|
||||
|
||||
# Device-aware lookup and sign application
|
||||
kE2M1 = kE2M1ToFloat.to(device=a.device)
|
||||
values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)
|
||||
|
||||
# Reshape to final form
|
||||
return values.reshape(m, n * 2).to(dtype=dtype)
|
||||
|
||||
|
||||
MNK_FACTORS = [
|
||||
(2, 1024, 1024),
|
||||
(2, 1024, 1536),
|
||||
(2, 3072, 1024),
|
||||
(2, 3072, 1536),
|
||||
(64, 1024, 1024),
|
||||
(64, 1024, 1536),
|
||||
(64, 3072, 1024),
|
||||
(64, 2048, 1024),
|
||||
(224, 1024, 1024),
|
||||
(224, 1024, 1536),
|
||||
]
|
||||
|
||||
|
||||
# Reference implementation of torch_moe
|
||||
def torch_moe(a, w1, w2, score, topk, expert_map):
|
||||
B, D = a.shape
|
||||
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
topk_weight, topk_ids = torch.topk(score, topk)
|
||||
topk_weight = topk_weight.view(-1)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
if expert_map is not None:
|
||||
topk_ids = expert_map[topk_ids]
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(
|
||||
0, 1
|
||||
)
|
||||
return (
|
||||
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
|
||||
).sum(dim=1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", [40, 64, 256])
|
||||
@pytest.mark.parametrize("topk", [1, 6, 8])
|
||||
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
|
||||
@torch.inference_mode()
|
||||
def test_cutlass_fp4_moe_no_graph(
|
||||
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
|
||||
):
|
||||
|
||||
torch.manual_seed(7)
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
quant_blocksize = 16
|
||||
round_up = lambda x, y: (x + y - 1) // y * y
|
||||
sf_w1_2n = round_up(2 * n, 128)
|
||||
sf_w1_k = round_up(k // quant_blocksize, 4)
|
||||
w1_blockscale = torch.empty(
|
||||
(e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
sf_w2_k = round_up(k, 128)
|
||||
sf_w2_n = round_up(n // quant_blocksize, 4)
|
||||
w2_blockscale = torch.empty(
|
||||
(e, sf_w2_k, sf_w2_n), device="cuda", dtype=torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
w1_q = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8)
|
||||
w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8)
|
||||
w1_gs = torch.empty((e,), device="cuda", dtype=torch.float32)
|
||||
w2_gs = torch.empty((e,), device="cuda", dtype=torch.float32)
|
||||
|
||||
for expert in range(e):
|
||||
w1_amax = torch.abs(w1).max().to(torch.float32)
|
||||
w2_amax = torch.abs(w2).max().to(torch.float32)
|
||||
w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
|
||||
w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
|
||||
|
||||
w1_q[expert], w1_blockscale[expert] = scaled_fp4_quant(
|
||||
w1[expert], w1_gs[expert]
|
||||
)
|
||||
|
||||
w2_q[expert], w2_blockscale[expert] = scaled_fp4_quant(
|
||||
w2[expert], w2_gs[expert]
|
||||
)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=a,
|
||||
router_logits=score,
|
||||
top_k=topk,
|
||||
use_grouped_topk=False,
|
||||
renormalize=False,
|
||||
)
|
||||
|
||||
a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
|
||||
a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
|
||||
# strides for the cutlass moe_fp4 kernel
|
||||
ab_strides_13 = torch.full(
|
||||
(e,), w1_q.shape[2] * 2, dtype=torch.int64, device=w1_q.device
|
||||
)
|
||||
c_strides_13 = torch.full(
|
||||
(e,), w1_q.shape[1], dtype=torch.int64, device=w1_q.device
|
||||
)
|
||||
ab_strides_2 = torch.full(
|
||||
(e,), w2_q.shape[2] * 2, dtype=torch.int64, device=w2_q.device
|
||||
)
|
||||
c_strides_2 = torch.full((e,), w2_q.shape[1], dtype=torch.int64, device=w2_q.device)
|
||||
cutlass_output = cutlass_moe_fp4(
|
||||
a=a,
|
||||
a1_gscale=a1_gs,
|
||||
w1_fp4=w1_q,
|
||||
w1_blockscale=w1_blockscale,
|
||||
w1_alphas=(1 / w1_gs),
|
||||
a2_gscale=a2_gs,
|
||||
w2_fp4=w2_q,
|
||||
w2_blockscale=w2_blockscale,
|
||||
w2_alphas=(1 / w2_gs),
|
||||
ab_strides_13=ab_strides_13,
|
||||
ab_strides_2=ab_strides_2,
|
||||
c_strides_13=c_strides_13,
|
||||
c_strides_2=c_strides_2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
m=m,
|
||||
n=n,
|
||||
k=k,
|
||||
e=e,
|
||||
device=a.device,
|
||||
)
|
||||
|
||||
# Reference check:
|
||||
a_global_scale = (
|
||||
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1)
|
||||
).to(torch.float32)
|
||||
a_fp4, a_scale_interleaved = scaled_fp4_quant(a, a_global_scale)
|
||||
_, m_k = a_fp4.shape
|
||||
a_in_dtype = dequantize_nvfp4_to_dtype(
|
||||
a_fp4,
|
||||
a_scale_interleaved,
|
||||
a_global_scale,
|
||||
dtype=a.dtype,
|
||||
device=a.device,
|
||||
block_size=quant_blocksize,
|
||||
)
|
||||
|
||||
w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
|
||||
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
|
||||
|
||||
for idx in range(0, e):
|
||||
w1_d[idx] = dequantize_nvfp4_to_dtype(
|
||||
w1_q[idx],
|
||||
w1_blockscale[idx],
|
||||
w1_gs[idx],
|
||||
dtype=w1.dtype,
|
||||
device=w1.device,
|
||||
block_size=quant_blocksize,
|
||||
)
|
||||
w2_d[idx] = dequantize_nvfp4_to_dtype(
|
||||
w2_q[idx],
|
||||
w2_blockscale[idx],
|
||||
w2_gs[idx],
|
||||
dtype=w2.dtype,
|
||||
device=w2.device,
|
||||
block_size=quant_blocksize,
|
||||
)
|
||||
|
||||
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None)
|
||||
|
||||
torch.testing.assert_close(torch_output, cutlass_output, atol=1e-1, rtol=1e-1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_cutlass_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half)
|
||||
@@ -210,6 +210,7 @@ set(SOURCES
|
||||
"csrc/gemm/fp8_blockwise_gemm_kernel.cu"
|
||||
"csrc/gemm/fp8_gemm_kernel.cu"
|
||||
"csrc/gemm/int8_gemm_kernel.cu"
|
||||
"csrc/gemm/nvfp4_expert_quant.cu"
|
||||
"csrc/gemm/nvfp4_quant_entry.cu"
|
||||
"csrc/gemm/nvfp4_quant_kernels.cu"
|
||||
"csrc/gemm/nvfp4_scaled_mm_entry.cu"
|
||||
@@ -222,6 +223,7 @@ set(SOURCES
|
||||
"csrc/moe/moe_align_kernel.cu"
|
||||
"csrc/moe/moe_fused_gate.cu"
|
||||
"csrc/moe/moe_topk_softmax_kernels.cu"
|
||||
"csrc/moe/nvfp4_blockwise_moe.cu"
|
||||
"csrc/moe/fp8_blockwise_moe_kernel.cu"
|
||||
"csrc/moe/prepare_moe_input.cu"
|
||||
"csrc/moe/ep_moe_reorder_kernel.cu"
|
||||
|
||||
@@ -132,6 +132,20 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
" Tensor! output_scale, Tensor! input_scale) -> ()");
|
||||
m.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
|
||||
|
||||
// Compute NVFP4 experts quantization.
|
||||
m.def(
|
||||
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
|
||||
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
|
||||
"Tensor output_scale_offset_by_experts) -> ()");
|
||||
m.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant);
|
||||
|
||||
m.def(
|
||||
"cutlass_fp4_group_mm(Tensor! output, Tensor a, Tensor b,"
|
||||
"Tensor a_blockscale, Tensor b_blockscale, Tensor alphas,"
|
||||
"Tensor ab_strides, Tensor c_strides, Tensor problem_sizes,"
|
||||
" Tensor expert_offsets, Tensor sf_offsets) -> ()");
|
||||
m.impl("cutlass_fp4_group_mm", torch::kCUDA, &cutlass_fp4_group_mm);
|
||||
|
||||
/*
|
||||
* From csrc/moe
|
||||
*/
|
||||
@@ -161,9 +175,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
"expert_offsets, Tensor workspace) -> ()");
|
||||
m.impl("fp8_blockwise_scaled_grouped_mm", torch::kCUDA, &fp8_blockwise_scaled_grouped_mm);
|
||||
m.def(
|
||||
"prepare_moe_input(Tensor topk_ids, Tensor expert_offsets, Tensor problem_sizes1, Tensor problem_sizes2, Tensor "
|
||||
"input_permutation, Tensor output_permutation, int num_experts, int n, int k) -> ()");
|
||||
"prepare_moe_input(Tensor topk_ids, Tensor expert_offsets, Tensor? blockscale_offsets, Tensor problem_sizes1,"
|
||||
" Tensor problem_sizes2, Tensor input_permutation, Tensor output_permutation, int num_experts, int n, int k) -> "
|
||||
"()");
|
||||
m.impl("prepare_moe_input", torch::kCUDA, &prepare_moe_input);
|
||||
|
||||
m.def("shuffle_rows(Tensor input, Tensor dst2src_map, Tensor output) -> ()");
|
||||
m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows);
|
||||
|
||||
/*
|
||||
* From csrc/speculative
|
||||
*/
|
||||
|
||||
431
sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu
Normal file
431
sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu
Normal file
@@ -0,0 +1,431 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
template <typename T>
|
||||
struct TypeConverter {
|
||||
using Type = half2;
|
||||
}; // keep for generality
|
||||
|
||||
template <>
|
||||
struct TypeConverter<half2> {
|
||||
using Type = half;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeConverter<half> {
|
||||
using Type = half2;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeConverter<__nv_bfloat162> {
|
||||
using Type = __nv_bfloat16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeConverter<__nv_bfloat16> {
|
||||
using Type = __nv_bfloat162;
|
||||
};
|
||||
|
||||
#define ELTS_PER_THREAD 8
|
||||
|
||||
constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
|
||||
constexpr int CVT_FP4_SF_VEC_SIZE = 16;
|
||||
|
||||
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
|
||||
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
|
||||
// PTX instructions used here requires sm100a.
|
||||
#if CUDA_VERSION >= 12080
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL)
|
||||
uint32_t val;
|
||||
asm volatile(
|
||||
"{\n"
|
||||
".reg .b8 byte0;\n"
|
||||
".reg .b8 byte1;\n"
|
||||
".reg .b8 byte2;\n"
|
||||
".reg .b8 byte3;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
|
||||
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
|
||||
"}"
|
||||
: "=r"(val)
|
||||
: "f"(array[0]),
|
||||
"f"(array[1]),
|
||||
"f"(array[2]),
|
||||
"f"(array[3]),
|
||||
"f"(array[4]),
|
||||
"f"(array[5]),
|
||||
"f"(array[6]),
|
||||
"f"(array[7]));
|
||||
return val;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
|
||||
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
|
||||
// PTX instructions used here requires sm100a.
|
||||
#if CUDA_VERSION >= 12080
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL)
|
||||
uint32_t val;
|
||||
asm volatile(
|
||||
"{\n"
|
||||
".reg .b8 byte0;\n"
|
||||
".reg .b8 byte1;\n"
|
||||
".reg .b8 byte2;\n"
|
||||
".reg .b8 byte3;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
|
||||
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
|
||||
"}"
|
||||
: "=r"(val)
|
||||
: "f"(array[0].x),
|
||||
"f"(array[0].y),
|
||||
"f"(array[1].x),
|
||||
"f"(array[1].y),
|
||||
"f"(array[2].x),
|
||||
"f"(array[2].y),
|
||||
"f"(array[3].x),
|
||||
"f"(array[3].y));
|
||||
return val;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
// Fast reciprocal.
|
||||
inline __device__ float reciprocal_approximate_ftz(float a) {
|
||||
float b;
|
||||
asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
|
||||
return b;
|
||||
}
|
||||
|
||||
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
|
||||
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, int numCols, SFType* SFout) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || CVT_FP4_NUM_THREADS_PER_SF == 2);
|
||||
|
||||
// One pair of threads write one SF to global memory.
|
||||
// TODO: stage through smem for packed STG.32
|
||||
// is it better than STG.8 from 4 threads ?
|
||||
if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
|
||||
// SF vector index (16 elements share one SF in the K dimension).
|
||||
int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
|
||||
int32_t mIdx = rowIdx;
|
||||
|
||||
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
|
||||
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
|
||||
|
||||
int32_t mTileIdx = mIdx / (32 * 4);
|
||||
// SF vector size 16.
|
||||
int factor = CVT_FP4_SF_VEC_SIZE * 4;
|
||||
int32_t numKTiles = (numCols + factor - 1) / factor;
|
||||
int64_t mTileStride = numKTiles * 32 * 4 * 4;
|
||||
|
||||
int32_t kTileIdx = (kIdx / 4);
|
||||
int64_t kTileStride = 32 * 4 * 4;
|
||||
|
||||
// M tile layout [32, 4] is column-major.
|
||||
int32_t outerMIdx = (mIdx % 32);
|
||||
int64_t outerMStride = 4 * 4;
|
||||
|
||||
int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
|
||||
int64_t innerMStride = 4;
|
||||
|
||||
int32_t innerKIdx = (kIdx % 4);
|
||||
int64_t innerKStride = 1;
|
||||
|
||||
// Compute the global offset.
|
||||
int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * outerMStride +
|
||||
innerMIdx * innerMStride + innerKIdx * innerKStride;
|
||||
|
||||
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
|
||||
}
|
||||
#endif
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Define a 16 bytes packed data type.
|
||||
template <class Type>
|
||||
struct PackedVec {
|
||||
typename TypeConverter<Type>::Type elts[4];
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PackedVec<__nv_fp8_e4m3> {
|
||||
__nv_fp8x2_e4m3 elts[8];
|
||||
};
|
||||
|
||||
// Quantizes the provided PackedVec into the uint32_t output
|
||||
template <class Type, bool UE8M0_SF = false>
|
||||
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, uint8_t* SFout) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
// Get absolute maximum values among the local 8 values.
|
||||
auto localMax = __habs2(vec.elts[0]);
|
||||
|
||||
// Local maximum value.
|
||||
#pragma unroll
|
||||
for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
|
||||
localMax = __hmax2(localMax, __habs2(vec.elts[i]));
|
||||
}
|
||||
|
||||
// Get the absolute maximum among all 16 values (two threads).
|
||||
localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
|
||||
// Get the final absolute maximum values.
|
||||
float vecMax = float(__hmax(localMax.x, localMax.y));
|
||||
|
||||
// Get the SF (max value of the vector / max value of e2m1).
|
||||
// maximum value of e2m1 = 6.0.
|
||||
// TODO: use half as compute data type.
|
||||
float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
|
||||
// 8 bits representation of the SF.
|
||||
uint8_t fp8SFVal;
|
||||
// Write the SF to global memory (STG.8).
|
||||
if constexpr (UE8M0_SF) {
|
||||
// Extract the 8 exponent bits from float32.
|
||||
// float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
|
||||
uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
|
||||
fp8SFVal = tmp & 0xff;
|
||||
// Convert back to fp32.
|
||||
reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
|
||||
} else {
|
||||
// Here SFValue is always positive, so E4M3 is the same as UE4M3.
|
||||
__nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
|
||||
reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
|
||||
// Convert back to fp32.
|
||||
SFValue = float(tmp);
|
||||
}
|
||||
// Get the output scale.
|
||||
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
|
||||
// reciprocal(SFScaleVal))
|
||||
float outputScale =
|
||||
SFValue != 0 ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal)) : 0.0f;
|
||||
|
||||
if (SFout) {
|
||||
// Write the SF to global memory (STG.8).
|
||||
*SFout = fp8SFVal;
|
||||
}
|
||||
|
||||
// Convert the input to float.
|
||||
float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
|
||||
if constexpr (std::is_same_v<Type, half>) {
|
||||
fp2Vals[i] = __half22float2(vec.elts[i]);
|
||||
} else {
|
||||
fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
|
||||
}
|
||||
fp2Vals[i].x *= outputScale;
|
||||
fp2Vals[i].y *= outputScale;
|
||||
}
|
||||
|
||||
// Convert to e2m1 values.
|
||||
uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
|
||||
|
||||
// Write the e2m1 values to global memory.
|
||||
return e2m1Vec;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Use UE4M3 by default.
|
||||
template <class Type, bool UE8M0_SF = false>
|
||||
__global__ void
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
__launch_bounds__(512, 4) cvt_fp16_to_fp4(
|
||||
#else
|
||||
cvt_fp16_to_fp4(
|
||||
#endif
|
||||
int32_t numRows,
|
||||
int32_t numCols,
|
||||
Type const* in,
|
||||
float const* SFScale,
|
||||
uint32_t* out,
|
||||
uint32_t* SFout,
|
||||
uint32_t* input_offset_by_experts,
|
||||
uint32_t* output_scale_offset_by_experts,
|
||||
int n_experts) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
using PackedVec = PackedVec<Type>;
|
||||
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
||||
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched.");
|
||||
|
||||
// Input tensor row/col loops.
|
||||
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
|
||||
for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD; colIdx += blockDim.x) {
|
||||
int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
|
||||
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
|
||||
// Get the output tensor offset.
|
||||
// Same as inOffset because 8 elements are packed into one uint32_t.
|
||||
int64_t outOffset = inOffset;
|
||||
auto& out_pos = out[outOffset];
|
||||
|
||||
// Find index within the experts.
|
||||
int rowIdx_in_expert = 0;
|
||||
int expert_idx = 0;
|
||||
for (int i = 0; i < n_experts; i++) {
|
||||
if (rowIdx >= input_offset_by_experts[i] && rowIdx < input_offset_by_experts[i + 1]) {
|
||||
rowIdx_in_expert = rowIdx - input_offset_by_experts[i];
|
||||
expert_idx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Get the global scaling factor, which will be applied to the SF.
|
||||
// Note SFScale is the same as next GEMM's alpha, which is
|
||||
// (448.f / (Alpha_A / 6.f)).
|
||||
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
|
||||
|
||||
int factor = CVT_FP4_SF_VEC_SIZE * 4;
|
||||
// The actual output_scales dim is computed from the padded numCols.
|
||||
int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
|
||||
int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
|
||||
uint32_t* SFout_in_expert = SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
|
||||
|
||||
auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_FP4_NUM_THREADS_PER_SF>(
|
||||
rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
|
||||
|
||||
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void quant_impl(
|
||||
void* output,
|
||||
void* output_scale,
|
||||
void* input,
|
||||
void* input_global_scale,
|
||||
void* input_offset_by_experts,
|
||||
void* output_scale_offset_by_experts,
|
||||
int m_topk,
|
||||
int k,
|
||||
int n_experts,
|
||||
cudaStream_t stream) {
|
||||
// TODO: this multiProcessorCount should be cached.
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int multiProcessorCount;
|
||||
cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, device);
|
||||
|
||||
// Grid, Block size.
|
||||
// Each thread converts 8 values.
|
||||
dim3 block(std::min(int(k / ELTS_PER_THREAD), 512));
|
||||
// Get number of blocks per SM (assume we can fully utilize the SM).
|
||||
int const numBlocksPerSM = 2048 / block.x;
|
||||
dim3 grid(std::min(int(m_topk), multiProcessorCount * numBlocksPerSM));
|
||||
|
||||
cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>(
|
||||
m_topk,
|
||||
k,
|
||||
reinterpret_cast<T*>(input),
|
||||
reinterpret_cast<float*>(input_global_scale),
|
||||
reinterpret_cast<uint32_t*>(output),
|
||||
reinterpret_cast<uint32_t*>(output_scale),
|
||||
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
||||
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
||||
n_experts);
|
||||
}
|
||||
|
||||
/*Quantization entry for fp4 experts quantization*/
|
||||
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
|
||||
#define CHECK_INPUT(x, m) \
|
||||
CHECK_TH_CUDA(x, m); \
|
||||
CHECK_CONTIGUOUS(x, m);
|
||||
|
||||
// constexpr auto FP8 = at::ScalarType::Float8_e4m3fn;
|
||||
constexpr auto HALF = at::ScalarType::Half;
|
||||
constexpr auto BF16 = at::ScalarType::BFloat16;
|
||||
constexpr auto FLOAT = at::ScalarType::Float;
|
||||
constexpr auto INT = at::ScalarType::Int;
|
||||
constexpr auto UINT8 = at::ScalarType::Byte;
|
||||
|
||||
void scaled_fp4_experts_quant_sm100a(
|
||||
torch::Tensor& output,
|
||||
torch::Tensor& output_scale,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts) {
|
||||
CHECK_INPUT(output, "output must be a CUDA tensor");
|
||||
CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor");
|
||||
CHECK_INPUT(input, "input must be a CUDA tensor");
|
||||
CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor");
|
||||
CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts must be a CUDA tensor");
|
||||
CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts must be a CUDA tensor");
|
||||
|
||||
TORCH_CHECK(output.dim() == 2);
|
||||
TORCH_CHECK(output_scale.dim() == 2);
|
||||
TORCH_CHECK(input.dim() == 2);
|
||||
TORCH_CHECK(input_global_scale.dim() == 1);
|
||||
TORCH_CHECK(input_offset_by_experts.dim() == 1);
|
||||
TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
|
||||
|
||||
TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
|
||||
TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
|
||||
TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
|
||||
TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
|
||||
// output is uint8 (two nvfp4 values are packed into one uint8)
|
||||
// output_scale is int32 (four fp8 values are packed into one int32)
|
||||
TORCH_CHECK(output.scalar_type() == UINT8);
|
||||
TORCH_CHECK(output_scale.scalar_type() == INT);
|
||||
|
||||
const int BLOCK_SIZE = 16;
|
||||
auto m_topk = input.size(0);
|
||||
auto k = input.size(1);
|
||||
TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
|
||||
auto n_experts = input_global_scale.size(0);
|
||||
TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
|
||||
TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
|
||||
TORCH_CHECK(output.size(0) == m_topk);
|
||||
TORCH_CHECK(output.size(1) == k / 2);
|
||||
int scales_k = k / BLOCK_SIZE;
|
||||
// 4 means the swizzle requirement by nvidia nvfp4.
|
||||
int padded_k = (scales_k + (4 - 1)) / 4 * 4;
|
||||
// 4 means 4 fp8 values are packed into one int32
|
||||
TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
|
||||
|
||||
auto in_dtype = input.dtype();
|
||||
at::cuda::CUDAGuard device_guard{(char)input.get_device()};
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
if (in_dtype == at::ScalarType::Half) {
|
||||
quant_impl<half>(
|
||||
output.data_ptr(),
|
||||
output_scale.data_ptr(),
|
||||
input.data_ptr(),
|
||||
input_global_scale.data_ptr(),
|
||||
input_offset_by_experts.data_ptr(),
|
||||
output_scale_offset_by_experts.data_ptr(),
|
||||
m_topk,
|
||||
k,
|
||||
n_experts,
|
||||
stream);
|
||||
} else if (in_dtype == at::ScalarType::BFloat16) {
|
||||
quant_impl<__nv_bfloat16>(
|
||||
output.data_ptr(),
|
||||
output_scale.data_ptr(),
|
||||
input.data_ptr(),
|
||||
input_global_scale.data_ptr(),
|
||||
input_offset_by_experts.data_ptr(),
|
||||
output_scale_offset_by_experts.data_ptr(),
|
||||
m_topk,
|
||||
k,
|
||||
n_experts,
|
||||
stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Expected input data type to be half or bfloat16");
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,15 @@ limitations under the License.
|
||||
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
|
||||
void scaled_fp4_quant_sm100a(
|
||||
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf);
|
||||
|
||||
void scaled_fp4_experts_quant_sm100a(
|
||||
torch::Tensor& output,
|
||||
torch::Tensor& output_scale,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts);
|
||||
|
||||
#endif
|
||||
|
||||
void scaled_fp4_quant(
|
||||
@@ -27,3 +36,17 @@ void scaled_fp4_quant(
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization");
|
||||
}
|
||||
|
||||
void scaled_fp4_experts_quant(
|
||||
torch::Tensor& output,
|
||||
torch::Tensor& output_scale,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts) {
|
||||
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
|
||||
return scaled_fp4_experts_quant_sm100a(
|
||||
output, output_scale, input, input_global_scale, input_offset_by_experts, output_scale_offset_by_experts);
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 experts quantization kernel");
|
||||
}
|
||||
|
||||
471
sgl-kernel/csrc/moe/nvfp4_blockwise_moe.cu
Normal file
471
sgl-kernel/csrc/moe/nvfp4_blockwise_moe.cu
Normal file
@@ -0,0 +1,471 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <cutlass/arch/arch.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <
|
||||
typename ElementAB,
|
||||
typename ElementC,
|
||||
typename ElementSF,
|
||||
typename ElementAccumulator,
|
||||
typename LayoutSFA,
|
||||
typename LayoutSFB,
|
||||
typename ScaleConfig>
|
||||
__global__ void __get_group_gemm_starts(
|
||||
ElementAB** a_offsets,
|
||||
ElementAB** b_offsets,
|
||||
ElementC** out_offsets,
|
||||
ElementSF** a_scales_offsets,
|
||||
ElementSF** b_scales_offsets,
|
||||
ElementAccumulator** alpha_offsets,
|
||||
LayoutSFA* layout_sfa_base_as_int,
|
||||
LayoutSFB* layout_sfb_base_as_int,
|
||||
ElementAB* a_base_as_int,
|
||||
ElementAB* b_base_as_int,
|
||||
ElementC* out_base_as_int,
|
||||
ElementSF* a_scales_base_as_int,
|
||||
ElementSF* b_scales_base_as_int,
|
||||
ElementAccumulator* alphas_base_as_int,
|
||||
const int32_t* expert_offsets,
|
||||
const int32_t* sf_offsets,
|
||||
const int32_t* problem_sizes_as_shapes,
|
||||
const int K,
|
||||
const int N) {
|
||||
int64_t expert_id = threadIdx.x;
|
||||
if (expert_id >= gridDim.x * blockDim.x) {
|
||||
return;
|
||||
}
|
||||
// Originally int32_t but upcasting to int64_t to avoid overflow
|
||||
// during offset calculations
|
||||
int64_t expert_offset = static_cast<int64_t>(expert_offsets[expert_id]);
|
||||
int64_t sf_offset = static_cast<int64_t>(sf_offsets[expert_id]);
|
||||
// size for block in block scale.
|
||||
int64_t group_size = 16;
|
||||
int64_t m = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3]);
|
||||
int64_t n = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3 + 1]);
|
||||
int64_t k = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3 + 2]);
|
||||
assert((m >= 0 && n == N && k == K && k % 2 == 0) && "unexpected problem sizes");
|
||||
|
||||
int64_t half_k = static_cast<int64_t>(k / 2);
|
||||
int64_t group_k = static_cast<int64_t>(k / group_size);
|
||||
// Shape of A as uint8/byte = [M, K // 2]
|
||||
// Shape of B as uint8/byte = [E, N, K // 2]
|
||||
a_offsets[expert_id] = a_base_as_int + expert_offset * half_k;
|
||||
|
||||
b_offsets[expert_id] = b_base_as_int + expert_id * n * half_k;
|
||||
// Shape of C = [M, N]
|
||||
out_offsets[expert_id] = out_base_as_int + expert_offset * n;
|
||||
// Shape of a_scale = [sum(sf_sizes), K // group_size]
|
||||
a_scales_offsets[expert_id] = a_scales_base_as_int + sf_offset * group_k;
|
||||
|
||||
assert((reinterpret_cast<uintptr_t>(a_scales_offsets[expert_id]) % 128) == 0 && "TMA requires 128-byte alignment");
|
||||
|
||||
// Shape of B scale = [E, N, K // group_size]
|
||||
b_scales_offsets[expert_id] = b_scales_base_as_int + expert_id * n * group_k;
|
||||
assert((reinterpret_cast<uintptr_t>(b_scales_offsets[expert_id]) % 128) == 0 && "TMA requires 128-byte alignment");
|
||||
// Shape of alpha = [E]
|
||||
alpha_offsets[expert_id] = alphas_base_as_int + expert_id;
|
||||
|
||||
LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id;
|
||||
LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id;
|
||||
|
||||
*layout_sfa_ptr = ScaleConfig::tile_atom_to_shape_SFA(
|
||||
cute::make_shape(static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1));
|
||||
*layout_sfb_ptr = ScaleConfig::tile_atom_to_shape_SFB(
|
||||
cute::make_shape(static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1));
|
||||
}
|
||||
|
||||
#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE( \
|
||||
ELEMENT_AB_TYPE, SF_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, ScaleConfig) \
|
||||
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
|
||||
__get_group_gemm_starts<ELEMENT_AB_TYPE, C_TYPE, SF_TYPE, float, LayoutSFA, LayoutSFB, ScaleConfig> \
|
||||
<<<1, num_experts, 0, stream>>>( \
|
||||
static_cast<ELEMENT_AB_TYPE**>(a_starts.data_ptr()), \
|
||||
static_cast<ELEMENT_AB_TYPE**>(b_starts.data_ptr()), \
|
||||
static_cast<C_TYPE**>(out_starts.data_ptr()), \
|
||||
static_cast<SF_TYPE**>(a_scales_starts.data_ptr()), \
|
||||
static_cast<SF_TYPE**>(b_scales_starts.data_ptr()), \
|
||||
static_cast<float**>(alpha_starts.data_ptr()), \
|
||||
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()), \
|
||||
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()), \
|
||||
static_cast<ELEMENT_AB_TYPE*>(a_tensors.data_ptr()), \
|
||||
static_cast<ELEMENT_AB_TYPE*>(b_tensors.data_ptr()), \
|
||||
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
|
||||
static_cast<SF_TYPE*>(a_scales.data_ptr()), \
|
||||
static_cast<SF_TYPE*>(b_scales.data_ptr()), \
|
||||
static_cast<float*>(alphas.data_ptr()), \
|
||||
static_cast<int32_t*>(expert_offsets.data_ptr()), \
|
||||
static_cast<int32_t*>(sf_offsets.data_ptr()), \
|
||||
static_cast<int32_t*>(problem_sizes.data_ptr()), \
|
||||
K, \
|
||||
N); \
|
||||
}
|
||||
|
||||
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
|
||||
void run_get_group_gemm_starts(
|
||||
const torch::Tensor& a_starts,
|
||||
const torch::Tensor& b_starts,
|
||||
const torch::Tensor& out_starts,
|
||||
const torch::Tensor& a_scales_starts,
|
||||
const torch::Tensor& b_scales_starts,
|
||||
const torch::Tensor& alpha_starts,
|
||||
const torch::Tensor& layout_sfa,
|
||||
const torch::Tensor& layout_sfb,
|
||||
/*these are used for their base addresses*/
|
||||
torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors,
|
||||
torch::Tensor const& out_tensors,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& alphas,
|
||||
torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& sf_offsets,
|
||||
torch::Tensor const& problem_sizes,
|
||||
int M,
|
||||
int N,
|
||||
int K) {
|
||||
int num_experts = (int)expert_offsets.size(0);
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
|
||||
|
||||
TORCH_CHECK(out_tensors.size(1) == N, "Output tensor shape doesn't match expected shape");
|
||||
TORCH_CHECK(
|
||||
K / 2 == b_tensors.size(2),
|
||||
"b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
|
||||
" dimension must match");
|
||||
if (false) {
|
||||
}
|
||||
//(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB,
|
||||
// ScaleConfig)
|
||||
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(
|
||||
cutlass::float_e2m1_t,
|
||||
cutlass::float_ue4m3_t,
|
||||
torch::kBFloat16,
|
||||
cutlass::bfloat16_t,
|
||||
LayoutSFA,
|
||||
LayoutSFB,
|
||||
ScaleConfig)
|
||||
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(
|
||||
cutlass::float_e2m1_t, cutlass::float_ue4m3_t, torch::kFloat16, half, LayoutSFA, LayoutSFB, ScaleConfig)
|
||||
else {
|
||||
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void run_fp4_blockwise_scaled_group_mm(
|
||||
torch::Tensor& output,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& a_blockscale,
|
||||
const torch::Tensor& b_blockscales,
|
||||
const torch::Tensor& alphas,
|
||||
const torch::Tensor& ab_strides,
|
||||
const torch::Tensor& c_strides,
|
||||
const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& sf_offsets,
|
||||
int M,
|
||||
int N,
|
||||
int K) {
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
|
||||
using ElementType = cutlass::float_e2m1_t;
|
||||
using ElementSFType = cutlass::float_ue4m3_t;
|
||||
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
|
||||
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
|
||||
|
||||
using ElementC = OutType;
|
||||
using ElementD = ElementC;
|
||||
using ElementAccumulator = float;
|
||||
// Layout definitions
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = LayoutC;
|
||||
|
||||
// Alignment constraints
|
||||
static constexpr int AlignmentA = 32;
|
||||
static constexpr int AlignmentB = 32;
|
||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
// Architecture definitions
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using EpilogueOperatorClass = cutlass::arch::OpClassTensorOp; // Epilogue Operator class tag
|
||||
using MainloopOperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Mainloop Operator class tag
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based
|
||||
// on the tile size
|
||||
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
struct MMA1SMConfig {
|
||||
using MmaTileShape = Shape<_128, _128, _128>;
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100; // Kernel to launch
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch
|
||||
};
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
EpilogueOperatorClass,
|
||||
typename MMA1SMConfig::MmaTileShape,
|
||||
ClusterShape,
|
||||
Shape<_128, _64>,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator,
|
||||
ElementC,
|
||||
LayoutC*,
|
||||
AlignmentC,
|
||||
ElementD,
|
||||
LayoutC*,
|
||||
AlignmentD,
|
||||
typename MMA1SMConfig::EpilogueSchedule>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
MainloopOperatorClass,
|
||||
ElementA,
|
||||
LayoutA*,
|
||||
AlignmentA,
|
||||
ElementB,
|
||||
LayoutB*,
|
||||
AlignmentB,
|
||||
ElementAccumulator,
|
||||
typename MMA1SMConfig::MmaTileShape,
|
||||
ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
typename MMA1SMConfig::KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;
|
||||
|
||||
using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
using Gemm = Gemm1SM;
|
||||
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
|
||||
|
||||
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
|
||||
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
|
||||
using ScaleConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
|
||||
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
|
||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||
auto options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
||||
|
||||
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
|
||||
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
|
||||
|
||||
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
out_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
alpha_ptrs,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
a,
|
||||
b,
|
||||
output,
|
||||
a_blockscale,
|
||||
b_blockscales,
|
||||
alphas,
|
||||
expert_offsets,
|
||||
sf_offsets,
|
||||
problem_sizes,
|
||||
M,
|
||||
N,
|
||||
K);
|
||||
|
||||
// Create an instance of the GEMM
|
||||
Gemm gemm_op;
|
||||
|
||||
// Initialize problem_sizes_as_shapes correctly
|
||||
UnderlyingProblemShape* problem_sizes_as_shapes = static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
|
||||
|
||||
// Set the Scheduler info
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams<
|
||||
typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
|
||||
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
|
||||
scheduler.raster_order = RasterOrderOptions::AlongM;
|
||||
hw_info.device_id = a.get_device();
|
||||
static std::unordered_map<int, int> cached_sm_counts;
|
||||
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
|
||||
cached_sm_counts[hw_info.device_id] =
|
||||
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
}
|
||||
hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX);
|
||||
|
||||
// Mainloop Arguments
|
||||
typename GemmKernel::MainloopArguments mainloop_args{
|
||||
static_cast<const ElementType**>(a_ptrs.data_ptr()),
|
||||
static_cast<StrideA*>(ab_strides.data_ptr()),
|
||||
static_cast<const ElementType**>(b_ptrs.data_ptr()),
|
||||
static_cast<StrideB*>(ab_strides.data_ptr()),
|
||||
static_cast<const ElementSFType**>(a_scales_ptrs.data_ptr()),
|
||||
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),
|
||||
static_cast<const ElementSFType**>(b_scales_ptrs.data_ptr()),
|
||||
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr())};
|
||||
|
||||
// Epilogue Arguments
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
{}, // epilogue.thread
|
||||
nullptr,
|
||||
static_cast<StrideC*>(c_strides.data_ptr()),
|
||||
static_cast<ElementD**>(out_ptrs.data_ptr()),
|
||||
static_cast<StrideC*>(c_strides.data_ptr())};
|
||||
auto& fusion_args = epilogue_args.thread;
|
||||
fusion_args.alpha_ptr_array = reinterpret_cast<float**>(alpha_ptrs.data_ptr());
|
||||
fusion_args.dAlpha = {_0{}, _0{}, 1};
|
||||
|
||||
// Gemm Arguments
|
||||
typename GemmKernel::Arguments args{
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{num_experts, problem_sizes_as_shapes, nullptr},
|
||||
mainloop_args,
|
||||
epilogue_args,
|
||||
hw_info,
|
||||
scheduler};
|
||||
|
||||
size_t workspace_size = Gemm::get_workspace_size(args);
|
||||
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
|
||||
auto can_implement_status = gemm_op.can_implement(args);
|
||||
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");
|
||||
|
||||
// Run the GEMM
|
||||
auto status = gemm_op.initialize(args, workspace.data_ptr());
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
|
||||
|
||||
status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
||||
}
|
||||
|
||||
#define CHECK_TYPE(x, st, m) TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
|
||||
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
|
||||
#define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
|
||||
#define CHECK_INPUT(x, st, m) \
|
||||
CHECK_TH_CUDA(x, m); \
|
||||
CHECK_CONTIGUOUS(x, m); \
|
||||
CHECK_TYPE(x, st, m)
|
||||
|
||||
void cutlass_fp4_group_mm(
|
||||
torch::Tensor& output,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& a_blockscale,
|
||||
const torch::Tensor& b_blockscales,
|
||||
const torch::Tensor& alphas,
|
||||
const torch::Tensor& ab_strides,
|
||||
const torch::Tensor& c_strides,
|
||||
const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& sf_offsets) {
|
||||
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
|
||||
|
||||
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
|
||||
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
|
||||
// Input validation
|
||||
CHECK_INPUT(a, FLOAT4_E2M1X2, "a");
|
||||
CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
|
||||
CHECK_INPUT(a_blockscale, SF_DTYPE, "a_blockscale");
|
||||
CHECK_INPUT(b_blockscales, SF_DTYPE, "b_blockscales");
|
||||
CHECK_INPUT(alphas, at::ScalarType::Float, "alphas");
|
||||
|
||||
TORCH_CHECK(
|
||||
a_blockscale.dim() == 2,
|
||||
"expected a_blockscale to be of shape [num_experts, rounded_m,"
|
||||
" k // group_size], observed rank: ",
|
||||
a_blockscale.dim())
|
||||
TORCH_CHECK(
|
||||
b_blockscales.dim() == 3,
|
||||
"expected b_blockscale to be of shape: "
|
||||
" [num_experts, n, k // group_size], observed rank: ",
|
||||
b_blockscales.dim())
|
||||
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor");
|
||||
TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have the shape (num_experts, 3)");
|
||||
TORCH_CHECK(
|
||||
problem_sizes.size(0) == expert_offsets.size(0), "Number of experts in problem_sizes must match expert_offsets");
|
||||
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32.");
|
||||
|
||||
int M = static_cast<int>(a.size(0));
|
||||
int N = static_cast<int>(b.size(1));
|
||||
int E = static_cast<int>(b.size(0));
|
||||
int K = static_cast<int>(2 * b.size(2));
|
||||
|
||||
if (output.scalar_type() == torch::kBFloat16) {
|
||||
run_fp4_blockwise_scaled_group_mm<cutlass::bfloat16_t>(
|
||||
output,
|
||||
a,
|
||||
b,
|
||||
a_blockscale,
|
||||
b_blockscales,
|
||||
alphas,
|
||||
ab_strides,
|
||||
c_strides,
|
||||
problem_sizes,
|
||||
expert_offsets,
|
||||
sf_offsets,
|
||||
M,
|
||||
N,
|
||||
K);
|
||||
} else {
|
||||
run_fp4_blockwise_scaled_group_mm<cutlass::half_t>(
|
||||
output,
|
||||
a,
|
||||
b,
|
||||
a_blockscale,
|
||||
b_blockscales,
|
||||
alphas,
|
||||
ab_strides,
|
||||
c_strides,
|
||||
problem_sizes,
|
||||
expert_offsets,
|
||||
sf_offsets,
|
||||
M,
|
||||
N,
|
||||
K);
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled cutlass_fp4_group_mm kernel, sgl-kernel must "
|
||||
"be compiled with ENABLE_NVFP4 for SM100+ and CUDA "
|
||||
"12.8 or above.");
|
||||
#endif
|
||||
}
|
||||
@@ -4,6 +4,8 @@
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/array.h"
|
||||
|
||||
constexpr uint64_t THREADS_PER_EXPERT = 512;
|
||||
|
||||
__global__ void compute_problem_sizes(
|
||||
@@ -11,9 +13,9 @@ __global__ void compute_problem_sizes(
|
||||
int32_t* problem_sizes1,
|
||||
int32_t* problem_sizes2,
|
||||
int32_t* atomic_buffer,
|
||||
const int topk_length,
|
||||
const int n,
|
||||
const int k) {
|
||||
const int64_t topk_length,
|
||||
const int64_t n,
|
||||
const int64_t k) {
|
||||
int expert_id = blockIdx.x;
|
||||
|
||||
int occurrences = 0;
|
||||
@@ -26,11 +28,11 @@ __global__ void compute_problem_sizes(
|
||||
if (threadIdx.x == 0) {
|
||||
int final_occurrences = atomic_buffer[expert_id];
|
||||
problem_sizes1[expert_id * 3] = final_occurrences;
|
||||
problem_sizes1[expert_id * 3 + 1] = 2 * n;
|
||||
problem_sizes1[expert_id * 3 + 2] = k;
|
||||
problem_sizes1[expert_id * 3 + 1] = static_cast<int32_t>(2 * n);
|
||||
problem_sizes1[expert_id * 3 + 2] = static_cast<int32_t>(k);
|
||||
problem_sizes2[expert_id * 3] = final_occurrences;
|
||||
problem_sizes2[expert_id * 3 + 1] = k;
|
||||
problem_sizes2[expert_id * 3 + 2] = n;
|
||||
problem_sizes2[expert_id * 3 + 1] = static_cast<int32_t>(k);
|
||||
problem_sizes2[expert_id * 3 + 2] = static_cast<int32_t>(n);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,7 +40,7 @@ __global__ void compute_expert_offsets(
|
||||
const int32_t* __restrict__ problem_sizes1,
|
||||
int32_t* expert_offsets,
|
||||
int32_t* atomic_buffer,
|
||||
const int num_experts) {
|
||||
const int64_t num_experts) {
|
||||
int32_t tot_offset = 0;
|
||||
expert_offsets[0] = 0;
|
||||
for (int i = 0; i < num_experts; ++i) {
|
||||
@@ -48,13 +50,34 @@ __global__ void compute_expert_offsets(
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void compute_expert_blockscale_offsets(
|
||||
const int32_t* __restrict__ problem_sizes1,
|
||||
int32_t* expert_offsets,
|
||||
int32_t* blockscale_offsets,
|
||||
int32_t* atomic_buffer,
|
||||
const int64_t num_experts) {
|
||||
int32_t tot_offset = 0;
|
||||
int32_t tot_rounded_offset = 0;
|
||||
expert_offsets[0] = 0;
|
||||
blockscale_offsets[0] = 0;
|
||||
for (int i = 0; i < num_experts; ++i) {
|
||||
atomic_buffer[i] = tot_offset;
|
||||
int num_tokens = problem_sizes1[i * 3];
|
||||
int rounded_num_tokens = (num_tokens + (128 - 1)) / 128 * 128;
|
||||
tot_offset += num_tokens;
|
||||
tot_rounded_offset += rounded_num_tokens;
|
||||
expert_offsets[i + 1] = tot_offset;
|
||||
blockscale_offsets[i + 1] = tot_rounded_offset;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void compute_arg_sorts(
|
||||
const int* __restrict__ topk_ids,
|
||||
const int32_t* __restrict__ topk_ids,
|
||||
int32_t* input_permutation,
|
||||
int32_t* output_permutation,
|
||||
int32_t* atomic_buffer,
|
||||
const int topk_length,
|
||||
const int topk) {
|
||||
const int64_t topk_length,
|
||||
const int64_t topk) {
|
||||
int expert_id = blockIdx.x;
|
||||
|
||||
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
|
||||
@@ -69,6 +92,7 @@ __global__ void compute_arg_sorts(
|
||||
void get_moe_prepare_input_caller(
|
||||
const torch::Tensor& topk_ids,
|
||||
torch::Tensor& expert_offsets,
|
||||
const std::optional<torch::Tensor>& blockscale_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation,
|
||||
@@ -80,8 +104,10 @@ void get_moe_prepare_input_caller(
|
||||
auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
|
||||
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
|
||||
|
||||
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
||||
compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
|
||||
uint32_t num_threads = static_cast<uint32_t>(min(THREADS_PER_EXPERT, topk_ids.numel()));
|
||||
uint32_t num_blocks = static_cast<uint32_t>(num_experts);
|
||||
|
||||
compute_problem_sizes<<<num_blocks, num_threads, 0, stream>>>(
|
||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||
@@ -89,12 +115,21 @@ void get_moe_prepare_input_caller(
|
||||
topk_ids.numel(),
|
||||
n,
|
||||
k);
|
||||
compute_expert_offsets<<<1, 1, 0, stream>>>(
|
||||
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
|
||||
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()),
|
||||
num_experts);
|
||||
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
|
||||
if (blockscale_offsets.has_value()) {
|
||||
compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
|
||||
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
|
||||
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
||||
static_cast<int32_t*>(blockscale_offsets.value().data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()),
|
||||
num_experts);
|
||||
} else {
|
||||
compute_expert_offsets<<<1, 1, 0, stream>>>(
|
||||
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
|
||||
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()),
|
||||
num_experts);
|
||||
}
|
||||
compute_arg_sorts<<<num_blocks, num_threads, 0, stream>>>(
|
||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<int32_t*>(input_permutation.data_ptr()),
|
||||
static_cast<int32_t*>(output_permutation.data_ptr()),
|
||||
@@ -106,6 +141,7 @@ void get_moe_prepare_input_caller(
|
||||
void prepare_moe_input(
|
||||
const torch::Tensor& topk_ids,
|
||||
torch::Tensor& expert_offsets,
|
||||
const std::optional<torch::Tensor>& blockscale_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation,
|
||||
@@ -117,6 +153,7 @@ void prepare_moe_input(
|
||||
get_moe_prepare_input_caller(
|
||||
topk_ids,
|
||||
expert_offsets,
|
||||
blockscale_offsets,
|
||||
problem_sizes1,
|
||||
problem_sizes2,
|
||||
input_permutation,
|
||||
@@ -126,3 +163,92 @@ void prepare_moe_input(
|
||||
k);
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void shuffleRowsKernel(
|
||||
const T* input,
|
||||
const int32_t* dst2src_map,
|
||||
T* output,
|
||||
int64_t num_src_rows,
|
||||
int64_t num_dst_rows,
|
||||
int64_t num_cols) {
|
||||
int64_t dest_row_idx = blockIdx.x;
|
||||
int64_t const source_row_idx = dst2src_map[dest_row_idx];
|
||||
|
||||
if (blockIdx.x < num_dst_rows) {
|
||||
// Load 128-bits per thread
|
||||
constexpr uint64_t ELEM_PER_THREAD = 128 / sizeof(T) / 8;
|
||||
using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;
|
||||
|
||||
// Duplicate and permute rows
|
||||
auto const* source_row_ptr = reinterpret_cast<DataElem const*>(input + source_row_idx * num_cols);
|
||||
auto* dest_row_ptr = reinterpret_cast<DataElem*>(output + dest_row_idx * num_cols);
|
||||
|
||||
auto const start_offset = threadIdx.x;
|
||||
auto const stride = blockDim.x;
|
||||
auto const num_elems_in_col = num_cols / ELEM_PER_THREAD;
|
||||
|
||||
for (auto elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
|
||||
dest_row_ptr[elem_index] = source_row_ptr[elem_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define DECLARE_SHUFFLE_ROWS(T) \
|
||||
__global__ void shuffleRowsKernel( \
|
||||
const T* input, \
|
||||
const int32_t* dst2src_map, \
|
||||
T* output, \
|
||||
int64_t num_src_rows, \
|
||||
int64_t num_dest_rows, \
|
||||
int64_t num_cols);
|
||||
|
||||
DECLARE_SHUFFLE_ROWS(float);
|
||||
DECLARE_SHUFFLE_ROWS(half);
|
||||
DECLARE_SHUFFLE_ROWS(__nv_bfloat16);
|
||||
DECLARE_SHUFFLE_ROWS(__nv_fp8_e4m3);
|
||||
DECLARE_SHUFFLE_ROWS(uint8_t);
|
||||
|
||||
#define SHUFFLE_ROWS(T) \
|
||||
shuffleRowsKernel<T><<<blocks, threads, 0, stream>>>( \
|
||||
reinterpret_cast<const T*>(input), \
|
||||
static_cast<const int32_t*>(dst2src_map.data_ptr()), \
|
||||
reinterpret_cast<T*>(output), \
|
||||
num_src_rows, \
|
||||
num_dst_rows, \
|
||||
num_cols)
|
||||
|
||||
#define DTYPE_DISPATCH_CASE(T, CUDA_T) \
|
||||
case T: \
|
||||
SHUFFLE_ROWS(CUDA_T); \
|
||||
break;
|
||||
|
||||
void shuffle_rows_caller(
|
||||
const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor) {
|
||||
TORCH_CHECK(
|
||||
input_tensor.scalar_type() == output_tensor.scalar_type(),
|
||||
"Input and output tensors must have the same data type");
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
uint32_t blocks = static_cast<uint32_t>(output_tensor.size(0));
|
||||
uint32_t threads = 256;
|
||||
int64_t num_dst_rows = output_tensor.size(0);
|
||||
int64_t num_src_rows = input_tensor.size(0);
|
||||
int64_t num_cols = input_tensor.size(1);
|
||||
const void* input = input_tensor.data_ptr();
|
||||
void* output = output_tensor.data_ptr();
|
||||
switch (input_tensor.scalar_type()) {
|
||||
DTYPE_DISPATCH_CASE(torch::kFloat16, half);
|
||||
DTYPE_DISPATCH_CASE(torch::kBFloat16, __nv_bfloat16);
|
||||
DTYPE_DISPATCH_CASE(torch::kFloat32, float);
|
||||
DTYPE_DISPATCH_CASE(torch::kFloat8_e4m3fn, __nv_fp8_e4m3);
|
||||
DTYPE_DISPATCH_CASE(torch::kUInt8, uint8_t);
|
||||
default:
|
||||
TORCH_CHECK(false, "[moe replicate input] data type dispatch fail!");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor) {
|
||||
shuffle_rows_caller(input_tensor, dst2src_map, output_tensor);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -232,6 +232,7 @@ void fp8_blockwise_scaled_grouped_mm(
|
||||
void prepare_moe_input(
|
||||
const torch::Tensor& topk_ids,
|
||||
torch::Tensor& expert_offsets,
|
||||
const std::optional<torch::Tensor>& blockscale_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation,
|
||||
@@ -251,6 +252,29 @@ void ep_moe_pre_reorder(
|
||||
int64_t topk,
|
||||
bool use_per_token_if_dynamic);
|
||||
|
||||
void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor);
|
||||
|
||||
void cutlass_fp4_group_mm(
|
||||
torch::Tensor& output,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& a_blockscale,
|
||||
const torch::Tensor& b_blockscales,
|
||||
const torch::Tensor& alphas,
|
||||
const torch::Tensor& ab_strides,
|
||||
const torch::Tensor& c_strides,
|
||||
const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& sf_offsets);
|
||||
|
||||
void scaled_fp4_experts_quant(
|
||||
torch::Tensor& output,
|
||||
torch::Tensor& output_scale,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts);
|
||||
|
||||
/*
|
||||
* From csrc/speculative
|
||||
*/
|
||||
|
||||
@@ -38,14 +38,17 @@ from sgl_kernel.gemm import (
|
||||
int8_scaled_mm,
|
||||
qserve_w4a8_per_chn_gemm,
|
||||
qserve_w4a8_per_group_gemm,
|
||||
scaled_fp4_experts_quant,
|
||||
scaled_fp4_quant,
|
||||
sgl_per_tensor_quant_fp8,
|
||||
sgl_per_token_group_quant_fp8,
|
||||
sgl_per_token_group_quant_int8,
|
||||
sgl_per_token_quant_fp8,
|
||||
shuffle_rows,
|
||||
)
|
||||
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
|
||||
from sgl_kernel.moe import (
|
||||
cutlass_fp4_group_mm,
|
||||
ep_moe_pre_reorder,
|
||||
fp8_blockwise_scaled_grouped_mm,
|
||||
moe_align_block_size,
|
||||
|
||||
@@ -241,3 +241,80 @@ def qserve_w4a8_per_group_gemm(
|
||||
in_feats, kernel, zeros, scales_i8, wscales, ascales, out_feats
|
||||
)
|
||||
return out_feats
|
||||
|
||||
|
||||
def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape):
|
||||
output_tensor = torch.empty(
|
||||
output_tensor_shape,
|
||||
device=input_tensor.device,
|
||||
dtype=input_tensor.dtype,
|
||||
)
|
||||
torch.ops.sgl_kernel.shuffle_rows.default(input_tensor, dst2src_map, output_tensor)
|
||||
return output_tensor
|
||||
|
||||
|
||||
def scaled_fp4_experts_quant(
|
||||
input_tensor: torch.Tensor,
|
||||
input_global_scale: torch.Tensor,
|
||||
expert_offsets: torch.Tensor,
|
||||
blockscale_offsets: torch.Tensor,
|
||||
topk: int,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Quantize input tensor to FP4 and return quantized tensor and scale, for
|
||||
packed MoE Inputs.
|
||||
Args:
|
||||
input: The input tensor to be quantized to FP4
|
||||
expert_map: The expert map tensor
|
||||
input_global_scale: A scalar scaling factor for the entire tensor.
|
||||
expert_offsets: The expert offsets tensor
|
||||
blockscale_offsets: The blockscale offsets tensor
|
||||
Outputs:
|
||||
output: The quantized tensor in FP4
|
||||
output_scales: The blockscale tensor in FP8-E4M3
|
||||
"""
|
||||
assert (
|
||||
input_tensor.ndim == 2
|
||||
), f"input.ndim needs to be == 2, but got {input_tensor.ndim}."
|
||||
if expert_map is not None:
|
||||
(m, k) = input_tensor.shape
|
||||
output_tensor_shape = (m * topk, k)
|
||||
input_tensor = shuffle_rows(input_tensor, expert_map, output_tensor_shape)
|
||||
m_numtopk, k = input_tensor.shape
|
||||
# Control the maximum number of tokens per expert supported by the
|
||||
# NVFP4 MoE Expert Quantization. This is used to prevent the kernel
|
||||
# from running out of memory. This value can also be increased to support
|
||||
# larger models.
|
||||
import os
|
||||
|
||||
MAX_TOKENS_PER_EXPERT = os.environ.get("MODELOPT_MAX_TOKENS_PER_EXPERT", 65536)
|
||||
assert m_numtopk <= MAX_TOKENS_PER_EXPERT * topk, (
|
||||
f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT("
|
||||
f"{MAX_TOKENS_PER_EXPERT})"
|
||||
f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use"
|
||||
f" MODELOPT_MAX_TOKENS_PER_EXPERT to set this value."
|
||||
)
|
||||
scales_k = k // 16
|
||||
padded_k = (scales_k + (4 - 1)) // 4
|
||||
|
||||
# output is uint8 and packed fp4 values
|
||||
output = torch.empty(
|
||||
m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8
|
||||
)
|
||||
output_scales = torch.empty(
|
||||
MAX_TOKENS_PER_EXPERT * topk,
|
||||
padded_k,
|
||||
dtype=torch.int32,
|
||||
device=input_tensor.device,
|
||||
)
|
||||
torch.ops.sgl_kernel.scaled_fp4_experts_quant.default(
|
||||
output,
|
||||
output_scales,
|
||||
input_tensor,
|
||||
input_global_scale,
|
||||
expert_offsets,
|
||||
blockscale_offsets,
|
||||
)
|
||||
output_scales = output_scales.view(torch.float8_e4m3fn)
|
||||
return output, output_scales
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@@ -138,10 +140,12 @@ def prepare_moe_input(
|
||||
num_experts,
|
||||
n,
|
||||
k,
|
||||
blockscale_offsets: Optional[torch.Tensor] = None,
|
||||
):
|
||||
torch.ops.sgl_kernel.prepare_moe_input.default(
|
||||
topk_ids,
|
||||
expert_offsets,
|
||||
blockscale_offsets,
|
||||
problem_sizes1,
|
||||
problem_sizes2,
|
||||
input_permutation,
|
||||
@@ -150,3 +154,54 @@ def prepare_moe_input(
|
||||
n,
|
||||
k,
|
||||
)
|
||||
|
||||
|
||||
def cutlass_fp4_group_mm(
|
||||
a_fp4,
|
||||
b_fp4,
|
||||
a_blockscale,
|
||||
b_blockscale,
|
||||
alphas,
|
||||
ab_strides,
|
||||
c_strides,
|
||||
problem_sizes,
|
||||
expert_offsets,
|
||||
blockscale_offsets,
|
||||
out_dtype,
|
||||
device,
|
||||
):
|
||||
"""
|
||||
An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs
|
||||
the gemms for each combination based on the specified problem sizes.
|
||||
|
||||
This is used as the MoE gemm during NVFP4 Quantized FusedMoE forward.
|
||||
- a/b_tensors: the NVFP4 a_ptrs and b_ptrs tensors which are quantized
|
||||
input and expert weights.
|
||||
- a_/b_scales: The blockscales in FP8-E4M3 precision
|
||||
- ab_strides/c_strides: Strides for the a/b tensors between rows.
|
||||
- expert_offsets/sf_offsets: Indices that mark at which token index
|
||||
each expert begins its computation. The number of tokens
|
||||
computed with expert E is expert_offsets[E + 1] -
|
||||
expert_offsets[E] And the sf_size per expert is
|
||||
sf_offset[E+1] - sf_offset[E]
|
||||
- problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
|
||||
MMs used in the fused MoE operation.
|
||||
"""
|
||||
m_topk = a_fp4.shape[0]
|
||||
n = b_fp4.shape[1]
|
||||
c_shape = (m_topk, n)
|
||||
c = torch.empty(c_shape, device=device, dtype=out_dtype)
|
||||
torch.ops.sgl_kernel.cutlass_fp4_group_mm.default(
|
||||
c,
|
||||
a_fp4,
|
||||
b_fp4,
|
||||
a_blockscale,
|
||||
b_blockscale,
|
||||
alphas,
|
||||
ab_strides,
|
||||
c_strides,
|
||||
problem_sizes,
|
||||
expert_offsets,
|
||||
blockscale_offsets,
|
||||
)
|
||||
return c.to(dtype=out_dtype)
|
||||
|
||||
Reference in New Issue
Block a user