# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only AfMoE model compatible with HuggingFace weights."""

import typing
from collections.abc import Callable, Iterable
from itertools import islice

import torch
from torch import nn

from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.llama import LlamaMLP as AfmoeMLP
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    WeightsMapper,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
from vllm.sequence import IntermediateTensors

logger = init_logger(__name__)


class AfmoeMoE(nn.Module):
    """Mixture-of-experts feed-forward block of an AfMoE decoder layer.

    Holds a float32 linear router (``gate``), a per-expert bias parameter that
    is handed to the fused MoE as ``e_score_correction_bias``, an optional
    shared-experts MLP, and a ``SharedFusedMoE`` that executes the routed
    experts together with the shared experts.

    Args:
        config: HuggingFace AfMoE config. The fields read here are
            ``route_scale``, ``score_func``, ``route_norm``, ``num_experts``,
            ``num_shared_experts``, ``hidden_act``, ``hidden_size``,
            ``moe_intermediate_size``, ``num_experts_per_tok``, ``n_group``
            and ``topk_group``.
        quant_config: Quantization config forwarded to the expert layers.
        prefix: Qualified module name used to derive sub-module prefixes.
        enable_eplb: Forwarded to ``SharedFusedMoE`` as ``enable_eplb``.

    Raises:
        ValueError: If ``config.hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        config,  # AfmoeConfig
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.route_scale = config.route_scale
        self.score_func = config.score_func
        self.route_norm = config.route_norm

        self.ep_group = get_ep_group().device_group
        self.ep_rank = self.ep_group.rank()
        self.ep_size = self.ep_group.size()
        self.n_routed_experts: int = config.num_experts
        self.n_shared_experts: int = config.num_shared_experts

        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        # Router gate
        # Plain (non-parallel) linear layer created in float32; `forward`
        # casts its input to float32 before calling it.
        self.gate = nn.Linear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            dtype=torch.float32,
        )
        # Allocated uninitialised; the values are expected to come from the
        # checkpoint during weight loading.
        self.expert_bias = nn.Parameter(
            torch.empty(config.num_experts, dtype=torch.float32)
        )

        # Load balancing settings
        vllm_config = get_current_vllm_config()
        eplb_config = vllm_config.parallel_config.eplb_config
        self.enable_eplb = enable_eplb

        # Physical experts = logical (routed) experts + redundant replicas,
        # split evenly (integer division) across the expert-parallel ranks.
        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_logical_experts = self.n_routed_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        self.shared_experts = None
        # Shared experts
        if config.num_shared_experts > 0:
            # All shared experts are fused into one MLP whose width is the
            # per-expert width times the number of shared experts.
            intermediate_size = config.moe_intermediate_size * config.num_shared_experts
            self.shared_experts = AfmoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=f"{prefix}.shared_experts",
            )

        # Routed experts using SharedFusedMoE
        self.experts = SharedFusedMoE(
            shared_experts=self.shared_experts,
            num_experts=config.num_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            # Renormalisation is only requested for sigmoid scoring.
            renormalize=self.route_norm if self.score_func == "sigmoid" else False,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            prefix=f"{prefix}.experts",
            scoring_func=self.score_func,
            routed_scaling_factor=self.route_scale,
            e_score_correction_bias=self.expert_bias,
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and combine the outputs.

        Args:
            hidden_states: 2-D tensor; its shape is unpacked as
                ``(num_tokens, hidden_dim)``.

        Returns:
            Tensor viewed back to ``(num_tokens, hidden_dim)``. When shared
            experts exist, their output is added to the routed output.
        """
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        # The router runs in float32 regardless of the activation dtype.
        router_logits = self.gate(hidden_states.to(dtype=torch.float32))

        fused_moe_out = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )

        if self.shared_experts is not None:
            # With shared experts the fused layer's result is unpacked as a
            # (shared_output, routed_output) pair.
            shared_output, final_hidden_states = fused_moe_out
            final_hidden_states = final_hidden_states + shared_output
        else:
            final_hidden_states = fused_moe_out
        if self.tp_size > 1:
            # Both expert paths were built with reduce_results=False, so the
            # tensor-parallel reduction is requested once here on the sum.
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
                final_hidden_states
            )

        return final_hidden_states.view(num_tokens, hidden_dim)


class AfmoeAttention(nn.Module):
    """Gated self-attention with per-head Q/K RMSNorm.

    Layers whose ``config.layer_types[layer_idx]`` equals
    ``"sliding_attention"`` are treated as local: only they get a sliding
    window and rotary position embeddings. All other layers run without
    rotary embeddings and without a per-layer sliding window.

    Args:
        config: HuggingFace AfMoE config (``layer_types``, ``sliding_window``,
            ``rms_norm_eps`` and ``rope_parameters`` are read).
        layer_idx: Index of the decoder layer this module belongs to.
        hidden_size: Model hidden size.
        num_heads: Total number of query heads (before tensor parallelism).
        num_kv_heads: Total number of key/value heads.
        max_position_embeddings: Maximum position passed to ``get_rope``.
        head_dim: Per-head dimension; falls back to
            ``hidden_size // num_heads`` when falsy.
        rms_norm_eps: Accepted for interface compatibility but not read in
            this class; the Q/K norms use ``config.rms_norm_eps``.
        cache_config: KV-cache config forwarded to ``Attention``.
        quant_config: Quantization config forwarded to the projections.
        prefix: Qualified module name used to derive sub-module prefixes.
        attn_type: Attention type forwarded to ``Attention``.
    """

    def __init__(
        self,
        config,  # AfmoeConfig
        layer_idx: int,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 131072,
        head_dim: int | None = None,
        rms_norm_eps: float = 1e-05,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim or (hidden_size // self.total_num_heads)
        # Per-rank widths of the Q and K/V slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        # Check if this is a local attention layer
        self.is_local_attention = config.layer_types[layer_idx] == "sliding_attention"
        self.sliding_window = config.sliding_window if self.is_local_attention else None

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # Gating projection
        # Output width equals the attention output width so the gate can be
        # applied element-wise in `forward`.
        self.gate_proj = ColumnParallelLinear(
            hidden_size,
            self.total_num_heads * self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_proj",
        )

        # Q/K normalization
        self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)

        # Only create rotary embeddings for local attention
        if self.is_local_attention:
            self.rotary_emb = get_rope(
                self.head_dim,
                max_position=max_position_embeddings,
                rope_parameters=config.rope_parameters,
                is_neox_style=True,
            )
        else:
            self.rotary_emb = None

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=self.sliding_window,
            prefix=f"{prefix}.attn",
            attn_type=attn_type,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run gated attention over ``hidden_states``.

        Args:
            positions: Token positions, used only when this is a local
                attention layer with rotary embeddings.
            hidden_states: Input activations fed to the QKV and gate
                projections.

        Returns:
            Output of ``o_proj`` applied to the sigmoid-gated attention
            output.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        gate, _ = self.gate_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # Apply Q/K normalization
        # Normalise per head (last dim = head_dim), then restore the
        # flattened layout expected by the rotary/attention calls.
        q = self.q_norm(q.reshape(-1, self.num_heads, self.head_dim)).reshape(q.shape)
        k = self.k_norm(k.reshape(-1, self.num_kv_heads, self.head_dim)).reshape(
            k.shape
        )

        # Apply rotary embeddings only for local attention
        if self.is_local_attention and self.rotary_emb is not None:
            q, k = self.rotary_emb(positions, q, k)

        attn_output = self.attn(q, k, v)

        # Apply gating
        attn_output = attn_output * torch.sigmoid(gate)
        output, _ = self.o_proj(attn_output)
        return output


class AfmoeDecoderLayer(nn.Module):
    """One AfMoE decoder layer: attention followed by a dense or MoE FFN.

    Both sub-blocks are wrapped by a norm before and a norm after
    (``input_layernorm``/``post_attention_layernorm`` around attention,
    ``pre_mlp_layernorm``/``post_mlp_layernorm`` around the FFN). Layers with
    index ``>= config.num_dense_layers`` use ``AfmoeMoE``; earlier layers use
    the dense ``AfmoeMLP``.

    Args:
        config: HuggingFace AfMoE config.
        cache_config: KV-cache config forwarded to the attention module.
        quant_config: Quantization config forwarded to sub-modules.
        prefix: Qualified module name; the layer index is parsed from it.
        enable_eplb: Forwarded to ``AfmoeMoE`` on MoE layers.
    """

    def __init__(
        self,
        config,  # AfmoeConfig
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        max_position_embeddings = getattr(config, "max_position_embeddings", 131072)

        # DecoderLayers are created with `make_layers` which passes the prefix
        # with the layer's index.
        self.layer_idx = extract_layer_index(prefix)

        self.self_attn = AfmoeAttention(
            config=config,
            layer_idx=self.layer_idx,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position_embeddings=max_position_embeddings,
            head_dim=config.head_dim,
            rms_norm_eps=config.rms_norm_eps,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )

        # MoE or dense FFN
        self.moe_enabled = self.layer_idx >= config.num_dense_layers
        if self.moe_enabled:
            self.mlp = AfmoeMoE(
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
                enable_eplb=enable_eplb,
            )
        else:
            self.mlp = AfmoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.pre_mlp_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_mlp_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply attention and the FFN, threading the residual stream.

        Args:
            positions: Token positions forwarded to the attention module.
            hidden_states: Layer input.
            residual: Residual carried from the previous layer, or ``None``
                for the first layer evaluated on this rank's first stage.

        Returns:
            ``(hidden_states, residual)``. The returned ``hidden_states`` has
            not yet been added to ``residual``; that is left to the next
            norm call that receives both.
        """
        if residual is None:
            # First layer: the raw input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Called with two tensors, the norm returns an updated
            # (hidden_states, residual) pair.
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)  # attn norm b

        # Fully Connected
        hidden_states, residual = self.pre_mlp_layernorm(  # ffn norm a
            hidden_states, residual
        )
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_mlp_layernorm(hidden_states)  # ffn norm b

        return hidden_states, residual


@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    }
)
class AfmoeModel(nn.Module):
    """AfMoE decoder stack (embeddings, decoder layers, final norm).

    Pipeline-parallel aware: the embedding exists only on the first PP rank,
    the final norm only on the last, and other ranks hold ``PPMissingLayer``
    placeholders instead.

    Args:
        vllm_config: Engine configuration; the HF config, cache config,
            quantization config and ``parallel_config.enable_eplb`` are read.
        prefix: Qualified module name used to derive sub-module prefixes.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        enable_eplb = vllm_config.parallel_config.enable_eplb
        self.config = config

        self.vocab_size = config.vocab_size
        self.mup_enabled = config.mup_enabled

        if get_pp_group().is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size, config.hidden_size, prefix=f"{prefix}.embed_tokens"
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: AfmoeDecoderLayer(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
                enable_eplb=enable_eplb,
            ),
            prefix=f"{prefix}.layers",
        )

        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # NOTE(review): this instance attribute has the same name as the
        # `make_empty_intermediate_tensors` method defined below, so attribute
        # lookup on an instance resolves to this factory-built callable.
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; embedded on the first PP rank when
                ``inputs_embeds`` is ``None``.
            positions: Token positions forwarded to every layer.
            intermediate_tensors: ``hidden_states``/``residual`` received from
                the previous PP rank. Required on non-first ranks.
            inputs_embeds: Optional precomputed embeddings that replace the
                embedding lookup on the first PP rank.

        Returns:
            On the last PP rank, the final normalised hidden states. On other
            ranks, an ``IntermediateTensors`` carrying ``hidden_states`` and
            ``residual`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)

            # Apply muP input scaling if enabled
            # NOTE(review): as written the scaling also applies to
            # caller-supplied `inputs_embeds` — confirm this is intended.
            if self.mup_enabled:
                hidden_states = hidden_states * (self.config.hidden_size**0.5)

            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Only the layers in [start_layer, end_layer) are evaluated here.
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        # The final norm receives the pending residual as well.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def make_empty_intermediate_tensors(
        self, batch_size: int, dtype: torch.dtype, device: torch.device
    ) -> IntermediateTensors:
        """Build zero-filled ``hidden_states``/``residual`` placeholders.

        Both tensors have shape ``(batch_size, config.hidden_size)``.

        NOTE(review): ``__init__`` assigns an instance attribute with this
        same name, so on instances built through ``__init__`` this method is
        shadowed and normally not the one that gets called.
        """
        return IntermediateTensors(
            {
                "hidden_states": torch.zeros(
                    (batch_size, self.config.hidden_size), dtype=dtype, device=device
                ),
                "residual": torch.zeros(
                    (batch_size, self.config.hidden_size), dtype=dtype, device=device
                ),
            }
        )

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the checkpoint-to-parameter mapping for routed experts."""
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return SharedFusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each checkpoint entry is tried, in order, as:

        1. a shard of a stacked parameter (``qkv_proj`` / ``gate_up_proj``),
        2. a routed-expert weight (via ``get_expert_mapping``),
        3. a regular parameter loaded under its (possibly kv-scale-remapped)
           name.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names that were recorded as loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                # The attention gate projection is a standalone layer named
                # `gate_proj`, so it must not be folded into `gate_up_proj`.
                if (weight_name not in name) or ("self_attn.gate_proj" in name):
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue

                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached only when no stacked mapping consumed the weight.
                is_expert_weight = False
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue

                    # Anyway, this is an expert weight and should not be
                    # attempted to load as other weights later
                    is_expert_weight = True

                    # Do not modify `name` since the loop may continue here
                    # Instead, create a new variable
                    name_mapped = name.replace(weight_name, param_name)

                    if is_pp_missing_parameter(name_mapped, self):
                        continue

                    param = params_dict[name_mapped]
                    # We should ask the weight loader to return success or not
                    # here since otherwise we may skip experts with other
                    # available replicas.
                    weight_loader = typing.cast(
                        Callable[..., bool], param.weight_loader
                    )
                    success = weight_loader(
                        param,
                        loaded_weight,
                        name_mapped,
                        shard_id=shard_id,
                        expert_id=expert_id,
                        return_success=True,
                    )
                    if success:
                        name = name_mapped
                        break
                else:
                    # Reached only when no expert mapping loaded the weight.
                    if is_expert_weight:
                        # We've checked that this is an expert weight
                        # However it's not mapped locally to this rank
                        # So we simply skip it
                        continue

                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            # Recorded under the final (possibly remapped) parameter name.
            loaded_params.add(name)

        return loaded_params


class AfmoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
    """AfMoE causal language model: decoder stack plus LM head.

    Also collects the routed-expert layers and expert-count metadata from the
    decoder layers present on this rank, for use by ``set_eplb_state``.

    Args:
        vllm_config: Engine configuration passed on to ``AfmoeModel``.
        prefix: Qualified module name used to derive sub-module prefixes.

    Raises:
        RuntimeError: If the config declares MoE layers but none is found
            among the decoder layers held by this rank.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # Checkpoint router weights end in `.router.gate.weight`; the module
    # defined in this file exposes them as `.gate.weight`.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_suffix={
            ".router.gate.weight": ".gate.weight",
        },
    )

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = AfmoeModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        if get_pp_group().is_last_rank:
            # Pass the qualified module name, as every other layer in this
            # file does, so prefix-based per-layer handling (e.g. quantization
            # configs that match on layer names) can identify the LM head.
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        else:
            self.lm_head = PPMissingLayer()
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )
        self.expert_weights = []

        # Set MoE hyperparameters
        self.num_moe_layers = config.num_hidden_layers - config.num_dense_layers
        self.num_expert_groups = config.n_group

        self.moe_layers: list[SharedFusedMoE] = []
        example_moe = None
        for layer in self.model.layers:
            # Layers owned by other pipeline ranks are placeholders here.
            if isinstance(layer, PPMissingLayer):
                continue

            assert isinstance(layer, AfmoeDecoderLayer)
            if layer.moe_enabled:
                example_moe = layer.mlp
                self.moe_layers.append(layer.mlp.experts)

        if example_moe is None and self.num_moe_layers > 0:
            raise RuntimeError("No AfmoeMoE layer found in model.layers.")

        if example_moe is not None:
            # Expert counts are copied from the last MoE layer visited; these
            # attributes are left unset when no MoE layer exists.
            self.num_logical_experts = example_moe.n_logical_experts
            self.num_physical_experts = example_moe.n_physical_experts
            self.num_local_physical_experts = example_moe.n_local_physical_experts
            self.num_routed_experts = example_moe.n_routed_experts
            self.num_shared_experts = example_moe.n_shared_experts
            self.num_redundant_experts = example_moe.n_redundant_experts

    def set_eplb_state(
        self,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> None:
        """Hand the shared EPLB state tensors to every local MoE layer.

        Each layer's expert weights are appended to ``self.expert_weights``
        on every call, and the layer receives its position in
        ``self.moe_layers`` as ``moe_layer_idx``.
        """
        for layer_idx, layer in enumerate(self.moe_layers):
            # Register the expert weights.
            self.expert_weights.append(layer.get_expert_weights())
            layer.set_eplb_state(
                moe_layer_idx=layer_idx,
                expert_load_view=expert_load_view,
                logical_to_physical_map=logical_to_physical_map,
                logical_replica_count=logical_replica_count,
            )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder stack; see ``AfmoeModel.forward`` for the contract."""
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
        """Project ``hidden_states`` to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, renaming them with ``hf_to_vllm_mapper``."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the routed-expert parameter mapping of the inner model."""
        return self.model.get_expert_mapping()