add flashinfer mxfp4 (#8847)
@@ -21,6 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import (
    direct_register_custom_op,
    get_bool_env_var,
    is_cuda,
    is_flashinfer_available,
    is_hip,
@@ -31,6 +32,12 @@ from sglang.srt.utils import (

has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None

# Environment variables for FlashInfer MXFP4 MoE backend
USE_FLASHINFER_MXFP4_MOE = get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_MOE", "false")
USE_FLASHINFER_MXFP4_BF16_MOE = get_bool_env_var(
    "SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE", "false"
)
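
# Illustrative sketch (an assumption, not part of this diff): get_bool_env_var
# from sglang.srt.utils is expected to treat "true"/"1" (case-insensitive) as
# enabled, so e.g. launching with SGLANG_USE_FLASHINFER_MXFP4_MOE=1 routes MoE
# through the FlashInfer trtllm-gen kernels used below. A minimal stand-in:
def _get_bool_env_var_sketch(name: str, default: str = "false") -> bool:
    import os

    return os.getenv(name, default).strip().lower() in ("true", "1")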

if is_flashinfer_available():
    # from flashinfer.fused_moe import cutlass_fused_moe
    from flashinfer import (
@@ -228,16 +235,28 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
        self.num_experts = num_experts
        weight_dtype = torch.uint8
        scale_dtype = torch.uint8

        intermediate_size *= 2
        mxfp4_block = 32

        self.intermediate_size = intermediate_size
        # pad the intermediate size to be a multiple of 2 * mxfp4_block
        # so it can hold non-uniformly sharded tensors as well as the swizzled layout
        if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
            intermediate_size_per_partition_after_pad = round_up(intermediate_size, 256)
            hidden_size = round_up(hidden_size, 256)
        elif is_hip():
            intermediate_size_per_partition_after_pad = round_up(intermediate_size, 128)
        else:
            intermediate_size_per_partition_after_pad = round_up(intermediate_size, 64)
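        # Illustrative invariant (assumed, not asserted by the original code):
        # every branch above pads to a multiple of 2 * mxfp4_block = 64, the
        # granularity the swizzled mxfp4 layout needs; the FlashInfer path pads
        # further to 256 (and also pads hidden_size).
        assert intermediate_size_per_partition_after_pad % (2 * mxfp4_block) == 0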

        self.intermediate_size = intermediate_size_per_partition_after_pad

        self.hidden_size = hidden_size
        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(
            torch.zeros(
                num_experts, 2 * intermediate_size, hidden_size // 2, dtype=weight_dtype
                num_experts,
                2 * intermediate_size_per_partition_after_pad,
                hidden_size // 2,
                dtype=weight_dtype,
            ),
            requires_grad=False,
        )
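        # Layout note (illustrative, based on the MXFP4 format): each uint8 in the
        # weight tensors packs two FP4 values, which is why the last weight dims are
        # hidden_size // 2 and intermediate_size // 2 bytes, while the scale tensors
        # hold one 8-bit scale per mxfp4_block (= 32) elements, giving the
        # // mxfp4_block dims below.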
@@ -247,7 +266,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
        w13_weight_scale = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                2 * intermediate_size,
                2 * intermediate_size_per_partition_after_pad,
                hidden_size // mxfp4_block,
                dtype=scale_dtype,
            ),
@@ -257,7 +276,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)

        w13_weight_bias = torch.nn.Parameter(
            torch.zeros(num_experts, 2 * intermediate_size, dtype=torch.bfloat16),
            torch.zeros(
                num_experts,
                2 * intermediate_size_per_partition_after_pad,
                dtype=torch.bfloat16,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_bias", w13_weight_bias)
@@ -266,7 +289,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(
            torch.zeros(
                num_experts, hidden_size, intermediate_size // 2, dtype=weight_dtype
                num_experts,
                hidden_size,
                intermediate_size_per_partition_after_pad // 2,
                dtype=weight_dtype,
            ),
            requires_grad=False,
        )
@@ -277,7 +303,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
            torch.zeros(
                num_experts,
                hidden_size,
                intermediate_size // mxfp4_block,
                intermediate_size_per_partition_after_pad // mxfp4_block,
                dtype=scale_dtype,
            ),
            requires_grad=False,
@@ -293,6 +319,158 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
        set_weight_attrs(w2_weight_bias, extra_weight_attrs)

    def process_weights_after_loading(self, layer):
        if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
            logger.info(
                "Shuffling MoE weights for FlashInfer; this might take a while..."
            )
            layer.gemm1_alpha = Parameter(
                torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_beta = Parameter(
                torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_clamp_limit = Parameter(
                torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            sf_block_size = 32  # mxfp4 block size

            assert (
                layer.w13_weight.dim() == 3
                and layer.w13_weight.shape[0] == self.num_experts
                and layer.w13_weight.shape[1] == self.intermediate_size * 2
                and layer.w13_weight.shape[2] == self.hidden_size // 2
            )
            assert (
                layer.w13_weight_scale.dim() == 3
                and layer.w13_weight_scale.shape[0] == self.num_experts
                and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
                and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
            )
            assert (
                layer.w2_weight.dim() == 3
                and layer.w2_weight.shape[0] == self.num_experts
                and layer.w2_weight.shape[1] == self.hidden_size
                and layer.w2_weight.shape[2] == self.intermediate_size // 2
            )
            assert (
                layer.w2_weight_scale.dim() == 3
                and layer.w2_weight_scale.shape[1] == self.hidden_size
                and layer.w2_weight_scale.shape[2]
                == self.intermediate_size // sf_block_size
            )
            assert (
                layer.w13_weight_bias.dim() == 2
                and layer.w13_weight_bias.shape[0] == self.num_experts
                and layer.w13_weight_bias.shape[1] == self.intermediate_size * 2
            )
            assert (
                layer.w2_weight_bias.dim() == 2
                and layer.w2_weight_bias.shape[0] == self.num_experts
                and layer.w2_weight_bias.shape[1] == self.hidden_size
            )

            w13_weight_scale = layer.w13_weight_scale.data
            w2_weight_scale = layer.w2_weight_scale.data
            w13_weight = layer.w13_weight.data
            w2_weight = layer.w2_weight.data
            w13_bias = layer.w13_weight_bias.data.to(torch.float32)
            w2_bias = layer.w2_weight_bias.data.to(torch.float32)

            # Swap w1 and w3 because the definition of
            # swiglu is different in the trtllm-gen kernels
            def swap_every_two_rows(x, axis=-1):
                shape = x.shape
                if axis < 0:
                    axis = len(shape) + axis

                # Create a new shape that exposes consecutive pairs along the specified axis
                new_shape = list(shape)
                new_shape[axis] = shape[axis] // 2
                new_shape.insert(axis + 1, 2)

                # Reshape to expose pairs, swap them, and reshape back
                x = x.reshape(*new_shape)
                x = x.flip(axis + 1)
                new_shape = list(shape)
                return x.reshape(*new_shape)
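
            # Illustrative self-check (an assumed example, not in the original code):
            # swapping consecutive row pairs turns rows [r0, r1, r2, r3] into
            # [r1, r0, r3, r2], i.e. the w1/w3 halves of the fused gate_up weight
            # are interleaved the way the trtllm-gen swiglu expects.
            _probe = torch.arange(4).reshape(4, 1)
            assert swap_every_two_rows(_probe, -2).flatten().tolist() == [1, 0, 3, 2]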

            w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
            w13_weight = swap_every_two_rows(w13_weight, -2)
            w13_bias = swap_every_two_rows(w13_bias, -1)

            # Shuffle weights and scaling factors for transposed mma output
            gemm1_weights_mxfp4_shuffled = []
            gemm1_scales_mxfp4_shuffled = []
            gemm2_weights_mxfp4_shuffled = []
            gemm2_scales_mxfp4_shuffled = []
            gemm1_bias_shuffled = []
            gemm2_bias_shuffled = []
            epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
            for i in range(self.num_experts):
                gemm1_weights_mxfp4_shuffled.append(
                    shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m)
                )
                gemm1_scales_mxfp4_shuffled.append(
                    shuffle_matrix_sf_a(
                        w13_weight_scale[i].view(torch.uint8), epilogue_tile_m
                    )
                )
                gemm1_bias_shuffled.append(
                    shuffle_matrix_a(
                        w13_bias[i].clone().reshape(-1, 1), epilogue_tile_m
                    )
                )

                gemm2_weights_mxfp4_shuffled.append(
                    shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m)
                )
                gemm2_scales_mxfp4_shuffled.append(
                    shuffle_matrix_sf_a(
                        w2_weight_scale[i].view(torch.uint8), epilogue_tile_m
                    )
                )
                gemm2_bias_shuffled.append(
                    shuffle_matrix_a(w2_bias[i].clone().reshape(-1, 1), epilogue_tile_m)
                )

            w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
            w13_weight_scale = (
                torch.stack(gemm1_scales_mxfp4_shuffled)
                .reshape(
                    self.num_experts,
                    2 * self.intermediate_size,
                    self.hidden_size // sf_block_size,
                )
                .view(torch.float8_e4m3fn)
            )

            w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled)
            w2_weight_scale = (
                torch.stack(gemm2_scales_mxfp4_shuffled)
                .reshape(
                    self.num_experts,
                    self.hidden_size,
                    self.intermediate_size // sf_block_size,
                )
                .view(torch.float8_e4m3fn)
            )
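            # Note (assumed rationale): .view(torch.float8_e4m3fn) above only
            # reinterprets the uint8 scale bytes as a different dtype, leaving the
            # bit pattern unchanged; this is presumably the dtype the trtllm-gen
            # block-scale MoE kernel expects for its scale tensors.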

            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale = Parameter(w13_weight_scale, requires_grad=False)
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale = Parameter(w2_weight_scale, requires_grad=False)
            layer.w13_weight_bias = Parameter(
                torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1),
                requires_grad=False,
            )
            layer.w2_weight_bias = Parameter(
                torch.stack(gemm2_bias_shuffled).reshape(self.num_experts, -1),
                requires_grad=False,
            )
            return

        from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig

@@ -366,22 +544,21 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
        activation_alpha: Optional[float] = None,
        swiglu_limit: Optional[float] = None,
    ) -> torch.Tensor:
        # avoid an import error when triton_kernels is not installed
        # from vllm.model_executor.layers.fused_moe.triton_kernels_moe import (
        #     triton_kernel_moe_forward)

        """
        if (envs.VLLM_USE_FLASHINFER_MXFP4_MOE
                or envs.VLLM_USE_FLASHINFER_MXFP4_BF16_MOE):
            assert not self.moe.use_ep, (
                "EP is not supported for flashinfer mxfp4 moe backend yet.")
            if envs.VLLM_USE_FLASHINFER_MXFP4_BF16_MOE:
        if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
            # When USE_FLASHINFER_MXFP4_BF16_MOE is enabled, we don't need to quantize
            # the input: TRT-LLM automatically handles quantization in the kernel
            # implementation and pipelines it with the GEMM operations, which can
            # theoretically improve performance
            if USE_FLASHINFER_MXFP4_BF16_MOE:
                assert x.dtype == torch.bfloat16
                x_quant = x
                x_scale = None
            else:
                x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
                x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)

            topk_weights, topk_ids, router_logits = topk_output
            top_k = topk_weights.shape[-1]

            trtllm_gen_output = trtllm_fp4_block_scale_moe(
                router_logits.to(torch.bfloat16),
                None,  # routing_bias
@@ -412,7 +589,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                True,  # do finalize
            )[0]
            return trtllm_gen_output
        """

        if self.use_triton_kernels:
            if self.with_bias: