# SPDX-License-Identifier: Apache-2.0

from abc import abstractmethod
from enum import Enum
from typing import Callable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch.nn.parameter import UninitializedParameter

import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op

if current_platform.is_cuda_alike():
    from .fused_moe import fused_experts
else:
    fused_experts = None  # type: ignore
if current_platform.is_tpu():
    # the iterative moe implementation is used until the moe_pallas is fixed
    from .moe_torch_iterative import fused_moe as fused_moe_pallas
else:
    fused_moe_pallas = None  # type: ignore
logger = init_logger(__name__)


class FusedMoeWeightScaleSupported(Enum):
    """Weight-scale layouts understood by ``FusedMoE.weight_loader``.

    The ``.value`` strings are compared against a parameter's
    ``quant_method`` attribute when loading scales / zero points.
    """
    TENSOR = "tensor"
    CHANNEL = "channel"
    GROUP = "group"
    BLOCK = "block"


class FusedMoEMethodBase(QuantizeMethodBase):
    """Interface every MoE quantization method must implement."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register the expert weight parameters on ``layer``."""
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ) -> torch.Tensor:
        """Route ``x`` through the experts and return the combined output.

        NOTE(review): ``FusedMoE.forward_impl`` also passes
        ``extra_residual`` and ``routed_scaling_factor`` as keywords, so
        concrete implementations used with that layer must accept them.
        """
        raise NotImplementedError


@CustomOp.register("unquantized_fused_moe")
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
    """MoE method without quantization."""

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Create uninitialized ``w13_weight`` and ``w2_weight`` on ``layer``.

        ``w13_weight`` has shape
        ``(num_experts, 2 * intermediate_size_per_partition, hidden_size)``
        and ``w2_weight`` has shape
        ``(num_experts, hidden_size, intermediate_size_per_partition)``.
        """
        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(
            torch.empty(num_experts,
                        2 * intermediate_size_per_partition,
                        hidden_size,
                        dtype=params_dtype),
            requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(
            torch.empty(num_experts,
                        hidden_size,
                        intermediate_size_per_partition,
                        dtype=params_dtype),
            requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

    def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
        """Optionally re-stride ``weight`` with 256 bytes of row padding.

        Only applies when VLLM_ROCM_MOE_PADDING is set, the platform is
        ROCm, the last dim is contiguous and the row stride in bytes is a
        multiple of 512. The returned tensor has the same logical shape;
        only its row stride grows.
        """
        # Pad the weight tensor. This is an optimization on ROCm platform,
        # which can benefit from tensors located far enough from one another
        # in memory.
        if (envs.VLLM_ROCM_MOE_PADDING and current_platform.is_rocm()
                and weight.stride(-1) == 1
                and (weight.stride(-2) * weight.element_size()) % 512 == 0):
            num_pad = 256 // weight.element_size()
            # Pad then slice the padding back off: same view shape, wider
            # stride.
            weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
            torch.cuda.empty_cache()
        return weight

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Post-process loaded weights for the active platform backend."""
        super().process_weights_after_loading(layer)

        layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
            layer.w13_weight.data),
                                              requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
            layer.w2_weight.data),
                                             requires_grad=False)
        # Lazy import to avoid importing triton.
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            is_rocm_aiter_moe_enabled, shuffle_weights)
        if is_rocm_aiter_moe_enabled():
            # reshaping weights is required for aiter moe kernel.
            shuffled_w13, shuffled_w2 = shuffle_weights(
                layer.w13_weight.data, layer.w2_weight.data)

            layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                 requires_grad=False)

        if current_platform.is_cpu():
            if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
                import intel_extension_for_pytorch as ipex
                layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
                    layer.w13_weight,
                    layer.w2_weight,
                    use_prepack=envs.VLLM_CPU_MOE_PREPACK,
                )
            else:
                raise NotImplementedError("CPU MOE only supports x86 arch.")

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        extra_residual: Optional[torch.Tensor] = None,
        routed_scaling_factor: float = 1.0,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ) -> torch.Tensor:
        """Run the unquantized MoE, then scale and optionally add a residual.

        The platform-dispatched ``forward`` output is multiplied by
        ``routed_scaling_factor``; if ``extra_residual`` is given it is
        added in place to that product before returning.
        """
        out = self.forward(
            x=x,
            layer=layer,
            router_logits=router_logits,
            top_k=top_k,
            renormalize=renormalize,
            use_grouped_topk=use_grouped_topk,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
        ) * routed_scaling_factor
        if extra_residual is not None:
            out += extra_residual
        return out

    def forward_cuda(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ) -> torch.Tensor:
        """Select experts, then run ``fused_experts`` with ``inplace=True``."""
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)

        return fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            expert_map=expert_map)

    def forward_cpu(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        **kwargs,
    ):
        """Delegate to the IPEX fused module built in
        ``process_weights_after_loading`` (silu only)."""
        assert activation == "silu", f"{activation} is not supported."
        assert apply_router_weight_on_input is False
        return layer.ipex_fusion(
            x,
            use_grouped_topk,
            top_k,
            router_logits,
            renormalize,
            topk_group,
            num_expert_group,
            custom_routing_function,
            scoring_func,
            e_score_correction_bias,
        )

    def forward_hpu(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ) -> torch.Tensor:
        """HPU path: plain softmax top-k routing only."""
        assert not use_grouped_topk
        assert num_expert_group is None
        assert topk_group is None
        assert custom_routing_function is None
        assert layer is not None
        assert apply_router_weight_on_input is False
        if scoring_func != "softmax":
            raise NotImplementedError(
                "Only softmax scoring function is supported for HPU.")
        if e_score_correction_bias is not None:
            raise NotImplementedError(
                "Expert score correction bias is not supported for HPU.")
        return layer.hpu_fused_moe(x, layer.w13_weight, layer.w2_weight,
                                   router_logits, top_k)

    def forward_tpu(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ) -> torch.Tensor:
        """TPU path: plain softmax top-k routing with silu only."""
        assert not use_grouped_topk
        assert num_expert_group is None
        assert topk_group is None
        assert custom_routing_function is None
        assert apply_router_weight_on_input is False
        if scoring_func != "softmax":
            raise NotImplementedError(
                "Only softmax scoring function is supported for TPU.")
        if e_score_correction_bias is not None:
            raise NotImplementedError(
                "Expert score correction bias is not supported for TPU.")
        assert activation == "silu", f"{activation} is not supported for TPU."
        return fused_moe_pallas(hidden_states=x,
                                w1=layer.w13_weight,
                                w2=layer.w2_weight,
                                topk=top_k,
                                gating_output=router_logits,
                                global_num_experts=global_num_experts,
                                expert_map=expert_map,
                                renormalize=renormalize)

    forward_native = forward_tpu if current_platform.is_tpu() else forward_cuda


def determine_expert_map(
        ep_size: int, ep_rank: int,
        global_num_experts: int) -> Tuple[int, Optional[torch.Tensor]]:
    """
        Calculates how many experts should be assigned to each rank for EP and
        creates a mapping from global to local expert index. Experts are
        distributed evenly across ranks. Any remaining are assigned to the
        last rank.

        Args:
            ep_size (int): The size of the expert parallel group
            ep_rank (int): The rank of the current process within the expert
                parallel group, in ``[0, ep_size)``.
            global_num_experts (int): The total number of experts in the model.

        Returns:
            Tuple[int, Optional[torch.Tensor]]: A tuple containing:
                - local_num_experts (int): The number of experts assigned
                    to the current rank.
                - expert_map (Optional[torch.Tensor]): A tensor of shape
                    (global_num_experts,) mapping from global to local index.
                    Contains -1 for experts not assigned to the current rank.
                    Returns None if ep_size is 1.
        """
    assert ep_size > 0
    if ep_size == 1:
        return (global_num_experts, None)

    local_num_experts = global_num_experts // ep_size

    # Create a tensor of size num_experts filled with -1
    expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
    # Create a expert map for the local experts
    if ep_rank < (ep_size - 1):
        # Each non-last rank gets local_num_experts experts.
        expert_map[ep_rank * local_num_experts:
                   (ep_rank + 1) * local_num_experts] = \
            torch.arange(0, local_num_experts, dtype=torch.int32)
    else:
        # All remaining experts are assigned to the last rank.
        local_num_experts = (global_num_experts - ep_rank * local_num_experts)

        expert_map[-local_num_experts:] = \
            torch.arange(0, local_num_experts, dtype=torch.int32)
    return (local_num_experts, expert_map)


class FusedMoE(torch.nn.Module):
    """FusedMoE layer for MoE models.

    This layer contains both MergedColumnParallel weights (gate_up_proj /
    w13) and RowParallelLinear weights (down_proj/ w2).

    Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
    copy that naming convention here and handle any remapping in the
    load_weights function in each model implementation.

    Args:
        num_experts: Number of experts in the model
        top_k: Number of experts selected for each token
        hidden_size: Input hidden state size of the transformer
        intermediate_size: Intermediate size of the experts
        params_dtype: Data type for the parameters.
        reduce_results: Whether to all_reduce on the output of the layer
        renormalize: Whether to renormalize the logits in the fused_moe kernel
        quant_config: Quantization config.
    """

    def __init__(
        self,
        num_experts: int,  # Global number of experts
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = False,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        ep_size: Optional[int] = None,
        dp_size: Optional[int] = None,
        prefix: str = "",
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ):
        super().__init__()

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        # Note: here we guard against accessing the TP and DP groups when
        # uninitialized (this happens when testing)
        self.tp_size = (tp_size if tp_size is not None else
                        get_tensor_model_parallel_world_size())
        tp_rank = 0 if self.tp_size == 1 else get_tensor_model_parallel_rank()
        self.dp_size = (dp_size
                        if dp_size is not None else get_dp_group().world_size)
        self.dp_rank = (0
                        if self.dp_size == 1 else get_dp_group().rank_in_group)
        self.global_num_experts = num_experts

        # Use expert parallelism instead of tensor parallelism?
        vllm_config = get_current_vllm_config()
        use_ep = (vllm_config.parallel_config.enable_expert_parallel
                  and self.tp_size > 1)

        # For smuggling this layer into the fused moe custom op
        # self.use_direct_call = self.dp_size == 1
        self.use_direct_call = True
        if not self.use_direct_call:
            compilation_config = vllm_config.compilation_config
            if prefix in compilation_config.static_forward_context:
                raise ValueError("Duplicate layer name: {}".format(prefix))
            compilation_config.static_forward_context[prefix] = self
            self.layer_name = prefix

        if use_ep:
            # Set TP size to 1 to adjust for EP and adjust EP size and rank
            # for DP attention.
            self.ep_rank = tp_rank + self.tp_size * self.dp_rank
            self.tp_rank = 0
            self.ep_size = self.tp_size * self.dp_size
            self.tp_size = 1

            self.local_num_experts, self.expert_map = determine_expert_map(
                ep_size=self.ep_size,
                ep_rank=self.ep_rank,
                global_num_experts=self.global_num_experts)
        else:
            # Adjust TP size for DP attention
            self.tp_rank = tp_rank + self.tp_size * self.dp_rank
            self.ep_rank = 0
            self.tp_size = self.tp_size * self.dp_size
            self.ep_size = 1
            self.local_num_experts = self.global_num_experts
            self.expert_map = None
        self.top_k = top_k
        self.hidden_size = hidden_size
        self.num_experts = num_experts
        assert intermediate_size % self.tp_size == 0
        self.intermediate_size_per_partition = intermediate_size // self.tp_size
        self.reduce_results = reduce_results
        self.renormalize = renormalize
        self.use_grouped_topk = use_grouped_topk
        if self.use_grouped_topk:
            assert num_expert_group is not None and topk_group is not None
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.custom_routing_function = custom_routing_function
        self.scoring_func = scoring_func
        self.e_score_correction_bias = e_score_correction_bias
        self.activation = activation

        if self.scoring_func != "softmax" and not self.use_grouped_topk:
            raise ValueError("Only softmax scoring function is supported for "
                             "non-grouped topk.")
        if current_platform.is_hpu():
            from vllm_hpu_extension.ops import DynamicFusedMOE
            self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)

        # Note: get_quant_method will look at the layer's local_num_experts
        # for heuristic purposes, so it must be initialized first.
        if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = (
                UnquantizedFusedMoEMethod())
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix)
        assert self.quant_method is not None

        self.apply_router_weight_on_input = apply_router_weight_on_input
        moe_quant_params = {
            "num_experts": self.local_num_experts,
            "hidden_size": hidden_size,
            "intermediate_size_per_partition":
            self.intermediate_size_per_partition,
            "params_dtype": params_dtype,
            "weight_loader": self.weight_loader,
        }
        # need full intermediate size pre-sharding for WNA16 act order
        if (self.quant_method.__class__.__name__
                in ("GPTQMarlinMoEMethod",
                    "CompressedTensorsWNA16MoEMethod")):
            moe_quant_params["intermediate_size_full"] = intermediate_size

        self.quant_method.create_weights(layer=self, **moe_quant_params)

    def _load_per_tensor_weight_scale(self, shard_id: str,
                                      param: torch.nn.Parameter,
                                      loaded_weight: torch.Tensor,
                                      expert_id: int):
        """Store a per-tensor scale for one expert's w1/w3 or w2 shard."""
        param_data = param.data
        # for per tensor weight quantization
        if shard_id in ("w1", "w3"):
            # We have to keep the weight scales of w1 and w3 because
            # we need to re-quantize w1/w3 weights after weight loading.
            idx = 0 if shard_id == "w1" else 1
            param_data[expert_id][idx] = loaded_weight
        # If we are in the row parallel case (down_proj)
        elif shard_id == "w2":
            param_data[expert_id] = loaded_weight

    def _load_model_weight_or_group_weight_scale(self,
                                                 shard_dim: int,
                                                 expert_data: torch.Tensor,
                                                 shard_id: str,
                                                 loaded_weight: torch.Tensor,
                                                 tp_rank: int,
                                                 load_full_w2: bool = False):
        """
        Load grouped weight scales for group quantization or model weights
            :param shard_dim: dimension to shard
            :param expert_data: parameter for a particular expert
            :param shard_id: either w1, w2, or w3
            :param loaded_weight: checkpoint weight to load into the param
            :param tp_rank: tensor parallel rank
            :param load_full_w2: whether or not the w2 loaded should be sharded.
        """
        if shard_id == "w2":
            # In the case where we have actorder/g_idx, we do not partition
            # the w2 scales, as indicated by `load_full` argument, for all tp
            # cases
            self._load_w2(shard_dim=shard_dim,
                          loaded_weight=loaded_weight,
                          expert_data=expert_data,
                          tp_rank=tp_rank,
                          load_full=load_full_w2)
        elif shard_id in ("w1", "w3"):
            self._load_w13(shard_id=shard_id,
                           shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
                           tp_rank=tp_rank)

    def _load_per_channel_weight_scale(self, expert_data: torch.Tensor,
                                       shard_dim: int, shard_id: str,
                                       loaded_weight: torch.Tensor,
                                       tp_rank: int):
        """Load per-channel scales: w2 copied whole, w1/w3 TP-sharded."""
        # for per channel weight quantization
        if shard_id == "w2":
            expert_data.copy_(loaded_weight)
        elif shard_id in ("w1", "w3"):
            self._load_w13(shard_id=shard_id,
                           shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
                           tp_rank=tp_rank)

    def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
                  shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
        """Copy this rank's slice of w1 (first half) or w3 (second half)."""
        # Index the loaded weight for tp sharding.
        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
        shard_size = expert_data.shape[shard_dim] // 2
        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
                                             shard_size)
        # Narrow parameter and load.
        # w1, gate_proj: Load into first logical weight of w13.
        if shard_id == "w1":
            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
        # w3, up_proj: Load into second logical weight of w13.
        else:
            assert shard_id == "w3"
            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
        expert_data.copy_(loaded_weight)

    def _load_w2(self,
                 expert_data: torch.Tensor,
                 shard_dim: int,
                 loaded_weight: torch.Tensor,
                 tp_rank: int,
                 load_full: bool = False):
        """Copy this rank's slice of w2 (or all of it when ``load_full``)."""
        # Index the loaded weight for tp sharding.
        # down_proj: "RowParallel" so tp sharding on input_dim
        # Narrow parameter and load.
        shard_size = expert_data.shape[shard_dim]
        if not load_full:
            loaded_weight = loaded_weight.narrow(shard_dim,
                                                 shard_size * tp_rank,
                                                 shard_size)
        # w2, down_proj: Load into only logical weight of w2.
        expert_data.copy_(loaded_weight)

    def _load_single_value(self, param: torch.nn.Parameter,
                           loaded_weight: torch.Tensor, expert_id: int):
        """Store ``loaded_weight`` at ``param.data[expert_id]``."""
        param_data = param.data

        # Input scales can be loaded directly and should be equal.
        param_data[expert_id] = loaded_weight

    def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
                    shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):
        """Load act-order g_idx: w2 is TP-sharded, w1/w3 copied whole."""
        if shard_id == "w2":
            self._load_w2(shard_dim=shard_dim,
                          loaded_weight=loaded_weight,
                          expert_data=expert_data,
                          tp_rank=tp_rank)
        else:
            assert shard_id in ("w1", "w3")
            expert_data.copy_(loaded_weight)

    def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
        """Return the local index of ``expert_id`` (-1 if not on this rank)."""
        if self.expert_map is None:
            return expert_id
        return self.expert_map[expert_id].item()

    def weight_loader(self, param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor, weight_name: str,
                      shard_id: str, expert_id: int) -> None:
        """Load one checkpoint tensor for one expert into ``param``.

        Dispatches on ``weight_name`` (input_scale, g_idx, scales / zero
        points / offsets, weight_shape, plain weights) and on ``shard_id``
        ("w1", "w2" or "w3"). Experts not owned by this rank under EP are
        skipped.

        Raises:
            ValueError: for an unknown ``shard_id``, mismatched w1/w3 input
                scales, or an unsupported scale ``quant_method``.
        """
        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
        if expert_id == -1:
            return

        # compressed-tensors checkpoints with packed weights are stored
        # flipped
        # TODO (mgoin): check self.quant_method.quant_config.quant_format
        # against known CompressionFormat enum values that have this quality
        loaded_weight = loaded_weight.t().contiguous() if (
            self.quant_method.__class__.__name__
            == "CompressedTensorsWNA16MoEMethod") else loaded_weight

        if shard_id not in ("w1", "w2", "w3"):
            raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
                             f"got {shard_id}.")

        WEIGHT_SCALE_SUPPORTED = [
            e.value for e in FusedMoeWeightScaleSupported
        ]
        # Fetch the dim to shard the parameter/loaded weight
        # based on the shard id. This will be whatever
        # dimension intermediate_size_per_partition is used.
        SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()
            param.data.copy_(loaded_weight)
            return

        # is_transposed: if the dim to shard the weight
        # should be flipped. Required by GPTQ, compressed-tensors
        # should be whatever dimension intermediate_size_per_partition is
        is_transposed = getattr(param, "is_transposed", False)
        shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
        if is_transposed:
            shard_dim = int(not shard_dim)

        shard_dim_force = getattr(param, "shard_dim", None)
        shard_dim = (shard_dim_force
                     if shard_dim_force is not None else shard_dim)

        # A 3-D tensor carries every expert at once; shift past expert dim.
        full_load = len(loaded_weight.shape) == 3
        if full_load:
            shard_dim += 1

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if shard_id in ["w1", "w3"]:
                final_shape[1] *= 2
            final_shape[shard_dim] = final_shape[shard_dim] // self.tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        expert_data = param.data if full_load else param.data[expert_id]
        # Case input scale: input_scale loading is only supported for fp8
        if "input_scale" in weight_name:
            # this is needed for compressed-tensors only
            loaded_weight = loaded_weight.to(param.data.device)

            if param.data[expert_id] != 1 and (param.data[expert_id] -
                                               loaded_weight).abs() > 1e-5:
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param.data[expert_id]} "
                    f"vs. {loaded_weight}")

            self._load_single_value(param=param,
                                    loaded_weight=loaded_weight,
                                    expert_id=expert_id)
            return

        # Case g_idx
        if "g_idx" in weight_name:
            self._load_g_idx(shard_dim=0,
                             shard_id=shard_id,
                             loaded_weight=loaded_weight,
                             expert_data=expert_data,
                             tp_rank=self.tp_rank)
            return

        # Case weight scales, zero_points and offset
        if ("scale" in weight_name or "zero" in weight_name
                or "offset" in weight_name):
            # load the weight scales and zp based on the quantization scheme
            # supported weight scales/zp can be found in
            # FusedMoeWeightScaleSupported
            # TODO @dsikka: once hardened, refactor to use vLLM Parameters
            # specific to each case
            quant_method = getattr(param, "quant_method", None)
            if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
                self._load_per_channel_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=self.tp_rank)
            elif quant_method in [
                    FusedMoeWeightScaleSupported.GROUP.value,
                    FusedMoeWeightScaleSupported.BLOCK.value,
            ]:
                self._load_model_weight_or_group_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=self.tp_rank,
                    load_full_w2=getattr(param, "load_full_w2", False))
            elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
                self._load_per_tensor_weight_scale(shard_id=shard_id,
                                                   param=param,
                                                   loaded_weight=loaded_weight,
                                                   expert_id=expert_id)
            else:
                raise ValueError(
                    f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
            return

        # Case weight_shape
        if "weight_shape" in weight_name:
            # only required by compressed-tensors
            self._load_single_value(param=param,
                                    loaded_weight=loaded_weight,
                                    expert_id=expert_id)
            return

        # Case model weights
        if "weight" in weight_name:
            self._load_model_weight_or_group_weight_scale(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=self.tp_rank)
            return

    @staticmethod
    def select_experts(
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        use_grouped_topk: bool,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        topk_ids: Optional[torch.Tensor] = None,
    ):
        """Return ``(topk_weights, topk_ids)`` for each token.

        Uses ixformer's grouped top-k when ``use_grouped_topk`` is set,
        ``fused_topk`` when no custom routing function is given, and the
        custom routing function otherwise. ``topk_ids`` is only forwarded
        to the grouped top-k kernel.
        """
        from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk

        # DeepSeek-V2 uses grouped top-k routing
        if use_grouped_topk:
            # Import lazily so the non-grouped paths do not require the
            # vendor ixformer package to be installed.
            from ixformer.inference.functions import (
                moe_grouped_topk as grouped_topk)
            assert topk_group is not None
            assert num_expert_group is not None
            topk_weights, topk_ids = grouped_topk(
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                scoring_func=scoring_func,
                e_score_correction_bias=e_score_correction_bias,
                topk_ids=topk_ids,
            )
        elif custom_routing_function is None:
            topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
                                                gating_output=router_logits,
                                                topk=top_k,
                                                renormalize=renormalize)
        else:
            topk_weights, topk_ids = custom_routing_function(
                hidden_states=hidden_states,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize)

        return topk_weights, topk_ids

    def naive_multicast(self, x: torch.Tensor,
                        cu_tokens_across_dp_cpu: torch.Tensor):
        """Gather every DP rank's rows of the 2-D ``x`` onto all ranks.

        ``cu_tokens_across_dp_cpu`` holds cumulative token counts per DP
        rank; rank ``i`` owns rows ``[cu[i-1], cu[i])`` of the result.
        """
        assert (len(x.shape) == 2)
        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
                             device=x.device,
                             dtype=x.dtype)

        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
            self.dp_rank - 1]
        end = cu_tokens_across_dp_cpu[self.dp_rank]
        buffer[start:end, :].copy_(x)
        for idx in range(get_dp_group().world_size):
            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
            end = cu_tokens_across_dp_cpu[idx]
            get_dp_group().broadcast(buffer[start:end, :], idx)

        return buffer

    def forward(self,
                hidden_states: torch.Tensor,
                router_logits: torch.Tensor,
                extra_residual: Optional[torch.Tensor] = None,
                routed_scaling_factor: float = 1.0):
        """Run the MoE layer, directly or through the ``moe_forward`` op.

        Raises:
            NotImplementedError: on the custom-op path when an
                ``extra_residual`` or a non-unit ``routed_scaling_factor``
                is requested, since that op cannot carry them.
        """
        if self.use_direct_call:
            return self.forward_impl(hidden_states, router_logits,
                                     extra_residual, routed_scaling_factor)
        # The registered moe_forward op only takes hidden_states,
        # router_logits and the layer name; refuse instead of silently
        # dropping the residual / scaling factor.
        if extra_residual is not None or routed_scaling_factor != 1.0:
            raise NotImplementedError(
                "extra_residual and routed_scaling_factor are not supported "
                "through the moe_forward custom op.")
        return torch.ops.vllm.moe_forward(hidden_states, router_logits,
                                          self.layer_name)

    def forward_impl(self,
                     hidden_states: torch.Tensor,
                     router_logits: torch.Tensor,
                     extra_residual: Optional[torch.Tensor] = None,
                     routed_scaling_factor: float = 1.0):
        """Compute the MoE output, handling DP gather/scatter and TP reduce.

        With ``dp_size > 1`` inputs are multicast across DP ranks, the
        output is all-reduced over the DP group and this rank's rows are
        sliced back out. ``extra_residual`` and ``routed_scaling_factor``
        are passed through to ``quant_method.apply``, i.e. before any
        all-reduce.
        """
        assert self.quant_method is not None

        if self.dp_size > 1:
            cu_tokens_across_dp_cpu = get_forward_context(
            ).dp_metadata.cu_tokens_across_dp_cpu

            hidden_states = self.naive_multicast(hidden_states,
                                                 cu_tokens_across_dp_cpu)
            router_logits = self.naive_multicast(router_logits,
                                                 cu_tokens_across_dp_cpu)

        # Matrix multiply.
        # NOTE: quant_method.apply must accept extra_residual and
        # routed_scaling_factor as keyword arguments.
        final_hidden_states = self.quant_method.apply(
            layer=self,
            x=hidden_states,
            router_logits=router_logits,
            top_k=self.top_k,
            renormalize=self.renormalize,
            use_grouped_topk=self.use_grouped_topk,
            global_num_experts=self.global_num_experts,
            expert_map=self.expert_map,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
            custom_routing_function=self.custom_routing_function,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias,
            activation=self.activation,
            apply_router_weight_on_input=self.apply_router_weight_on_input,
            extra_residual=extra_residual,
            routed_scaling_factor=routed_scaling_factor,
        )

        if self.dp_size > 1:
            start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
                self.dp_rank - 1]
            end = cu_tokens_across_dp_cpu[self.dp_rank]

            all_hidden_states = get_dp_group().all_reduce(final_hidden_states)
            final_hidden_states = all_hidden_states[start:end, :]

        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
            # Default set to False. (May have to add shared expert outputs.)
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states

    @classmethod
    def make_expert_params_mapping(
            cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
            ckpt_up_proj_name: str,
            num_experts: int) -> List[Tuple[str, str, int, str]]:
        """Build ``(param_name, weight_name, expert_id, shard_id)`` tuples
        mapping checkpoint expert weight names onto w13_/w2_ parameters."""

        return [
            # (param_name, weight_name, expert_id, shard_id)
            ("experts.w13_" if weight_name
             in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
             f"experts.{expert_id}.{weight_name}.", expert_id, shard_id)
            for expert_id in range(num_experts) for shard_id, weight_name in [
                ("w1", ckpt_gate_proj_name),
                ("w2", ckpt_down_proj_name),
                ("w3", ckpt_up_proj_name),
            ]
        ]

    def extra_repr(self) -> str:
        """Summarize the layer configuration for ``repr``."""

        s = (
            f"global_num_experts={self.global_num_experts}, "
            f"local_num_experts={self.local_num_experts}, "
            f"top_k={self.top_k}, "
            f"intermediate_size_per_partition={self.intermediate_size_per_partition}, "  # noqa: E501
            f"tp_size={self.tp_size},\n"
            f"ep_size={self.ep_size}, "
            f"reduce_results={self.reduce_results}, "
            f"renormalize={self.renormalize}, "
            f"use_grouped_topk={self.use_grouped_topk}")

        if self.use_grouped_topk:
            s += f", num_expert_group={self.num_expert_group}, topk_group={self.topk_group}"  # noqa: E501

        s += f", scoring_func='{self.scoring_func}', activation='{self.activation}'"  # noqa: E501

        return s


def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
                layer_name: str) -> torch.Tensor:
    """Custom-op body: look up the layer by name and run ``forward_impl``."""
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]
    assert self.quant_method is not None

    return self.forward_impl(hidden_states, router_logits)


def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
                     layer_name: str) -> torch.Tensor:
    """Fake (meta) implementation: output matches ``hidden_states``."""
    return torch.empty_like(hidden_states)


direct_register_custom_op(
    op_name="moe_forward",
    op_func=moe_forward,
    mutates_args=[],
    fake_impl=moe_forward_fake,
    dispatch_key=current_platform.dispatch_key,
)