add flashinfer mxfp4 (#8847)
This commit is contained in:
@@ -38,6 +38,7 @@ from sglang.srt.utils import (
|
||||
is_flashinfer_available,
|
||||
is_hip,
|
||||
next_power_of_2,
|
||||
round_up,
|
||||
)
|
||||
|
||||
if is_flashinfer_available():
|
||||
@@ -146,7 +147,6 @@ class FusedMoE(torch.nn.Module):
|
||||
|
||||
self.layer_id = layer_id
|
||||
self.top_k = top_k
|
||||
self.hidden_size = hidden_size
|
||||
self.num_experts = num_experts
|
||||
self.num_fused_shared_experts = num_fused_shared_experts
|
||||
self.expert_map_cpu = None
|
||||
@@ -206,6 +206,16 @@ class FusedMoE(torch.nn.Module):
|
||||
assert self.quant_method is not None
|
||||
|
||||
self.quant_config = quant_config
|
||||
if (
|
||||
self.quant_config is not None
|
||||
and self.quant_config.get_name() == "mxfp4"
|
||||
and (
|
||||
get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_MOE")
|
||||
or get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE")
|
||||
)
|
||||
):
|
||||
hidden_size = round_up(hidden_size, 256)
|
||||
self.hidden_size = hidden_size
|
||||
self.quant_method.create_weights(
|
||||
layer=self,
|
||||
num_experts=self.num_local_experts,
|
||||
@@ -784,6 +794,14 @@ class FusedMoE(torch.nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
|
||||
origin_hidden_states_dim = hidden_states.shape[-1]
|
||||
if self.hidden_size != origin_hidden_states_dim:
|
||||
hidden_states = torch.nn.functional.pad(
|
||||
hidden_states,
|
||||
(0, self.hidden_size - origin_hidden_states_dim),
|
||||
mode="constant",
|
||||
value=0.0,
|
||||
)
|
||||
assert self.quant_method is not None
|
||||
|
||||
if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:
|
||||
@@ -829,7 +847,7 @@ class FusedMoE(torch.nn.Module):
|
||||
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
|
||||
return final_hidden_states
|
||||
return final_hidden_states[..., :origin_hidden_states_dim].contiguous()
|
||||
|
||||
@classmethod
|
||||
def make_expert_params_mapping(
|
||||
|
||||
@@ -21,6 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
||||
from sglang.srt.utils import (
|
||||
direct_register_custom_op,
|
||||
get_bool_env_var,
|
||||
is_cuda,
|
||||
is_flashinfer_available,
|
||||
is_hip,
|
||||
@@ -31,6 +32,12 @@ from sglang.srt.utils import (
|
||||
|
||||
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
|
||||
|
||||
# Environment variables for FlashInfer MXFP4 MoE backend
|
||||
USE_FLASHINFER_MXFP4_MOE = get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_MOE", "false")
|
||||
USE_FLASHINFER_MXFP4_BF16_MOE = get_bool_env_var(
|
||||
"SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE", "false"
|
||||
)
|
||||
|
||||
if is_flashinfer_available():
|
||||
# from flashinfer.fused_moe import cutlass_fused_moe
|
||||
from flashinfer import (
|
||||
@@ -228,16 +235,28 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
self.num_experts = num_experts
|
||||
weight_dtype = torch.uint8
|
||||
scale_dtype = torch.uint8
|
||||
|
||||
intermediate_size *= 2
|
||||
mxfp4_block = 32
|
||||
|
||||
self.intermediate_size = intermediate_size
|
||||
# pad the intermediate size to be a multiple of 2 * mxfp4_block
|
||||
# for to hold non-uniform sharded tensor as well as swizzling
|
||||
if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
|
||||
intermediate_size_per_partition_after_pad = round_up(intermediate_size, 256)
|
||||
hidden_size = round_up(hidden_size, 256)
|
||||
elif is_hip():
|
||||
intermediate_size_per_partition_after_pad = round_up(intermediate_size, 128)
|
||||
else:
|
||||
intermediate_size_per_partition_after_pad = round_up(intermediate_size, 64)
|
||||
|
||||
self.intermediate_size = intermediate_size_per_partition_after_pad
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts, 2 * intermediate_size, hidden_size // 2, dtype=weight_dtype
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition_after_pad,
|
||||
hidden_size // 2,
|
||||
dtype=weight_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
@@ -247,7 +266,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size,
|
||||
2 * intermediate_size_per_partition_after_pad,
|
||||
hidden_size // mxfp4_block,
|
||||
dtype=scale_dtype,
|
||||
),
|
||||
@@ -257,7 +276,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
|
||||
w13_weight_bias = torch.nn.Parameter(
|
||||
torch.zeros(num_experts, 2 * intermediate_size, dtype=torch.bfloat16),
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition_after_pad,
|
||||
dtype=torch.bfloat16,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_bias", w13_weight_bias)
|
||||
@@ -266,7 +289,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
# down_proj (row parallel)
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts, hidden_size, intermediate_size // 2, dtype=weight_dtype
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition_after_pad // 2,
|
||||
dtype=weight_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
@@ -277,7 +303,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size // mxfp4_block,
|
||||
intermediate_size_per_partition_after_pad // mxfp4_block,
|
||||
dtype=scale_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
@@ -293,6 +319,158 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
set_weight_attrs(w2_weight_bias, extra_weight_attrs)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
|
||||
logger.info(
|
||||
"Shuffling MoE weights for FlashInfer, it might take a while..."
|
||||
)
|
||||
layer.gemm1_alpha = Parameter(
|
||||
torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.gemm1_beta = Parameter(
|
||||
torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.gemm1_clamp_limit = Parameter(
|
||||
torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
|
||||
requires_grad=False,
|
||||
)
|
||||
sf_block_size = 32 # mxfp4 block size
|
||||
|
||||
assert (
|
||||
layer.w13_weight.dim() == 3
|
||||
and layer.w13_weight.shape[0] == self.num_experts
|
||||
and layer.w13_weight.shape[1] == self.intermediate_size * 2
|
||||
and layer.w13_weight.shape[2] == self.hidden_size // 2
|
||||
)
|
||||
assert (
|
||||
layer.w13_weight_scale.dim() == 3
|
||||
and layer.w13_weight_scale.shape[0] == self.num_experts
|
||||
and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
|
||||
and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
|
||||
)
|
||||
assert (
|
||||
layer.w2_weight.dim() == 3
|
||||
and layer.w2_weight.shape[0] == self.num_experts
|
||||
and layer.w2_weight.shape[1] == self.hidden_size
|
||||
and layer.w2_weight.shape[2] == self.intermediate_size // 2
|
||||
)
|
||||
assert (
|
||||
layer.w2_weight_scale.dim() == 3
|
||||
and layer.w2_weight_scale.shape[1] == self.hidden_size
|
||||
and layer.w2_weight_scale.shape[2]
|
||||
== self.intermediate_size // sf_block_size
|
||||
)
|
||||
assert (
|
||||
layer.w13_weight_bias.dim() == 2
|
||||
and layer.w13_weight_bias.shape[0] == self.num_experts
|
||||
and layer.w13_weight_bias.shape[1] == self.intermediate_size * 2
|
||||
)
|
||||
assert (
|
||||
layer.w2_weight_bias.dim() == 2
|
||||
and layer.w2_weight_bias.shape[0] == self.num_experts
|
||||
and layer.w2_weight_bias.shape[1] == self.hidden_size
|
||||
)
|
||||
|
||||
w13_weight_scale = layer.w13_weight_scale.data
|
||||
w2_weight_scale = layer.w2_weight_scale.data
|
||||
w13_weight = layer.w13_weight.data
|
||||
w2_weight = layer.w2_weight.data
|
||||
w13_bias = layer.w13_weight_bias.data.to(torch.float32)
|
||||
w2_bias = layer.w2_weight_bias.data.to(torch.float32)
|
||||
|
||||
# Swap w1 and w3 as the definition of
|
||||
# swiglu is different in the trtllm-gen
|
||||
def swap_every_two_rows(x, axis=-1):
|
||||
shape = x.shape
|
||||
if axis < 0:
|
||||
axis = len(shape) + axis
|
||||
|
||||
# Create a new shape with pairs swapped along specified axis
|
||||
new_shape = list(shape)
|
||||
new_shape[axis] = shape[axis] // 2
|
||||
new_shape.insert(axis + 1, 2)
|
||||
|
||||
# Reshape to expose pairs, swap them, and reshape back
|
||||
x = x.reshape(*new_shape)
|
||||
x = x.flip(axis + 1)
|
||||
new_shape = list(shape)
|
||||
return x.reshape(*new_shape)
|
||||
|
||||
w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
|
||||
w13_weight = swap_every_two_rows(w13_weight, -2)
|
||||
w13_bias = swap_every_two_rows(w13_bias, -1)
|
||||
|
||||
# Shuffle weights and scaling factors for transposed mma output
|
||||
gemm1_weights_mxfp4_shuffled = []
|
||||
gemm1_scales_mxfp4_shuffled = []
|
||||
gemm2_weights_mxfp4_shuffled = []
|
||||
gemm2_scales_mxfp4_shuffled = []
|
||||
gemm1_bias_shuffled = []
|
||||
gemm2_bias_shuffled = []
|
||||
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
|
||||
for i in range(self.num_experts):
|
||||
gemm1_weights_mxfp4_shuffled.append(
|
||||
shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m)
|
||||
)
|
||||
gemm1_scales_mxfp4_shuffled.append(
|
||||
shuffle_matrix_sf_a(
|
||||
w13_weight_scale[i].view(torch.uint8), epilogue_tile_m
|
||||
)
|
||||
)
|
||||
gemm1_bias_shuffled.append(
|
||||
shuffle_matrix_a(
|
||||
w13_bias[i].clone().reshape(-1, 1), epilogue_tile_m
|
||||
)
|
||||
)
|
||||
|
||||
gemm2_weights_mxfp4_shuffled.append(
|
||||
shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m)
|
||||
)
|
||||
gemm2_scales_mxfp4_shuffled.append(
|
||||
shuffle_matrix_sf_a(
|
||||
w2_weight_scale[i].view(torch.uint8), epilogue_tile_m
|
||||
)
|
||||
)
|
||||
gemm2_bias_shuffled.append(
|
||||
shuffle_matrix_a(w2_bias[i].clone().reshape(-1, 1), epilogue_tile_m)
|
||||
)
|
||||
|
||||
w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
|
||||
w13_weight_scale = (
|
||||
torch.stack(gemm1_scales_mxfp4_shuffled)
|
||||
.reshape(
|
||||
self.num_experts,
|
||||
2 * self.intermediate_size,
|
||||
self.hidden_size // sf_block_size,
|
||||
)
|
||||
.view(torch.float8_e4m3fn)
|
||||
)
|
||||
|
||||
w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled)
|
||||
w2_weight_scale = (
|
||||
torch.stack(gemm2_scales_mxfp4_shuffled)
|
||||
.reshape(
|
||||
self.num_experts,
|
||||
self.hidden_size,
|
||||
self.intermediate_size // sf_block_size,
|
||||
)
|
||||
.view(torch.float8_e4m3fn)
|
||||
)
|
||||
|
||||
layer.w13_weight = Parameter(w13_weight, requires_grad=False)
|
||||
layer.w13_weight_scale = Parameter(w13_weight_scale, requires_grad=False)
|
||||
layer.w2_weight = Parameter(w2_weight, requires_grad=False)
|
||||
layer.w2_weight_scale = Parameter(w2_weight_scale, requires_grad=False)
|
||||
layer.w13_weight_bias = Parameter(
|
||||
torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.w2_weight_bias = Parameter(
|
||||
torch.stack(gemm2_bias_shuffled).reshape(self.num_experts, -1),
|
||||
requires_grad=False,
|
||||
)
|
||||
return
|
||||
|
||||
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
||||
|
||||
@@ -366,22 +544,21 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
activation_alpha: Optional[float] = None,
|
||||
swiglu_limit: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
# avoid import error when triton_kernel is not installed
|
||||
# from vllm.model_executor.layers.fused_moe.triton_kernels_moe import (
|
||||
# triton_kernel_moe_forward)
|
||||
|
||||
"""
|
||||
if (envs.VLLM_USE_FLASHINFER_MXFP4_MOE
|
||||
or envs.VLLM_USE_FLASHINFER_MXFP4_BF16_MOE):
|
||||
assert not self.moe.use_ep, (
|
||||
"EP is not supported for flashinfer mxfp4 moe backend yet.")
|
||||
if envs.VLLM_USE_FLASHINFER_MXFP4_BF16_MOE:
|
||||
if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
|
||||
# When USE_FLASHINFER_MXFP4_BF16_MOE is enabled, we don't need to quantize the input,
|
||||
# TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,
|
||||
# which can theoretically improve performance
|
||||
if USE_FLASHINFER_MXFP4_BF16_MOE:
|
||||
assert x.dtype == torch.bfloat16
|
||||
x_quant = x
|
||||
x_scale = None
|
||||
else:
|
||||
x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8
|
||||
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
|
||||
|
||||
topk_weights, topk_ids, router_logits = topk_output
|
||||
top_k = topk_weights.shape[-1]
|
||||
|
||||
trtllm_gen_output = trtllm_fp4_block_scale_moe(
|
||||
router_logits.to(torch.bfloat16),
|
||||
None, # routing_bias
|
||||
@@ -412,7 +589,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
True, # do finalize
|
||||
)[0]
|
||||
return trtllm_gen_output
|
||||
"""
|
||||
|
||||
if self.use_triton_kernels:
|
||||
if self.with_bias:
|
||||
|
||||
@@ -464,7 +464,21 @@ class ServerArgs:
|
||||
model_arch = self.get_hf_config().architectures[0]
|
||||
if model_arch in ["GptOssForCausalLM"]:
|
||||
self.attention_backend = "triton"
|
||||
self.enable_triton_kernel_moe = True
|
||||
|
||||
# Check if FlashInfer MXFP4 MoE is enabled
|
||||
from sglang.srt.utils import get_bool_env_var
|
||||
|
||||
USE_FLASHINFER_MXFP4_MOE = get_bool_env_var(
|
||||
"SGLANG_USE_FLASHINFER_MXFP4_MOE", "false"
|
||||
)
|
||||
USE_FLASHINFER_MXFP4_BF16_MOE = get_bool_env_var(
|
||||
"SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE", "false"
|
||||
)
|
||||
|
||||
# Only enable Triton kernel MoE if FlashInfer is not enabled
|
||||
if not (USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE):
|
||||
self.enable_triton_kernel_moe = True
|
||||
|
||||
self.disable_hybrid_swa_memory = True
|
||||
|
||||
quantization_config = getattr(
|
||||
|
||||
Reference in New Issue
Block a user