# Files
#
# 292 lines
# 15 KiB
# Python
# Raw Permalink Normal View History
#
# 2026-04-02 04:53:13 +00:00
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Set, Tuple, Optional
import torch
import torch.nn as nn
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.forward_context import ForwardContext, get_forward_context
from .vars import *
from transformers import PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.models.deepseek_mtp import DeepSeekMultiTokenPredictorLayer as DeepSeekMultiTokenPredictorLayerOrig
from vllm.distributed import get_tp_group
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
def DeepSeekMultiTokenPredictorLayer__init__(self, vllm_config: VllmConfig, prefix: str) -> None:
    """Replacement ``__init__`` for vLLM's ``DeepSeekMultiTokenPredictorLayer``.

    Builds, in this order: ``embed_tokens``, ``enorm``, ``hnorm``,
    ``eh_proj``, ``shared_head`` and ``mtp_block``.

    ``eh_proj`` maps ``2 * hidden_size -> hidden_size`` without bias. It is a
    ``RowParallelLinear`` when ``USE_PARALLEL_MTP_EH_PROJ`` is set and a
    plain ``nn.Linear`` otherwise.

    A config that has an ``index_topk`` attribute is flagged as ``is_v32``.
    For such configs an int32 buffer of shape
    ``(max_num_batched_tokens, index_topk)`` is allocated and handed to the
    decoder layer. Other configs pass ``None``.

    Args:
        vllm_config: engine config supplying the HF model config, the quant
            config and the scheduler limits.
        prefix: weight-name prefix forwarded to ``DeepseekV2DecoderLayer``.
    """
    # Bypass the upstream class's own __init__ and run the next initializer
    # in its MRO (presumably nn.Module.__init__), because every submodule
    # is constructed below instead.
    super(DeepSeekMultiTokenPredictorLayerOrig, self).__init__()

    hf_config = vllm_config.model_config.hf_config
    hidden = hf_config.hidden_size
    eps = hf_config.rms_norm_eps

    self.embed_tokens = VocabParallelEmbedding(hf_config.vocab_size, hidden)
    self.enorm = RMSNorm(hidden, eps=eps)
    self.hnorm = RMSNorm(hidden, eps=eps)

    eh_in_features = hidden * 2
    if not USE_PARALLEL_MTP_EH_PROJ:
        self.eh_proj = nn.Linear(eh_in_features, hidden, bias=False)
    else:
        self.eh_proj = RowParallelLinear(eh_in_features,
                                         hidden,
                                         bias=False,
                                         return_bias=False)

    # Upstream SharedHead is used on purpose. It shadows the module-level
    # SharedHead defined later in this file.
    from vllm.model_executor.models.deepseek_mtp import SharedHead

    self.is_v32 = hasattr(hf_config, "index_topk")
    topk_indices_buffer = None
    if self.is_v32:
        # NOTE(review): device="cuda" on a vacc build presumably relies on
        # torch_vacc's cuda->vacc device mapping; confirm.
        topk_indices_buffer = torch.empty(
            vllm_config.scheduler_config.max_num_batched_tokens,
            hf_config.index_topk,
            dtype=torch.int32,
            device="cuda")

    self.shared_head = SharedHead(config=hf_config,
                                  quant_config=vllm_config.quant_config)
    self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix,
                                            topk_indices_buffer)
# Batches with more tokens than this take the fuse_mtp_stage0 + eh_proj path.
# Smaller batches use the single fuse_mtp_allreduce kernel, which also applies
# eh_proj.
_MTP_STAGE0_TOKEN_THRESHOLD = 256


def unwrap_attn_metadata(attn_metadata):
    """Return the attention-metadata object the MTP layer should use.

    The forward context's ``attn_metadata`` may be a dict. In that case the
    first value is used; the keys are presumably layer names (TODO confirm).
    An empty dict yields ``None``. Any non-dict value, including ``None``, is
    returned unchanged.
    """
    if isinstance(attn_metadata, dict):
        return next(iter(attn_metadata.values()), None)
    return attn_metadata


class DeepSeekMultiTokenPredictorLayer(nn.Module):
    """Holder for the vacc ``forward`` of the DeepSeek MTP layer.

    NOTE(review): only ``forward`` is defined here. The attributes it reads
    (``embed_tokens``, ``enorm``, ``hnorm``, ``eh_proj``, ``mtp_block``) are
    presumably created by ``DeepSeekMultiTokenPredictorLayer__init__`` once
    this method is patched onto vLLM's class. Confirm at the patch site.
    """

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        """Run one MTP step using the vacc fused kernels where possible.

        Stage 0 combines token embeddings with the previous hidden states,
        using the enorm/hnorm weights and ``eh_proj``. The result then goes
        through ``self.mtp_block``. The block runs unfused for prefill, when
        ``USE_DECODER_LAYER_FUSE_MODE`` is off, or when no attention metadata
        is available. Otherwise it runs as a single fused decode kernel.

        Args:
            input_ids: token ids, embedded via ``self.embed_tokens`` when
                ``inputs_embeds`` is not given.
            positions: token positions for the stage-0 op and the unfused
                decoder block.
            previous_hidden_states: hidden states from the previous step.
            inputs_embeds: optional precomputed embeddings.
            spec_step_index: unused here; kept for interface compatibility.

        Returns:
            ``residual + hidden_states`` from the MTP decoder block, computed
            in place into the residual tensor.
        """
        forward_context: ForwardContext = get_forward_context()
        attn_metadata = unwrap_attn_metadata(forward_context.attn_metadata)

        # Build the weight capture lazily on the first call.
        if not hasattr(self, "weight_capture"):
            from vllm_vacc.vllm.model_executor.models.weight_capture.deepseek_weight_capture import DeepseekMTPWegitCapture
            self.weight_capture = DeepseekMTPWegitCapture(self.mtp_block)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        assert inputs_embeds is not None

        if inputs_embeds.shape[0] > _MTP_STAGE0_TOKEN_THRESHOLD:
            from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler, DeepseekMTPMemoryRecycler
            from torch_vacc.vacc.custom_ops import fuse_mtp_stage0
            # Reuse the recycler's preallocated buffer when one is available.
            stage0_output = None
            if isinstance(memory_recycler, DeepseekMTPMemoryRecycler):
                stage0_output = memory_recycler.DEEPSEEK_MTP_LAYER_INPUT
            tp_group = get_tp_group()
            hidden_states = fuse_mtp_stage0(
                inputs_embeds,
                previous_hidden_states,
                positions,
                self.enorm.weight,
                self.hnorm.weight,
                self.enorm.variance_epsilon,
                world_size=tp_group.world_size,
                rank=tp_group.rank_in_group,
                group_id=tp_group.group_id,
                dev_info=tp_group.rank_device_infos,
                output=stage0_output,
            )
            # NOTE(review): with USE_PARALLEL_MTP_EH_PROJ, eh_proj is a
            # RowParallelLinear. The per-rank slicing that used to happen here
            # was removed, so fuse_mtp_stage0 presumably already returns the
            # rank-local shard. Confirm.
            hidden_states = self.eh_proj(hidden_states)
        else:
            dist_args = self.weight_capture.layer_moe.dist_args
            hidden_states = torch.vacc.fuse_mtp_allreduce(
                inputs_embeds,
                previous_hidden_states,
                positions,
                self.enorm.weight,
                self.hnorm.weight,
                self.eh_proj.weight,
                self.enorm.variance_epsilon,
                world_size=dist_args._0_world_size,
                rank=dist_args._1_rank,
                group_id=dist_args._2_group_id,
                dev_info=dist_args._3_dev_info)

        # Without metadata (e.g. a dummy/profile run) the fused decode kernel
        # cannot be fed, so fall back to the unfused block instead of raising
        # AttributeError.
        use_unfused_block = (attn_metadata is None
                             or attn_metadata.prefill_metadata is not None
                             or not USE_DECODER_LAYER_FUSE_MODE)
        if use_unfused_block:
            hidden_states, residual = self.mtp_block(positions=positions,
                                                     hidden_states=hidden_states,
                                                     residual=None)
        else:
            from torch_vacc.vacc.custom_ops import fuse_mla_moe_v2_allreduce_decode
            layer = self.mtp_block
            # Weights were captured from self.mtp_block alone, so index 0.
            layer_id = 0
            layer_moe = self.weight_capture.layer_moe
            attn_args = layer_moe.attn_args
            moe_args = layer_moe.moe_args
            dist_args = layer_moe.dist_args
            mla_attn = layer.self_attn.mla_attn

            kv_cache = mla_attn.kv_cache[forward_context.virtual_engine]
            # Position of each sequence's newest token is seq_len - 1. Kept
            # separate from the `positions` argument on purpose.
            decode_positions = [seq_len - 1 for seq_len in attn_metadata.decode_metadata.seq_lens]
            rotary_emb = mla_attn.impl.rotary_emb
            cos_cache = [rotary_emb.cos_cache[p] for p in decode_positions]
            sin_cache = [rotary_emb.sin_cache[p] for p in decode_positions]

            # For the MTP layer the incoming residual is None, and the fused
            # kernel must return the residual.
            hidden_states, residual = fuse_mla_moe_v2_allreduce_decode(
                hidden_states=hidden_states,
                residual=None,
                hidden_states_norm_weight=attn_args._a_hidden_states_norm_weight[layer_id],
                q_a_proj_weight=attn_args._0_merge_q_kv_weights[layer_id],
                q_a_proj_weight_scale_inv=attn_args._1_merge_q_kv_scale_inv[layer_id],
                q_a_layernorm_weight=attn_args._2_q_a_layernorm_weight[layer_id],
                w_q=attn_args._3_W_Q[layer_id],
                w_q_scale=attn_args._4_W_Q_scales[layer_id],
                w_uk=attn_args._5_W_UK[layer_id],
                w_uk_scale=attn_args._6_W_UK_scales[layer_id],
                w_qr=attn_args._7_W_QR[layer_id],
                w_qr_scale=attn_args._8_W_QR_scales[layer_id],
                kv_a_layernorm_weight=attn_args._9_kv_a_layernorm_weight[layer_id],
                sin_cache=sin_cache,
                cos_cache=cos_cache,
                slot_mapping=attn_metadata.slot_mapping,
                kv_cache=kv_cache,
                block_tables=attn_metadata.decode_metadata.block_tables,
                block_group_size=attn_args._15_env_blk_grp_size,
                w_uv=attn_args._16_W_UV[layer_id],
                w_uv_scale=attn_args._17_W_UV_scales[layer_id],
                o_proj_weight=attn_args._18_o_proj_weight[layer_id],
                o_proj_weight_scale_inv=attn_args._19_o_proj_weight_scale_inv[layer_id],
                # mla params
                seq_lens=attn_metadata.decode_metadata.seq_lens,
                sm_scale=attn_args._21_sm_scale,
                head_num=attn_args._22_head_num,
                # flash attention
                flash_attention=(USE_FLASH_ATTENTION == 1),
                # moe weights
                rms_weight=moe_args._0_moe_rms_weight[layer_id],
                mlp_weight_13=moe_args._1_moe_share_mlp_w13[layer_id],
                mlp_weight_2=moe_args._2_moe_share_mlp_w2[layer_id],
                mlp_weight_scale_13=moe_args._3_moe_share_mlp_w13_scale[layer_id],
                mlp_weight_scale_2=moe_args._4_moe_share_mlp_w2_scale[layer_id],
                moe_weight_13=moe_args._5_moe_w13[layer_id],
                moe_weight_2=moe_args._6_moe_w2[layer_id],
                moe_weight_scale_13=moe_args._7_moe_w13_scale[layer_id],
                moe_weight_scale_2=moe_args._8_moe_w2_scale[layer_id],
                mm_weight=moe_args._9_gate_weight[layer_id],
                moe_bias=moe_args._10_moe_bias[layer_id],
                # moe params
                mlp_block_size_w13=moe_args._11_moe_mlp_w13_block_size,
                mlp_block_size_w2=moe_args._12_moe_mlp_w2_block_size,
                moe_block_size_w13=moe_args._13_moe_w13_block_size,
                moe_block_size_w2=moe_args._14_moe_w2_block_size,
                # vccl info
                world_size=dist_args._0_world_size,
                rank=dist_args._1_rank,
                group_id=dist_args._2_group_id,
                dev_info=dist_args._3_dev_info)

        # In-place residual + hidden_states.
        hidden_states = residual.add_(hidden_states)
        return hidden_states
class DeepSeekMTP(nn.Module):
    """Holder for the vacc ``load_weights`` of the DeepSeek MTP model.

    NOTE(review): ``self.config``, ``self.model`` and
    ``self._rewrite_spec_layer_name`` are not defined here. They are
    presumably provided by vLLM's ``DeepSeekMTP`` once this method is patched
    onto it. Confirm.
    """

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights for the MTP (speculative) layers.

        Weights are skipped in these cases:
        - names containing ``rotary_emb.inv_freq``;
        - names for which ``get_spec_layer_idx_from_weight_name`` returns
          ``None``;
        - extra ``.bias`` tensors that have no matching parameter.

        Stacked ``gate_proj``/``up_proj`` weights are routed into
        ``gate_up_proj``. Expert weights go through FusedMoE's expert mapping.
        Everything else uses the parameter's ``weight_loader``, falling back to
        ``default_weight_loader``.

        When ``USE_MERGE_Q_KV_GEN_AND_Q_QR`` is set, ``merge_qkv_weights()`` is
        afterwards called on every MTP layer that is not a ``PPMissingLayer``.

        Args:
            weights: iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of (rewritten) parameter names that were loaded.
        """
        stacked_params_mapping = [
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts)
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        from vllm.model_executor.models.deepseek_v2 import get_spec_layer_idx_from_weight_name
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is None:
                continue
            name = self._rewrite_spec_layer_name(spec_layer, name)
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # The checkpoint has mlp.experts[0].gate_proj. Experts are
                # handled below via expert_params_mapping, so skip them here
                # BEFORE renaming. Otherwise the name would first become
                # mlp.experts[0].gate_up_proj and then
                # mlp.experts[0].gate_gate_up_proj, which breaks loading.
                if (("mlp.experts." in name) and name not in params_dict):
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if USE_MERGE_Q_KV_GEN_AND_Q_QR:
            from vllm.model_executor.models.utils import PPMissingLayer
            # self.model.layers is indexed by its own iteration keys, so it is
            # dict-like. Check the *layer* (not its key) for PPMissingLayer;
            # the previous check tested the key and could never match.
            for layer_key in self.model.layers:
                layer = self.model.layers[layer_key]
                if isinstance(layer, PPMissingLayer):
                    continue
                layer.mtp_block.self_attn.merge_qkv_weights()
        return loaded_params
class SharedHead(nn.Module):
    """Holder for the vacc ``forward`` of the MTP shared head.

    NOTE(review): ``self.norm`` is not defined here. It is presumably provided
    by vLLM's ``SharedHead`` once this method is patched onto it. Confirm.
    """

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the head's RMS norm, preferring the fused vacc kernel.

        The fused ``rms_norm`` writes its result into ``hidden_states``
        (``output=hidden_states``). If importing or running the fused op raises,
        a warning is logged and ``self.norm(hidden_states)`` is returned.

        Args:
            hidden_states: input activations; overwritten on the fused path.

        Returns:
            The normalized hidden states.
        """
        try:
            from torch_vacc.vacc.custom_ops import rms_norm
            return rms_norm(hidden_states, self.norm.weight, output=hidden_states)
        except Exception as e:
            # Deliberate best-effort fallback. Use logging rather than print
            # so the message respects the process's logging configuration.
            # NOTE(review): if the fused op fails after partially writing
            # into hidden_states, the fallback normalizes that buffer; confirm
            # whether the kernel can fail mid-write.
            import logging
            logging.getLogger(__name__).warning(
                "fuse rms_norm run fail, now use unfused ops: %s", e)
            return self.norm(hidden_states)