from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union

import os

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tp_group,
                              get_tensor_model_parallel_world_size,
                              get_tensor_model_parallel_rank,
                              tensor_model_parallel_all_reduce)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.utils import (
    PPMissingLayer, is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.model_executor.models.deepseek_v2 import (DeepseekV2MLAAttention,
                                                    Indexer, yarn_get_mscale)
from vllm.logger import init_logger

from .vars import *
from ..ops.deepseek_fused_mlp_moe import (vacc_fused_decode_moe_fp8,
                                          vacc_fused_prefill_moe_fp8,
                                          vacc_fused_mlp_fp8)
from .fused_forward import *

logger = init_logger(__name__)

# Debug switch: when "1", only the first few decoder layers are loaded
# (see DeepseekV2ForCausalLM.load_weights).
test_layer_en = os.getenv("test_layer_en", "0")
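
# NOTE: the classes below carry the same names as the corresponding modules in
# vllm.model_executor.models.deepseek_v2 but define only the hot-path methods
# (forward / load_weights). They are evidently meant as drop-in overrides that
# route those paths through vacc fused FP8 kernels, falling back to the stock
# unfused vLLM ops whenever a fused kernel is unavailable or raises.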

class DeepseekV2MoE(nn.Module):

    def forward(self,
                hidden_states: torch.Tensor,
                residual=None,
                rms_norm=None):
        # MoE layer with fused prefill/decode vacc ops.
        if residual is not None:
            try:
                reduce_result = self.tp_size > 1
                if self.is_decode:
                    # Decode MoE (first seq).
                    hidden_states, residual = vacc_fused_decode_moe_fp8(
                        self, self.shared_experts, hidden_states, residual,
                        rms_norm, self.gate, self.experts,
                        self.routed_scaling_factor, reduce_result)
                else:
                    # Prefill MoE (first expert).
                    hidden_states, residual = vacc_fused_prefill_moe_fp8(
                        self, self.shared_experts, hidden_states, residual,
                        rms_norm, self.gate, self.experts,
                        self.routed_scaling_factor, reduce_result)
                return hidden_states, residual
            except Exception as e:
                logger.warning(
                    "vacc fused MoE failed, falling back to unfused ops: %s", e)
                hidden_states, residual = rms_norm(hidden_states, residual)

        self.experts.is_decode = self.is_decode

        # 1. Pre-MoE: shared-expert output and router logits.
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        if self.n_shared_experts is not None:
            try:
                shared_output = vacc_fused_mlp_fp8(self.shared_experts,
                                                   hidden_states,
                                                   moe_share=True)
            except Exception as e:
                logger.warning(
                    "Fused MLP failed, falling back to the default path: %s", e)
                shared_output = self.shared_experts(hidden_states)
        router_logits, _ = self.gate(hidden_states)

        # 2. Fused MoE over the routed experts.
        final_hidden_states = self.experts(hidden_states=hidden_states,
                                           router_logits=router_logits)

        # 3. Add + reduce. The shared-expert add is now fused into the expert
        #    kernel, so only the tensor-parallel all-reduce remains here.
        # if shared_output is not None:
        #     # out = input + other * alpha
        #     final_hidden_states = shared_output.add_(
        #         final_hidden_states, alpha=self.routed_scaling_factor)
        # else:
        #     final_hidden_states = final_hidden_states * self.routed_scaling_factor
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        if residual is not None:
            return final_hidden_states.view(num_tokens, hidden_dim), residual
        return final_hidden_states.view(num_tokens, hidden_dim)
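
# A plain-PyTorch sketch of the shared/routed combine that the fused MoE
# kernels above are expected to perform (see the commented-out reference in
# DeepseekV2MoE.forward). The helper and its name are illustrative only and
# are not called by the model.
def _reference_combine_moe_outputs(
        shared_output: Optional[torch.Tensor],
        routed_output: torch.Tensor,
        routed_scaling_factor: float) -> torch.Tensor:
    if shared_output is not None:
        # out = shared + routed * alpha (in place on the shared-expert output)
        return shared_output.add_(routed_output, alpha=routed_scaling_factor)
    return routed_output * routed_scaling_factor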

class DeepseekV2MLP(nn.Module):

    def forward(self, x, residual=None, rms_norm=None):
        # Fully fused path: rms_norm + MLP (+ all-reduce) in a single vacc op.
        if residual is not None:
            reduce_result = (self.down_proj.reduce_results
                             and self.down_proj.tp_size > 1)
            hidden_states, residual = vacc_fused_mlp_fp8(
                self, x, residual, rms_norm, reduce_result)
            return hidden_states, residual

        # Fused MLP with an explicit all-reduce, falling back to the unfused
        # gate_up / act / down projections on failure.
        try:
            output_parallel = vacc_fused_mlp_fp8(self, x, residual, rms_norm)
            if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
                x = tensor_model_parallel_all_reduce(output_parallel)
            else:
                x = output_parallel
        except Exception as e:
            logger.warning(
                "Fused MLP failed, falling back to the default path: %s", e)
            gate_up, _ = self.gate_up_proj(x)
            x = self.act_fn(gate_up)
            x, _ = self.down_proj(x)
        return x
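
# Unfused reference for the MLP math that vacc_fused_mlp_fp8 is expected to
# reproduce on its happy path: gate/up projection -> SiLU(gate) * up -> down
# projection, i.e. the same steps taken by the fallback in
# DeepseekV2MLP.forward. Illustrative only; not called by the model.
# gate_up_proj / down_proj are vLLM linear layers returning (output, bias).
def _reference_mlp(x: torch.Tensor, gate_up_proj, down_proj) -> torch.Tensor:
    gate_up, _ = gate_up_proj(x)     # [..., 2 * intermediate_size]
    act = SiluAndMul()(gate_up)      # silu(gate) * up -> [..., intermediate_size]
    out, _ = down_proj(act)
    return out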

class DeepseekV2Model(nn.Module):

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        if isinstance(attn_metadata, dict):
            # A dict means per-group metadata; take the first (only) entry.
            attn_metadata = next(iter(attn_metadata.values()))

        first_k_dense_replace = getattr(self.config, "first_k_dense_replace", 3)

        if not hasattr(self, "weight_capture"):
            from vllm_vacc.vllm.model_executor.models.weight_capture.deepseek_weight_capture import (
                DeepseekWeightCapture)
            self.weight_capture = DeepseekWeightCapture(
                self.layers, self.start_layer, self.end_layer)
            self.cached_weights_state = True
            self.cached_batch = 1
            self.layer_nums = self.end_layer - self.start_layer
            self.is_pipeline_first = get_pp_group().is_first_rank

        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        if (attn_metadata.prefill_metadata is not None
                or not USE_DECODER_LAYER_FUSE_MODE):
            # Prefill (or layer fusion disabled): run the layers one by one.
            for i in range(self.start_layer, self.end_layer):
                layer = self.layers[i]
                hidden_states, residual = layer(positions, hidden_states,
                                                residual)
        else:
            # update global seq lens, used for serve infos
            # update_seqence_length(attn_metadata.decode_metadata.seq_lens)
            if FUSE_ALL_DECODER_LAYERS:
                self.weight_capture.update_attn_args(
                    attn_metadata.decode_metadata.seq_lens,
                    attn_metadata.slot_mapping,
                    [
                        self.layers[i].self_attn.mla_attn.kv_cache[
                            forward_context.virtual_engine]
                        for i in range(self.start_layer, first_k_dense_replace)
                    ],
                    [
                        self.layers[i].self_attn.mla_attn.kv_cache[
                            forward_context.virtual_engine]
                        for i in range(first_k_dense_replace, self.end_layer)
                    ],
                    attn_metadata.decode_metadata.block_tables)
                # The leading dense layers run one by one.
                hidden_states, residual = forward_mla_mlp_single_layer(
                    hidden_states, residual, self.weight_capture, 0)
                hidden_states, residual = forward_mla_mlp_single_layer(
                    hidden_states, residual, self.weight_capture, 1)
                hidden_states, residual = forward_mla_mlp_single_layer(
                    hidden_states, residual, self.weight_capture, 2)
                if hidden_states.shape[0] != self.cached_batch:
                    # Batch size changed: re-capture the cached weights.
                    self.cached_weights_state = True
                    self.cached_batch = hidden_states.shape[0]
                if self.cached_weights_state:
                    self.cached_weights_state = False
                    hidden_states, residual = forward_mla_moe_layers_with_weights(
                        hidden_states, residual, self.weight_capture)
                else:
                    hidden_states, residual = forward_mla_moe_layers_without_weights(
                        hidden_states, residual, self.weight_capture)
            else:
                from torch_vacc.vacc.custom_ops import (
                    fuse_mla_mlp_v2_allreduce_decode,
                    fuse_mla_moe_v2_allreduce_decode)
                for i in range(0, self.layer_nums):
                    layer_id = i + self.start_layer
                    layer = self.layers[layer_id]
                    kv_cache = layer.self_attn.mla_attn.kv_cache[
                        forward_context.virtual_engine]
                    # Decode positions are the last token of each sequence.
                    positions = [
                        p - 1 for p in attn_metadata.decode_metadata.seq_lens
                    ]
                    cos_cache = [
                        layer.self_attn.mla_attn.impl.rotary_emb.cos_cache[p]
                        for p in positions
                    ]
                    sin_cache = [
                        layer.self_attn.mla_attn.impl.rotary_emb.sin_cache[p]
                        for p in positions
                    ]
                    if layer_id < first_k_dense_replace:
                        # Dense layers: fused MLA attention + MLP + all-reduce.
                        attn_args = self.weight_capture.layer_mlp.attn_args
                        mlp_args = self.weight_capture.layer_mlp.mlp_args
                        dist_args = self.weight_capture.layer_mlp.dist_args
                        hidden_states, residual = fuse_mla_mlp_v2_allreduce_decode(
                            hidden_states=hidden_states,
                            residual=residual,
                            hidden_states_norm_weight=attn_args._a_hidden_states_norm_weight[i],
                            q_a_proj_weight=attn_args._0_merge_q_kv_weights[i],
                            q_a_proj_weight_scale_inv=attn_args._1_merge_q_kv_scale_inv[i],
                            q_a_layernorm_weight=attn_args._2_q_a_layernorm_weight[i],
                            w_q=attn_args._3_W_Q[i],
                            w_q_scale=attn_args._4_W_Q_scales[i],
                            w_uk=attn_args._5_W_UK[i],
                            w_uk_scale=attn_args._6_W_UK_scales[i],
                            w_qr=attn_args._7_W_QR[i],
                            w_qr_scale=attn_args._8_W_QR_scales[i],
                            kv_a_layernorm_weight=attn_args._9_kv_a_layernorm_weight[i],
                            sin_cache=sin_cache,  # attn_args._10_sin_cache
                            cos_cache=cos_cache,  # attn_args._11_cos_cache
                            slot_mapping=attn_metadata.slot_mapping,  # _12_slot_mapping
                            kv_cache=kv_cache,  # _13_kv_cache
                            block_tables=attn_metadata.decode_metadata.block_tables,  # _14_block_tables
                            block_group_size=attn_args._15_env_blk_grp_size,
                            w_uv=attn_args._16_W_UV[i],
                            w_uv_scale=attn_args._17_W_UV_scales[i],
                            o_proj_weight=attn_args._18_o_proj_weight[i],
                            o_proj_weight_scale_inv=attn_args._19_o_proj_weight_scale_inv[i],
                            # mla params
                            seq_lens=attn_metadata.decode_metadata.seq_lens,
                            sm_scale=attn_args._21_sm_scale,
                            head_num=attn_args._22_head_num,
                            # flash attention
                            flash_attention=(USE_FLASH_ATTENTION == 1),
                            # mlp weights
                            rms_weight=mlp_args._0_mlp_rms_weight[i],
                            mlp_weight_13=mlp_args._1_mlp_w13[i],
                            mlp_weight_2=mlp_args._2_mlp_w2[i],
                            mlp_weight_scale_13=mlp_args._3_mlp_w13_scale[i],
                            mlp_weight_scale_2=mlp_args._4_mlp_w2_scale[i],
                            # mlp params
                            mlp_block_size_w13=mlp_args._5_mlp_w13_block_size,
                            mlp_block_size_w2=mlp_args._6_mlp_w2_block_size,
                            # vccl info
                            world_size=dist_args._0_world_size,
                            rank=dist_args._1_rank,
                            group_id=dist_args._2_group_id,
                            dev_info=dist_args._3_dev_info)
                    else:
                        # MoE layers: fused MLA attention + shared/routed MoE
                        # + all-reduce. MoE weight-capture lists are indexed
                        # from 0, so the first pipeline rank offsets by the
                        # number of leading dense layers.
                        wid = (i - first_k_dense_replace
                               if self.is_pipeline_first else i)
                        attn_args = self.weight_capture.layer_moe.attn_args
                        moe_args = self.weight_capture.layer_moe.moe_args
                        dist_args = self.weight_capture.layer_moe.dist_args
                        hidden_states, residual = fuse_mla_moe_v2_allreduce_decode(
                            hidden_states=hidden_states,
                            residual=residual,
                            hidden_states_norm_weight=attn_args._a_hidden_states_norm_weight[wid],
                            q_a_proj_weight=attn_args._0_merge_q_kv_weights[wid],
                            q_a_proj_weight_scale_inv=attn_args._1_merge_q_kv_scale_inv[wid],
                            q_a_layernorm_weight=attn_args._2_q_a_layernorm_weight[wid],
                            w_q=attn_args._3_W_Q[wid],
                            w_q_scale=attn_args._4_W_Q_scales[wid],
                            w_uk=attn_args._5_W_UK[wid],
                            w_uk_scale=attn_args._6_W_UK_scales[wid],
                            w_qr=attn_args._7_W_QR[wid],
                            w_qr_scale=attn_args._8_W_QR_scales[wid],
                            kv_a_layernorm_weight=attn_args._9_kv_a_layernorm_weight[wid],
                            sin_cache=sin_cache,  # attn_args._10_sin_cache
                            cos_cache=cos_cache,  # attn_args._11_cos_cache
                            slot_mapping=attn_metadata.slot_mapping,  # _12_slot_mapping
                            kv_cache=kv_cache,  # _13_kv_cache
                            block_tables=attn_metadata.decode_metadata.block_tables,
                            block_group_size=attn_args._15_env_blk_grp_size,
                            w_uv=attn_args._16_W_UV[wid],
                            w_uv_scale=attn_args._17_W_UV_scales[wid],
                            o_proj_weight=attn_args._18_o_proj_weight[wid],
                            o_proj_weight_scale_inv=attn_args._19_o_proj_weight_scale_inv[wid],
                            # mla params
                            seq_lens=attn_metadata.decode_metadata.seq_lens,
                            sm_scale=attn_args._21_sm_scale,
                            head_num=attn_args._22_head_num,
                            # flash attention
                            flash_attention=(USE_FLASH_ATTENTION == 1),
                            # moe weights (shared-expert MLP + routed experts)
                            rms_weight=moe_args._0_moe_rms_weight[wid],
                            mlp_weight_13=moe_args._1_moe_share_mlp_w13[wid],
                            mlp_weight_2=moe_args._2_moe_share_mlp_w2[wid],
                            mlp_weight_scale_13=moe_args._3_moe_share_mlp_w13_scale[wid],
                            mlp_weight_scale_2=moe_args._4_moe_share_mlp_w2_scale[wid],
                            moe_weight_13=moe_args._5_moe_w13[wid],
                            moe_weight_2=moe_args._6_moe_w2[wid],
                            moe_weight_scale_13=moe_args._7_moe_w13_scale[wid],
                            moe_weight_scale_2=moe_args._8_moe_w2_scale[wid],
                            mm_weight=moe_args._9_gate_weight[wid],
                            moe_bias=moe_args._10_moe_bias[wid],
                            # moe params
                            mlp_block_size_w13=moe_args._11_moe_mlp_w13_block_size,
                            mlp_block_size_w2=moe_args._12_moe_mlp_w2_block_size,
                            moe_block_size_w13=moe_args._13_moe_w13_block_size,
                            moe_block_size_w2=moe_args._14_moe_w2_block_size,
                            # vccl info
                            world_size=dist_args._0_world_size,
                            rank=dist_args._1_rank,
                            group_id=dist_args._2_group_id,
                            dev_info=dist_args._3_dev_info)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

class DeepseekV2ForCausalLM(nn.Module, SupportsPP):

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        from .memory.memory_recycling import init_huge_memory_allocator
        from .vars import LLM_MAX_PREFILL_SEQ_LEN
        from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager

        # Default model name is "deepseek"; the config may switch it to
        # e.g. "deepseek_mtp".
        model_name = "deepseek"
        config_infos = vllm_vacc_config_manager().get_model_infos()
        if config_infos != "default":
            if config_infos in ['mtp']:
                model_name = "deepseek_mtp"
            else:
                model_name = config_infos
        if not init_huge_memory_allocator(LLM_MAX_PREFILL_SEQ_LEN,
                                          self.config.hidden_size,
                                          vllm_model=model_name):
            logger.warning("Failed to init the huge memory allocator; "
                           "prefill memory recycling will be disabled.")

        from vllm.model_executor.model_loader.weight_utils import (
            maybe_remap_kv_scale_name)

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            if test_layer_en == "1":
                # Debug mode: only load decoder layers 0..test_layer.
                test_layer = 5
                if name not in [
                        'model.embed_tokens.weight', 'model.norm.weight',
                        'lm_head.weight'
                ]:
                    if int(name.split(".")[2]) > test_layer:
                        continue

            # TODO(simon): support nextn predict layers
            if hasattr(self.config, "num_nextn_predict_layers"
                       ) and self.config.num_nextn_predict_layers > 0:
                assert self.config.num_nextn_predict_layers == 1
                layer_idx = self.config.num_hidden_layers
                if name.startswith(f"model.layers.{layer_idx}"):
                    continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if (("mlp.experts." in name) and name not in params_dict):
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if USE_MERGE_Q_KV_GEN_AND_Q_QR:
            # Merge the per-layer q/kv generation weights after loading so the
            # fused decode path above can consume a single merged tensor.
            for layer in self.model.layers:
                if isinstance(layer, PPMissingLayer):
                    continue
                layer.self_attn.merge_qkv_weights()

        return loaded_params

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        attn_metadata = get_forward_context().attn_metadata
        if isinstance(attn_metadata, dict):
            attn_metadata = next(iter(attn_metadata.values()))

        if attn_metadata.prefill_metadata is not None:
            from .memory.memory_recycling import alloc_memory_recycler
            from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager

            if hasattr(attn_metadata, 'num_prefill_tokens'):
                tokens = attn_metadata.num_prefill_tokens
            else:
                tokens = attn_metadata.prefill_metadata.num_prefill_tokens

            vllm_model_mode = "deepseek"
            config_infos = vllm_vacc_config_manager().get_model_infos()
            if config_infos != "default":
                if config_infos in ['mtp']:
                    vllm_model_mode = "deepseek_mtp"
                else:
                    vllm_model_mode = config_infos

            if get_tp_group().rank_in_group == 0:
                logger.info("[MemoryRecycler] enable: %s", vllm_model_mode)
            if not alloc_memory_recycler(tokens,
                                         vllm_model=vllm_model_mode,
                                         world_size=get_tp_group().world_size):
                logger.warning(
                    "DeepSeek memory recycler allocation failed; the current "
                    "request (%s tokens) may run inefficiently.", tokens)

        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states
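
# NOTE: USE_DECODER_LAYER_FUSE_MODE, FUSE_ALL_DECODER_LAYERS,
# USE_FLASH_ATTENTION and USE_MERGE_Q_KV_GEN_AND_Q_QR are presumably provided
# by the star import from .vars above; setting test_layer_en=1 in the
# environment restricts weight loading to decoder layers 0-5 for debugging.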