# SPDX-License-Identifier: Apache-2.0

from typing import Iterable, Set, Tuple, Optional

import torch
import torch.nn as nn
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.forward_context import ForwardContext, get_forward_context
from .vars import *

from transformers import PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.models.deepseek_mtp import DeepSeekMultiTokenPredictorLayer as DeepSeekMultiTokenPredictorLayerOrig
from vllm.distributed import get_tp_group
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer


def DeepSeekMultiTokenPredictorLayer__init__(self, vllm_config: VllmConfig, prefix: str) -> None:
    """Replacement ``__init__`` for vLLM's ``DeepSeekMultiTokenPredictorLayer``.

    ``self`` must be an instance of ``DeepSeekMultiTokenPredictorLayerOrig``:
    the ``super`` call starts the MRO lookup *after* that class, so the
    upstream ``__init__`` body is not run and the submodules are built here.

    Args:
        vllm_config: Global vLLM config; ``model_config.hf_config``,
            ``quant_config`` and (for ``index_topk`` configs)
            ``scheduler_config.max_num_batched_tokens`` are read.
        prefix: Module prefix forwarded to ``DeepseekV2DecoderLayer``.
    """
    super(DeepSeekMultiTokenPredictorLayerOrig, self).__init__()
    config = vllm_config.model_config.hf_config
    quant_config = vllm_config.quant_config

    self.embed_tokens = VocabParallelEmbedding(
        config.vocab_size,
        config.hidden_size,
    )

    self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    # eh_proj maps the concatenated [embedding, previous hidden] vector
    # (2 * hidden_size) back to hidden_size; optionally tensor-parallel.
    if USE_PARALLEL_MTP_EH_PROJ:
        self.eh_proj = RowParallelLinear(config.hidden_size * 2, config.hidden_size, bias=False, return_bias=False)
    else:
        self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)

    # Local import: this deliberately resolves to upstream vLLM's SharedHead,
    # not the SharedHead class defined further down in this module.
    from vllm.model_executor.models.deepseek_mtp import SharedHead

    # Configs that carry `index_topk` need a preallocated int32 buffer of
    # shape (max_num_batched_tokens, index_topk) for the decoder layer.
    self.is_v32 = hasattr(config, "index_topk")
    if self.is_v32:
        topk_tokens = config.index_topk
        topk_indices_buffer = torch.empty(
            vllm_config.scheduler_config.max_num_batched_tokens,
            topk_tokens,
            dtype=torch.int32,
            device="cuda")
    else:
        topk_indices_buffer = None
    self.shared_head = SharedHead(config=config, quant_config=quant_config)
    self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix, topk_indices_buffer)


class DeepSeekMultiTokenPredictorLayer(nn.Module):
    """Holder for the patched MTP-layer ``forward``.

    NOTE(review): the attributes used below (``embed_tokens``, ``enorm``,
    ``hnorm``, ``eh_proj``, ``mtp_block``) are the ones created by
    ``DeepSeekMultiTokenPredictorLayer__init__``; how this method is attached
    to the upstream class is not visible here — confirm.
    """

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        """Run one MTP step: fused embed/hidden norm + projection, then the block.

        Stage 0 (norm + concat + ``eh_proj``) uses ``fuse_mtp_stage0`` followed
        by ``self.eh_proj`` when more than 256 tokens are present, otherwise
        the all-in-one ``torch.vacc.fuse_mtp_allreduce``. The decoder block
        runs unfused for prefill (or when ``USE_DECODER_LAYER_FUSE_MODE`` is
        off) and through ``fuse_mla_moe_v2_allreduce_decode`` otherwise.

        Args:
            input_ids: Token ids; embedded only when ``inputs_embeds`` is None.
            positions: Token positions passed to stage 0 and the unfused block.
            previous_hidden_states: Hidden states from the previous step.
            inputs_embeds: Optional precomputed embeddings.
            spec_step_index: Unused here; kept for interface compatibility.

        Returns:
            ``residual + block_output``, computed in place in ``residual``.
        """
        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        if isinstance(attn_metadata, dict):
            # Per-layer metadata: take the first entry.
            # NOTE(review): assumes all entries are equivalent for this layer.
            attn_metadata = next(iter(attn_metadata.values()))

        if not hasattr(self, "weight_capture"):
            # Built lazily on first call; the fused VACC ops below read their
            # weights and distributed args from it.
            from vllm_vacc.vllm.model_executor.models.weight_capture.deepseek_weight_capture import DeepseekMTPWegitCapture
            self.weight_capture = DeepseekMTPWegitCapture(self.mtp_block)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        assert inputs_embeds is not None

        if inputs_embeds.shape[0] > 256:
            from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler, DeepseekMTPMemoryRecycler
            # Reuse a recycled output buffer when the MTP recycler is active.
            deepseek_mtp_layer_input_buffer = None
            if isinstance(memory_recycler, DeepseekMTPMemoryRecycler):
                deepseek_mtp_layer_input_buffer = memory_recycler.DEEPSEEK_MTP_LAYER_INPUT
            from torch_vacc.vacc.custom_ops import fuse_mtp_stage0
            tp_group = get_tp_group()
            hidden_states = fuse_mtp_stage0(
                inputs_embeds, previous_hidden_states, positions,
                self.enorm.weight, self.hnorm.weight,
                self.enorm.variance_epsilon,
                world_size=tp_group.world_size,
                rank=tp_group.rank_in_group,
                group_id=tp_group.group_id,
                dev_info=tp_group.rank_device_infos,
                output=deepseek_mtp_layer_input_buffer,
            )
            # if USE_PARALLEL_MTP_EH_PROJ:
            #     tp_size = get_tensor_model_parallel_world_size()
            #     rank_id = get_tensor_model_parallel_rank()
            #     last_dim = hidden_states.shape[-1]
            #     if tp_size > 1:
            #         hiddens_tp = last_dim//tp_size
            #         hidden_states = hidden_states[...,rank_id*hiddens_tp : (rank_id+1)*hiddens_tp]
            hidden_states = self.eh_proj(hidden_states)
        else:
            dist_args = self.weight_capture.layer_moe.dist_args
            hidden_states = torch.vacc.fuse_mtp_allreduce(
                inputs_embeds, previous_hidden_states, positions,
                self.enorm.weight, self.hnorm.weight, self.eh_proj.weight,
                self.enorm.variance_epsilon,
                world_size=dist_args._0_world_size,
                rank=dist_args._1_rank,
                group_id=dist_args._2_group_id,
                dev_info=dist_args._3_dev_info)

        if (attn_metadata.prefill_metadata is not None
                or not USE_DECODER_LAYER_FUSE_MODE):
            hidden_states, residual = self.mtp_block(
                positions=positions, hidden_states=hidden_states, residual=None)
        else:
            from torch_vacc.vacc.custom_ops import fuse_mla_moe_v2_allreduce_decode
            layer = self.mtp_block
            # The MTP module wraps a single decoder block, so all captured
            # per-layer weight lists are indexed at 0.
            layer_id = 0
            layer_moe = self.weight_capture.layer_moe
            attn_args = layer_moe.attn_args
            moe_args = layer_moe.moe_args
            dist_args = layer_moe.dist_args

            kv_cache = layer.self_attn.mla_attn.kv_cache[forward_context.virtual_engine]
            # Decode rotary positions come from seq_lens (index of each
            # sequence's last token), not from the `positions` argument.
            decode_positions = [p - 1 for p in attn_metadata.decode_metadata.seq_lens]
            rotary_emb = layer.self_attn.mla_attn.impl.rotary_emb
            cos_cache = [rotary_emb.cos_cache[p] for p in decode_positions]
            sin_cache = [rotary_emb.sin_cache[p] for p in decode_positions]

            # For the MTP layer the incoming residual is None, and the fused
            # op must return the residual.
            hidden_states, residual = fuse_mla_moe_v2_allreduce_decode(
                hidden_states=hidden_states,
                residual=None,
                hidden_states_norm_weight=attn_args._a_hidden_states_norm_weight[layer_id],
                q_a_proj_weight=attn_args._0_merge_q_kv_weights[layer_id],
                q_a_proj_weight_scale_inv=attn_args._1_merge_q_kv_scale_inv[layer_id],
                q_a_layernorm_weight=attn_args._2_q_a_layernorm_weight[layer_id],
                w_q=attn_args._3_W_Q[layer_id],
                w_q_scale=attn_args._4_W_Q_scales[layer_id],
                w_uk=attn_args._5_W_UK[layer_id],
                w_uk_scale=attn_args._6_W_UK_scales[layer_id],
                w_qr=attn_args._7_W_QR[layer_id],
                w_qr_scale=attn_args._8_W_QR_scales[layer_id],
                kv_a_layernorm_weight=attn_args._9_kv_a_layernorm_weight[layer_id],
                sin_cache=sin_cache,
                cos_cache=cos_cache,
                slot_mapping=attn_metadata.slot_mapping,
                kv_cache=kv_cache,
                block_tables=attn_metadata.decode_metadata.block_tables,
                block_group_size=attn_args._15_env_blk_grp_size,
                w_uv=attn_args._16_W_UV[layer_id],
                w_uv_scale=attn_args._17_W_UV_scales[layer_id],
                o_proj_weight=attn_args._18_o_proj_weight[layer_id],
                o_proj_weight_scale_inv=attn_args._19_o_proj_weight_scale_inv[layer_id],
                # mla params
                seq_lens=attn_metadata.decode_metadata.seq_lens,
                sm_scale=attn_args._21_sm_scale,
                head_num=attn_args._22_head_num,
                # flash attention
                flash_attention=(USE_FLASH_ATTENTION == 1),
                # moe weight
                rms_weight=moe_args._0_moe_rms_weight[layer_id],
                mlp_weight_13=moe_args._1_moe_share_mlp_w13[layer_id],
                mlp_weight_2=moe_args._2_moe_share_mlp_w2[layer_id],
                mlp_weight_scale_13=moe_args._3_moe_share_mlp_w13_scale[layer_id],
                mlp_weight_scale_2=moe_args._4_moe_share_mlp_w2_scale[layer_id],
                moe_weight_13=moe_args._5_moe_w13[layer_id],
                moe_weight_2=moe_args._6_moe_w2[layer_id],
                moe_weight_scale_13=moe_args._7_moe_w13_scale[layer_id],
                moe_weight_scale_2=moe_args._8_moe_w2_scale[layer_id],
                mm_weight=moe_args._9_gate_weight[layer_id],
                moe_bias=moe_args._10_moe_bias[layer_id],
                # moe params
                mlp_block_size_w13=moe_args._11_moe_mlp_w13_block_size,
                mlp_block_size_w2=moe_args._12_moe_mlp_w2_block_size,
                moe_block_size_w13=moe_args._13_moe_w13_block_size,
                moe_block_size_w2=moe_args._14_moe_w2_block_size,
                # vccl info
                world_size=dist_args._0_world_size,
                rank=dist_args._1_rank,
                group_id=dist_args._2_group_id,
                dev_info=dist_args._3_dev_info)

        # In-place add: the result lives in `residual`'s storage.
        hidden_states = residual.add_(hidden_states)
        return hidden_states


class DeepSeekMTP(nn.Module):
    """Holder for the patched ``DeepSeekMTP.load_weights``."""

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into the MTP layers.

        Weights without a speculative-layer index, and ``rotary_emb.inv_freq``,
        are skipped. ``gate_proj``/``up_proj`` are stacked into
        ``gate_up_proj``; routed-expert weights go through the FusedMoE
        expert mapping; everything else uses the parameter's ``weight_loader``
        (or ``default_weight_loader``). If ``USE_MERGE_Q_KV_GEN_AND_Q_QR`` is
        set, each non-placeholder MTP block then merges its q/kv weights.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of (rewritten) parameter names that were loaded.
        """
        stacked_params_mapping = [
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        from vllm.model_executor.models.deepseek_v2 import get_spec_layer_idx_from_weight_name
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is None:
                continue
            name = self._rewrite_spec_layer_name(spec_layer, name)
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if USE_MERGE_Q_KV_GEN_AND_Q_QR:
            from vllm.model_executor.models.utils import PPMissingLayer
            for layer_id in self.model.layers:
                layer = self.model.layers[layer_id]
                # Check the module itself, not its key: pipeline-parallel
                # placeholders have no mtp_block to merge.
                if isinstance(layer, PPMissingLayer):
                    continue
                layer.mtp_block.self_attn.merge_qkv_weights()
        return loaded_params


class SharedHead(nn.Module):
    """Holder for the patched ``SharedHead.forward``."""

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the head's RMS norm, preferring the fused VACC kernel.

        The fused ``rms_norm`` is given ``output=hidden_states``, so it
        presumably writes in place (TODO confirm). On any exception the
        message is printed and the unfused ``self.norm`` is used instead.

        NOTE(review): the fused call does not pass ``self.norm``'s epsilon;
        confirm that the op's default matches the config's ``rms_norm_eps``.
        NOTE(review): on failure, the print happens on every call.
        """
        try:
            from torch_vacc.vacc.custom_ops import rms_norm
            return rms_norm(hidden_states, self.norm.weight, output=hidden_states)
        except Exception as e:
            print(f"fuse rms_norm run fail, now use unfused ops: {e}")
            return self.norm(hidden_states)