[CUTLASS-FP4-MOE] Introduce CutlassMoEParams class for easy initialization of Cutlass Grouped Gems Metadata (#6887)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
@@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
|
||||
from sglang.srt.utils import is_cuda
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
@@ -18,11 +19,12 @@ if _is_cuda:
|
||||
fp8_blockwise_scaled_grouped_mm,
|
||||
prepare_moe_input,
|
||||
scaled_fp4_experts_quant,
|
||||
shuffle_rows,
|
||||
silu_and_mul,
|
||||
)
|
||||
|
||||
|
||||
def cutlass_fused_experts(
|
||||
def cutlass_fused_experts_fp8(
|
||||
a: torch.Tensor,
|
||||
w1_q: torch.Tensor,
|
||||
w2_q: torch.Tensor,
|
||||
@@ -223,17 +225,10 @@ def cutlass_moe_fp4(
|
||||
w2_fp4: torch.Tensor,
|
||||
w2_blockscale: torch.Tensor,
|
||||
w2_alphas: torch.Tensor,
|
||||
ab_strides_13: torch.Tensor,
|
||||
ab_strides_2: torch.Tensor,
|
||||
c_strides_13: torch.Tensor,
|
||||
c_strides_2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
device: torch.device,
|
||||
params: CutlassMoEParams,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
):
|
||||
"""
|
||||
MoE implementation for FP4 Inputs
|
||||
@@ -291,77 +286,70 @@ def cutlass_moe_fp4(
|
||||
e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
|
||||
e_w2, k_w2, half_n_w2 = w2_fp4.shape
|
||||
|
||||
assert e_w1 == e_w2 and e_w1 == e, (
|
||||
assert e_w1 == e_w2 and e_w1 == params.num_experts, (
|
||||
"Number of experts must match",
|
||||
" between weights.",
|
||||
)
|
||||
assert (
|
||||
k_a // 2 == half_k_w1 and k == k_w2
|
||||
k_a // 2 == half_k_w1 and params.hidden_size == k_w2
|
||||
), "Hidden size mismatch between a, w1 and w2"
|
||||
assert nx2_w1 == n * 2 and half_n_w2 == n // 2, "mismatch in " "expected `n`"
|
||||
assert m == m_a, "input shape mismatch"
|
||||
assert (
|
||||
nx2_w1 == params.intermediate_size_per_partition * 2
|
||||
and half_n_w2 == params.intermediate_size_per_partition // 2
|
||||
), ("mismatch in " "expected `n`")
|
||||
assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
|
||||
assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
|
||||
assert (
|
||||
topk_weights.shape[0] == m and topk_ids.shape[0] == m
|
||||
), "topk must be provided for each row of a"
|
||||
|
||||
out_dtype = a.dtype
|
||||
num_topk = topk_ids.shape[1]
|
||||
|
||||
expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
|
||||
# Problem size: (num_experts, (m,2n,k))
|
||||
problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device)
|
||||
# Problem size: (num_experts, (m,n,k))
|
||||
problem_sizes2 = torch.empty((e, 3), dtype=torch.int32, device=device)
|
||||
|
||||
device = a.device
|
||||
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
|
||||
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
|
||||
|
||||
# problem shapes should have [m, n, k]
|
||||
# Note that problem sizes are based on logical number of elements.
|
||||
blockscale_offsets = torch.empty(e + 1, dtype=torch.int32, device=device)
|
||||
prepare_moe_input(
|
||||
topk_ids,
|
||||
expert_offsets,
|
||||
problem_sizes1,
|
||||
problem_sizes2,
|
||||
params.expert_offsets,
|
||||
params.problem_sizes1,
|
||||
params.problem_sizes2,
|
||||
a_map,
|
||||
c_map,
|
||||
e,
|
||||
n,
|
||||
k,
|
||||
blockscale_offsets,
|
||||
params.num_experts,
|
||||
params.intermediate_size_per_partition,
|
||||
params.hidden_size,
|
||||
params.blockscale_offsets,
|
||||
)
|
||||
|
||||
rep_a_fp4, rep_a_blockscale = scaled_fp4_experts_quant(
|
||||
a, a1_gscale, expert_offsets, blockscale_offsets, num_topk, expert_map=a_map
|
||||
a,
|
||||
a1_gscale,
|
||||
params.expert_offsets,
|
||||
params.blockscale_offsets,
|
||||
num_topk,
|
||||
expert_map=a_map,
|
||||
)
|
||||
|
||||
c1 = cutlass_fp4_group_mm(
|
||||
rep_a_fp4,
|
||||
w1_fp4,
|
||||
rep_a_blockscale,
|
||||
w1_blockscale,
|
||||
w1_alphas,
|
||||
ab_strides_13,
|
||||
c_strides_13,
|
||||
problem_sizes1,
|
||||
expert_offsets[:-1],
|
||||
blockscale_offsets[:-1],
|
||||
out_dtype,
|
||||
device,
|
||||
params.to_gemm1_args(),
|
||||
)
|
||||
del rep_a_fp4, rep_a_blockscale
|
||||
|
||||
# hidden size dimension is split to one halfpytho sized tensor.
|
||||
intermediate = torch.empty(
|
||||
(m * num_topk, w1_fp4.shape[1] // 2), device=device, dtype=out_dtype
|
||||
(m_a * num_topk, w1_fp4.shape[1] // 2), device=device, dtype=out_dtype
|
||||
)
|
||||
|
||||
silu_and_mul(c1, intermediate)
|
||||
|
||||
int_fp4, int_blockscale = scaled_fp4_experts_quant(
|
||||
intermediate, a2_gscale, expert_offsets, blockscale_offsets, num_topk
|
||||
intermediate,
|
||||
a2_gscale,
|
||||
params.expert_offsets,
|
||||
params.blockscale_offsets,
|
||||
num_topk,
|
||||
)
|
||||
c2 = cutlass_fp4_group_mm(
|
||||
int_fp4,
|
||||
@@ -369,16 +357,13 @@ def cutlass_moe_fp4(
|
||||
int_blockscale,
|
||||
w2_blockscale,
|
||||
w2_alphas,
|
||||
ab_strides_2,
|
||||
c_strides_2,
|
||||
problem_sizes2,
|
||||
expert_offsets[:-1],
|
||||
blockscale_offsets[:-1],
|
||||
out_dtype,
|
||||
device,
|
||||
params.to_gemm2_args(),
|
||||
)
|
||||
del int_fp4, int_blockscale
|
||||
out = (
|
||||
c2[c_map].view(m, num_topk, k) * topk_weights.view(m, num_topk, 1).half()
|
||||
).sum(dim=1)
|
||||
return out.to(dtype=out_dtype)
|
||||
c2 = shuffle_rows(c2, c_map, (m_a * num_topk, params.hidden_size))
|
||||
c2 = c2.view(m_a, num_topk, params.hidden_size)
|
||||
if not apply_router_weight_on_input:
|
||||
c2 = c2 * topk_weights.view(m_a, num_topk, 1).to(out_dtype)
|
||||
return c2.sum(dim=1).to(out_dtype)
|
||||
|
||||
169
python/sglang/srt/layers/moe/cutlass_moe_params.py
Normal file
169
python/sglang/srt/layers/moe/cutlass_moe_params.py
Normal file
@@ -0,0 +1,169 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class CutlassMoEType(Enum):
|
||||
"""
|
||||
Enum for the different types of cutlass moe operations
|
||||
that are currently supported in SGLang.
|
||||
"""
|
||||
|
||||
BlockscaledFP8 = auto()
|
||||
BlockscaledFP4 = auto()
|
||||
|
||||
|
||||
@dataclass
|
||||
class CutlassMoEParams:
|
||||
"""
|
||||
Parameters for the cutlass moe operation.
|
||||
"""
|
||||
|
||||
# Type as defined above
|
||||
cutlass_moe_type: CutlassMoEType
|
||||
|
||||
# Strides for activations, weights and output in logical number of elements.
|
||||
# The activations & output stride is the number of elements to the next row.
|
||||
# The weights stride is the number of elements to the next row per expert.
|
||||
# For example, if the weight is [e, n, k], then the b_stride is a tensor of
|
||||
# shape [e] with each element being k. Similarly for activations, if the
|
||||
# shape is [m, k], then the a_stride has shape [e] with each value k.
|
||||
# Similarly for output, if the output is [m, n], then the c_stride is a
|
||||
# tensor of shape [e] with each element being k.
|
||||
|
||||
# Note: cutlass_fp4_group_mm is designed to accept the strides of
|
||||
# activations and weights to be the same, so it is passed in as a single
|
||||
# tensor.
|
||||
# ab_strides_13: [e] dtype: int64 [Gemm 1: Activation / Weight strides]
|
||||
# ab_strides_2: [e] dtype: int64 [Gemm 2: Activation / Weight strides]
|
||||
# c_strides_13: [e] dtype: int64 [Gemm 1: Output Strides]
|
||||
# c_strides_2: [e] dtype: int64 [Gemm 2: Output Strides]
|
||||
ab_strides_13: torch.Tensor
|
||||
ab_strides_2: torch.Tensor
|
||||
c_strides_13: torch.Tensor
|
||||
c_strides_2: torch.Tensor
|
||||
|
||||
# m: Total number of tokens
|
||||
# n: intermediate size per partition
|
||||
# k: hidden size per expert
|
||||
# e: Number of experts
|
||||
# device: Device to run computation on and store tensors
|
||||
m: int
|
||||
intermediate_size_per_partition: int
|
||||
hidden_size: int
|
||||
num_experts: int
|
||||
device: torch.device
|
||||
|
||||
# Pointers container for calculating offsets of the input activations for each expert
|
||||
# a_ptrs: [e] dtype: int64
|
||||
a_ptrs: torch.Tensor
|
||||
|
||||
# Pointers container for calculating offsets of the input weights for each expert
|
||||
# b_ptrs: [e] dtype: int64
|
||||
b_ptrs: torch.Tensor
|
||||
|
||||
# Pointers container for calculating offsets of the output activations for each expert
|
||||
# out_ptrs: [e] dtype: int64
|
||||
out_ptrs: torch.Tensor
|
||||
# Pointers container for calculating offsets of the input scales for each expert
|
||||
# a_scales_ptrs: [e] dtype: int64
|
||||
# b_scales_ptrs: [e] dtype: int64
|
||||
a_scales_ptrs: torch.Tensor
|
||||
b_scales_ptrs: torch.Tensor
|
||||
|
||||
# Offsets that mark at which token index each expert begins its computation
|
||||
# The number of tokens computed with expert E is expert_offsets[E + 1] - expert_offsets[E]
|
||||
# expert_offsets: [e+1] dtype: int32
|
||||
expert_offsets: torch.Tensor
|
||||
|
||||
# Problem size: (num_experts, (m,2n,k)) for first GEMM
|
||||
# problem_sizes1: [e, 3] dtype: int32
|
||||
# Problem size: (num_experts, (m,n,k)) for second GEMM
|
||||
# problem_sizes2: [e, 3] dtype: int32
|
||||
problem_sizes1: torch.Tensor
|
||||
problem_sizes2: torch.Tensor
|
||||
# Similar to expert_offsets, but for blockscales for FP4 blockscaled Group GEMM
|
||||
blockscale_offsets: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cutlass_moe_type: CutlassMoEType,
|
||||
device: torch.device,
|
||||
num_experts: int,
|
||||
intermediate_size_per_partition: int,
|
||||
hidden_size: int,
|
||||
):
|
||||
self.cutlass_moe_type = cutlass_moe_type
|
||||
self.device = device
|
||||
self.num_experts = num_experts
|
||||
self.intermediate_size_per_partition = intermediate_size_per_partition
|
||||
self.hidden_size = hidden_size
|
||||
self.n = self.intermediate_size_per_partition
|
||||
self.k = self.hidden_size
|
||||
self.e = self.num_experts
|
||||
self.ab_strides_13 = torch.full(
|
||||
(self.e,), self.k, dtype=torch.int64, device=self.device
|
||||
)
|
||||
self.ab_strides_2 = torch.full(
|
||||
(self.e,), self.n, dtype=torch.int64, device=self.device
|
||||
)
|
||||
self.c_strides_13 = torch.full(
|
||||
(self.e,), 2 * self.n, dtype=torch.int64, device=self.device
|
||||
)
|
||||
self.c_strides_2 = torch.full(
|
||||
(self.e,), self.k, dtype=torch.int64, device=self.device
|
||||
)
|
||||
self.expert_offsets = torch.empty(
|
||||
(self.e + 1,), dtype=torch.int32, device=self.device
|
||||
)
|
||||
self.problem_sizes1 = torch.empty(
|
||||
(self.e, 3), dtype=torch.int32, device=self.device
|
||||
)
|
||||
self.problem_sizes2 = torch.empty(
|
||||
(self.e, 3), dtype=torch.int32, device=self.device
|
||||
)
|
||||
if self.cutlass_moe_type == CutlassMoEType.BlockscaledFP4:
|
||||
self.blockscale_offsets = torch.empty(
|
||||
(self.e + 1,), dtype=torch.int32, device=self.device
|
||||
)
|
||||
else:
|
||||
self.blockscale_offsets = None
|
||||
self.a_ptrs = torch.empty((self.e,), dtype=torch.int64, device=self.device)
|
||||
self.b_ptrs = torch.empty((self.e,), dtype=torch.int64, device=self.device)
|
||||
self.out_ptrs = torch.empty((self.e,), dtype=torch.int64, device=self.device)
|
||||
self.a_scales_ptrs = torch.empty(
|
||||
(self.e,), dtype=torch.int64, device=self.device
|
||||
)
|
||||
self.b_scales_ptrs = torch.empty(
|
||||
(self.e,), dtype=torch.int64, device=self.device
|
||||
)
|
||||
|
||||
def to_gemm1_args(self) -> dict:
|
||||
return {
|
||||
"ab_strides": self.ab_strides_13,
|
||||
"c_strides": self.c_strides_13,
|
||||
"problem_sizes": self.problem_sizes1,
|
||||
"expert_offsets": self.expert_offsets[:-1],
|
||||
"blockscale_offsets": self.blockscale_offsets[:-1],
|
||||
# "a_ptrs": self.a_ptrs,
|
||||
# "b_ptrs": self.b_ptrs,
|
||||
# "out_ptrs": self.out_ptrs,
|
||||
# "a_scales_ptrs": self.a_scales_ptrs,
|
||||
# "b_scales_ptrs": self.b_scales_ptrs,
|
||||
}
|
||||
|
||||
def to_gemm2_args(self) -> dict:
|
||||
return {
|
||||
"ab_strides": self.ab_strides_2,
|
||||
"c_strides": self.c_strides_2,
|
||||
"problem_sizes": self.problem_sizes2,
|
||||
"expert_offsets": self.expert_offsets[:-1],
|
||||
"blockscale_offsets": self.blockscale_offsets[:-1],
|
||||
# "a_ptrs": self.a_ptrs,
|
||||
# "b_ptrs": self.b_ptrs,
|
||||
# "out_ptrs": self.out_ptrs,
|
||||
# "a_scales_ptrs": self.a_scales_ptrs,
|
||||
# "b_scales_ptrs": self.b_scales_ptrs,
|
||||
}
|
||||
@@ -982,9 +982,9 @@ class Fp8MoEMethod:
|
||||
and self.block_quant
|
||||
and is_sm100_supported()
|
||||
):
|
||||
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts
|
||||
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
|
||||
|
||||
return cutlass_fused_experts(
|
||||
return cutlass_fused_experts_fp8(
|
||||
x,
|
||||
layer.w13_weight.transpose(1, 2),
|
||||
layer.w2_weight.transpose(1, 2),
|
||||
|
||||
@@ -6,7 +6,7 @@ import triton # Added import
|
||||
import triton.testing # Added import
|
||||
from transformers import AutoConfig
|
||||
|
||||
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts
|
||||
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
|
||||
|
||||
@@ -125,7 +125,7 @@ def run_test(tp_size, batch_size, model_config, check=False):
|
||||
problem_sizes2 = torch.empty((E, 3), dtype=torch.int32, device="cuda")
|
||||
|
||||
# --- Lambdas for Benchmarking ---
|
||||
cutlass_lambda = lambda: cutlass_fused_experts(
|
||||
cutlass_lambda = lambda: cutlass_fused_experts_fp8(
|
||||
x,
|
||||
w1.transpose(1, 2), # Transposed
|
||||
w2.transpose(1, 2), # Transposed
|
||||
@@ -193,7 +193,7 @@ def run_test(tp_size, batch_size, model_config, check=False):
|
||||
print("Running correctness check...")
|
||||
with torch.no_grad():
|
||||
# Run CUTLASS version (requires transposed weights)
|
||||
y_cutlass = cutlass_fused_experts(
|
||||
y_cutlass = cutlass_fused_experts_fp8(
|
||||
x,
|
||||
w1.transpose(1, 2), # Transposed
|
||||
w2.transpose(1, 2), # Transposed
|
||||
|
||||
@@ -5,6 +5,7 @@ from sgl_kernel import scaled_fp4_quant
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
|
||||
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
|
||||
if torch.cuda.get_device_capability() < (10, 0):
|
||||
@@ -179,6 +180,13 @@ def test_cutlass_fp4_moe_no_graph(
|
||||
(e,), w2_q.shape[2] * 2, dtype=torch.int64, device=w2_q.device
|
||||
)
|
||||
c_strides_2 = torch.full((e,), w2_q.shape[1], dtype=torch.int64, device=w2_q.device)
|
||||
params = CutlassMoEParams(
|
||||
CutlassMoEType.BlockscaledFP4,
|
||||
device=a.device,
|
||||
num_experts=e,
|
||||
intermediate_size_per_partition=n, # n
|
||||
hidden_size=k,
|
||||
) # k
|
||||
cutlass_output = cutlass_moe_fp4(
|
||||
a=a,
|
||||
a1_gscale=a1_gs,
|
||||
@@ -189,17 +197,10 @@ def test_cutlass_fp4_moe_no_graph(
|
||||
w2_fp4=w2_q,
|
||||
w2_blockscale=w2_blockscale,
|
||||
w2_alphas=(1 / w2_gs),
|
||||
ab_strides_13=ab_strides_13,
|
||||
ab_strides_2=ab_strides_2,
|
||||
c_strides_13=c_strides_13,
|
||||
c_strides_2=c_strides_2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
m=m,
|
||||
n=n,
|
||||
k=k,
|
||||
e=e,
|
||||
device=a.device,
|
||||
params=params,
|
||||
apply_router_weight_on_input=False,
|
||||
)
|
||||
|
||||
# Reference check:
|
||||
|
||||
Reference in New Issue
Block a user