from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tp_group,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
from vllm.model_executor.models.deepseek_v2 import yarn_get_mscale, DeepseekV2MLAAttention, Indexer
from vllm.logger import init_logger
logger = init_logger(__name__)
from .vars import *
from ..ops.deepseek_fused_mlp_moe import (vacc_fused_decode_moe_fp8,
vacc_fused_prefill_moe_fp8,
vacc_fused_mlp_fp8)
from .fused_forward import *
import os
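# NOTE: when the env var test_layer_en is set to "1", load_weights() below only
# loads the first few transformer layers (see DeepseekV2ForCausalLM.load_weights),
# which is handy for quick debugging runs.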
test_layer_en = os.getenv("test_layer_en", "0")
# class DeepseekV2MLAAttention(nn.Module):
# def __init__(
# self,
# vllm_config: VllmConfig,
# config: Union[DeepseekV2Config, DeepseekV3Config],
# hidden_size: int,
# num_heads: int,
# qk_nope_head_dim: int,
# qk_rope_head_dim: int,
# v_head_dim: int,
# q_lora_rank: Optional[int],
# kv_lora_rank: int,
# rope_theta: float = 10000,
# rope_scaling: Optional[dict[str, Any]] = None,
# max_position_embeddings: int = 8192,
# cache_config: Optional[CacheConfig] = None,
# quant_config: Optional[QuantizationConfig] = None,
# prefix: str = "",
# topk_indices_buffer: Optional[torch.Tensor] = None,
# ) -> None:
# super(DeepseekV2MLAAttention,self).__init__()
# self.hidden_size = hidden_size
# self.qk_nope_head_dim = qk_nope_head_dim
# self.qk_rope_head_dim = qk_rope_head_dim
# self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
# self.v_head_dim = v_head_dim
# self.q_lora_rank = q_lora_rank
# self.kv_lora_rank = kv_lora_rank
# self.num_heads = num_heads
# tp_size = get_tensor_model_parallel_world_size()
# assert num_heads % tp_size == 0
# self.num_local_heads = num_heads // tp_size
# self.scaling = self.qk_head_dim**-0.5
# self.rope_theta = rope_theta
# self.max_position_embeddings = max_position_embeddings
# if self.q_lora_rank is not None:
# if USE_PARALLEL_Q_KV_GEN:
# self.q_a_proj = RowParallelLinear(self.hidden_size,
# self.q_lora_rank,
# bias=False,
# quant_config=quant_config,
# prefix=f"{prefix}.q_a_proj")
# else:
# self.q_a_proj = ReplicatedLinear(self.hidden_size,
# self.q_lora_rank,
# bias=False,
# quant_config=quant_config,
# prefix=f"{prefix}.q_a_proj")
# self.q_a_layernorm = RMSNorm(self.q_lora_rank,
# eps=config.rms_norm_eps)
# self.q_b_proj = ColumnParallelLinear(q_lora_rank,
# self.num_heads *
# self.qk_head_dim,
# bias=False,
# quant_config=quant_config,
# prefix=f"{prefix}.q_b_proj")
# else:
# self.q_proj = ColumnParallelLinear(self.hidden_size,
# self.num_heads *
# self.qk_head_dim,
# bias=False,
# quant_config=quant_config,
# prefix=f"{prefix}.q_proj")
# if USE_PARALLEL_Q_KV_GEN:
# self.kv_a_proj_with_mqa = RowParallelLinear(
# self.hidden_size,
# self.kv_lora_rank + self.qk_rope_head_dim,
# bias=False,
# quant_config=quant_config,
# prefix=f"{prefix}.kv_a_proj_with_mqa")
# else:
# self.kv_a_proj_with_mqa = ReplicatedLinear(
# self.hidden_size,
# self.kv_lora_rank + self.qk_rope_head_dim,
# bias=False,
# quant_config=quant_config,
# prefix=f"{prefix}.kv_a_proj_with_mqa")
# self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
# eps=config.rms_norm_eps)
# self.kv_b_proj = ColumnParallelLinear(
# self.kv_lora_rank,
# self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
# bias=False,
# quant_config=quant_config,
# prefix=f"{prefix}.kv_b_proj")
# self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
# self.hidden_size,
# bias=False,
# quant_config=quant_config,
# prefix=f"{prefix}.o_proj")
# rope_scaling["rope_type"] = 'deepseek_yarn'
# self.rotary_emb = get_rope(qk_rope_head_dim,
# rotary_dim=qk_rope_head_dim,
# max_position=max_position_embeddings,
# base=rope_theta,
# rope_scaling=rope_scaling,
# is_neox_style=False)
# if rope_scaling:
# mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
# scaling_factor = rope_scaling["factor"]
# mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
# self.scaling = self.scaling * mscale * mscale
# self.is_v32 = hasattr(config, "index_topk")
# if self.is_v32:
# self.indexer = Indexer(vllm_config, config, hidden_size,
# q_lora_rank, quant_config, cache_config,
# topk_indices_buffer, f"{prefix}.indexer")
# else:
# self.indexer = None
# self.mla_attn = Attention(
# num_heads=self.num_local_heads,
# head_size=self.kv_lora_rank,
# scale=self.scaling,
# num_kv_heads=1,
# cache_config=cache_config,
# quant_config=quant_config,
# prefix=f"{prefix}.attn",
# use_mla=True,
# # MLA Args
# q_lora_rank=self.q_lora_rank,
# kv_lora_rank=self.kv_lora_rank,
# qk_nope_head_dim=self.qk_nope_head_dim,
# qk_rope_head_dim=self.qk_rope_head_dim,
# qk_head_dim=self.qk_head_dim,
# v_head_dim=self.v_head_dim,
# rotary_emb=self.rotary_emb,
# q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
# kv_b_proj=self.kv_b_proj,
# o_proj=self.o_proj,
# )
# self.prefix = prefix
# self.debug_layer_idx = int(self.prefix.split(".")[-2])
# def forward(
# self,
# positions: torch.Tensor,
# hidden_states: torch.Tensor,
# kv_cache: torch.Tensor,
# attn_metadata: AttentionMetadata,
# ) -> torch.Tensor:
# tp_size = get_tensor_model_parallel_world_size()
# rank_id = get_tensor_model_parallel_rank()
# last_dim = hidden_states.shape[-1]
# if USE_PARALLEL_Q_KV_GEN: #tp qa and kva
# hidden_states_split = hidden_states
# if tp_size > 1:
# hiddens_tp = last_dim//tp_size
# hidden_states_split = hidden_states[...,rank_id*hiddens_tp : (rank_id+1)*hiddens_tp].contiguous()
# if self.q_lora_rank is not None:
# ckq = self.q_a_proj(hidden_states_split)[0]
# hidden_states_or_q_c = self.q_a_layernorm(ckq)
# else:
# hidden_states_or_q_c = hidden_states
# kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states_split)[0].split(
# [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
# kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
# return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache,
# attn_metadata)
# if self.q_lora_rank is not None:
# ckq = self.q_a_proj(hidden_states)[0]
# hidden_states_or_q_c = self.q_a_layernorm(ckq)
# else:
# hidden_states_or_q_c = hidden_states
# kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
# [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
# kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
# return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache,
# attn_metadata)
class DeepseekV2MoE(nn.Module):
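# Forward-only override of the MoE block: when a residual is supplied, the whole
# norm + shared-expert + routed-expert path runs through the VACC fused FP8 kernels
# (decode and prefill variants); on any failure, or when no residual is given, it
# falls back to the stock gate -> FusedMoE -> all-reduce path. Attributes such as
# tp_size, gate, experts and shared_experts are assumed to come from the original
# DeepseekV2MoE.__init__.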
def forward(self, hidden_states: torch.Tensor, residual = None, rms_norm = None):
# MoE layer forward: supports both prefill and decode through VACC fused ops
if residual is not None:
try:
reduce_result = self.tp_size > 1
# decode moe, first seq
if self.is_decode:
hidden_states, residual = vacc_fused_decode_moe_fp8(self, self.shared_experts,
hidden_states, residual,
rms_norm, self.gate, self.experts,
self.routed_scaling_factor,
reduce_result)
return hidden_states, residual
# prefill moe, first expert
else:
hidden_states, residual = vacc_fused_prefill_moe_fp8(self, self.shared_experts,
hidden_states, residual,
rms_norm, self.gate, self.experts,
self.routed_scaling_factor,
reduce_result)
return hidden_states, residual
except Exception as e:
logger.warning("vacc fused moe run fail, now use unfused ops %s", e)
hidden_states, residual = rms_norm(hidden_states, residual)
self.experts.is_decode = self.is_decode
# 1. fuse_prefill_pre_moe
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.n_shared_experts is not None:
try:
shared_output = vacc_fused_mlp_fp8(self.shared_experts, hidden_states, moe_share=True)
except Exception as e:
logger.warning("fused mlp is Error, now use Default:%s", e)
shared_output = self.shared_experts(hidden_states)
router_logits, _ = self.gate(hidden_states)
# 2. fused_moe
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits)
# 3. add_reduce
# now fuse share_mlp add to experts
# if shared_output is not None:
# # out = input + other * alpha
# final_hidden_states = shared_output.add_(final_hidden_states, alpha=self.routed_scaling_factor)
# else:
# final_hidden_states = final_hidden_states * self.routed_scaling_factor
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
if residual is not None:
return final_hidden_states.view(num_tokens, hidden_dim), residual
return final_hidden_states.view(num_tokens, hidden_dim)
class DeepseekV2MLP(nn.Module):
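# Forward-only override of the dense MLP: a single VACC fused FP8 kernel covers
# RMSNorm (when a residual is passed), gate_up_proj, SiLU-and-mul and down_proj,
# with an optional tensor-parallel all-reduce; the unfused gate_up/act/down path
# is kept as a fallback.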
def forward(self, x, residual = None, rms_norm = None):
# use all fused ops
if residual is not None:
reduce_result = self.down_proj.reduce_results and self.down_proj.tp_size > 1
hidden_states, residual = vacc_fused_mlp_fp8(self,
x, residual,
rms_norm,
reduce_result)
return hidden_states, residual
# no residual: try the fused MLP op, fall back to the default unfused path on failure
try:
output_parallel = vacc_fused_mlp_fp8(self, x, residual, rms_norm)
if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
x = tensor_model_parallel_all_reduce(output_parallel)
else:
x = output_parallel
except Exception as e:
logger.warning("fuse_mlp run fail, now use default: %s", e)
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class DeepseekV2Model(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = next(iter(attn_metadata.values()))
first_k_dense_replace = self.config.first_k_dense_replace if hasattr(self.config, "first_k_dense_replace") else 3
if not hasattr(self, "weight_capture"):
from vllm_vacc.vllm.model_executor.models.weight_capture.deepseek_weight_capture import DeepseekWeightCapture
self.weight_capture = DeepseekWeightCapture(self.layers, self.start_layer, self.end_layer)
self.cached_weights_state = True
self.cached_batch = 1
self.layer_nums = self.end_layer - self.start_layer
self.is_pipeline_first = get_pp_group().is_first_rank
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
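# Three execution paths from here:
#   1) prefill, or decoder-layer fusion disabled: run each decoder layer as usual;
#   2) FUSE_ALL_DECODER_LAYERS: decode with weights captured/cached in
#      weight_capture and bulk fused-layer kernels;
#   3) otherwise: decode layer by layer through the fuse_mla_mlp/moe_v2_allreduce
#      custom ops.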
if attn_metadata.prefill_metadata is not None or not USE_DECODER_LAYER_FUSE_MODE:
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, residual)
else:
# update global sequence lengths, used for serving info
# update_seqence_length(attn_metadata.decode_metadata.seq_lens)
if FUSE_ALL_DECODER_LAYERS:
self.weight_capture.update_attn_args(attn_metadata.decode_metadata.seq_lens,
attn_metadata.slot_mapping,
[self.layers[i].self_attn.mla_attn.kv_cache[forward_context.virtual_engine] for i in range(self.start_layer, first_k_dense_replace)],
[self.layers[i].self_attn.mla_attn.kv_cache[forward_context.virtual_engine] for i in range(first_k_dense_replace, self.end_layer)],
attn_metadata.decode_metadata.block_tables)
hidden_states, residual = forward_mla_mlp_single_layer(hidden_states, residual, self.weight_capture, 0)
hidden_states, residual = forward_mla_mlp_single_layer(hidden_states, residual, self.weight_capture, 1)
hidden_states, residual = forward_mla_mlp_single_layer(hidden_states, residual, self.weight_capture, 2)
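# The first first_k_dense_replace layers are dense and were run one at a time
# above; the remaining MoE layers run in bulk below, with weights (re)captured
# whenever the decode batch size changes.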
if hidden_states.shape[0] != self.cached_batch:
# batch size changed: re-capture the cached weights
self.cached_weights_state = True
self.cached_batch = hidden_states.shape[0]
if self.cached_weights_state:
self.cached_weights_state = False
hidden_states, residual = forward_mla_moe_layers_with_weights(hidden_states, residual, self.weight_capture)
else:
hidden_states, residual = forward_mla_moe_layers_without_weights(hidden_states, residual, self.weight_capture)
else:
from torch_vacc.vacc.custom_ops import fuse_mla_mlp_v2_allreduce_decode,fuse_mla_moe_v2_allreduce_decode
for i in range(0, self.layer_nums):
layer_id = i + self.start_layer
layer = self.layers[layer_id]
kv_cache = layer.self_attn.mla_attn.kv_cache[forward_context.virtual_engine]
positions = [p - 1 for p in attn_metadata.decode_metadata.seq_lens]
cos_cache = [layer.self_attn.mla_attn.impl.rotary_emb.cos_cache[p] for p in positions]
sin_cache = [layer.self_attn.mla_attn.impl.rotary_emb.sin_cache[p] for p in positions]
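# rope cos/sin are gathered at the last token position (seq_len - 1) of each
# decode sequence; layers below first_k_dense_replace go through the fused
# MLA + dense-MLP kernel, the remaining layers through the fused MLA + MoE kernel.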
if layer_id < first_k_dense_replace:
hidden_states, residual = fuse_mla_mlp_v2_allreduce_decode(
hidden_states = hidden_states,
residual = residual,
hidden_states_norm_weight = self.weight_capture.layer_mlp.attn_args._a_hidden_states_norm_weight[i],
q_a_proj_weight = self.weight_capture.layer_mlp.attn_args._0_merge_q_kv_weights[i],
q_a_proj_weight_scale_inv = self.weight_capture.layer_mlp.attn_args._1_merge_q_kv_scale_inv[i],
q_a_layernorm_weight = self.weight_capture.layer_mlp.attn_args._2_q_a_layernorm_weight[i],
w_q = self.weight_capture.layer_mlp.attn_args._3_W_Q[i],
w_q_scale = self.weight_capture.layer_mlp.attn_args._4_W_Q_scales[i],
w_uk = self.weight_capture.layer_mlp.attn_args._5_W_UK[i],
w_uk_scale = self.weight_capture.layer_mlp.attn_args._6_W_UK_scales[i],
w_qr = self.weight_capture.layer_mlp.attn_args._7_W_QR[i],
w_qr_scale = self.weight_capture.layer_mlp.attn_args._8_W_QR_scales[i],
kv_a_layernorm_weight = self.weight_capture.layer_mlp.attn_args._9_kv_a_layernorm_weight[i],
sin_cache = sin_cache,# self.weight_capture.layer_mlp.attn_args._10_sin_cache,
cos_cache = cos_cache,# self.weight_capture.layer_mlp.attn_args._11_cos_cache,
slot_mapping = attn_metadata.slot_mapping,#self.weight_capture.layer_mlp.attn_args._12_slot_mapping[i],
kv_cache = kv_cache,#self.weight_capture.layer_mlp.attn_args._13_kv_cache[i],
block_tables = attn_metadata.decode_metadata.block_tables,#self.weight_capture.layer_mlp.attn_args._14_block_tables[i],
block_group_size = self.weight_capture.layer_mlp.attn_args._15_env_blk_grp_size,
w_uv = self.weight_capture.layer_mlp.attn_args._16_W_UV[i],
w_uv_scale = self.weight_capture.layer_mlp.attn_args._17_W_UV_scales[i],
o_proj_weight = self.weight_capture.layer_mlp.attn_args._18_o_proj_weight[i],
o_proj_weight_scale_inv = self.weight_capture.layer_mlp.attn_args._19_o_proj_weight_scale_inv[i],
# mla params
seq_lens = attn_metadata.decode_metadata.seq_lens,
sm_scale = self.weight_capture.layer_mlp.attn_args._21_sm_scale,
head_num = self.weight_capture.layer_mlp.attn_args._22_head_num,
# flash attention
flash_attention = (USE_FLASH_ATTENTION==1),
# mlp weight
rms_weight = self.weight_capture.layer_mlp.mlp_args._0_mlp_rms_weight[i],
mlp_weight_13 = self.weight_capture.layer_mlp.mlp_args._1_mlp_w13[i],
mlp_weight_2 = self.weight_capture.layer_mlp.mlp_args._2_mlp_w2[i],
mlp_weight_scale_13 = self.weight_capture.layer_mlp.mlp_args._3_mlp_w13_scale[i],
mlp_weight_scale_2 = self.weight_capture.layer_mlp.mlp_args._4_mlp_w2_scale[i],
# mlp params
mlp_block_size_w13 = self.weight_capture.layer_mlp.mlp_args._5_mlp_w13_block_size,
mlp_block_size_w2 = self.weight_capture.layer_mlp.mlp_args._6_mlp_w2_block_size,
# vccl info
world_size = self.weight_capture.layer_mlp.dist_args._0_world_size,
rank = self.weight_capture.layer_mlp.dist_args._1_rank,
group_id = self.weight_capture.layer_mlp.dist_args._2_group_id,
dev_info = self.weight_capture.layer_mlp.dist_args._3_dev_info)
else:
wid = i - first_k_dense_replace if self.is_pipeline_first else i
hidden_states, residual = fuse_mla_moe_v2_allreduce_decode(
hidden_states = hidden_states,
residual = residual,
hidden_states_norm_weight = self.weight_capture.layer_moe.attn_args._a_hidden_states_norm_weight[wid],
q_a_proj_weight = self.weight_capture.layer_moe.attn_args._0_merge_q_kv_weights[wid],
q_a_proj_weight_scale_inv = self.weight_capture.layer_moe.attn_args._1_merge_q_kv_scale_inv[wid],
q_a_layernorm_weight = self.weight_capture.layer_moe.attn_args._2_q_a_layernorm_weight[wid],
w_q = self.weight_capture.layer_moe.attn_args._3_W_Q[wid],
w_q_scale = self.weight_capture.layer_moe.attn_args._4_W_Q_scales[wid],
w_uk = self.weight_capture.layer_moe.attn_args._5_W_UK[wid],
w_uk_scale = self.weight_capture.layer_moe.attn_args._6_W_UK_scales[wid],
w_qr = self.weight_capture.layer_moe.attn_args._7_W_QR[wid],
w_qr_scale = self.weight_capture.layer_moe.attn_args._8_W_QR_scales[wid],
kv_a_layernorm_weight = self.weight_capture.layer_moe.attn_args._9_kv_a_layernorm_weight[wid],
sin_cache = sin_cache,# self.weight_capture.layer_mlp.attn_args._10_sin_cache,
cos_cache = cos_cache,# self.weight_capture.layer_mlp.attn_args._11_cos_cache,
slot_mapping = attn_metadata.slot_mapping,#self.weight_capture.layer_mlp.attn_args._12_slot_mapping[i],
kv_cache = kv_cache,#self.weight_capture.layer_mlp.attn_args._13_kv_cache[i],
block_tables = attn_metadata.decode_metadata.block_tables,
block_group_size = self.weight_capture.layer_moe.attn_args._15_env_blk_grp_size,
w_uv = self.weight_capture.layer_moe.attn_args._16_W_UV[wid],
w_uv_scale = self.weight_capture.layer_moe.attn_args._17_W_UV_scales[wid],
o_proj_weight = self.weight_capture.layer_moe.attn_args._18_o_proj_weight[wid],
o_proj_weight_scale_inv = self.weight_capture.layer_moe.attn_args._19_o_proj_weight_scale_inv[wid],
# mla params
seq_lens = attn_metadata.decode_metadata.seq_lens,
sm_scale = self.weight_capture.layer_moe.attn_args._21_sm_scale,
head_num = self.weight_capture.layer_moe.attn_args._22_head_num,
# flash attention
flash_attention = (USE_FLASH_ATTENTION==1),
# moe weight
rms_weight = self.weight_capture.layer_moe.moe_args._0_moe_rms_weight[wid],
mlp_weight_13 = self.weight_capture.layer_moe.moe_args._1_moe_share_mlp_w13[wid],
mlp_weight_2 = self.weight_capture.layer_moe.moe_args._2_moe_share_mlp_w2[wid],
mlp_weight_scale_13 = self.weight_capture.layer_moe.moe_args._3_moe_share_mlp_w13_scale[wid],
mlp_weight_scale_2 = self.weight_capture.layer_moe.moe_args._4_moe_share_mlp_w2_scale[wid],
moe_weight_13 = self.weight_capture.layer_moe.moe_args._5_moe_w13[wid],
moe_weight_2 = self.weight_capture.layer_moe.moe_args._6_moe_w2[wid],
moe_weight_scale_13 = self.weight_capture.layer_moe.moe_args._7_moe_w13_scale[wid],
moe_weight_scale_2 = self.weight_capture.layer_moe.moe_args._8_moe_w2_scale[wid],
mm_weight = self.weight_capture.layer_moe.moe_args._9_gate_weight[wid],
moe_bias = self.weight_capture.layer_moe.moe_args._10_moe_bias[wid],
# moe params
mlp_block_size_w13 = self.weight_capture.layer_moe.moe_args._11_moe_mlp_w13_block_size,
mlp_block_size_w2 = self.weight_capture.layer_moe.moe_args._12_moe_mlp_w2_block_size,
moe_block_size_w13 = self.weight_capture.layer_moe.moe_args._13_moe_w13_block_size,
moe_block_size_w2 = self.weight_capture.layer_moe.moe_args._14_moe_w2_block_size,
# vccl info
world_size = self.weight_capture.layer_moe.dist_args._0_world_size,
rank = self.weight_capture.layer_moe.dist_args._1_rank,
group_id = self.weight_capture.layer_moe.dist_args._2_group_id,
dev_info = self.weight_capture.layer_moe.dist_args._3_dev_info)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
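# load_weights maps checkpoint tensor names onto this model's fused parameters:
# gate_proj/up_proj shards fold into gate_up_proj via stacked_params_mapping,
# per-expert weights go through FusedMoE.make_expert_params_mapping, and
# everything else falls back to default_weight_loader. Optionally (see
# USE_MERGE_Q_KV_GEN_AND_Q_QR) the Q/KV projection weights are merged afterwards.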
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
from .memory.memory_recycling import init_huge_memory_allocator
from .vars import LLM_MAX_PREFILL_SEQ_LEN
from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
# default model name is "deepseek"; the config may switch it, e.g. to "deepseek_mtp"
model_name = "deepseek"
config_infos = vllm_vacc_config_manager().get_model_infos()
if config_infos != "default":
if config_infos in ['mtp']:
model_name = "deepseek_mtp"
else:
model_name = config_infos
if not init_huge_memory_allocator(LLM_MAX_PREFILL_SEQ_LEN, self.config.hidden_size, vllm_model=model_name):
logger.warning("init huge memory allocator fail. prefill memory recycling will disable")
from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
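# Example of the stacked mapping above (illustrative layer index):
#   "model.layers.3.mlp.gate_proj.weight" -> "model.layers.3.mlp.gate_up_proj.weight" (shard 0)
#   "model.layers.3.mlp.up_proj.weight"   -> "model.layers.3.mlp.gate_up_proj.weight" (shard 1)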
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if test_layer_en == "1":
test_layer = 5
if name not in ['model.embed_tokens.weight', 'model.norm.weight', 'lm_head.weight']:
if int(name.split(".")[2]) > test_layer:
continue
# TODO(simon): support nextn predict layers
if hasattr(self.config, "num_nextn_predict_layers"
) and self.config.num_nextn_predict_layers > 0:
assert self.config.num_nextn_predict_layers == 1
layer_idx = self.config.num_hidden_layers
if name.startswith(f"model.layers.{layer_idx}"):
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
if USE_MERGE_Q_KV_GEN_AND_Q_QR:
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
layer.self_attn.merge_qkv_weights()
return loaded_params
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
attn_metadata = get_forward_context().attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = next(iter(attn_metadata.values()))
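# On prefill, size the huge-page memory recycler for the incoming prefill tokens
# before running the model; decode steps skip this.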
if attn_metadata.prefill_metadata is not None:
from .memory.memory_recycling import alloc_memory_recycler
from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
if hasattr(attn_metadata, 'num_prefill_tokens'):
tokens = attn_metadata.num_prefill_tokens
else:
tokens = attn_metadata.prefill_metadata.num_prefill_tokens
vllm_model_mode = "deepseek"
config_infos = vllm_vacc_config_manager().get_model_infos()
if config_infos != "default":
if config_infos in ['mtp']:
vllm_model_mode = "deepseek_mtp"
else:
vllm_model_mode = config_infos
if get_tp_group().rank_in_group == 0:
memory_infos = f'[MemoryRecycler] enable: {vllm_model_mode}'
logger.info(memory_infos)
if not alloc_memory_recycler(tokens, vllm_model=vllm_model_mode, world_size=get_tp_group().world_size):
logger.warning("deepseek memory recycler allock fail. current request may inefficient %s", tokens)
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
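# --------------------------------------------------------------------------
# Minimal usage sketch (assumptions: the vllm_vacc plugin is installed so that
# vLLM resolves DeepseekV2ForCausalLM to this implementation; the model path is
# a placeholder):
#
#     from vllm import LLM, SamplingParams
#     llm = LLM(model="deepseek-ai/DeepSeek-V2", trust_remote_code=True)
#     out = llm.generate(["Hello"], SamplingParams(max_tokens=16))
#
# The engine then drives DeepseekV2ForCausalLM.forward() above for the prefill
# and decode steps.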