[NVIDIA] Add Low Latency NVFP4 decode kernels from Flashinfer (#8552)
Co-authored-by: Cheng Wan <cwan@x.ai>
@@ -14,13 +14,9 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
    silu_and_mul_masked_post_quant_fwd,
    tma_align_input_scale,
)
from sglang.srt.layers.moe.fused_moe_triton.layer import (
    FlashInferFusedMoE,
    FusedMoE,
    should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.moe.utils import DeepEPMode, should_use_flashinfer_trtllm_moe
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8 import (
@@ -48,7 +44,6 @@ _is_npu = is_npu()
_is_fp8_fnuz = is_fp8_fnuz()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip


if not (_is_npu or _is_hip):
    from sgl_kernel import silu_and_mul

@@ -741,6 +736,22 @@ class FlashInferEPMoE(EPMoE):
def get_moe_impl_class():
    if global_server_args_dict["moe_a2a_backend"].is_deepep():
        return DeepEPMoE

    # NEW: Direct FP4 detection (bypasses EP requirements)
    # Check for FP4 quantization with TRTLLM flag, regardless of EP
    if global_server_args_dict.get("enable_flashinfer_trtllm_moe", False):
        try:
            # Check the quantization argument directly
            quantization = global_server_args_dict.get("quantization")
            if quantization == "modelopt_fp4":
                from sglang.srt.layers.moe.fused_moe_triton.layer import (
                    FlashInferFP4MoE,
                )

                return FlashInferFP4MoE
        except:
            pass

    if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
        return FusedMoE
    if get_moe_expert_parallel_world_size() > 1:

@@ -1,13 +1,14 @@
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py

import importlib.util
import datetime
import glob
import logging
import os
import sys
from enum import Enum
from functools import lru_cache
from typing import List, Optional, Tuple

import torch
from packaging import version as pkg_version

from sglang.srt.distributed import (
    get_moe_expert_parallel_rank,
@@ -22,6 +23,7 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.layers.moe.topk import StandardTopKOutput
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
@@ -29,22 +31,58 @@ from sglang.srt.layers.quantization.base_config import (
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
from sglang.srt.utils import cpu_has_amx_support, get_bool_env_var, is_cpu, is_hip
from sglang.srt.utils import (
    cpu_has_amx_support,
    get_bool_env_var,
    is_cpu,
    is_flashinfer_available,
    is_hip,
    next_power_of_2,
)

if is_flashinfer_available():
    from flashinfer import (
        RoutingMethodType,
        fp4_quantize,
        reorder_rows_for_gated_act_gemm,
        shuffle_matrix_a,
        shuffle_matrix_sf_a,
    )

_is_hip = is_hip()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()


# Try to import FP4 TRTLLM function if flashinfer is available
trtllm_fp4_block_scale_moe = None
if should_use_flashinfer_trtllm_moe():
    try:
        from flashinfer.fused_moe import trtllm_fp4_block_scale_moe
    except ImportError:
        trtllm_fp4_block_scale_moe = None

logger = logging.getLogger(__name__)


@lru_cache(maxsize=1)
def should_use_flashinfer_trtllm_moe():
    return global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
        not importlib.util.find_spec("flashinfer")
        or pkg_version.parse(__import__("flashinfer").__version__)
        >= pkg_version.parse("0.2.9rc1")
    )
def _is_fp4_quantization_enabled():
    """Check if ModelOpt FP4 quantization is enabled."""
    try:
        # Use the same simple check that works for class selection
        quantization = global_server_args_dict.get("quantization")
        return quantization == "modelopt_fp4"
    except:
        return False


def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
    # Guess tokens per expert assuming perfect expert distribution first.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    # And pad the number to the next power of 2.
    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
    return tile_tokens_dim
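
A rough standalone sketch of the tile sizing above (re-implemented here for illustration; the inline power-of-two helper stands in for sglang's next_power_of_2, and the batch sizes are hypothetical):

def _tile_tokens_dim_example(num_tokens: int, top_k: int, num_experts: int) -> int:
    # Assume a perfectly even expert distribution, round up to a power of two,
    # then clamp into the 8-64 range the kernel supports.
    tokens_per_expert = (num_tokens * top_k) // num_experts
    tile = 1 << max(0, (tokens_per_expert - 1).bit_length())
    return min(max(tile, 8), 64)

assert _tile_tokens_dim_example(4, 8, 256) == 8      # small decode batch -> floor of 8
assert _tile_tokens_dim_example(8192, 8, 256) == 64  # large batch -> capped at 64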

class FusedMoeWeightScaleSupported(Enum):
@@ -157,10 +195,6 @@ class FusedMoE(torch.nn.Module):
            )
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix)
            if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
                self.quant_method.enable_flashinfer_cutlass_moe = (
                    self.enable_flashinfer_cutlass_moe
                )
        assert self.quant_method is not None

        self.quant_config = quant_config
@@ -747,7 +781,130 @@ class FlashInferFusedMoE(FusedMoE):
            routed_scaling_factor=self.routed_scaling_factor,
        )

        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
        if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

        return final_hidden_states


class FlashInferFP4MoE(FusedMoE):
    """FP4 TRTLLM MoE implementation using FlashInfer."""

    def __init__(self, *args, **kwargs):
        # Extract DeepSeek-specific parameters
        renormalize = kwargs.pop("renormalize", True)
        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
        num_expert_group = kwargs.pop("num_expert_group", None)
        topk_group = kwargs.pop("topk_group", None)
        correction_bias = kwargs.pop("correction_bias", None)

        # Extract additional TopK parameters that were previously extracted in forward
        routed_scaling_factor = kwargs.pop("routed_scaling_factor", None)

        super().__init__(*args, **kwargs)

        # Store DeepSeek parameters
        self.renormalize = renormalize
        self.num_fused_shared_experts = num_fused_shared_experts
        self.use_grouped_topk = use_grouped_topk
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.correction_bias = correction_bias
        self.routed_scaling_factor = routed_scaling_factor

    # ---------------------------------------------------------------------
    # Helper: quantize hidden states to FP4 each forward pass
    # ---------------------------------------------------------------------
    def _quantize_hidden_states_fp4(self, hidden_states: torch.Tensor):
        """
        Quantize hidden states using global scale factor from quantization method.

        Global scale factor is set by ModelOptNvFp4FusedMoEMethod during weight loading.
        Only block scales are computed at runtime for efficiency.

        Returns (packed_fp4_uint8, scale_float8_e4m3fn_runtime)
        """

        # flashinfer.fp4_quantize returns (packed_uint8, scale_fp8)
        # Only the block scales are computed at runtime
        hs_fp4_bytes, hs_sf_bytes = fp4_quantize(
            hidden_states,
            self.w13_input_scale_quant,
            16,  # sf_vec_size
            False,  # use_ue8m0
            False,  # is_sf_swizzled_layout
        )

        hs_fp4 = hs_fp4_bytes.reshape(
            hidden_states.shape[0], hidden_states.shape[1] // 2
        )
        hs_sf = hs_sf_bytes.view(torch.float8_e4m3fn).reshape(-1)

        return hs_fp4, hs_sf
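
As a rough shape walk-through of the helper above (hypothetical sizes, not taken from the change itself):

# Assumed input: [num_tokens, hidden_size] activations with num_tokens=8, hidden_size=7168.
num_tokens, hidden_size, sf_vec_size = 8, 7168, 16
# fp4_quantize packs two 4-bit values per byte along the hidden dimension,
# so hs_fp4 ends up as uint8 with shape [num_tokens, hidden_size // 2] = [8, 3584].
packed_shape = (num_tokens, hidden_size // 2)
# One float8_e4m3fn block scale is emitted per 16 consecutive elements and then
# flattened, giving num_tokens * hidden_size // sf_vec_size = 3584 entries.
num_block_scales = num_tokens * hidden_size // sf_vec_size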

    def forward(self, hidden_states: torch.Tensor, topk_output):
        """Forward pass using FP4 TRTLLM kernel.

        Args:
            hidden_states: Input tensor
            topk_output: Should be tuple of (TopK_config, router_logits) for TRTLLM mode
        """

        # TRTLLM mode expects (TopK_config, router_logits) tuple
        if not isinstance(topk_output, tuple) or len(topk_output) != 2:
            raise ValueError(
                f"FlashInferFP4MoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
            )

        _, router_logits = topk_output

        hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states)

        router_logits = router_logits.to(torch.float32)

        result = trtllm_fp4_block_scale_moe(
            routing_logits=router_logits,
            routing_bias=self.correction_bias.to(hidden_states.dtype),
            hidden_states=hs_fp4,
            hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).flatten(),
            gemm1_weights=self.gemm1_weights_fp4_shuffled.data,
            gemm1_weights_scale=self.gemm1_scales_fp4_shuffled.data.view(
                torch.float8_e4m3fn
            ),
            gemm2_weights=self.gemm2_weights_fp4_shuffled.data,
            gemm2_weights_scale=self.gemm2_scales_fp4_shuffled.data.view(
                torch.float8_e4m3fn
            ),
            output1_scale_scalar=self.g1_scale_c.data,
            output1_scale_gate_scalar=self.g1_alphas.data,
            output2_scale_scalar=self.g2_alphas.data,
            num_experts=self.num_experts,
            top_k=self.top_k,
            n_group=self.num_expert_group,
            topk_group=self.topk_group,
            intermediate_size=self.intermediate_size_per_partition,
            local_expert_offset=self.moe_ep_rank * self.num_local_experts,
            local_num_experts=self.num_local_experts,
            routed_scaling_factor=self.routed_scaling_factor,
            tile_tokens_dim=_get_tile_tokens_dim(
                hidden_states.shape[0], self.top_k, self.num_local_experts
            ),
            routing_method_type=RoutingMethodType.DeepSeekV3,
            do_finalize=True,
        )[0]

        return result
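
A minimal caller sketch for the contract enforced above (module names here are illustrative, not part of the change): the TRTLLM path hands the layer the TopK configuration plus the raw router logits and lets the fused kernel do the routing itself.

# Hypothetical call pattern for FlashInferFP4MoE.forward:
router_logits = gate(hidden_states)                              # assumed gate projection module
out = fp4_moe(hidden_states, topk_output=(topk, router_logits))  # (TopK config, logits) tuple
# Passing a StandardTopKOutput here would trigger the ValueError above instead.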


def get_fused_moe_impl_class():
    """Factory function to get the appropriate FusedMoE implementation class."""
    if should_use_flashinfer_trtllm_moe() and _is_fp4_quantization_enabled():
        # Use FP4 variant when FP4 quantization is enabled
        return FlashInferFP4MoE
    elif should_use_flashinfer_trtllm_moe():
        # Use regular FlashInfer variant for non-FP4 FlashInfer cases
        return FlashInferFusedMoE
    else:
        # Default case
        return FusedMoE
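
Rough selection outcomes for the factory above (a sketch, assuming the server-args checks behave as defined earlier in this file):

# enable_flashinfer_trtllm_moe and quantization == "modelopt_fp4" -> FlashInferFP4MoE
# enable_flashinfer_trtllm_moe with any other quantization        -> FlashInferFusedMoE
# otherwise                                                        -> FusedMoE
ExpertsCls = get_fused_moe_impl_class()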

@@ -1,4 +1,20 @@
import importlib.util
from enum import Enum
from functools import lru_cache

from packaging import version as pkg_version

from sglang.srt.managers.schedule_batch import global_server_args_dict


@lru_cache(maxsize=1)
def should_use_flashinfer_trtllm_moe():
    result = global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
        not importlib.util.find_spec("flashinfer")
        or pkg_version.parse(__import__("flashinfer").__version__)
        >= pkg_version.parse("0.2.9rc1")
    )
    return result
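
A rough truth table for the cached gate above (the check runs once per process because of lru_cache(maxsize=1), so later changes to the server args are not observed):

# enable_flashinfer_trtllm_moe = False                            -> False
# enable_flashinfer_trtllm_moe = True, flashinfer not importable  -> True (the import is attempted later)
# enable_flashinfer_trtllm_moe = True, flashinfer >= 0.2.9rc1     -> True
# enable_flashinfer_trtllm_moe = True, flashinfer <  0.2.9rc1     -> False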


class MoeA2ABackend(Enum):

@@ -1,13 +1,15 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
from __future__ import annotations

import importlib.util
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

import torch
from torch.nn.parameter import Parameter

from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.quantization.base_config import (
    FusedMoEMethodBase,
@@ -29,6 +31,7 @@ from sglang.srt.layers.quantization.utils import (
    requantize_with_max_scale,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import is_cuda, next_power_of_2

if TYPE_CHECKING:
@@ -39,6 +42,11 @@ if is_cuda():

    try:
        from flashinfer import mm_fp4 as fp4_gemm
        from flashinfer import (
            reorder_rows_for_gated_act_gemm,
            shuffle_matrix_a,
            shuffle_matrix_sf_a,
        )

        enable_flashinfer_fp4_gemm = True
    except ImportError:
@@ -47,6 +55,9 @@ except ImportError:
else:
    fp4_gemm = None
    enable_flashinfer_fp4_gemm = False
    reorder_rows_for_gated_act_gemm = None
    shuffle_matrix_a = None
    shuffle_matrix_sf_a = None

try:
    from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
@@ -527,6 +538,7 @@ class ModelOptFp4Config(QuantizationConfig):
    ) -> Optional[QuantizeMethodBase]:
        from sglang.srt.layers.linear import LinearBase
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
        from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFP4MoE

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
@@ -536,6 +548,9 @@ class ModelOptFp4Config(QuantizationConfig):
                return ModelOptFp4LinearMethod(self)
        if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
            return ModelOptFp8KVCacheMethod(self)
        elif isinstance(layer, FlashInferFP4MoE):
            # FlashInferFP4MoE needs the same quantization method but with compatible attribute handling
            return ModelOptNvFp4FusedMoEMethod(self)
        elif isinstance(layer, FusedMoE):
            return ModelOptNvFp4FusedMoEMethod(self)
        return None
@@ -726,7 +741,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                " quantization. Please use Blackwell and"
                " above."
            )
        self.enable_flashinfer_cutlass_moe = False
        self.enable_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()

    @property
    def enable_flashinfer_cutlass_moe(self) -> bool:
        """Access the global enable_flashinfer_cutlass_moe setting."""
        return global_server_args_dict.get("enable_flashinfer_cutlass_moe", False)

    def create_weights(
        self,
@@ -743,16 +763,20 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                " dynamic quantization is not supported."
            )

        # TODO(ch-wan): check if this is needed
        layer.num_experts = num_experts
        layer.num_local_experts = num_experts
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config

        weight_dtype = torch.uint8
        weight_scale_dtype = torch.float8_e4m3fn
        weight_loader = extra_weight_attrs.get("weight_loader")
        # GEMM 1
        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                layer.local_num_experts,
                2 * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // 2,
@@ -767,7 +791,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
        # GEMM 2
        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                layer.num_local_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition // 2,
@@ -781,7 +805,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):

        w13_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                layer.num_local_experts,
                2 * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // self.quant_config.group_size,
@@ -795,7 +819,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):

        w2_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                layer.num_local_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition // self.quant_config.group_size,
@@ -814,13 +838,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
        )

        w13_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, 2, dtype=torch.float32),
            data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

        w2_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, dtype=torch.float32),
            data=torch.empty(layer.num_local_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)
@@ -830,18 +854,18 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
        )

        w13_input_scale = PerTensorScaleParameter(
            data=torch.empty(num_experts, 2, dtype=torch.float32),
            data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)

        w2_input_scale = PerTensorScaleParameter(
            data=torch.empty(num_experts, dtype=torch.float32),
            data=torch.empty(layer.num_local_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def swizzle_blockscale(self, scale: torch.tensor):
    def swizzle_blockscale(self, scale: torch.Tensor):
        assert scale.dtype == torch.float8_e4m3fn
        # Pad and blockwise interleave weight_scale
        scale_ndim = scale.ndim
@@ -866,9 +890,125 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
            else swizzled_scale.reshape(B, M, K)
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    def prepare_static_weights_for_kernel(
        self,
        # args_dequant,
        # args,
        gemm1_weights,
        gemm2_weights,
        gemm1_scales_linear_fp4_bytes,
        gemm2_scales_linear_fp4_bytes,
        hidden_size,
        intermediate_size,
        num_experts,
    ):
        from flashinfer import (
            RoutingMethodType,
            e2m1_and_ufp8sf_scale_to_float,
            fp4_quantize,
            next_positive_power_of_2,
            reorder_rows_for_gated_act_gemm,
            shuffle_matrix_a,
            shuffle_matrix_sf_a,
        )

        # GEMM 1
        """Prepare quantized weights for kernel (done offline with weights)."""
        epilogue_tile_m = 128  # FIXME: this depends on the kernel internals

        # Convert quantized weights to proper formats
        gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
            num_experts, 2 * intermediate_size, hidden_size // 2
        )  # packed fp4
        gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
            torch.float8_e4m3fn
        ).reshape(
            num_experts, 2 * intermediate_size, hidden_size // 16
        )  # fp8 scaling factors

        gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
            num_experts, hidden_size, intermediate_size // 2
        )  # packed fp4
        gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
            torch.float8_e4m3fn
        ).reshape(
            num_experts, hidden_size, intermediate_size // 16
        )  # fp8 scaling factors

        # Reorder rows of W1 and scales for fused gated activation
        gemm1_weights_fp4_interleaved = []
        gemm1_scales_fp4_interleaved = []
        for i in range(num_experts):
            gemm1_weights_fp4_interleaved.append(
                reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone())
            )
            gemm1_scales_fp4_interleaved.append(
                reorder_rows_for_gated_act_gemm(gemm1_scales_linear_fp4[i].clone())
            )

        # Stack weights and scales for all experts
        gemm1_weights_fp4_interleaved = torch.stack(
            gemm1_weights_fp4_interleaved
        ).reshape(num_experts, 2 * intermediate_size, hidden_size // 2)
        gemm1_scales_fp4_interleaved = torch.stack(
            gemm1_scales_fp4_interleaved
        ).reshape(num_experts, 2 * intermediate_size, hidden_size // 16)

        # Shuffle weights and scaling factors for transposed mma output
        gemm1_weights_fp4_shuffled = []
        gemm1_scales_fp4_shuffled = []
        gemm2_weights_fp4_shuffled = []
        gemm2_scales_fp4_shuffled = []
        for i in range(num_experts):
            gemm1_weights_fp4_shuffled.append(
                shuffle_matrix_a(
                    gemm1_weights_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
                )
            )
            gemm1_scales_fp4_shuffled.append(
                shuffle_matrix_sf_a(
                    gemm1_scales_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
                )
            )

            gemm2_weights_fp4_shuffled.append(
                shuffle_matrix_a(
                    gemm2_weights_fp4[i].view(torch.uint8), epilogue_tile_m
                )
            )
            gemm2_scales_fp4_shuffled.append(
                shuffle_matrix_sf_a(
                    gemm2_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m
                )
            )

        # Stack weights for all experts
        gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
        gemm1_scales_fp4_shuffled = (
            torch.stack(gemm1_scales_fp4_shuffled)
            .view(torch.float8_e4m3fn)
            .reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
        )

        gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
        gemm2_scales_fp4_shuffled = (
            torch.stack(gemm2_scales_fp4_shuffled)
            .view(torch.float8_e4m3fn)
            .reshape(num_experts, hidden_size, intermediate_size // 16)
        )
        return (
            gemm1_weights_fp4_shuffled,
            gemm1_scales_fp4_shuffled,
            gemm2_weights_fp4_shuffled,
            gemm2_scales_fp4_shuffled,
        )
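
A rough shape summary of the weight preparation above (per expert, following the packed-FP4 convention of the reshapes; sizes are symbolic, not taken from the change):

# gemm1 (w13): weights [E, 2*I, H/2] uint8 (two FP4 values per byte),
#              block scales [E, 2*I, H/16] float8_e4m3fn,
#              rows interleaved for the fused gated activation, then shuffled
#              for the transposed-MMA epilogue (tile_m = 128).
# gemm2 (w2):  weights [E, H, I/2] uint8, block scales [E, H, I/16] float8_e4m3fn,
#              shuffled the same way but without the gated-activation reorder.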

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Process FP4 MoE weights after loading from serialized checkpoint.

        Only supports pre-quantized checkpoints with FP8 weights and scales.
        """

        # GEMM 1 scale processing
        if not torch.allclose(
            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
        ):
@@ -880,73 +1020,123 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
        layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

        if self.enable_flashinfer_cutlass_moe:
        # Calculate input scales based on strategy
        if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe:
            w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
            w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
        else:
            w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
            w2_input_scale = layer.w2_input_scale

        # Create shared parameters
        layer.g1_alphas = Parameter(
            (w13_input_scale * w13_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        assert (
            layer.w13_weight_scale.shape[2] % 16 == 0
        ), "Expected weight_scale.dim(1) to be divisible by 16"
        assert (
            layer.w13_weight_scale.dtype == torch.float8_e4m3fn
        ), "Weight Blockscale must be represented as FP8-E4M3"
        w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)

        layer.w13_blockscale_swizzled = Parameter(
            w13_blockscale_swizzled, requires_grad=False
        )
        del layer.w13_weight_scale

        # This is for quantization, so we need to invert it.
        layer.w13_input_scale_quant = Parameter(
            (1 / w13_input_scale).to(torch.float32), requires_grad=False
        )

        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)

        # GEMM 2
        if self.enable_flashinfer_cutlass_moe:
            w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
        else:
            w2_input_scale = layer.w2_input_scale

        layer.g2_alphas = Parameter(
            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        # This is for quantization, so we need to invert it.
        layer.w13_input_scale_quant = Parameter(
            (1 / w13_input_scale).to(torch.float32), requires_grad=False
        )
        layer.w2_input_scale_quant = Parameter(
            (1 / w2_input_scale).to(torch.float32), requires_grad=False
        )

        assert (
            layer.w2_weight_scale.shape[2] % 16 == 0
        ), "Expected weight_scale.dim(1) to be divisible by 16"
        assert (
            layer.w2_weight_scale.dtype == torch.float8_e4m3fn
        ), "Weight Blockscale must be represented as FP8-E4M3"
        w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
        # Validate weight scales
        for name, weight_scale in [
            ("w13", layer.w13_weight_scale),
            ("w2", layer.w2_weight_scale),
        ]:
            assert (
                weight_scale.shape[2] % 16 == 0
            ), f"Expected {name}_weight_scale.dim(2) to be divisible by 16"
            assert (
                weight_scale.dtype == torch.float8_e4m3fn
            ), f"{name} Weight Blockscale must be represented as FP8-E4M3"

        layer.w2_blockscale_swizzled = Parameter(
            w2_blockscale_swizzled, requires_grad=False
        )
        del layer.w2_weight_scale
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
        # Weight processing based on strategy
        if (
            self.enable_flashinfer_trtllm_moe
            and reorder_rows_for_gated_act_gemm is not None
            and shuffle_matrix_sf_a is not None
        ):
            # FlashInfer TRTLLM processing - handles both w13 and w2
            (
                gemm1_weights_fp4_shuffled,
                gemm1_scales_fp4_shuffled,
                gemm2_weights_fp4_shuffled,
                gemm2_scales_fp4_shuffled,
            ) = self.prepare_static_weights_for_kernel(
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                layer.w2_weight.size(-2),  # hidden_size
                layer.w13_weight.size(-2) // 2,  # intermediate_size
                layer.w13_weight.size(0),  # num_experts
            )

            device = layer.w13_weight.device
            layer.cutlass_moe_params = CutlassMoEParams(
                CutlassMoEType.BlockscaledFP4,
                device,
                num_experts=layer.num_experts,  # global num experts
                intermediate_size_per_partition=layer.w2_weight.shape[2] * 2,  # n
                hidden_size=layer.w13_weight.shape[2] * 2,
            )  # k
            # Set flashinfer parameters
            layer.gemm1_weights_fp4_shuffled = Parameter(
                gemm1_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_weights_fp4_shuffled = Parameter(
                gemm2_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm1_scales_fp4_shuffled = Parameter(
                gemm1_scales_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_scales_fp4_shuffled = Parameter(
                gemm2_scales_fp4_shuffled, requires_grad=False
            )

            # Additional parameter needed for TRT-LLM
            layer.g1_scale_c = Parameter(
                (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
                requires_grad=False,
            )

            # Clean up weights that won't be used by TRT-LLM
            del (
                layer.w2_weight,
                layer.w2_weight_scale,
                layer.w13_weight,
                layer.w13_weight_scale,
            )

            print("Applied flashinfer weight processing for both w13 and w2")

        else:
            # CUTLASS processing - handle w13 and w2 separately

            # Process w13 weights
            w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
            layer.w13_blockscale_swizzled = Parameter(
                w13_blockscale_swizzled, requires_grad=False
            )
            layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)

            # Process w2 weights
            w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
            layer.w2_blockscale_swizzled = Parameter(
                w2_blockscale_swizzled, requires_grad=False
            )
            layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

            # Both flashinfer cutlass and regular cutlass use same processing for w2
            print("Applied weight processing for both w13 and w2")

            # Set up CUTLASS MoE parameters
            device = layer.w13_weight.device
            layer.cutlass_moe_params = CutlassMoEParams(
                CutlassMoEType.BlockscaledFP4,
                device,
                num_experts=layer.num_experts,  # global num experts
                intermediate_size_per_partition=layer.w2_weight.shape[2] * 2,  # n
                hidden_size=layer.w13_weight.shape[2] * 2,
            )  # k

    @property
    def load_up_proj_weight_first(self) -> bool:
@@ -971,13 +1161,20 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
    ) -> torch.Tensor:
        assert activation == "silu", "Only SiLU activation is supported."

        # Check if this is a FlashInferFP4MoE layer that should handle its own forward
        if hasattr(layer, "gemm1_weights_fp4_shuffled"):
            # This layer was processed with flashinfer TRTLLM - delegate to its own forward
            return layer.forward(x, topk_output)

        if self.enable_flashinfer_cutlass_moe:
            assert (
                not apply_router_weight_on_input
            ), "apply_router_weight_on_input is not supported for Flashinfer"
            # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
            # and fp4 quantized weights loaded from the checkpoint
            topk_weights, topk_ids, _ = topk_output

            topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids

            output = flashinfer_cutlass_fused_moe(
                x,
                topk_ids.to(torch.int),
@@ -1005,7 +1202,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):

        from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4

        topk_weights, topk_ids, _ = topk_output
        topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
        output = cutlass_moe_fp4(
            a=x,
            a1_gscale=layer.w13_input_scale_quant,

@@ -51,7 +51,6 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
    ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
from sglang.srt.mem_cache.allocator import (
    BaseTokenToKVPoolAllocator,
    SWATokenToKVPoolAllocator,
@@ -109,6 +108,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
    "enable_triton_kernel_moe",
    "enable_multimodal",
    "enable_symm_mem",
    "quantization",
]

# Put some global args for easy access

@@ -60,12 +60,9 @@ from sglang.srt.layers.linear import (
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import (
    DeepEPMoE,
    get_moe_impl_class,
    should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import (
@@ -307,19 +304,15 @@ class DeepseekV2MoE(nn.Module):
            config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
        )

        self.topk = (
            TopK(
                top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
                renormalize=config.norm_topk_prob,
                use_grouped_topk=True,
                num_expert_group=config.n_group,
                num_fused_shared_experts=self.num_fused_shared_experts,
                topk_group=config.topk_group,
                correction_bias=self.gate.e_score_correction_bias,
                routed_scaling_factor=self.routed_scaling_factor,
            )
            if not should_use_flashinfer_trtllm_moe()
            else None
        self.topk = TopK(
            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
            renormalize=config.norm_topk_prob,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            num_fused_shared_experts=self.num_fused_shared_experts,
            topk_group=config.topk_group,
            correction_bias=self.gate.e_score_correction_bias,
            routed_scaling_factor=self.routed_scaling_factor,
        )

        self.experts = get_moe_impl_class()(
@@ -476,10 +469,14 @@ class DeepseekV2MoE(nn.Module):
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states)
        kwargs = {"hidden_states": hidden_states}
        if self.topk is not None:
            kwargs["topk_output"] = self.topk(hidden_states, router_logits)

        # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
        # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
        if should_use_flashinfer_trtllm_moe():
            kwargs["topk_output"] = (self.topk, router_logits)
        else:
            kwargs["router_logits"] = router_logits
            kwargs["topk_output"] = self.topk(hidden_states, router_logits)

        final_hidden_states = self.experts(**kwargs)
        if not _is_cuda:
            final_hidden_states *= self.routed_scaling_factor
@@ -505,10 +502,14 @@ class DeepseekV2MoE(nn.Module):
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states)
        kwargs = {"hidden_states": hidden_states}
        if self.topk is not None:
            kwargs["topk_output"] = self.topk(hidden_states, router_logits)

        # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
        # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
        if should_use_flashinfer_trtllm_moe():
            kwargs["topk_output"] = (self.topk, router_logits)
        else:
            kwargs["router_logits"] = router_logits
            kwargs["topk_output"] = self.topk(hidden_states, router_logits)

        final_hidden_states = self.experts(**kwargs)
        if not _is_cuda and not _use_aiter:
            # fused in biased_grouped_topk so we can skip here
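
Taken together with the FlashInferFP4MoE.forward contract earlier in this change, the decode-path flow the two hunks above set up looks roughly like this (a sketch, not an excerpt):

router_logits = self.gate(hidden_states)             # routing stays in logits form on the host
out = self.experts(
    hidden_states=hidden_states,
    topk_output=(self.topk, router_logits),          # TRTLLM path only
)
# Inside FlashInferFP4MoE.forward: fp4_quantize(...) -> trtllm_fp4_block_scale_moe(...),
# which performs DeepSeekV3-style routing, both grouped GEMMs, and finalize in one kernel.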

@@ -50,11 +50,9 @@ from sglang.srt.layers.linear import (
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import (
    get_moe_impl_class,
    should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import (
    is_fp8_fnuz,

@@ -481,6 +481,13 @@ class ServerArgs:
            self.tp_size,
        ], "The expert parallel size must be 1 or the same as the tensor parallel size"

        if self.enable_flashinfer_trtllm_moe:
            if not self.disable_shared_experts_fusion:
                self.disable_shared_experts_fusion = True
                logger.warning(
                    "FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
                )

        # DeepEP MoE
        if self.moe_a2a_backend == "deepep":
            if self.deepep_mode == "normal":
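
A hypothetical ServerArgs combination that exercises the new FP4 TRTLLM decode path (field names come from the hunks above; treat the snippet as illustrative rather than a supported configuration):

# server_args.quantization = "modelopt_fp4"          # checked in get_moe_impl_class()
# server_args.enable_flashinfer_trtllm_moe = True    # also forces disable_shared_experts_fusion = True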