Files
sglang/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
Yuan Luo 253454de9b Integrate triton moe kernel (#7689)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
2025-07-06 20:05:49 -07:00

981 lines
35 KiB
Python

# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
import importlib
from abc import abstractmethod
from enum import Enum
from typing import Callable, List, Optional, Tuple
import torch
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
from sglang.srt.utils import (
cpu_has_amx_support,
get_bool_env_var,
is_cpu,
is_hip,
set_weight_attrs,
use_intel_amx_backend,
)
# `import importlib` alone does not guarantee that the `importlib.util`
# submodule is loaded, so import it explicitly before using find_spec.
import importlib.util
import logging

# Optional dependency backing the `enable_triton_kernel_moe` path of
# UnquantizedFusedMoEMethod.forward_cuda.
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None

if torch.cuda.is_available():
    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

    if has_triton_kernels:
        from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
            triton_kernel_moe_forward,
        )
else:
    fused_experts = None  # type: ignore

# Platform / backend flags, evaluated once at import time.
_is_hip = is_hip()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

if _use_aiter:
    from aiter import ActivationType
    from aiter.fused_moe import fused_moe
    from aiter.fused_moe_bf16_asm import ck_moe_2stages
    from aiter.ops.shuffle import shuffle_weight

logger = logging.getLogger(__name__)
class FusedMoeWeightScaleSupported(Enum):
    """Weight-scale layouts recognised by ``FusedMoE.weight_loader``.

    A scale parameter declares its layout through a ``quant_method``
    attribute whose value is one of the strings below.
    """

    TENSOR = "tensor"  # per-tensor scales
    CHANNEL = "channel"  # per-channel scales
    GROUP = "group"  # group-wise scales
    BLOCK = "block"  # block-wise scales
class FusedMoEMethodBase(QuantizeMethodBase):
    """Contract shared by every fused-MoE (quantization) method.

    Implementations allocate the expert weights on a layer via
    ``create_weights`` and compute the expert output via ``apply``.
    """

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the expert weight parameters on ``layer``."""
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
    ) -> torch.Tensor:
        """Return the MoE output for ``x`` routed by ``router_logits``."""
        raise NotImplementedError
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
    """MoE method without quantization.

    Stores expert weights in ``params_dtype``. ``CustomOp`` dispatches the
    forward pass to ``forward_cuda`` / ``forward_cpu`` / ``forward_npu``.
    """

    def __init__(self, use_triton_kernels: bool = False):
        """
        Args:
            use_triton_kernels: When True, the last two dims of the expert
                weights are swapped at allocation time and ``forward_cuda``
                routes through ``triton_kernel_moe_forward``.
        """
        super().__init__()
        self.use_triton_kernels = use_triton_kernels

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register ``w13_weight`` (fused gate/up) and ``w2_weight`` (down) on ``layer``.

        Default shapes are ``(num_experts, 2 * intermediate_size, hidden_size)``
        for w13 and ``(num_experts, hidden_size, intermediate_size)`` for w2;
        with triton kernels the last two dims of each are swapped. Both
        parameters receive ``extra_weight_attrs`` (e.g. a ``weight_loader``).
        """
        # Fused gate_up_proj (column parallel)
        w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
        if self.use_triton_kernels:
            w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
        w13_weight = torch.nn.Parameter(
            torch.empty(num_experts, w13_weight_n, w13_weight_k, dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight_n, w2_weight_k = (
            hidden_size,
            intermediate_size,
        )
        if self.use_triton_kernels:
            w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n
        w2_weight = torch.nn.Parameter(
            torch.empty(num_experts, w2_weight_n, w2_weight_k, dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack the loaded expert weights for the active backend."""
        if _use_aiter:
            # Rearrange each weight with aiter's shuffle_weight (layout (16, 16))
            # for the aiter fused_moe kernel, freeing cached memory after each.
            for name in ("w13_weight", "w2_weight"):
                shuffled = shuffle_weight(getattr(layer, name).data, (16, 16))
                setattr(
                    layer, name, torch.nn.Parameter(shuffled, requires_grad=False)
                )
                torch.cuda.empty_cache()

        # Pack weights for better performance on CPU
        if _is_cpu and _is_cpu_amx_available:
            _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        num_fused_shared_experts: int = 0,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
    ) -> torch.Tensor:
        """Forward every argument to the platform-specific ``forward_*`` via ``CustomOp``."""
        return self.forward(
            x=x,
            layer=layer,
            router_logits=router_logits,
            top_k=top_k,
            renormalize=renormalize,
            use_grouped_topk=use_grouped_topk,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            num_fused_shared_experts=num_fused_shared_experts,
            custom_routing_function=custom_routing_function,
            correction_bias=correction_bias,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            inplace=inplace,
            no_combine=no_combine,
            routed_scaling_factor=routed_scaling_factor,
        )

    def forward_cuda(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        num_fused_shared_experts: int = 0,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
    ) -> torch.Tensor:
        """GPU path: triton_kernels, aiter (ROCm), or the Triton ``fused_experts``.

        Raises:
            ValueError: On the aiter path, if ``activation`` is neither
                ``"silu"`` nor ``"gelu"``.
        """
        if self.use_triton_kernels:
            # NOTE(review): only top_k and renormalize are forwarded here;
            # grouped top-k, custom routing, correction bias, activation,
            # apply_router_weight_on_input, no_combine and routed_scaling_factor
            # are not passed to this kernel.
            return triton_kernel_moe_forward(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize,
            )

        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            num_fused_shared_experts=num_fused_shared_experts,
            custom_routing_function=custom_routing_function,
            correction_bias=correction_bias,
            routed_scaling_factor=routed_scaling_factor,
        )

        if _use_aiter:
            assert not no_combine, "unsupported"
            # Map explicitly instead of treating every non-"silu" value as
            # gelu, so an unsupported activation fails loudly.
            if activation == "silu":
                aiter_activation = ActivationType.Silu
            elif activation == "gelu":
                aiter_activation = ActivationType.Gelu
            else:
                raise ValueError(
                    f"Unsupported activation for the aiter fused MoE: {activation}"
                )
            if apply_router_weight_on_input:
                assert (
                    topk_weights.dim() == 2
                ), "`topk_weights` should be in shape (num_tokens, topk)"
                _, topk = topk_weights.shape
                assert (
                    topk == 1
                ), "Only support topk=1 when `apply_router_weight_on_input` is True"
                # Fold the router weight into the input, then neutralise it.
                x = x * topk_weights.to(x.dtype)
                topk_weights = torch.ones_like(
                    topk_weights, dtype=torch.float32
                )  # topk_weights must be FP32 (float32)
            return fused_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights,
                topk_ids,
                activation=aiter_activation,
            )

        return fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=inplace and not no_combine,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            no_combine=no_combine,
            routed_scaling_factor=routed_scaling_factor,
        )

    def forward_cpu(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        num_fused_shared_experts: int = 0,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
    ) -> torch.Tensor:
        """CPU path: the AMX ``fused_experts_cpu`` kernel when usable, else the native MoE."""
        assert activation == "silu", f"activation = {activation} is not supported."

        if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
            topk_weights, topk_ids = select_experts(
                hidden_states=x,
                router_logits=router_logits,
                use_grouped_topk=use_grouped_topk,
                top_k=top_k,
                renormalize=renormalize,
                topk_group=topk_group,
                num_expert_group=num_expert_group,
                num_fused_shared_experts=num_fused_shared_experts,
                custom_routing_function=custom_routing_function,
                correction_bias=correction_bias,
                routed_scaling_factor=routed_scaling_factor,
            )

            # TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel
            return torch.ops.sgl_kernel.fused_experts_cpu(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights.to(
                    torch.float
                ),  # TODO: the topk_weights of llama4 is computed via Llama4MoE:custom_routing_function and is bfloat16 while the kernel requires it to be float32
                topk_ids,
                False,  # inplace # See [Note] inplace should be False in fused_experts.
                False,  # use_int8_w8a8
                False,  # use_fp8_w8a16
                None,  # w1_scale
                None,  # w2_scale
                None,  # block_size
                None,  # a1_scale
                None,  # a2_scale
                True,  # is_vnni
            )
        else:
            return moe_forward_native(
                layer,
                x,
                use_grouped_topk,
                top_k,
                router_logits,
                renormalize,
                topk_group,
                num_expert_group,
                num_fused_shared_experts,
                custom_routing_function,
                correction_bias,
                activation,
                apply_router_weight_on_input,
                inplace,
                no_combine,
                routed_scaling_factor,
            )

    def forward_npu(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        num_fused_shared_experts: int = 0,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
    ) -> torch.Tensor:
        """NPU path: always the native MoE implementation."""
        return moe_forward_native(
            layer,
            x,
            use_grouped_topk,
            top_k,
            router_logits,
            renormalize,
            topk_group,
            num_expert_group,
            num_fused_shared_experts,
            custom_routing_function,
            correction_bias,
            activation,
            apply_router_weight_on_input,
            inplace,
            no_combine,
            routed_scaling_factor,
        )

    def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("The TPU backend currently does not support MoE.")

    # The native path shares the CPU implementation.
    forward_native = forward_cpu
class FusedMoE(torch.nn.Module):
"""FusedMoE layer for MoE models.
This layer contains both MergedColumnParallel weights (gate_up_proj /
w13) and RowParallelLinear weights (down_proj/ w2).
Note: Mixtral uses w1, w3, and w2 for gate_proj, up_proj, and down_proj
respectively. We copy that naming convention here and handle any remapping
in the load_weights function in each model implementation.
Args:
num_experts: Number of experts in the model
top_k: Number of experts selected for each token
hidden_size: Input hidden state size of the transformer
intermediate_size: Intermediate size of the experts
params_dtype: Data type for the parameters.
reduce_results: Whether to all-reduce the output of the layer
renormalize: Whether to renormalize the logits in the fused_moe kernel
quant_config: Quantization config.
inplace: suggestion to compute inplace (modify input activation).
"""
def __init__(
    self,
    num_experts: int,
    top_k: int,
    hidden_size: int,
    intermediate_size: int,
    layer_id: Optional[int] = None,
    params_dtype: Optional[torch.dtype] = None,
    reduce_results: bool = False,
    renormalize: bool = True,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    num_fused_shared_experts: int = 0,
    topk_group: Optional[int] = None,
    quant_config: Optional[QuantizationConfig] = None,
    tp_size: Optional[int] = None,
    prefix: str = "",
    custom_routing_function: Optional[Callable] = None,
    correction_bias: Optional[torch.Tensor] = None,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_presharded_weights: bool = False,
    inplace: bool = True,
    no_combine: bool = False,
    routed_scaling_factor: Optional[float] = None,
    enable_flashinfer_moe: Optional[bool] = False,
    enable_ep_moe: Optional[bool] = False,
):
    """Set up parallelism, routing options and the quant method, then
    allocate the expert weights.

    With ``enable_ep_moe`` (only allowed together with flashinfer MoE), the
    former TP size/rank become the EP size/rank, TP is reset to 1/0, and
    ``expert_map`` maps global expert ids to local ids (``-1`` for experts
    owned by other ranks). Without a quant config, flashinfer MoE and EP
    are both disabled.

    Raises:
        ImportError: If triton-kernel MoE is enabled for an unquantized
            layer but the ``triton_kernels`` package is not installed.
    """
    super().__init__()

    if params_dtype is None:
        params_dtype = torch.get_default_dtype()

    self.hidden_size = hidden_size
    self.tp_size = (
        tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
    )
    self.tp_rank = get_tensor_model_parallel_rank()
    self.num_experts = num_experts
    self.expert_map = None

    if enable_flashinfer_moe and quant_config is None:
        logger.warning("Disable flashinfer MoE when quantization config is None.")
        enable_flashinfer_moe = False
        # EP below requires flashinfer MoE, so it is switched off too.
        enable_ep_moe = False

    self.enable_flashinfer_moe = enable_flashinfer_moe
    if enable_ep_moe:
        assert (
            self.enable_flashinfer_moe
        ), "FusedMoE only supports EP with --enable-flashinfer-moe"
        self.ep_size = self.tp_size
        self.ep_rank = self.tp_rank
        self.tp_size = 1
        self.tp_rank = 0
        # Create a tensor of size num_experts filled with -1
        self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32)
        # Create an expert map for the local experts
        assert num_experts % self.ep_size == 0
        self.local_num_experts = num_experts // self.ep_size
        self.expert_map[
            self.ep_rank
            * self.local_num_experts : (self.ep_rank + 1)
            * self.local_num_experts
        ] = torch.arange(0, self.local_num_experts, dtype=torch.int32, device="cpu")
    else:
        self.ep_size = 1
        self.ep_rank = 0
        self.local_num_experts = num_experts

    self.routed_scaling_factor = routed_scaling_factor
    self.top_k = top_k
    assert intermediate_size % self.tp_size == 0
    self.intermediate_size_per_partition = intermediate_size // self.tp_size
    self.reduce_results = reduce_results
    self.renormalize = renormalize
    self.use_grouped_topk = use_grouped_topk
    if self.use_grouped_topk:
        assert num_expert_group is not None and topk_group is not None
    self.num_expert_group = num_expert_group
    self.num_fused_shared_experts = num_fused_shared_experts
    self.topk_group = topk_group
    self.custom_routing_function = custom_routing_function
    self.correction_bias = correction_bias
    self.activation = activation
    self.apply_router_weight_on_input = apply_router_weight_on_input
    self.use_presharded_weights = use_presharded_weights
    self.inplace = inplace
    self.no_combine = no_combine

    self.use_triton_kernels = (
        not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
    )

    if quant_config is None:
        # Fail at construction instead of with a NameError on the first
        # forward: the unquantized triton path needs triton_kernel_moe_forward.
        if self.use_triton_kernels and not has_triton_kernels:
            raise ImportError(
                "enable_triton_kernel_moe requires the `triton_kernels` package, "
                "which is not installed."
            )
        self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
            self.use_triton_kernels
        )
    else:
        self.quant_method = quant_config.get_quant_method(self, prefix)
        if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
            self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe
    assert self.quant_method is not None

    self.quant_method.create_weights(
        layer=self,
        num_experts=self.local_num_experts,
        hidden_size=hidden_size,
        # FIXME: figure out which intermediate_size to use
        intermediate_size=self.intermediate_size_per_partition,
        intermediate_size_per_partition=self.intermediate_size_per_partition,
        params_dtype=params_dtype,
        weight_loader=self.weight_loader,
    )
def _load_per_tensor_weight_scale(
    self,
    shard_id: str,
    param: torch.nn.Parameter,
    loaded_weight: torch.Tensor,
    expert_id: int,
):
    """Store a per-tensor scale for one expert.

    The w2 (down_proj) scale fills the expert's entry directly. The w1 and w3
    scales are kept in separate slots (0 for w1, 1 for w3) because the w1/w3
    weights are re-quantized after weight loading. Other shard ids are ignored.
    """
    scales = param.data
    if shard_id == "w2":
        scales[expert_id] = loaded_weight
        return
    if shard_id in ("w1", "w3"):
        slot = {"w1": 0, "w3": 1}[shard_id]
        scales[expert_id][slot] = loaded_weight
def _load_model_weight_or_group_weight_scale(
    self,
    shard_dim: int,
    expert_data: torch.Tensor,
    shard_id: str,
    loaded_weight: torch.tensor,
    tp_rank: int,
):
    """Load a model weight or a group/block weight scale for one expert.

    Both kinds are TP-sharded exactly like the weights themselves, so w1/w3
    go through ``_load_w13`` and w2 through ``_load_w2``. Other shard ids are
    ignored.
    """
    if shard_id in ("w1", "w3"):
        loader = self._load_w13
    elif shard_id == "w2":
        loader = self._load_w2
    else:
        return
    loader(
        shard_id=shard_id,
        shard_dim=shard_dim,
        loaded_weight=loaded_weight,
        expert_data=expert_data,
        tp_rank=tp_rank,
    )
def _load_per_channel_weight_scale(
    self,
    expert_data: torch.Tensor,
    shard_dim: int,
    shard_id: str,
    loaded_weight: torch.tensor,
    tp_rank: int,
):
    """Load per-channel weight scales for one expert.

    w1/w3 scales are TP-sharded via ``_load_w13``. The w2 scale is copied
    as-is, with no TP narrowing. Other shard ids are ignored.
    """
    if shard_id in ("w1", "w3"):
        self._load_w13(
            shard_id=shard_id,
            shard_dim=shard_dim,
            loaded_weight=loaded_weight,
            expert_data=expert_data,
            tp_rank=tp_rank,
        )
    elif shard_id == "w2":
        expert_data.copy_(loaded_weight)
def _load_w13(
    self,
    expert_data: torch.Tensor,
    shard_dim: int,
    shard_id: str,
    loaded_weight: torch.tensor,
    tp_rank: int,
):
    """Copy the gate (``w1``) or up (``w3``) half of a fused w13 tensor for one expert.

    Args:
        expert_data: This expert's slice of the fused w13 parameter.
        shard_dim: Dimension along which the tensor is TP-sharded.
        shard_id: ``"w1"`` or ``"w3"``.
        loaded_weight: Checkpoint tensor. It is the full tensor unless
            ``use_presharded_weights`` is set, in which case it is already
            this rank's shard.
        tp_rank: Tensor-parallel rank whose shard is loaded.
    """
    # gate_up_proj is "MergedColumnParallel": each half of w13 holds one TP
    # shard of size shard_size along shard_dim.
    shard_size = expert_data.shape[shard_dim] // 2

    # w1 (gate_proj) normally goes into the first half and w3 (up_proj) into
    # the second; quant methods that set load_up_proj_weight_first (the
    # trtllm cutlass kernel) swap the two.
    assert shard_id in ("w1", "w3")
    switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
    if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
        start = shard_size
    else:
        start = 0

    if _is_cpu:
        expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
            expert_data,
            loaded_weight,
            start,
            shard_size * tp_rank,
            shard_dim,
            shard_size,
            not self.use_presharded_weights,
        )
    else:
        # Triton-kernel weights are allocated with their last two dims swapped,
        # so the checkpoint tensor must be transposed whether or not it is
        # pre-sharded (previously pre-sharded tensors skipped this).
        if self.use_triton_kernels:
            loaded_weight = loaded_weight.transpose(-2, -1)
        if not self.use_presharded_weights:
            loaded_weight = loaded_weight.narrow(
                shard_dim, shard_size * tp_rank, shard_size
            )
        expert_data = expert_data.narrow(shard_dim, start, shard_size)
    expert_data.copy_(loaded_weight)
def _load_w2(
    self,
    expert_data: torch.Tensor,
    shard_dim: int,
    shard_id: str,
    loaded_weight: torch.tensor,
    tp_rank: int,
):
    """Copy this rank's shard of a down_proj (w2) tensor for one expert.

    Args:
        expert_data: This expert's slice of the w2 parameter; its size along
            ``shard_dim`` is the shard size.
        shard_dim: Dimension along which the tensor is TP-sharded.
        shard_id: Shard identifier; not used here.
        loaded_weight: Checkpoint tensor. It is the full tensor unless
            ``use_presharded_weights`` is set.
        tp_rank: Tensor-parallel rank whose shard is loaded.
    """
    # down_proj: "RowParallel", so TP shards along its input dim.
    shard_size = expert_data.shape[shard_dim]

    if _is_cpu:
        expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
            expert_data,
            loaded_weight,
            0,  # param_data_start
            shard_size * tp_rank,
            shard_dim,
            shard_size,
            not self.use_presharded_weights,
        )
    else:
        # Triton-kernel weights are allocated with their last two dims swapped,
        # so the checkpoint tensor must be transposed whether or not it is
        # pre-sharded (previously pre-sharded tensors skipped this).
        if self.use_triton_kernels:
            loaded_weight = loaded_weight.transpose(-2, -1)
        if not self.use_presharded_weights:
            loaded_weight = loaded_weight.narrow(
                shard_dim, shard_size * tp_rank, shard_size
            )

    # w2, down_proj: Load into only logical weight of w2.
    expert_data.copy_(loaded_weight)
def _load_single_value(
    self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int
):
    """Overwrite ``param``'s entry for ``expert_id`` with ``loaded_weight``.

    Used for per-expert values that need no sharding (input scales, which
    are expected to be equal across shards, and ``weight_shape``).
    """
    target = param.data
    target[expert_id] = loaded_weight
def _load_g_idx(
    self,
    shard_id: str,
    expert_data: torch.Tensor,
    shard_dim: int,
    loaded_weight: torch.tensor,
    tp_rank: int,
):
    """Load a ``g_idx`` tensor for one expert.

    The w2 ``g_idx`` is TP-sharded like the w2 weight. The w1/w3 ``g_idx``
    is copied unchanged.
    """
    if shard_id != "w2":
        assert shard_id in ("w1", "w3")
        expert_data.copy_(loaded_weight)
        return

    self._load_w2(
        shard_id=shard_id,
        shard_dim=shard_dim,
        loaded_weight=loaded_weight,
        expert_data=expert_data,
        tp_rank=tp_rank,
    )
def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
    """Translate a global expert id to this rank's local id.

    Without an expert map (no EP) the id is returned unchanged. Otherwise the
    mapped value is returned, which is ``-1`` for experts owned by another
    rank.
    """
    mapping = self.expert_map
    return expert_id if mapping is None else mapping[expert_id].item()
def weight_loader(
    self,
    param: torch.nn.Parameter,
    loaded_weight: torch.Tensor,
    weight_name: str,
    shard_id: str,
    expert_id: int,
) -> None:
    """Load one checkpoint tensor for one expert into ``param``.

    This is passed as ``weight_loader`` to the quant method's
    ``create_weights``. The tensor is routed by substrings of
    ``weight_name``, checked in this order:

    1. input scales,
    2. ``g_idx``,
    3. ModelOpt weights and scales,
    4. other weight scales / zero points, dispatched on the parameter's
       ``quant_method`` attribute,
    5. ``weight_shape``,
    6. plain model weights.

    Args:
        param: Destination parameter, indexed by local expert id in dim 0.
        loaded_weight: Tensor read from the checkpoint.
        weight_name: Checkpoint tensor name; substring checks pick the path.
        shard_id: ``"w1"`` (gate), ``"w3"`` (up) or ``"w2"`` (down).
        expert_id: Global expert id; experts not owned by this EP rank are
            skipped.

    Raises:
        ValueError: For an unknown ``shard_id``, mismatching input scales
            (compressed-tensors methods), or an unsupported weight-scale
            ``quant_method``.
    """
    # Translate to this rank's local expert id; -1 means another EP rank owns it.
    expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
    if expert_id == -1:
        return

    # TP rank is set to 0 if EP is enabled
    tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()

    # compressed-tensors checkpoints with packed weights are stored flipped
    # TODO (mgoin): check self.quant_method.quant_config.quant_format
    # against known CompressionFormat enum values that have this quality
    loaded_weight = (
        loaded_weight.t().contiguous()
        if (
            self.quant_method.__class__.__name__
            == "CompressedTensorsWNA16MoEMethod"
        )
        else loaded_weight
    )

    if shard_id not in ("w1", "w2", "w3"):
        raise ValueError(
            f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}."
        )

    # Supported weight-scale layouts; only used in the error message below.
    WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
    # Fetch the dim to shard the parameter/loaded weight
    # based on the shard id. This will be whatever
    # dimension intermediate_size is used.
    SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

    # This expert's slice of the destination parameter.
    expert_data = param.data[expert_id]

    # is_transposed: if the dim to shard the weight
    # should be flipped. Required by GPTQ, compressed-tensors
    # should be whatever dimension intermediate_size is
    is_transposed = getattr(param, "is_transposed", False)
    shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
    # Triton-kernel weights are allocated with their last two dims swapped.
    if self.use_triton_kernels:
        is_transposed = True
    if is_transposed:
        shard_dim = int(not shard_dim)

    # Case input scale: input_scale loading is only supported for fp8
    if "input_scale" in weight_name:
        # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
            loaded_weight = loaded_weight * 2.0

        # this is needed for compressed-tensors only
        loaded_weight = loaded_weight.to(param.data.device)

        # w1 and w3 share one input scale per expert. The equality check is
        # skipped while the stored value is still 1 (presumably the initial
        # value set at allocation — confirm in the quant method).
        if (
            "compressed" in self.quant_method.__class__.__name__.lower()
            and param.data[expert_id] != 1
            and (param.data[expert_id] - loaded_weight).abs() > 1e-5
        ):
            raise ValueError(
                "input_scales of w1 and w3 of a layer "
                f"must be equal. But got {param.data[expert_id]} "
                f"vs. {loaded_weight}"
            )

        self._load_single_value(
            param=param, loaded_weight=loaded_weight, expert_id=expert_id
        )
        return

    # Case g_idx
    if "g_idx" in weight_name:
        self._load_g_idx(
            shard_dim=0,
            shard_id=shard_id,
            loaded_weight=loaded_weight,
            expert_data=expert_data,
            tp_rank=tp_rank,
        )
        return

    # ModelOpt methods: per-tensor scales vs. weights / block scales.
    # NOTE(review): "input_scale" names already returned above, so only
    # "weight_scale_2" can reach the first branch here.
    if "ModelOpt" in self.quant_method.__class__.__name__:
        if "weight_scale_2" in weight_name or "input_scale" in weight_name:
            self._load_per_tensor_weight_scale(
                shard_id=shard_id,
                param=param,
                loaded_weight=loaded_weight,
                expert_id=expert_id,
            )
        elif "weight" in weight_name:
            self._load_model_weight_or_group_weight_scale(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
            )
        return

    # Case weight scales and zero_points
    if "scale" in weight_name or "zero" in weight_name:
        # load the weight scales and zp based on the quantization scheme
        # supported weight scales/zp can be found in
        # FusedMoeWeightScaleSupported
        # TODO @dsikka: once hardened, refactor to use vLLM Parameters
        # specific to each case
        quant_method = getattr(param, "quant_method", None)
        if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
            # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD)
            if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
                loaded_weight = loaded_weight * 0.5

            self._load_per_channel_weight_scale(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
            )
        elif quant_method in [
            FusedMoeWeightScaleSupported.GROUP.value,
            FusedMoeWeightScaleSupported.BLOCK.value,
        ]:
            # Group/block scales are sharded exactly like the weights.
            self._load_model_weight_or_group_weight_scale(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
            )
        elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
            # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD)
            if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
                loaded_weight = loaded_weight * 2.0

            self._load_per_tensor_weight_scale(
                shard_id=shard_id,
                param=param,
                loaded_weight=loaded_weight,
                expert_id=expert_id,
            )
        else:
            raise ValueError(
                f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}"
            )
        return

    # Case weight_shape
    if "weight_shape" in weight_name:
        # only required by compressed-tensors
        self._load_single_value(
            param=param, loaded_weight=loaded_weight, expert_id=expert_id
        )
        return

    # Case model weights
    if "weight" in weight_name:
        self._load_model_weight_or_group_weight_scale(
            shard_id=shard_id,
            shard_dim=shard_dim,
            loaded_weight=loaded_weight,
            expert_data=expert_data,
            tp_rank=tp_rank,
        )
        return
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
    """Run the experts on ``hidden_states`` and optionally all-reduce the result.

    Routing and activation options come from the layer's configuration.
    The all-reduce runs only when ``reduce_results`` is set and the layer
    spans more than one TP or EP rank.
    """
    assert self.quant_method is not None

    # Only the ModelOpt NVFP4 method is told about the parallel layout.
    if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
        parallel_kwargs = dict(
            tp_rank=self.tp_rank,
            tp_size=self.tp_size,
            ep_rank=self.ep_rank,
            ep_size=self.ep_size,
        )
    else:
        parallel_kwargs = {}

    # Matrix multiply.
    output = self.quant_method.apply(
        layer=self,
        x=hidden_states,
        router_logits=router_logits,
        top_k=self.top_k,
        renormalize=self.renormalize,
        use_grouped_topk=self.use_grouped_topk,
        topk_group=self.topk_group,
        num_expert_group=self.num_expert_group,
        num_fused_shared_experts=self.num_fused_shared_experts,
        custom_routing_function=self.custom_routing_function,
        correction_bias=self.correction_bias,
        activation=self.activation,
        apply_router_weight_on_input=self.apply_router_weight_on_input,
        routed_scaling_factor=self.routed_scaling_factor,
        **parallel_kwargs,
    )

    spans_multiple_ranks = self.tp_size > 1 or self.ep_size > 1
    if self.reduce_results and spans_multiple_ranks:
        output = tensor_model_parallel_all_reduce(output)
    return output
@classmethod
def make_expert_params_mapping(
    cls,
    ckpt_gate_proj_name: str,
    ckpt_down_proj_name: str,
    ckpt_up_proj_name: str,
    num_experts: int,
) -> List[Tuple[str, str, int, str]]:
    """Build ``(param_name, weight_name, expert_id, shard_id)`` tuples for loading.

    For each expert, entries follow the order w1 (gate), w2 (down), w3 (up).
    Gate and up checkpoint names map to the fused ``"experts.w13_"`` prefix;
    the down name maps to ``"experts.w2_"``.
    """
    shard_order = (
        ("w1", ckpt_gate_proj_name),
        ("w2", ckpt_down_proj_name),
        ("w3", ckpt_up_proj_name),
    )
    fused_w13_names = (ckpt_gate_proj_name, ckpt_up_proj_name)

    mapping: List[Tuple[str, str, int, str]] = []
    for expert_id in range(num_experts):
        for shard_id, weight_name in shard_order:
            if weight_name in fused_w13_names:
                param_prefix = "experts.w13_"
            else:
                param_prefix = "experts.w2_"
            mapping.append(
                (
                    param_prefix,
                    f"experts.{expert_id}.{weight_name}.",
                    expert_id,
                    shard_id,
                )
            )
    return mapping
def _load_fp8_scale(
    self,
    param: torch.nn.Parameter,
    loaded_weight: torch.Tensor,
    weight_name: str,
    shard_id: str,
    expert_id: int,
) -> None:
    """Store an FP8 input or weight scale for one expert.

    * ``input_scale``: one value per expert, shared by w1 and w3. The check
      is skipped while the stored value is still 1 (presumably the initial
      value — confirm); otherwise a difference above 1e-5 raises.
    * ``weight_scale``: w1/w3 scales go to slots 0/1 of the expert entry,
      since w1/w3 are re-quantized after loading; any other shard writes the
      expert entry directly.
    Names matching neither pattern are ignored.

    Raises:
        ValueError: If the w1 and w3 input scales differ.
    """
    param_data = param.data

    if "input_scale" in weight_name:
        if (
            param_data[expert_id] != 1
            and (param_data[expert_id] - loaded_weight).abs() > 1e-5
        ):
            raise ValueError(
                "input_scales of w1 and w3 of a layer "
                f"must be equal. But got {param_data[expert_id]} "
                f"vs. {loaded_weight}"
            )
        param_data[expert_id] = loaded_weight
        return

    if "weight_scale" not in weight_name:
        return

    if shard_id in ("w1", "w3"):
        # Merged column case (gate_up_proj).
        slot = 0 if shard_id == "w1" else 1
        param_data[expert_id][slot] = loaded_weight
    else:
        # Row parallel case (down_proj).
        param_data[expert_id] = loaded_weight