Files
sglang/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
Lianmin Zheng 9a44b643c6 Fix CI (#9012)
2025-08-09 13:33:42 -07:00

1128 lines
41 KiB
Python

# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
import datetime
import glob
import logging
import os
import sys
from enum import Enum
from typing import List, Optional, Tuple
import torch
from sglang.srt.distributed import (
get_moe_expert_parallel_rank,
get_moe_expert_parallel_world_size,
get_moe_tensor_parallel_rank,
get_moe_tensor_parallel_world_size,
get_tp_group,
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.layers.moe.topk import StandardTopKOutput
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
from sglang.srt.utils import (
cpu_has_amx_support,
get_bool_env_var,
is_cpu,
is_flashinfer_available,
is_hip,
next_power_of_2,
round_up,
)
# FlashInfer helpers are optional: they are only imported when
# is_flashinfer_available() reports that the package can be used.
if is_flashinfer_available():
    from flashinfer import (
        RoutingMethodType,
        fp4_quantize,
        reorder_rows_for_gated_act_gemm,
        shuffle_matrix_a,
        shuffle_matrix_sf_a,
    )

# Platform flags, evaluated once at import time and reused by the loaders below.
_is_hip = is_hip()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()

# Try to import FP4 TRTLLM function if flashinfer is available
# The name stays None when the TRTLLM MoE backend is not requested or when
# the import fails, so callers must check it before calling.
trtllm_fp4_block_scale_moe = None
if should_use_flashinfer_trtllm_moe():
    try:
        from flashinfer.fused_moe import trtllm_fp4_block_scale_moe
    except ImportError:
        trtllm_fp4_block_scale_moe = None

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def _is_fp4_quantization_enabled() -> bool:
    """Check if ModelOpt FP4 quantization is enabled.

    Returns:
        True when the ``"quantization"`` entry of ``global_server_args_dict``
        equals ``"modelopt_fp4"``; False otherwise, including when the lookup
        itself fails.
    """
    try:
        # Use the same simple check that works for class selection
        quantization = global_server_args_dict.get("quantization")
    except Exception:
        # Best effort: an unusable args dict means "not enabled". Only
        # ordinary errors are swallowed, so KeyboardInterrupt / SystemExit
        # still propagate (a bare ``except:`` would hide them).
        return False
    return quantization == "modelopt_fp4"
def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
    """Pick the per-CTA token tile size for the given routing workload.

    The estimate assumes tokens are spread perfectly evenly over the experts,
    rounds that count with ``next_power_of_2`` and clamps it to [8, 64].
    """
    # Expected tokens per expert under a perfectly uniform distribution.
    expected_per_expert = (num_tokens * top_k) // num_experts
    padded = next_power_of_2(expected_per_expert)
    # The kernel supports 8-64 tokens per CTA tile; clamp into that range.
    if padded < 8:
        return 8
    if padded > 64:
        return 64
    return padded
class FusedMoeWeightScaleSupported(Enum):
    """String tags for the weight-scale schemes the FusedMoE loader dispatches on.

    The loader compares a parameter's ``quant_method`` attribute against the
    ``.value`` of these members to choose how a scale tensor is copied.
    """

    # One scale per tensor.
    TENSOR = "tensor"
    # One scale per channel.
    CHANNEL = "channel"
    # Group-wise scales.
    GROUP = "group"
    # Block-wise scales.
    BLOCK = "block"
class FusedMoE(torch.nn.Module):
    """FusedMoE layer for MoE models.

    This layer contains both MergedColumnParallel weights (gate_up_proj /
    w13) and RowParallelLinear weights (down_proj / w2).

    Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
    copy that naming convention here and handle any remapping in the
    load_weights function in each model implementation.

    Args:
        num_experts: Number of experts in the model.
        hidden_size: Input hidden state size of the transformer.
        intermediate_size: Intermediate size of the experts.
        layer_id: Index of this layer; used when looking up physical expert
            locations during weight loading.
        top_k: Number of experts selected for each token.
        num_fused_shared_experts: Number of trailing expert ids that are
            treated as shared experts by the weight loader.
        params_dtype: Data type for the parameters. Defaults to
            ``torch.get_default_dtype()``.
        reduce_results: Whether to all-reduce the output of the layer.
        quant_config: Quantization configuration. ``None`` selects
            ``UnquantizedFusedMoEMethod``.
        tp_size: Accepted for interface compatibility; not read here.
        prefix: Passed to ``quant_config.get_quant_method``.
        activation: Activation name forwarded to the quant method.
        apply_router_weight_on_input: Forwarded to the quant method.
        use_presharded_weights: When True, loaded weights are not narrowed
            along the tensor-parallel dimension.
        inplace: Suggestion to compute in place (modify input activation).
        no_combine: Stored on the layer for use by quant methods.
        routed_scaling_factor: Forwarded to the quant method.
        enable_flashinfer_cutlass_moe: Disabled (with a warning) when
            ``quant_config`` is None.
        activation_alpha: Optional extra activation argument; only forwarded
            when not None.
        swiglu_limit: Optional extra activation argument; only forwarded
            when not None.
        use_weight_loader_fused: Register ``weight_loader_fused`` instead of
            ``weight_loader`` as the parameters' weight loader.
        with_bias: Forwarded to ``quant_method.create_weights``.
    """

    def __init__(
        self,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        layer_id: int,
        top_k: Optional[int] = None,
        num_fused_shared_experts: int = 0,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = False,
        quant_config: Optional["QuantizationConfig"] = None,
        tp_size: Optional[int] = None,
        prefix: str = "",
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        use_presharded_weights: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
        enable_flashinfer_cutlass_moe: Optional[bool] = False,
        activation_alpha: Optional[float] = None,
        swiglu_limit: Optional[float] = None,
        use_weight_loader_fused: bool = False,
        with_bias=False,
    ):
        super().__init__()

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        self.layer_id = layer_id
        self.top_k = top_k
        self.hidden_size = hidden_size
        self.num_experts = num_experts
        self.num_fused_shared_experts = num_fused_shared_experts
        self.expert_map_cpu = None
        self.expert_map_gpu = None

        # For activation
        self.activation_alpha = activation_alpha
        self.swiglu_limit = swiglu_limit

        if enable_flashinfer_cutlass_moe and quant_config is None:
            logger.warning("Disable flashinfer MoE when quantization config is None.")
            enable_flashinfer_cutlass_moe = False

        self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
        self.moe_ep_size = get_moe_expert_parallel_world_size()
        self.moe_ep_rank = get_moe_expert_parallel_rank()
        self.moe_tp_size = get_moe_tensor_parallel_world_size()
        self.moe_tp_rank = get_moe_tensor_parallel_rank()
        assert num_experts % self.moe_ep_size == 0
        self.num_local_experts = num_experts // self.moe_ep_size
        if self.moe_ep_size > 1:
            # TODO(ch-wan): support shared experts fusion
            # Create a tensor of size num_experts filled with -1
            # (-1 marks experts that are not hosted on this EP rank).
            self.expert_map_cpu = torch.full(
                (self.num_experts,), -1, dtype=torch.int32, device="cpu"
            )
            # Create a expert map for the local experts: the contiguous slice
            # owned by this rank maps to local ids 0..num_local_experts-1.
            self.expert_map_cpu[
                self.moe_ep_rank
                * self.num_local_experts : (self.moe_ep_rank + 1)
                * self.num_local_experts
            ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")

        self.routed_scaling_factor = routed_scaling_factor
        assert intermediate_size % self.moe_tp_size == 0
        self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
        self.reduce_results = reduce_results
        self.activation = activation
        self.apply_router_weight_on_input = apply_router_weight_on_input
        self.use_presharded_weights = use_presharded_weights
        self.inplace = inplace
        self.no_combine = no_combine

        self.use_triton_kernels = (
            not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
        )

        if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
                self.use_triton_kernels
            )
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix)
        assert self.quant_method is not None

        self.quant_config = quant_config
        self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
            "enable_flashinfer_mxfp4_moe", False
        )
        if (
            self.quant_config is not None
            and self.quant_config.get_name() == "mxfp4"
            and self.use_enable_flashinfer_mxfp4_moe
        ):
            # Weights are created with hidden_size rounded up to a multiple
            # of 256; forward() pads activations to match.
            hidden_size = round_up(hidden_size, 256)
        self.hidden_size = hidden_size
        self.quant_method.create_weights(
            layer=self,
            num_experts=self.num_local_experts,
            hidden_size=hidden_size,
            # FIXME: figure out which intermediate_size to use
            intermediate_size=self.intermediate_size_per_partition,
            intermediate_size_per_partition=self.intermediate_size_per_partition,
            params_dtype=params_dtype,
            weight_loader=(
                self.weight_loader
                if not use_weight_loader_fused
                else self.weight_loader_fused
            ),
            with_bias=with_bias,
        )

    def _load_per_tensor_weight_scale(
        self,
        shard_id: str,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        expert_id: int,
    ):
        """Store a per-tensor scale for one expert.

        w1/w3 scales go to slots 0/1 of ``param.data[expert_id]``; a w2 scale
        replaces ``param.data[expert_id]``. Other shard ids are ignored.
        """
        param_data = param.data

        # for per tensor weight quantization
        if shard_id in ("w1", "w3"):
            # We have to keep the weight scales of w1 and w3 because
            # we need to re-quantize w1/w3 weights after weight loading.
            idx = 0 if shard_id == "w1" else 1
            param_data[expert_id][idx] = loaded_weight
        # If we are in the row parallel case (down_proj)
        elif shard_id == "w2":
            param_data[expert_id] = loaded_weight

    def _load_model_weight_or_group_weight_scale(
        self,
        shard_dim: int,
        expert_data: torch.Tensor,
        shard_id: str,
        loaded_weight: torch.Tensor,
        tp_rank: int,
        is_bias: bool = False,
    ):
        """Dispatch a model weight (or grouped weight scale) to the w2 or w13 loader.

        Shard ids other than "w1", "w2", "w3" and "w13" are silently ignored.
        """
        # Load grouped weight scales for group quantization
        # or model weights
        if shard_id == "w2":
            self._load_w2(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
                is_bias=is_bias,
            )
        elif shard_id in ("w1", "w3", "w13"):
            self._load_w13(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
                is_bias=is_bias,
            )

    def _load_per_channel_weight_scale(
        self,
        expert_data: torch.Tensor,
        shard_dim: int,
        shard_id: str,
        loaded_weight: torch.Tensor,
        tp_rank: int,
    ):
        """Load a per-channel scale: w2 is copied whole, w1/w3 go through ``_load_w13``."""
        # for per channel weight quantization
        if shard_id == "w2":
            expert_data.copy_(loaded_weight)
        elif shard_id in ("w1", "w3"):
            self._load_w13(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
            )

    def _load_w13(
        self,
        expert_data: torch.Tensor,
        shard_dim: int,
        shard_id: str,
        loaded_weight: torch.Tensor,
        tp_rank: int,
        is_bias: bool = False,
    ):
        """Copy a gate/up projection shard into the merged w13 parameter.

        Args:
            expert_data: Destination tensor (a view into the w13 parameter).
            shard_dim: Dimension that is sharded across tensor-parallel ranks.
            shard_id: "w1", "w3" (half of w13 each) or "w13" (fused).
            loaded_weight: Checkpoint tensor to copy from.
            tp_rank: Tensor-parallel rank selecting the slice to load.
            is_bias: When True the last dimension is used as the shard dim.
        """
        # Index the loaded weight for tp sharding.
        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
        assert shard_id in {"w1", "w3", "w13"}

        if is_bias:
            # if this weight is a bias, the last dimension must be the sharded dimension
            shard_dim = -1

        if shard_id in {"w1", "w3"}:
            # non-fused version
            shard_size = expert_data.shape[shard_dim] // 2
        elif shard_id in {"w13"}:
            # fused version
            shard_size = expert_data.shape[shard_dim]
        else:
            raise NotImplementedError

        # Narrow parameter and load.
        # w1, gate_proj: Load into first logical weight of w13.
        # w3, up_proj: Load into second logical weight of w13.
        # trtllm cutlass kernel assumes differently
        switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
        if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
            start = shard_size
        else:
            start = 0

        if _is_cpu:
            expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
                expert_data,
                loaded_weight,
                start,
                shard_size * tp_rank,
                shard_dim,
                shard_size,
                not self.use_presharded_weights,
            )
        else:
            if not self.use_presharded_weights:
                if not is_bias and self.use_triton_kernels:
                    # do not transpose for bias
                    loaded_weight = loaded_weight.transpose(-2, -1)
                loaded_weight = loaded_weight.narrow(
                    shard_dim, shard_size * tp_rank, shard_size
                )

            expert_data = expert_data.narrow(shard_dim, start, shard_size)
        expert_data.copy_(loaded_weight)

    def _load_w2(
        self,
        expert_data: torch.Tensor,
        shard_dim: int,
        shard_id: str,
        loaded_weight: torch.Tensor,
        tp_rank: int,
        is_bias: bool = False,
    ):
        """Load w2 weights for down projection.

        Args:
            expert_data: The expert data tensor to load into
            shard_dim: The dimension to shard along
            shard_id: The shard ID (must be "w2")
            loaded_weight: The weight tensor to load from
            tp_rank: The tensor parallel rank
            is_bias: Whether ``expert_data`` is a bias; biases are copied
                without tensor-parallel narrowing on the non-CPU path.

        Raises:
            ValueError: If the inputs are not tensors, if a modelopt quant
                config is active and either tensor is not 2D, or if
                ``shard_id`` is not "w2".
        """
        if not isinstance(expert_data, torch.Tensor) or not isinstance(
            loaded_weight, torch.Tensor
        ):
            raise ValueError("expert_data and loaded_weight must be torch.Tensor")

        if (
            self.quant_config is not None
            and "modelopt" in self.quant_config.get_name()
            and (expert_data.dim() != 2 or loaded_weight.dim() != 2)
        ):
            raise ValueError(
                f"Expected 2D tensors, got expert_data shape {expert_data.shape} and loaded_weight shape {loaded_weight.shape}"
            )

        if shard_id != "w2":
            raise ValueError(f"shard_id must be 'w2', got {shard_id}")

        # Index the loaded weight for tp sharding.
        # down_proj: "RowParallel" so tp sharding on input_dim
        # Narrow parameter and load.
        if is_bias:
            # this expert_data is a bias, not weight,
            # for w2_weight_bias in TP, it does not need to be sharded
            shard_size = expert_data.shape[-1]
        else:
            # this parameter is a weight matrix
            # for w2 in TP, it shards the input_features, i.e., shard_dim=2
            shard_size = expert_data.shape[shard_dim]

        if _is_cpu:
            expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
                expert_data,
                loaded_weight,
                0,  # param_data_start
                shard_size * tp_rank,
                shard_dim,
                shard_size,
                not self.use_presharded_weights,
            )
        else:
            if not is_bias and not self.use_presharded_weights:
                if self.use_triton_kernels:
                    loaded_weight = loaded_weight.transpose(-2, -1)
                loaded_weight = loaded_weight.narrow(
                    shard_dim, shard_size * tp_rank, shard_size
                )

        # w2, down_proj: Load into only logical weight of w2.
        expert_data.copy_(loaded_weight)

    def _load_single_value(
        self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int
    ):
        """Write ``loaded_weight`` into ``param.data[expert_id]`` unchanged."""
        param_data = param.data

        # Input scales can be loaded directly and should be equal.
        param_data[expert_id] = loaded_weight

    def _load_g_idx(
        self,
        shard_id: str,
        expert_data: torch.Tensor,
        shard_dim: int,
        loaded_weight: torch.Tensor,
        tp_rank: int,
    ):
        """Load a ``g_idx`` tensor: w2 is sharded via ``_load_w2``, w1/w3 are copied whole."""
        if shard_id == "w2":
            self._load_w2(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
            )
        else:
            assert shard_id in ("w1", "w3")
            expert_data.copy_(loaded_weight)

    def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
        """Translate a global expert id to this rank's local id.

        Returns the id unchanged when no expert map exists; otherwise returns
        the map entry, which is -1 for experts not hosted on this rank.
        """
        if self.expert_map_cpu is None:
            return expert_id
        return self.expert_map_cpu[expert_id].item()

    def weight_loader(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        shard_id: str,
        expert_id: Optional[int],
    ) -> None:
        """Load one checkpoint tensor into ``param``.

        Args:
            param: Destination parameter.
            loaded_weight: Tensor read from the checkpoint.
            weight_name: Checkpoint weight name; substrings select the branch.
            shard_id: "w1", "w2" or "w3".
            expert_id: Expert the tensor belongs to, or None when the tensor
                holds all experts at once (mxfp4 only).
        """
        # if expert_id is None, then
        # all the experts are loaded at the same time
        # NOTE: test identity with None, not truthiness; expert 0 is a
        # regular expert and must go through the per-expert path below.
        if (
            expert_id is None
            and self.quant_config is not None
            and self.quant_config.get_name() == "mxfp4"
        ):
            # The parameter may be larger than the checkpoint tensor, so
            # copy into the leading slice only.
            if "bias" in weight_name:
                dim1 = loaded_weight.shape[1]
                param.data[:, :dim1].copy_(loaded_weight)
            else:
                dim1 = loaded_weight.shape[1]
                dim2 = loaded_weight.shape[2]
                param.data[:, :dim1, :dim2].copy_(loaded_weight)
            return

        global_expert_location_metadata = get_global_expert_location_metadata()
        if global_expert_location_metadata is None:
            self._weight_loader_impl(
                param=param,
                loaded_weight=loaded_weight,
                weight_name=weight_name,
                shard_id=shard_id,
                expert_id=expert_id,
            )
            return

        if expert_id >= self.num_experts - self.num_fused_shared_experts:
            # This is a shared expert.
            physical_expert_ids = [expert_id]
        else:
            physical_expert_ids = (
                global_expert_location_metadata.logical_to_all_physical(
                    self.layer_id, expert_id
                )
            )

        # A logical expert may map to several physical experts; load each.
        for physical_expert_id in physical_expert_ids:
            self._weight_loader_physical(
                param=param,
                loaded_weight=loaded_weight,
                weight_name=weight_name,
                shard_id=shard_id,
                expert_id=physical_expert_id,
            )

    def _weight_loader_physical(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        shard_id: str,
        expert_id: int,
    ) -> None:
        """Load a tensor for one physical expert, skipping experts not on this rank."""
        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
        if expert_id == -1:
            return
        self._weight_loader_impl(
            param=param,
            loaded_weight=loaded_weight,
            weight_name=weight_name,
            shard_id=shard_id,
            expert_id=expert_id,
        )

    def _weight_loader_impl(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        shard_id: str,
        expert_id: int,
    ) -> None:
        """Route one per-expert tensor to the matching loader by ``weight_name``.

        Branches are tried in order: input_scale, g_idx, ModelOpt methods,
        scale/zero/offset, weight_shape, then plain weights. Names matching
        none of them are ignored.

        Raises:
            ValueError: If ``shard_id`` is not "w1"/"w2"/"w3", if a
                compressed-tensors input scale disagrees with the one already
                stored, or if a scale's ``quant_method`` is unsupported.
        """
        tp_rank = self.moe_tp_rank

        # compressed-tensors checkpoints with packed weights are stored flipped
        # TODO (mgoin): check self.quant_method.quant_config.quant_format
        # against known CompressionFormat enum values that have this quality
        loaded_weight = (
            loaded_weight.t().contiguous()
            if (
                self.quant_method.__class__.__name__
                == "CompressedTensorsWNA16MoEMethod"
            )
            else loaded_weight
        )

        if shard_id not in ("w1", "w2", "w3"):
            raise ValueError(
                f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}."
            )

        # Flashinfer assumes w31 format for w13_weight. Same for the scales.
        if should_use_flashinfer_trtllm_moe():
            shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]

        WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
        # Fetch the dim to shard the parameter/loaded weight
        # based on the shard id. This will be whatever
        # dimension intermediate_size is used.
        SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

        expert_data = param.data[expert_id]

        # is_transposed: if the dim to shard the weight
        # should be flipped. Required by GPTQ, compressed-tensors
        # should be whatever dimension intermediate_size is
        is_transposed = getattr(param, "is_transposed", False)
        shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
        if self.use_triton_kernels:
            is_transposed = True
        if is_transposed:
            shard_dim = int(not shard_dim)

        # Case input scale: input_scale loading is only supported for fp8
        if "input_scale" in weight_name:
            # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
            if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
                loaded_weight = loaded_weight * 2.0

            # this is needed for compressed-tensors only
            loaded_weight = loaded_weight.to(param.data.device)

            if (
                "compressed" in self.quant_method.__class__.__name__.lower()
                and param.data[expert_id] != 1
                and (param.data[expert_id] - loaded_weight).abs() > 1e-5
            ):
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param.data[expert_id]} "
                    f"vs. {loaded_weight}"
                )

            self._load_single_value(
                param=param, loaded_weight=loaded_weight, expert_id=expert_id
            )
            return

        # Case g_idx
        if "g_idx" in weight_name:
            self._load_g_idx(
                shard_dim=0,
                shard_id=shard_id,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
            )
            return

        if "ModelOpt" in self.quant_method.__class__.__name__:
            # Determine per-tensor weight scale patterns based on variant
            is_fp4_variant = (
                "ModelOptNvFp4FusedMoEMethod" in self.quant_method.__class__.__name__
            )

            # FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor
            per_tensor_conditions = (
                "weight_scale_2" in weight_name
                if is_fp4_variant
                else "weight_scale" in weight_name
            ) or "input_scale" in weight_name

            if per_tensor_conditions:
                self._load_per_tensor_weight_scale(
                    shard_id=shard_id,
                    param=param,
                    loaded_weight=loaded_weight,
                    expert_id=expert_id,
                )
            elif "weight" in weight_name:
                self._load_model_weight_or_group_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=tp_rank,
                )
            return

        # Case weight scales and zero_points
        if "scale" in weight_name or "zero" in weight_name or "offset" in weight_name:
            # load the weight scales and zp based on the quantization scheme
            # supported weight scales/zp can be found in
            # FusedMoeWeightScaleSupported
            # TODO @dsikka: once hardened, refactor to use vLLM Parameters
            # specific to each case
            quant_method = getattr(param, "quant_method", None)
            if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
                # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD)
                if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
                    loaded_weight = loaded_weight * 0.5

                self._load_per_channel_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=tp_rank,
                )
            elif quant_method in [
                FusedMoeWeightScaleSupported.GROUP.value,
                FusedMoeWeightScaleSupported.BLOCK.value,
            ]:
                self._load_model_weight_or_group_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=tp_rank,
                )
            elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
                # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD)
                if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
                    loaded_weight = loaded_weight * 2.0

                self._load_per_tensor_weight_scale(
                    shard_id=shard_id,
                    param=param,
                    loaded_weight=loaded_weight,
                    expert_id=expert_id,
                )
            else:
                raise ValueError(
                    f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}"
                )
            return

        # Case weight_shape
        if "weight_shape" in weight_name:
            # only required by compressed-tensors
            self._load_single_value(
                param=param, loaded_weight=loaded_weight, expert_id=expert_id
            )
            return

        # Case model weights
        if "weight" in weight_name:
            self._load_model_weight_or_group_weight_scale(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
            )
            return

    def weight_loader_fused(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        shard_id: str,
    ) -> None:
        """Load a checkpoint tensor that holds all experts at once.

        Args:
            param: Destination parameter covering every local expert.
            loaded_weight: Tensor read from the checkpoint.
            weight_name: Checkpoint weight name; substrings select the branch.
            shard_id: "w13" or "w2".

        Raises:
            ValueError: If ``shard_id`` is neither "w13" nor "w2" (not
                checked on the mxfp4 path, which returns earlier).
        """
        tp_rank = self.moe_tp_rank

        if self.quant_config is not None and self.quant_config.get_name() == "mxfp4":
            # The parameter may be larger than the checkpoint tensor, so
            # biases and weights are copied into the leading slice only.
            if "bias" in weight_name:
                dim1 = loaded_weight.shape[1]
                param.data[:, :dim1].copy_(loaded_weight)
            elif "scale" in weight_name:
                param.data.copy_(loaded_weight)
            else:
                dim1 = loaded_weight.shape[1]
                dim2 = loaded_weight.shape[2]
                param.data[:, :dim1, :dim2].copy_(loaded_weight)
            return

        # compressed-tensors checkpoints with packed weights are stored flipped
        # TODO: check self.quant_method.quant_config.quant_format
        # against known CompressionFormat enum values that have this quality
        loaded_weight = (
            loaded_weight.t().contiguous()
            if (
                self.quant_method.__class__.__name__
                == "CompressedTensorsWNA16MoEMethod"
            )
            else loaded_weight
        )

        if shard_id not in ("w13", "w2"):
            raise ValueError(f"shard_id must be ['w13','w2'] but " f"got {shard_id}.")

        # Fetch the dim to shard the parameter/loaded weight
        # based on the shard id. This will be whatever
        # dimension intermediate_size is used.
        SHARD_ID_TO_SHARDED_DIM = {"w13": 1, "w2": 2}
        SHARD_ID_TO_SHARDED_DIM_TRANSPOSE = {"w13": 2, "w2": 1}

        expert_data = param.data
        # A 2-D parameter is treated as a bias.
        is_bias = expert_data.dim() == 2

        # is_transposed: if the dim to shard the weight
        # should be flipped. Required by GPTQ, compressed-tensors
        # should be whatever dimension intermediate_size is
        is_transposed = getattr(param, "is_transposed", False)

        if self.use_triton_kernels:
            is_transposed = True
        shard_dim = (
            SHARD_ID_TO_SHARDED_DIM[shard_id]
            if not is_transposed
            else SHARD_ID_TO_SHARDED_DIM_TRANSPOSE[shard_id]
        )

        # Case model weights
        if "weight" in weight_name:
            self._load_model_weight_or_group_weight_scale(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
                is_bias=is_bias,
            )
            return
        else:
            # Report through the module logger (not the root logger) so the
            # message carries this module's name and handler configuration.
            logger.warning(
                "Unsupported weight_name %s for FusedMoE weight_loader_fused. Nothing is loaded.",
                weight_name,
            )

    def forward(self, hidden_states: torch.Tensor, topk_output: "StandardTopKOutput"):
        """Run the MoE computation through ``quant_method.apply``.

        The input is zero-padded on its last dimension up to
        ``self.hidden_size`` when narrower, and the result is sliced back to
        the original width before being returned as a contiguous tensor.
        """
        origin_hidden_states_dim = hidden_states.shape[-1]
        if self.hidden_size != origin_hidden_states_dim:
            hidden_states = torch.nn.functional.pad(
                hidden_states,
                (0, self.hidden_size - origin_hidden_states_dim),
                mode="constant",
                value=0.0,
            )
        assert self.quant_method is not None

        if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:
            if self.expert_map_cpu is not None and self.expert_map_gpu is None:
                # If we are in EP mode, we need to move the expert map to GPU.
                self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")

        if self.expert_map_gpu is not None and isinstance(
            topk_output, StandardTopKOutput
        ):
            # Rewrite global expert ids to local ids via the expert map.
            topk_output = topk_output._replace(
                topk_ids=self.expert_map_gpu[topk_output.topk_ids]
            )

        # Matrix multiply.
        with use_symmetric_memory(get_tp_group()) as sm:
            # Optional activation arguments are forwarded only when set.
            kwargs = {}
            if self.activation_alpha is not None:
                kwargs["activation_alpha"] = self.activation_alpha
            if self.swiglu_limit is not None:
                kwargs["swiglu_limit"] = self.swiglu_limit

            final_hidden_states = self.quant_method.apply(
                layer=self,
                x=hidden_states,
                topk_output=topk_output,
                activation=self.activation,
                apply_router_weight_on_input=self.apply_router_weight_on_input,
                routed_scaling_factor=self.routed_scaling_factor,
                # Only the ModelOpt NVFP4 method receives the TP/EP layout.
                **(
                    dict(
                        tp_rank=self.moe_tp_rank,
                        tp_size=self.moe_tp_size,
                        ep_rank=self.moe_ep_rank,
                        ep_size=self.moe_ep_size,
                    )
                    if self.quant_method.__class__.__name__
                    == "ModelOptNvFp4FusedMoEMethod"
                    else {}
                ),
                **kwargs,
            )
            sm.tag(final_hidden_states)

        if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

        return final_hidden_states[..., :origin_hidden_states_dim].contiguous()

    @classmethod
    def make_expert_params_mapping(
        cls,
        ckpt_gate_proj_name: str,
        ckpt_down_proj_name: str,
        ckpt_up_proj_name: str,
        num_experts: int,
    ) -> List[Tuple[str, str, int, str]]:
        """Build (param_name, weight_name, expert_id, shard_id) tuples per expert.

        Gate and up projections map to the ``experts.w13_`` prefix, the down
        projection to ``experts.w2_``. Each expert yields w1, w2, w3 in order.
        """
        return [
            # (param_name, weight_name, expert_id, shard_id)
            (
                (
                    "experts.w13_"
                    if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
                    else "experts.w2_"
                ),
                f"experts.{expert_id}.{weight_name}.",
                expert_id,
                shard_id,
            )
            for expert_id in range(num_experts)
            for shard_id, weight_name in [
                ("w1", ckpt_gate_proj_name),
                ("w2", ckpt_down_proj_name),
                ("w3", ckpt_up_proj_name),
            ]
        ]

    @classmethod
    def make_expert_params_mapping_fused(
        cls,
        ckpt_gate_up_proj_name: str,
        ckpt_down_proj_name: str,
        ckpt_gate_up_proj_bias_name: str,
        ckpt_down_proj_bias_name: str,
    ):
        """Build (param_name, weight_name, shard_id) tuples for fused checkpoints."""
        return [
            ("experts.w13_weight", f"experts.{ckpt_gate_up_proj_name}", "w13"),
            (
                "experts.w13_weight_bias",
                f"experts.{ckpt_gate_up_proj_bias_name}",
                "w13",
            ),
            ("experts.w2_weight", f"experts.{ckpt_down_proj_name}", "w2"),
            ("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"),
        ]

    @classmethod
    def make_expert_params_mapping_fused_mxfp4(
        cls,
        ckpt_gate_up_proj_name: str,
        ckpt_down_proj_name: str,
        ckpt_gate_up_proj_bias_name: str,
        ckpt_down_proj_bias_name: str,
        ckpt_gate_up_proj_scale_name: str,
        ckpt_down_proj_scale_name: str,
    ):
        """Same as ``make_expert_params_mapping_fused`` plus the two weight-scale entries."""
        return [
            ("experts.w13_weight", f"experts.{ckpt_gate_up_proj_name}", "w13"),
            (
                "experts.w13_weight_bias",
                f"experts.{ckpt_gate_up_proj_bias_name}",
                "w13",
            ),
            ("experts.w2_weight", f"experts.{ckpt_down_proj_name}", "w2"),
            ("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"),
            (
                "experts.w13_weight_scale",
                f"experts.{ckpt_gate_up_proj_scale_name}",
                "w13",
            ),
            ("experts.w2_weight_scale", f"experts.{ckpt_down_proj_scale_name}", "w2"),
        ]

    @classmethod
    def make_expert_input_scale_params_mapping(
        cls,
        num_experts: int,
    ) -> List[Tuple[str, str, int, str]]:
        """Build (param_name, weight_name, expert_id, shard_id) tuples keyed by shard id."""
        # (param_name, weight_name, expert_id, shard_id)
        return [
            (
                "experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
                f"experts.{expert_id}.{shard_id}.",
                expert_id,
                shard_id,
            )
            for expert_id in range(num_experts)
            for shard_id in ["w1", "w2", "w3"]
        ]
class FlashInferFusedMoE(FusedMoE):
    """FusedMoE variant that hands raw router logits to the quant method.

    Routing-related keyword arguments are removed from ``kwargs`` before the
    base constructor runs and are stored on the instance afterwards.
    """

    def __init__(self, *args, **kwargs):
        # Keyword arguments consumed here, with their defaults, in pop order.
        routing_defaults = (
            ("renormalize", True),
            ("num_fused_shared_experts", 0),
            ("use_grouped_topk", False),
            ("num_expert_group", None),
            ("topk_group", None),
            ("correction_bias", None),
        )
        routing = {key: kwargs.pop(key, default) for key, default in routing_defaults}
        super().__init__(*args, **kwargs)

        for key in ("renormalize", "num_fused_shared_experts", "use_grouped_topk"):
            setattr(self, key, routing[key])
        if self.use_grouped_topk:
            # Grouped top-k needs both group settings.
            assert (
                routing["num_expert_group"] is not None
                and routing["topk_group"] is not None
            )
        for key in ("num_expert_group", "topk_group", "correction_bias"):
            setattr(self, key, routing[key])
        self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()

    def forward(self, hidden_states: torch.Tensor, topk_output: tuple):
        """Apply the MoE using router logits taken from ``topk_output``.

        Raises:
            ValueError: If ``topk_output`` is not a 2-tuple.
        """
        assert self.use_flashinfer_trtllm_moe
        assert (
            self.activation == "silu"
        ), "Only silu is supported for flashinfer blockscale fp8 moe"
        assert self.quant_method is not None
        assert (
            self.renormalize
        ), "Renormalize is required for flashinfer blockscale fp8 moe"
        assert (
            self.num_fused_shared_experts == 0
        ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"

        # TRTLLM mode expects (TopK_config, router_logits) tuple
        is_pair = isinstance(topk_output, tuple) and len(topk_output) == 2
        if not is_pair:
            raise ValueError(
                f"FlashInferFusedMoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
            )
        router_logits = topk_output[1]

        # Matrix multiply.
        output = self.quant_method.apply_with_router_logits(
            layer=self,
            x=hidden_states,
            router_logits=router_logits,
            activation=self.activation,
            routed_scaling_factor=self.routed_scaling_factor,
        )

        is_distributed = self.moe_tp_size > 1 or self.moe_ep_size > 1
        if self.reduce_results and is_distributed:
            output = tensor_model_parallel_all_reduce(output)

        return output
class FlashInferFP4MoE(FusedMoE):
    """FP4 TRTLLM MoE implementation using FlashInfer."""

    def __init__(self, *args, **kwargs):
        # Extract DeepSeek-specific parameters
        renormalize = kwargs.pop("renormalize", True)
        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
        num_expert_group = kwargs.pop("num_expert_group", None)
        topk_group = kwargs.pop("topk_group", None)
        correction_bias = kwargs.pop("correction_bias", None)

        # Extract additional TopK parameters that were previously extracted in forward
        routed_scaling_factor = kwargs.pop("routed_scaling_factor", None)

        super().__init__(*args, **kwargs)

        # Store DeepSeek parameters
        self.renormalize = renormalize
        self.num_fused_shared_experts = num_fused_shared_experts
        self.use_grouped_topk = use_grouped_topk
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.correction_bias = correction_bias
        self.routed_scaling_factor = routed_scaling_factor

    # ---------------------------------------------------------------------
    # Helper: quantize hidden states to FP4 each forward pass
    # ---------------------------------------------------------------------
    def _quantize_hidden_states_fp4(self, hidden_states: torch.Tensor):
        """Quantize hidden states to FP4 with ``flashinfer.fp4_quantize``.

        The global scale is read from ``self.w13_input_scale_quant``; only
        the block scales are computed at runtime.

        Returns:
            A 2-tuple ``(hs_fp4, hs_sf)``: the packed FP4 data reshaped to
            ``(hidden_states.shape[0], hidden_states.shape[1] // 2)`` and the
            block scales viewed as ``torch.float8_e4m3fn`` and flattened.
        """
        # flashinfer.fp4_quantize returns (packed_uint8, scale_fp8)
        # Only the block scales are computed at runtime
        hs_fp4_bytes, hs_sf_bytes = fp4_quantize(
            hidden_states,
            self.w13_input_scale_quant,
            16,  # sf_vec_size
            False,  # use_ue8m0
            False,  # is_sf_swizzled_layout
        )

        # Two FP4 values are packed per byte, hence half the columns.
        hs_fp4 = hs_fp4_bytes.reshape(
            hidden_states.shape[0], hidden_states.shape[1] // 2
        )
        hs_sf = hs_sf_bytes.view(torch.float8_e4m3fn).reshape(-1)

        return hs_fp4, hs_sf

    def forward(self, hidden_states: torch.Tensor, topk_output):
        """Forward pass using FP4 TRTLLM kernel.

        Args:
            hidden_states: Input tensor
            topk_output: Should be tuple of (TopK_config, router_logits) for TRTLLM mode

        Raises:
            ValueError: If ``topk_output`` is not a 2-tuple.
            RuntimeError: If ``trtllm_fp4_block_scale_moe`` was not imported.
        """
        # TRTLLM mode expects (TopK_config, router_logits) tuple
        if not isinstance(topk_output, tuple) or len(topk_output) != 2:
            raise ValueError(
                f"FlashInferFP4MoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
            )

        # The kernel name is None when the TRTLLM backend is not enabled or
        # its import failed; fail with a clear message instead of
        # "'NoneType' object is not callable".
        if trtllm_fp4_block_scale_moe is None:
            raise RuntimeError(
                "trtllm_fp4_block_scale_moe is unavailable: it is only imported "
                "when the FlashInfer TRTLLM MoE backend is enabled and "
                "flashinfer.fused_moe can be imported."
            )

        _, router_logits = topk_output

        hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states)

        router_logits = router_logits.to(torch.float32)

        # correction_bias defaults to None in __init__; pass None through
        # rather than calling .to() on it.
        routing_bias = (
            None
            if self.correction_bias is None
            else self.correction_bias.to(hidden_states.dtype)
        )

        result = trtllm_fp4_block_scale_moe(
            routing_logits=router_logits,
            routing_bias=routing_bias,
            hidden_states=hs_fp4,
            hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).flatten(),
            gemm1_weights=self.gemm1_weights_fp4_shuffled.data,
            gemm1_weights_scale=self.gemm1_scales_fp4_shuffled.data.view(
                torch.float8_e4m3fn
            ),
            gemm1_bias=None,
            gemm1_alpha=None,
            gemm1_beta=None,
            gemm1_clamp_limit=None,
            gemm2_weights=self.gemm2_weights_fp4_shuffled.data,
            gemm2_weights_scale=self.gemm2_scales_fp4_shuffled.data.view(
                torch.float8_e4m3fn
            ),
            gemm2_bias=None,
            output1_scale_scalar=self.g1_scale_c.data,
            output1_scale_gate_scalar=self.g1_alphas.data,
            output2_scale_scalar=self.g2_alphas.data,
            num_experts=self.num_experts,
            top_k=self.top_k,
            n_group=self.num_expert_group,
            topk_group=self.topk_group,
            intermediate_size=self.intermediate_size_per_partition,
            local_expert_offset=self.moe_ep_rank * self.num_local_experts,
            local_num_experts=self.num_local_experts,
            routed_scaling_factor=self.routed_scaling_factor,
            tile_tokens_dim=_get_tile_tokens_dim(
                hidden_states.shape[0], self.top_k, self.num_local_experts
            ),
            routing_method_type=RoutingMethodType.DeepSeekV3,
            do_finalize=True,
        )[0]

        return result
def get_fused_moe_impl_class():
    """Return the FusedMoE class matching the active backend settings.

    Without the FlashInfer TRTLLM backend the plain ``FusedMoE`` is used;
    with it, ``FlashInferFP4MoE`` is chosen when ModelOpt FP4 quantization is
    enabled and ``FlashInferFusedMoE`` otherwise.
    """
    if not should_use_flashinfer_trtllm_moe():
        # Default case
        return FusedMoE
    if _is_fp4_quantization_enabled():
        # Use FP4 variant when FP4 quantization is enabled
        return FlashInferFP4MoE
    # Use regular FlashInfer variant for non-FP4 FlashInfer cases
    return FlashInferFusedMoE