# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from itertools import islice
from typing import Any

import torch
import torch.nn as nn

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
from vllm.model_executor.layers.mamba.short_conv import ShortConv
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import Lfm2MoeConfig

from .interfaces import (
    HasInnerState,
    IsHybrid,
    MixtureOfExperts,
    SupportsLoRA,
    SupportsPP,
    SupportsQuant,
)
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)


class Lfm2MoeMlp(nn.Module):
    """Dense gated feed-forward block.

    `w1` is a merged column-parallel projection producing two outputs of
    size `ff_dim` each; `SiluAndMul` combines them, and `w2` projects the
    result back to `dim`.
    """

    def __init__(
        self,
        dim: int,
        ff_dim: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the merged input projection, activation and output projection.

        Args:
            dim: Input and output feature size of the block.
            ff_dim: Size of each of the two merged `w1` outputs and the
                input size of `w2`.
            quant_config: Optional quantization config forwarded to both
                linear layers.
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()
        self.w1 = MergedColumnParallelLinear(
            input_size=dim,
            output_sizes=[ff_dim] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.w1",
        )
        self.w2 = RowParallelLinear(
            input_size=ff_dim,
            output_size=dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.w2",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply `w1`, the gated activation and `w2` in sequence."""
        gate_up, _ = self.w1(x)
        x = self.act_fn(gate_up)
        x, _ = self.w2(x)
        return x


class Lfm2MoeSparseMoeBlock(nn.Module):
    """Sparse mixture-of-experts feed-forward block.

    A replicated linear gate produces router logits, which are passed to a
    `FusedMoE` layer together with the hidden states. The expert output is
    scaled by `config.routed_scaling_factor`.
    """

    def __init__(
        self,
        config: Lfm2MoeConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ):
        """Create the router gate and the fused expert layer.

        Args:
            config: Model config; reads `routed_scaling_factor`,
                `num_experts`, `hidden_size`, `use_expert_bias`,
                `num_experts_per_tok`, `moe_intermediate_size` and
                `norm_topk_prob`.
            quant_config: Optional quantization config forwarded to the gate
                and the experts.
            prefix: Parameter-name prefix for this module.
            enable_eplb: Forwarded to `FusedMoE` as `enable_eplb`.

        Raises:
            ValueError: If the tensor-parallel size exceeds the number of
                experts.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()

        self.routed_scaling_factor = config.routed_scaling_factor

        self.ep_group = get_ep_group().device_group
        self.ep_rank = get_ep_group().rank_in_group
        self.ep_size = self.ep_group.size()
        self.n_routed_experts = config.num_experts

        if self.tp_size > self.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {self.n_routed_experts}."
            )

        # Load balancing settings.
        vllm_config = get_current_vllm_config()
        eplb_config = vllm_config.parallel_config.eplb_config
        self.enable_eplb = enable_eplb

        # Physical experts = logical (routed) experts + redundant replicas,
        # split evenly (integer division) across the expert-parallel group.
        self.n_logical_experts = self.n_routed_experts
        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate",
        )
        # The correction bias lives on the gate module so that checkpoint
        # "expert_bias" tensors can be loaded as
        # "gate.e_score_correction_bias" (see Lfm2MoeModel.load_weights).
        if config.use_expert_bias:
            self.gate.e_score_correction_bias = nn.Parameter(
                torch.empty(self.n_routed_experts, dtype=torch.float32)
            )
        else:
            self.gate.e_score_correction_bias = None

        self.experts = FusedMoE(
            num_experts=self.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            # Grouped top-k routing with a single expert group.
            # NOTE(review): the previous comment here said "needed for softmax
            # score func", but scoring_func below is "sigmoid" -- presumably
            # the grouped path is what supports non-softmax scoring and the
            # correction bias; confirm against FusedMoE.
            use_grouped_topk=True,
            num_expert_group=1,
            topk_group=1,
            prefix=f"{prefix}.experts",
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            scoring_func="sigmoid",
            e_score_correction_bias=self.gate.e_score_correction_bias,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and return the scaled output.

        The input is flattened to `(-1, hidden_dim)` for routing and the
        result is reshaped back to the input's original shape.
        """
        orig_shape = hidden_states.shape
        hidden_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = (
            self.experts(hidden_states=hidden_states, router_logits=router_logits)
            * self.routed_scaling_factor
        )

        # The experts were built with reduce_results=False, so the reduction
        # across tensor-parallel ranks is requested explicitly here.
        if self.tp_size > 1:
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
                final_hidden_states
            )

        return final_hidden_states.view(orig_shape)


class Lfm2MoeAttention(nn.Module):
    """Self-attention with per-head RMSNorm on Q and K and rotary embeddings."""

    def __init__(
        self,
        config: Lfm2MoeConfig,
        layer_idx: int,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: dict[str, Any] | None = None,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create the QKV/output projections, rotary embedding and norms.

        Args:
            config: Model config; only `norm_eps` is read here.
            layer_idx: Index of the layer this module belongs to.
            hidden_size: Model hidden size.
            num_heads: Total number of query heads (before TP sharding).
            num_kv_heads: Total number of key/value heads (before TP
                sharding).
            rope_theta: Base passed to `get_rope`.
            rope_scaling: Optional rope-scaling dict passed to `get_rope`.
            max_position_embeddings: Maximum position passed to `get_rope`.
            cache_config: Optional cache config forwarded to `Attention`.
            quant_config: Optional quantization config forwarded to the
                linear layers.
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = hidden_size
        self.num_kv_heads = num_kv_heads

        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            prefix=f"{prefix}.attn",
        )
        # Norms are sized by head_dim, i.e. applied per head.
        self.q_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps)
        self.k_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Compute attention output for a 2-D `(n_tokens, hidden)` input.

        Q and K are reshaped to per-head layout, normalized, rotated, and
        flattened back before the attention call; V is used as split.
        """
        n_tokens, _ = hidden_states.shape
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q = q.view(n_tokens, self.num_heads, self.head_dim).contiguous()
        k = k.view(n_tokens, self.num_kv_heads, self.head_dim).contiguous()
        q = self.q_layernorm(q)
        k = self.k_layernorm(k)
        q, k = self.rotary_emb(positions, q, k)
        q = q.view(n_tokens, self.num_heads * self.head_dim)
        k = k.view(n_tokens, self.num_kv_heads * self.head_dim)
        attn_output = self.attn(q, k, v)
        output, _ = self.out_proj(attn_output)
        return output


class Lfm2MoeAttentionDecoderLayer(nn.Module):
    """Decoder layer whose operator is self-attention.

    The feed-forward part is a dense `Lfm2MoeMlp` for the first
    `config.num_dense_layers` layers and an `Lfm2MoeSparseMoeBlock`
    otherwise.
    """

    def __init__(
        self,
        config: Lfm2MoeConfig,
        layer_idx: int,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ) -> None:
        """Build the attention operator, feed-forward block and norms.

        Args:
            config: Model config.
            layer_idx: Index of this layer; selects dense vs. MoE
                feed-forward.
            model_config: Accepted for signature parity with the short-conv
                layer; not used here.
            cache_config: Optional cache config forwarded to the attention.
            quant_config: Optional quantization config.
            prefix: Parameter-name prefix for this layer.
            enable_eplb: Forwarded to the MoE block when one is created.
        """
        super().__init__()
        self.prefix = prefix
        self.config = config
        self.layer_idx = layer_idx

        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        # Note: this writes into the dict obtained from `config`.
        if rope_scaling is not None and getattr(
            config, "original_max_position_embeddings", None
        ):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings
            )
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)

        self.self_attn = Lfm2MoeAttention(
            config=config,
            layer_idx=layer_idx,
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )

        if layer_idx < config.num_dense_layers:
            self.feed_forward = Lfm2MoeMlp(
                dim=config.hidden_size,
                ff_dim=config.intermediate_size,
                quant_config=quant_config,
                prefix=f"{prefix}.feed_forward",
            )
        else:
            self.feed_forward = Lfm2MoeSparseMoeBlock(
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.feed_forward",
                enable_eplb=enable_eplb,
            )

        self.operator_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run norm -> attention -> norm -> feed-forward.

        Returns:
            `(feed_forward_output, residual)`; the residual is carried to the
            next layer rather than added here.
        """
        if residual is None:
            # First layer: the raw input becomes the residual stream.
            residual = hidden_states
            hidden_states = self.operator_norm(hidden_states)
        else:
            hidden_states, residual = self.operator_norm(hidden_states, residual)
        hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)
        hidden_states, residual = self.ffn_norm(hidden_states, residual)
        return self.feed_forward(hidden_states), residual


class Lfm2MoeShortConvDecoderLayer(nn.Module):
    """Decoder layer whose operator is a `ShortConv` block.

    The feed-forward part is a dense `Lfm2MoeMlp` for the first
    `config.num_dense_layers` layers and an `Lfm2MoeSparseMoeBlock`
    otherwise.
    """

    def __init__(
        self,
        config: Lfm2MoeConfig,
        layer_idx: int,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ) -> None:
        """Build the short-conv operator, feed-forward block and norms.

        Args:
            config: Model config.
            layer_idx: Index of this layer; selects dense vs. MoE
                feed-forward.
            model_config: Forwarded to `ShortConv`.
            cache_config: Forwarded to `ShortConv`.
            quant_config: Optional quantization config for the feed-forward.
            prefix: Parameter-name prefix for this layer.
            enable_eplb: Forwarded to the MoE block when one is created.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.conv = ShortConv(
            config=config,
            dim=config.hidden_size,
            layer_idx=layer_idx,
            model_config=model_config,
            cache_config=cache_config,
            prefix=f"{prefix}.conv",
        )

        if layer_idx < config.num_dense_layers:
            self.feed_forward = Lfm2MoeMlp(
                dim=config.hidden_size,
                ff_dim=config.intermediate_size,
                quant_config=quant_config,
                prefix=f"{prefix}.feed_forward",
            )
        else:
            self.feed_forward = Lfm2MoeSparseMoeBlock(
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.feed_forward",
                enable_eplb=enable_eplb,
            )

        self.operator_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run norm -> short conv -> norm -> feed-forward.

        Extra keyword arguments (e.g. `positions` passed by the model loop)
        are accepted and ignored.

        Returns:
            `(hidden_states, residual)`.
        """
        if residual is None:
            # First layer: the raw input becomes the residual stream.
            residual = hidden_states
            hidden_states = self.operator_norm(hidden_states)
        else:
            hidden_states, residual = self.operator_norm(hidden_states, residual)
        # The conv operator is handed a preallocated buffer and its return
        # value is not used -- presumably it fills `output` in place.
        output = torch.empty_like(hidden_states)
        self.conv(
            hidden_states,
            output,
        )
        hidden_states, residual = self.ffn_norm(output, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual


@support_torch_compile
class Lfm2MoeModel(nn.Module):
    """Token embedding, the stack of decoder layers, and the final norm."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, the per-rank slice of layers and final norm.

        Args:
            vllm_config: Full vLLM config; model, cache, quantization and
                parallel configs are read from it.
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config
        enable_eplb = parallel_config.enable_eplb
        eplb_config = parallel_config.eplb_config
        self.num_redundant_experts = eplb_config.num_redundant_experts

        self.config = config

        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size
        )

        def get_layer(prefix: str):
            # config.layer_types decides, per layer index, whether the
            # operator is attention or short convolution.
            layer_idx = extract_layer_index(prefix)
            is_attn = self.config.layer_types[layer_idx] == "full_attention"
            layer_class = (
                Lfm2MoeAttentionDecoderLayer
                if is_attn
                else Lfm2MoeShortConvDecoderLayer
            )
            return layer_class(
                config,
                layer_idx,
                model_config,
                cache_config,
                quant_config=quant_config,
                prefix=prefix,
                enable_eplb=enable_eplb,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

        # Only the last pipeline rank applies the final norm.
        if get_pp_group().is_last_rank:
            self.embedding_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
        else:
            self.embedding_norm = PPMissingLayer()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for `input_ids`."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of decoder layers.

        Args:
            input_ids: Token ids; used on the first pipeline rank when
                `inputs_embeds` is None.
            positions: Positions forwarded to every layer.
            intermediate_tensors: Required on non-first pipeline ranks;
                supplies `hidden_states` and `residual`.
            inputs_embeds: Optional precomputed embeddings (first rank only).

        Returns:
            The normalized hidden states on the last pipeline rank, otherwise
            an `IntermediateTensors` holding `hidden_states` and `residual`.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.embedding_norm(hidden_states, residual)
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the expert parameter mapping built by `FusedMoE`.

        Checkpoint projections are named `w1` (gate), `w2` (down) and `w3`
        (up).
        """
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=self.config.num_experts,
            num_redundant_experts=self.num_redundant_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each tensor is tried, in order, as (1) a shard of a stacked
        parameter (`qkv_proj`, merged `w1`), (2) an expert weight, and
        (3) a plain parameter.

        Args:
            weights: Iterable of `(checkpoint_name, tensor)` pairs.

        Returns:
            The set of (possibly remapped) names recorded for tensors that
            reached the end of the loop body.
        """
        # (param_name, checkpoint_name, shard_id)
        stacked_params_mapping = [
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".w1", ".w1", 0),
            (".w1", ".w3", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            # The router bias is stored on the gate module in this
            # implementation.
            if "expert_bias" in name:
                name = name.replace("expert_bias", "gate.e_score_correction_bias")

            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue

                if ("feed_forward.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith((".bias", "_bias")) and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith((".bias", "_bias")) and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # The `continue` statements below target the outer
                    # `for name, loaded_weight` loop, so skipped tensors are
                    # not recorded in `loaded_params`.
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith((".bias", "_bias")) and name not in params_dict:
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Lfm2MoeForCausalLM(
    nn.Module,
    HasInnerState,
    SupportsLoRA,
    SupportsPP,
    IsHybrid,
    SupportsQuant,
    MixtureOfExperts,
):
    """Causal-LM wrapper around `Lfm2MoeModel` with an LM head.

    Also exposes the mixture-of-experts metadata (expert counts and the list
    of fused expert layers) collected from the decoder layers.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "w1": [
            "w1",
            "w3",
        ],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, ...]:
        """Return the short-conv state dtype(s) for the given config.

        Delegates to `MambaStateDtypeCalculator.short_conv_state_dtype` with
        the model dtype and the configured mamba cache dtype.
        """
        return MambaStateDtypeCalculator.short_conv_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[tuple[int, int]]:
        """Calculate shapes for LFM2's convolutional cache.

        Args:
            vllm_config: vLLM config

        Returns:
            Tuple containing:
            - conv_state_shape: Shape for convolutional state cache
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_config

        return MambaStateShapeCalculator.short_conv_state_shape(
            tp_world_size=parallel_config.tensor_parallel_size,
            intermediate_size=hf_config.hidden_size,
            conv_kernel=hf_config.conv_L_cache,
        )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the backbone, LM head and MoE metadata.

        Args:
            vllm_config: Full vLLM config.
            prefix: Parameter-name prefix for this module.

        Raises:
            AssertionError: If prefix caching is enabled.
            RuntimeError: If no `Lfm2MoeSparseMoeBlock` is found among the
                locally present layers.
        """
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        cache_config = vllm_config.cache_config
        assert not cache_config.enable_prefix_caching, (
            "Lfm2Moe currently does not support prefix caching"
        )

        super().__init__()
        self.config = config
        self.model = Lfm2MoeModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
        else:
            self.lm_head = PPMissingLayer()

        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        # Set MoE hyperparameters
        self.expert_weights = []

        self.moe_layers = []
        example_layer = None
        for layer in self.model.layers:
            # Layers owned by other pipeline ranks are placeholders.
            if isinstance(layer, PPMissingLayer):
                continue

            assert isinstance(
                layer, (Lfm2MoeAttentionDecoderLayer, Lfm2MoeShortConvDecoderLayer)
            )
            if isinstance(layer.feed_forward, Lfm2MoeSparseMoeBlock):
                example_layer = layer.feed_forward
                self.moe_layers.append(layer.feed_forward.experts)

        if example_layer is None:
            raise RuntimeError(
                "No Lfm2MoeSparseMoeBlock layer found in the model.layers."
            )

        self.num_moe_layers = len(self.moe_layers)
        self.num_expert_groups = 1
        self.num_shared_experts = 0
        self.num_logical_experts = example_layer.n_logical_experts
        self.num_physical_experts = example_layer.n_physical_experts
        self.num_local_physical_experts = example_layer.n_local_physical_experts
        self.num_routed_experts = example_layer.n_routed_experts
        self.num_redundant_experts = example_layer.n_redundant_experts

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone model."""
        return self.model.embed_input_ids(input_ids)

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Update expert counts on this model and on every local MoE block.

        Args:
            num_physical_experts: New total number of physical experts.
            num_local_physical_experts: Number of physical experts on this
                rank; must equal the current value.

        Raises:
            AssertionError: If `num_local_physical_experts` differs from the
                currently stored value.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for layer in self.model.layers:
            # Layers owned by other pipeline-parallel ranks are placeholders
            # without a `feed_forward` attribute; skip them, as __init__ does,
            # instead of failing on the attribute access.
            if isinstance(layer, PPMissingLayer):
                continue
            if isinstance(layer.feed_forward, Lfm2MoeSparseMoeBlock):
                moe = layer.feed_forward
                moe.n_local_physical_experts = num_local_physical_experts
                moe.n_physical_experts = num_physical_experts
                moe.n_redundant_experts = self.num_redundant_experts
                moe.experts.update_expert_map()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone model and return its output unchanged.

        Extra keyword arguments are accepted and ignored.
        """
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project hidden states to logits through the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights via `AutoWeightsLoader`.

        `lm_head.` tensors are skipped when the config ties word embeddings.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the expert parameter mapping from the backbone model."""
        return self.model.get_expert_mapping()