# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
import logging
from enum import Enum
from typing import List, Optional, Tuple

import torch

from sglang.srt.distributed import (
get_moe_expert_parallel_rank,
get_moe_expert_parallel_world_size,
get_moe_tensor_parallel_rank,
get_moe_tensor_parallel_world_size,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.layers.moe.topk import StandardTopKOutput
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
from sglang.srt.utils import cpu_has_amx_support, get_bool_env_var, is_cpu, is_hip

_is_hip = is_hip()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()

logger = logging.getLogger(__name__)


class FusedMoeWeightScaleSupported(Enum):
TENSOR = "tensor"
CHANNEL = "channel"
GROUP = "group"
BLOCK = "block"


class FusedMoE(torch.nn.Module):
"""FusedMoE layer for MoE models.
This layer contains both MergedColumnParallel weights (gate_up_proj /
w13) and RowParallelLinear weights (down_proj/ w2).
Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
copy that naming convention here and handle any remapping in the
load_weights function in each model implementation.
Args:
num_experts: Number of experts in the model
top_k: Number of experts selected for each token
hidden_size: Input hidden state size of the transformer
intermediate_size: Intermediate size of the experts
params_dtype: Data type for the parameters.
reduce_results: Whether to all all_reduce on the output of the layer
renomalize: Whether to renormalize the logits in the fused_moe kernel
quant_config: Quantization configure.
inplace: suggestion to compute inplace (modify input activation).
"""
def __init__(
self,
num_experts: int,
hidden_size: int,
intermediate_size: int,
layer_id: int,
top_k: Optional[int] = None,
num_fused_shared_experts: int = 0,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "",
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_presharded_weights: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
enable_flashinfer_cutlass_moe: Optional[bool] = False,
enable_ep_moe: Optional[bool] = False,
):
super().__init__()
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.layer_id = layer_id
self.top_k = top_k
self.hidden_size = hidden_size
self.num_experts = num_experts
self.num_fused_shared_experts = num_fused_shared_experts
self.expert_map_cpu = None
self.expert_map_gpu = None
if enable_flashinfer_cutlass_moe and quant_config is None:
logger.warning("Disable flashinfer MoE when quantization config is None.")
enable_flashinfer_cutlass_moe = False
enable_ep_moe = False
self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
self.moe_ep_size = get_moe_expert_parallel_world_size()
self.moe_ep_rank = get_moe_expert_parallel_rank()
self.moe_tp_size = get_moe_tensor_parallel_world_size()
self.moe_tp_rank = get_moe_tensor_parallel_rank()
assert num_experts % self.moe_ep_size == 0
self.num_local_experts = num_experts // self.moe_ep_size
if enable_ep_moe:
# TODO(ch-wan): support shared experts fusion
# Create a tensor of size num_experts filled with -1
self.expert_map_cpu = torch.full((self.num_experts,), -1, dtype=torch.int32)
            # Create an expert map for the local experts
self.expert_map_cpu[
self.moe_ep_rank
* self.num_local_experts : (self.moe_ep_rank + 1)
* self.num_local_experts
] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
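            # e.g. under a hypothetical config with num_experts=8 and
            # moe_ep_size=2, moe_ep_rank=1 gets
            # expert_map_cpu == [-1, -1, -1, -1, 0, 1, 2, 3], i.e. this rank
            # owns global experts 4..7 as local experts 0..3.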
if not self.enable_flashinfer_cutlass_moe:
self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
self.routed_scaling_factor = routed_scaling_factor
assert intermediate_size % self.moe_tp_size == 0
self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
self.reduce_results = reduce_results
self.activation = activation
self.apply_router_weight_on_input = apply_router_weight_on_input
self.use_presharded_weights = use_presharded_weights
self.inplace = inplace
self.no_combine = no_combine
self.use_triton_kernels = (
not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
)
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
self.use_triton_kernels
)
else:
self.quant_method = quant_config.get_quant_method(self, prefix)
if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
self.quant_method.enable_flashinfer_cutlass_moe = (
self.enable_flashinfer_cutlass_moe
)
assert self.quant_method is not None
self.quant_config = quant_config
self.quant_method.create_weights(
layer=self,
num_experts=self.num_local_experts,
hidden_size=hidden_size,
# FIXME: figure out which intermediate_size to use
intermediate_size=self.intermediate_size_per_partition,
intermediate_size_per_partition=self.intermediate_size_per_partition,
params_dtype=params_dtype,
weight_loader=self.weight_loader,
)

    def _load_per_tensor_weight_scale(
self,
shard_id: str,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
expert_id: int,
):
param_data = param.data
# for per tensor weight quantization
if shard_id in ("w1", "w3"):
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx = 0 if shard_id == "w1" else 1
param_data[expert_id][idx] = loaded_weight
# If we are in the row parallel case (down_proj)
elif shard_id == "w2":
param_data[expert_id] = loaded_weight

    def _load_model_weight_or_group_weight_scale(
self,
shard_dim: int,
expert_data: torch.Tensor,
shard_id: str,
loaded_weight: torch.Tensor,
tp_rank: int,
):
# Load grouped weight scales for group quantization
# or model weights
if shard_id == "w2":
self._load_w2(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
)
elif shard_id in ("w1", "w3"):
self._load_w13(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
)

    def _load_per_channel_weight_scale(
self,
expert_data: torch.Tensor,
shard_dim: int,
shard_id: str,
loaded_weight: torch.Tensor,
tp_rank: int,
):
# for per channel weight quantization
if shard_id == "w2":
expert_data.copy_(loaded_weight)
elif shard_id in ("w1", "w3"):
self._load_w13(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
)

    def _load_w13(
self,
expert_data: torch.Tensor,
shard_dim: int,
shard_id: str,
loaded_weight: torch.Tensor,
tp_rank: int,
):
# Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
shard_size = expert_data.shape[shard_dim] // 2
        # Narrow parameter and load.
        # w1, gate_proj: Load into first logical weight of w13.
        # w3, up_proj: Load into second logical weight of w13.
        # Note: the trtllm cutlass kernel expects up_proj first; quant methods
        # signal that ordering via load_up_proj_weight_first below.
assert shard_id in ("w1", "w3")
switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
start = shard_size
else:
start = 0
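        # Worked example (hypothetical sizes, default ordering): if expert_data
        # is [2 * I_p, H] with shard_dim=0, then shard_size == I_p and a "w3"
        # shard starts at row I_p; loaded_weight is narrowed below to this
        # rank's [I_p * tp_rank : I_p * (tp_rank + 1)] slice.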
if _is_cpu:
expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
expert_data,
loaded_weight,
start,
shard_size * tp_rank,
shard_dim,
shard_size,
not self.use_presharded_weights,
)
else:
if not self.use_presharded_weights:
if self.use_triton_kernels:
loaded_weight = loaded_weight.transpose(-2, -1)
loaded_weight = loaded_weight.narrow(
shard_dim, shard_size * tp_rank, shard_size
)
expert_data = expert_data.narrow(shard_dim, start, shard_size)
expert_data.copy_(loaded_weight)

    def _load_w2(
self,
expert_data: torch.Tensor,
shard_dim: int,
shard_id: str,
loaded_weight: torch.Tensor,
tp_rank: int,
):
"""Load w2 weights for down projection.
Args:
expert_data: The expert data tensor to load into
shard_dim: The dimension to shard along
shard_id: The shard ID (must be "w2")
loaded_weight: The weight tensor to load from
tp_rank: The tensor parallel rank
"""
if not isinstance(expert_data, torch.Tensor) or not isinstance(
loaded_weight, torch.Tensor
):
raise ValueError("expert_data and loaded_weight must be torch.Tensor")
if (
self.quant_config is not None
and "modelopt" in self.quant_config.get_name()
and (expert_data.dim() != 2 or loaded_weight.dim() != 2)
):
raise ValueError(
f"Expected 2D tensors, got expert_data shape {expert_data.shape} and loaded_weight shape {loaded_weight.shape}"
)
if shard_id != "w2":
raise ValueError(f"shard_id must be 'w2', got {shard_id}")
# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
# Narrow parameter and load.
shard_size = expert_data.shape[shard_dim]
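        # Example (hypothetical sizes): for down_proj, expert_data is [H, I_p]
        # with shard_dim=1, so shard_size == I_p and this rank loads columns
        # [I_p * tp_rank : I_p * (tp_rank + 1)] of the full weight.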
if _is_cpu:
expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
expert_data,
loaded_weight,
0, # param_data_start
shard_size * tp_rank,
shard_dim,
shard_size,
not self.use_presharded_weights,
)
else:
if not self.use_presharded_weights:
if self.use_triton_kernels:
loaded_weight = loaded_weight.transpose(-2, -1)
if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
raise ValueError(
f"Shard size {shard_size} at rank {tp_rank} exceeds loaded_weight dimension {loaded_weight.shape[shard_dim]}"
)
loaded_weight = loaded_weight.narrow(
shard_dim, shard_size * tp_rank, shard_size
)
# w2, down_proj: Load into only logical weight of w2.
expert_data.copy_(loaded_weight)

    def _load_single_value(
self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int
):
param_data = param.data
# Input scales can be loaded directly and should be equal.
param_data[expert_id] = loaded_weight

    def _load_g_idx(
self,
shard_id: str,
expert_data: torch.Tensor,
shard_dim: int,
loaded_weight: torch.Tensor,
tp_rank: int,
):
if shard_id == "w2":
self._load_w2(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
)
else:
assert shard_id in ("w1", "w3")
expert_data.copy_(loaded_weight)

    def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
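        # e.g. with the hypothetical 8-expert / 2-rank EP map built in
        # __init__, rank 1 maps global expert 5 to local expert 1, while
        # global expert 2 maps to -1 and is skipped by this rank's loader.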
if self.expert_map_cpu is None:
return expert_id
return self.expert_map_cpu[expert_id].item()

    def weight_loader(
self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
expert_id: int,
) -> None:
global_expert_location_metadata = get_global_expert_location_metadata()
if global_expert_location_metadata is None:
self._weight_loader_impl(
param=param,
loaded_weight=loaded_weight,
weight_name=weight_name,
shard_id=shard_id,
expert_id=expert_id,
)
return
if expert_id >= self.num_experts - self.num_fused_shared_experts:
# This is a shared expert.
physical_expert_ids = [expert_id]
else:
physical_expert_ids = (
global_expert_location_metadata.logical_to_all_physical(
self.layer_id, expert_id
)
)
for physical_expert_id in physical_expert_ids:
self._weight_loader_physical(
param=param,
loaded_weight=loaded_weight,
weight_name=weight_name,
shard_id=shard_id,
expert_id=physical_expert_id,
)

    def _weight_loader_physical(
self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
expert_id: int,
) -> None:
expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
if expert_id == -1:
return
self._weight_loader_impl(
param=param,
loaded_weight=loaded_weight,
weight_name=weight_name,
shard_id=shard_id,
expert_id=expert_id,
)

    def _weight_loader_impl(
self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
expert_id: int,
) -> None:
tp_rank = self.moe_tp_rank
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
loaded_weight = (
loaded_weight.t().contiguous()
if (
self.quant_method.__class__.__name__
== "CompressedTensorsWNA16MoEMethod"
)
else loaded_weight
)
        if shard_id not in ("w1", "w2", "w3"):
            raise ValueError(
                f"shard_id must be one of ('w1', 'w2', 'w3') but got {shard_id}."
            )
# Flashinfer assumes w31 format for w13_weight. Same for the scales.
if getattr(self, "use_flashinfer_trtllm_moe", False):
shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
        # Fetch the dim to shard the parameter/loaded weight
        # based on the shard id. This is the dimension along which
        # intermediate_size lies.
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
expert_data = param.data[expert_id]
        # is_transposed: whether the dim to shard the weight should be
        # flipped. Required by GPTQ and compressed-tensors, which store the
        # weight transposed relative to the usual (out_features, in_features)
        # layout.
is_transposed = getattr(param, "is_transposed", False)
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
if self.use_triton_kernels:
is_transposed = True
if is_transposed:
shard_dim = int(not shard_dim)
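            # i.e. 0 -> 1 and 1 -> 0: w1/w3 then shard along dim 1 and w2
            # along dim 0.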
# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
loaded_weight = loaded_weight * 2.0
# this is needed for compressed-tensors only
loaded_weight = loaded_weight.to(param.data.device)
if (
"compressed" in self.quant_method.__class__.__name__.lower()
and param.data[expert_id] != 1
and (param.data[expert_id] - loaded_weight).abs() > 1e-5
):
            raise ValueError(
                "input_scales of w1 and w3 of a layer must be equal, "
                f"but got {param.data[expert_id]} vs. {loaded_weight}"
            )
self._load_single_value(
param=param, loaded_weight=loaded_weight, expert_id=expert_id
)
return
# Case g_idx
if "g_idx" in weight_name:
self._load_g_idx(
shard_dim=0,
shard_id=shard_id,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
)
return
if "ModelOpt" in self.quant_method.__class__.__name__:
# Determine per-tensor weight scale patterns based on variant
is_fp4_variant = (
"ModelOptNvFp4FusedMoEMethod" in self.quant_method.__class__.__name__
)
# FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor
per_tensor_conditions = (
"weight_scale_2" in weight_name
if is_fp4_variant
else "weight_scale" in weight_name
) or "input_scale" in weight_name
if per_tensor_conditions:
self._load_per_tensor_weight_scale(
shard_id=shard_id,
param=param,
loaded_weight=loaded_weight,
expert_id=expert_id,
)
elif "weight" in weight_name:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
)
return
# Case weight scales and zero_points
if "scale" in weight_name or "zero" in weight_name or "offset" in weight_name:
# load the weight scales and zp based on the quantization scheme
# supported weight scales/zp can be found in
# FusedMoeWeightScaleSupported
# TODO @dsikka: once hardened, refactor to use vLLM Parameters
# specific to each case
quant_method = getattr(param, "quant_method", None)
if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD)
if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
loaded_weight = loaded_weight * 0.5
self._load_per_channel_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
)
elif quant_method in [
FusedMoeWeightScaleSupported.GROUP.value,
FusedMoeWeightScaleSupported.BLOCK.value,
]:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
)
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD)
if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
loaded_weight = loaded_weight * 2.0
self._load_per_tensor_weight_scale(
shard_id=shard_id,
param=param,
loaded_weight=loaded_weight,
expert_id=expert_id,
)
else:
raise ValueError(
f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}"
)
return
# Case weight_shape
if "weight_shape" in weight_name:
# only required by compressed-tensors
self._load_single_value(
param=param, loaded_weight=loaded_weight, expert_id=expert_id
)
return
# Case model weights
if "weight" in weight_name:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
)
return

    def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
assert self.quant_method is not None
if self.expert_map_gpu is not None:
topk_output = topk_output._replace(
topk_ids=self.expert_map_gpu[topk_output.topk_ids]
)
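            # e.g. with the hypothetical rank-1 EP map above, topk_ids [5, 2]
            # become [1, -1]; -1 marks experts owned by other EP ranks.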
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
topk_output=topk_output,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
routed_scaling_factor=self.routed_scaling_factor,
**(
dict(
tp_rank=self.moe_tp_rank,
tp_size=self.moe_tp_size,
ep_rank=self.moe_ep_rank,
ep_size=self.moe_ep_size,
)
if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
else {}
),
)
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states

    @classmethod
def make_expert_params_mapping(
cls,
ckpt_gate_proj_name: str,
ckpt_down_proj_name: str,
ckpt_up_proj_name: str,
num_experts: int,
) -> List[Tuple[str, str, int, str]]:
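        # e.g. make_expert_params_mapping("gate_proj", "down_proj", "up_proj",
        # num_experts=2) yields, for expert 0:
        #   ("experts.w13_", "experts.0.gate_proj.", 0, "w1")
        #   ("experts.w2_",  "experts.0.down_proj.", 0, "w2")
        #   ("experts.w13_", "experts.0.up_proj.",   0, "w3")
        # and likewise for expert 1.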
return [
# (param_name, weight_name, expert_id, shard_id)
(
(
"experts.w13_"
if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
else "experts.w2_"
),
f"experts.{expert_id}.{weight_name}.",
expert_id,
shard_id,
)
for expert_id in range(num_experts)
for shard_id, weight_name in [
("w1", ckpt_gate_proj_name),
("w2", ckpt_down_proj_name),
("w3", ckpt_up_proj_name),
]
]

    @classmethod
def make_expert_input_scale_params_mapping(
cls,
num_experts: int,
) -> List[Tuple[str, str, int, str]]:
# (param_name, weight_name, expert_id, shard_id)
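        # e.g. for expert 0 this yields:
        #   ("experts.w13_", "experts.0.w1.", 0, "w1")
        #   ("experts.w2_",  "experts.0.w2.", 0, "w2")
        #   ("experts.w13_", "experts.0.w3.", 0, "w3")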
return [
(
"experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
f"experts.{expert_id}.{shard_id}.",
expert_id,
shard_id,
)
for expert_id in range(num_experts)
for shard_id in ["w1", "w2", "w3"]
]