enginex-vastai-va16-vllm/vllm_vacc/vllm/model_executor/models/deepseek_mtp.py
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Set, Tuple, Optional
import torch
import torch.nn as nn
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.forward_context import ForwardContext, get_forward_context
from .vars import *
from transformers import PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.models.deepseek_mtp import DeepSeekMultiTokenPredictorLayer as DeepSeekMultiTokenPredictorLayerOrig
from vllm.distributed import get_tp_group
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
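

# NOTE: This module mirrors vllm's deepseek_mtp.py but routes the hot paths
# through VastAI VACC fused kernels (torch_vacc custom ops). The function and
# classes below define only the pieces they override; presumably they are
# patched onto the corresponding upstream vllm classes elsewhere in vllm_vacc,
# along the lines of (hypothetical wiring, not part of this file):
#   DeepSeekMultiTokenPredictorLayerOrig.__init__ = \
#       DeepSeekMultiTokenPredictorLayer__init__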
def DeepSeekMultiTokenPredictorLayer__init__(self, vllm_config: VllmConfig,
                                             prefix: str) -> None:
    super(DeepSeekMultiTokenPredictorLayerOrig, self).__init__()
    config = vllm_config.model_config.hf_config
    quant_config = vllm_config.quant_config
    self.embed_tokens = VocabParallelEmbedding(
        config.vocab_size,
        config.hidden_size,
    )
    self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    if USE_PARALLEL_MTP_EH_PROJ:
        self.eh_proj = RowParallelLinear(config.hidden_size * 2,
                                         config.hidden_size,
                                         bias=False,
                                         return_bias=False)
    else:
        self.eh_proj = nn.Linear(config.hidden_size * 2,
                                 config.hidden_size,
                                 bias=False)
    from vllm.model_executor.models.deepseek_mtp import SharedHead
    self.is_v32 = hasattr(config, "index_topk")
    if self.is_v32:
        topk_tokens = config.index_topk
        topk_indices_buffer = torch.empty(
            vllm_config.scheduler_config.max_num_batched_tokens,
            topk_tokens,
            dtype=torch.int32,
            device="cuda")
    else:
        topk_indices_buffer = None
    self.shared_head = SharedHead(config=config, quant_config=quant_config)
    self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix,
                                            topk_indices_buffer)
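

# Replacement forward() for the upstream DeepSeekMultiTokenPredictorLayer:
# the enorm/hnorm/concat/eh_proj prologue and (in decode) the whole
# MLA + MoE decoder block are executed as VACC fused ops instead of the
# individual PyTorch modules.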
class DeepSeekMultiTokenPredictorLayer(nn.Module):

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        if isinstance(attn_metadata, dict):
            attn_metadata = next(iter(attn_metadata.values()))
        if not hasattr(self, "weight_capture"):
            from vllm_vacc.vllm.model_executor.models.weight_capture.deepseek_weight_capture import DeepseekMTPWegitCapture
            self.weight_capture = DeepseekMTPWegitCapture(self.mtp_block)
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        assert inputs_embeds is not None
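
        # MTP prologue: RMS-norm the token embeddings and the previous hidden
        # states, concatenate them, and project back to hidden_size with
        # eh_proj. For large token counts (presumably prefill) fuse_mtp_stage0
        # computes the norms/concat and eh_proj runs separately; otherwise a
        # single fused op also folds in eh_proj and the tensor-parallel
        # allreduce.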
        if inputs_embeds.shape[0] > 256:
            from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler, DeepseekMTPMemoryRecycler
            deepseek_mtp_layer_input_buffer = None
            if isinstance(memory_recycler, DeepseekMTPMemoryRecycler):
                deepseek_mtp_layer_input_buffer = memory_recycler.DEEPSEEK_MTP_LAYER_INPUT
            from torch_vacc.vacc.custom_ops import fuse_mtp_stage0
            hidden_states = fuse_mtp_stage0(
                inputs_embeds,
                previous_hidden_states,
                positions,
                self.enorm.weight,
                self.hnorm.weight,
                self.enorm.variance_epsilon,
                world_size=get_tp_group().world_size,
                rank=get_tp_group().rank_in_group,
                group_id=get_tp_group().group_id,
                dev_info=get_tp_group().rank_device_infos,
                output=deepseek_mtp_layer_input_buffer,
            )
            # if USE_PARALLEL_MTP_EH_PROJ:
            #     tp_size = get_tensor_model_parallel_world_size()
            #     rank_id = get_tensor_model_parallel_rank()
            #     last_dim = hidden_states.shape[-1]
            #     if tp_size > 1:
            #         hiddens_tp = last_dim // tp_size
            #         hidden_states = hidden_states[..., rank_id * hiddens_tp:(rank_id + 1) * hiddens_tp]
            hidden_states = self.eh_proj(hidden_states)
        else:
            hidden_states = torch.vacc.fuse_mtp_allreduce(
                inputs_embeds,
                previous_hidden_states,
                positions,
                self.enorm.weight,
                self.hnorm.weight,
                self.eh_proj.weight,
                self.enorm.variance_epsilon,
                world_size=self.weight_capture.layer_moe.dist_args._0_world_size,
                rank=self.weight_capture.layer_moe.dist_args._1_rank,
                group_id=self.weight_capture.layer_moe.dist_args._2_group_id,
                dev_info=self.weight_capture.layer_moe.dist_args._3_dev_info)
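
        # Prefill (or fused decode disabled): run the regular DeepseekV2
        # decoder layer. Otherwise, run a single fused MLA + MoE + allreduce
        # decode kernel over the weights captured from self.mtp_block.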
        if (attn_metadata.prefill_metadata is not None
                or not USE_DECODER_LAYER_FUSE_MODE):
            hidden_states, residual = self.mtp_block(positions=positions,
                                                     hidden_states=hidden_states,
                                                     residual=None)
        else:
            from torch_vacc.vacc.custom_ops import fuse_mla_moe_v2_allreduce_decode
            layer = self.mtp_block
            layer_id = 0
            kv_cache = layer.self_attn.mla_attn.kv_cache[forward_context.virtual_engine]
            positions = [p - 1 for p in attn_metadata.decode_metadata.seq_lens]
            cos_cache = [layer.self_attn.mla_attn.impl.rotary_emb.cos_cache[p] for p in positions]
            sin_cache = [layer.self_attn.mla_attn.impl.rotary_emb.sin_cache[p] for p in positions]
            # For the MTP layer the incoming residual is None, and the fused
            # op must return a residual.
            hidden_states, residual = fuse_mla_moe_v2_allreduce_decode(
                hidden_states=hidden_states,
                residual=None,
                hidden_states_norm_weight=self.weight_capture.layer_moe.attn_args._a_hidden_states_norm_weight[layer_id],
                q_a_proj_weight=self.weight_capture.layer_moe.attn_args._0_merge_q_kv_weights[layer_id],
                q_a_proj_weight_scale_inv=self.weight_capture.layer_moe.attn_args._1_merge_q_kv_scale_inv[layer_id],
                q_a_layernorm_weight=self.weight_capture.layer_moe.attn_args._2_q_a_layernorm_weight[layer_id],
                w_q=self.weight_capture.layer_moe.attn_args._3_W_Q[layer_id],
                w_q_scale=self.weight_capture.layer_moe.attn_args._4_W_Q_scales[layer_id],
                w_uk=self.weight_capture.layer_moe.attn_args._5_W_UK[layer_id],
                w_uk_scale=self.weight_capture.layer_moe.attn_args._6_W_UK_scales[layer_id],
                w_qr=self.weight_capture.layer_moe.attn_args._7_W_QR[layer_id],
                w_qr_scale=self.weight_capture.layer_moe.attn_args._8_W_QR_scales[layer_id],
                kv_a_layernorm_weight=self.weight_capture.layer_moe.attn_args._9_kv_a_layernorm_weight[layer_id],
                sin_cache=sin_cache,
                cos_cache=cos_cache,
                slot_mapping=attn_metadata.slot_mapping,
                kv_cache=kv_cache,
                block_tables=attn_metadata.decode_metadata.block_tables,
                block_group_size=self.weight_capture.layer_moe.attn_args._15_env_blk_grp_size,
                w_uv=self.weight_capture.layer_moe.attn_args._16_W_UV[layer_id],
                w_uv_scale=self.weight_capture.layer_moe.attn_args._17_W_UV_scales[layer_id],
                o_proj_weight=self.weight_capture.layer_moe.attn_args._18_o_proj_weight[layer_id],
                o_proj_weight_scale_inv=self.weight_capture.layer_moe.attn_args._19_o_proj_weight_scale_inv[layer_id],
                # mla params
                seq_lens=attn_metadata.decode_metadata.seq_lens,
                sm_scale=self.weight_capture.layer_moe.attn_args._21_sm_scale,
                head_num=self.weight_capture.layer_moe.attn_args._22_head_num,
                # flash attention
                flash_attention=(USE_FLASH_ATTENTION == 1),
                # moe weight
                rms_weight=self.weight_capture.layer_moe.moe_args._0_moe_rms_weight[layer_id],
                mlp_weight_13=self.weight_capture.layer_moe.moe_args._1_moe_share_mlp_w13[layer_id],
                mlp_weight_2=self.weight_capture.layer_moe.moe_args._2_moe_share_mlp_w2[layer_id],
                mlp_weight_scale_13=self.weight_capture.layer_moe.moe_args._3_moe_share_mlp_w13_scale[layer_id],
                mlp_weight_scale_2=self.weight_capture.layer_moe.moe_args._4_moe_share_mlp_w2_scale[layer_id],
                moe_weight_13=self.weight_capture.layer_moe.moe_args._5_moe_w13[layer_id],
                moe_weight_2=self.weight_capture.layer_moe.moe_args._6_moe_w2[layer_id],
                moe_weight_scale_13=self.weight_capture.layer_moe.moe_args._7_moe_w13_scale[layer_id],
                moe_weight_scale_2=self.weight_capture.layer_moe.moe_args._8_moe_w2_scale[layer_id],
                mm_weight=self.weight_capture.layer_moe.moe_args._9_gate_weight[layer_id],
                moe_bias=self.weight_capture.layer_moe.moe_args._10_moe_bias[layer_id],
                # moe params
                mlp_block_size_w13=self.weight_capture.layer_moe.moe_args._11_moe_mlp_w13_block_size,
                mlp_block_size_w2=self.weight_capture.layer_moe.moe_args._12_moe_mlp_w2_block_size,
                moe_block_size_w13=self.weight_capture.layer_moe.moe_args._13_moe_w13_block_size,
                moe_block_size_w2=self.weight_capture.layer_moe.moe_args._14_moe_w2_block_size,
                # vccl info
                world_size=self.weight_capture.layer_moe.dist_args._0_world_size,
                rank=self.weight_capture.layer_moe.dist_args._1_rank,
                group_id=self.weight_capture.layer_moe.dist_args._2_group_id,
                dev_info=self.weight_capture.layer_moe.dist_args._3_dev_info)

        # Both branches return (hidden_states, residual); add the residual
        # in place before returning.
        # hidden_states = residual + hidden_states
        hidden_states = residual.add_(hidden_states)
        return hidden_states
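

# Replacement load_weights() for the upstream DeepSeekMTP model: same stacked
# and expert parameter mapping as vllm's DeepSeek-V2 loader, plus an optional
# post-load merge of the Q/KV projection weights (presumably so the fused
# decode kernel can consume a single merged tensor).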
class DeepSeekMTP(nn.Module):

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        stacked_params_mapping = [
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        from vllm.model_executor.models.deepseek_v2 import get_spec_layer_idx_from_weight_name
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is None:
                continue
            name = self._rewrite_spec_layer_name(spec_layer, name)
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if (("mlp.experts." in name) and name not in params_dict):
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if USE_MERGE_Q_KV_GEN_AND_Q_QR:
            from vllm.model_executor.models.utils import PPMissingLayer
            for layer_id in self.model.layers:
                layer = self.model.layers[layer_id]
                # Skip layers owned by other pipeline-parallel ranks.
                if isinstance(layer, PPMissingLayer):
                    continue
                layer.mtp_block.self_attn.merge_qkv_weights()
        return loaded_params
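

# Replacement SharedHead.forward(): tries the VACC fused rms_norm custom op
# (writing in place into hidden_states) and falls back to the standard
# RMSNorm module if the custom op is unavailable or fails.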
class SharedHead(nn.Module):

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        try:
            from torch_vacc.vacc.custom_ops import rms_norm
            return rms_norm(hidden_states, self.norm.weight, output=hidden_states)
        except Exception as e:
            print(f"fused rms_norm failed, falling back to unfused op: {e}")
            return self.norm(hidden_states)