292 lines
15 KiB
Python
292 lines
15 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
||
from typing import Iterable, Set, Tuple, Optional
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||
|
||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||
from vllm.forward_context import ForwardContext, get_forward_context
|
||
|
||
from .vars import *
|
||
|
||
from transformers import PretrainedConfig
|
||
|
||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||
from vllm.model_executor.layers.linear import RowParallelLinear
|
||
from vllm.model_executor.models.deepseek_mtp import DeepSeekMultiTokenPredictorLayer as DeepSeekMultiTokenPredictorLayerOrig
|
||
|
||
from vllm.distributed import get_tp_group
|
||
|
||
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
|
||
|
||
def DeepSeekMultiTokenPredictorLayer__init__(self, vllm_config: VllmConfig, prefix: str) -> None:
    """Replacement ``__init__`` for vLLM's ``DeepSeekMultiTokenPredictorLayer``.

    Builds the submodules of one MTP layer from the model's HF config:
    ``embed_tokens``, ``enorm``, ``hnorm``, ``eh_proj``, ``shared_head`` and
    ``mtp_block``. It also sets ``is_v32`` from the presence of
    ``index_topk`` on the config.

    Args:
        vllm_config: Global vLLM config. Reads ``model_config.hf_config``,
            ``quant_config`` and, for the ``index_topk`` case,
            ``scheduler_config.max_num_batched_tokens``.
        prefix: Module prefix forwarded to ``DeepseekV2DecoderLayer``.
    """
    # Deliberately skip DeepSeekMultiTokenPredictorLayerOrig.__init__ and run
    # its base-class __init__ instead (presumably nn.Module — confirm), so
    # that only the submodules built below exist.
    super(DeepSeekMultiTokenPredictorLayerOrig, self).__init__()

    hf_config = vllm_config.model_config.hf_config
    hidden = hf_config.hidden_size
    norm_eps = hf_config.rms_norm_eps

    self.embed_tokens = VocabParallelEmbedding(hf_config.vocab_size, hidden)
    self.enorm = RMSNorm(hidden, eps=norm_eps)
    self.hnorm = RMSNorm(hidden, eps=norm_eps)

    # eh_proj maps the 2*hidden input down to hidden. The row-parallel
    # variant is selected by the USE_PARALLEL_MTP_EH_PROJ switch.
    if USE_PARALLEL_MTP_EH_PROJ:
        self.eh_proj = RowParallelLinear(hidden * 2, hidden,
                                         bias=False, return_bias=False)
    else:
        self.eh_proj = nn.Linear(hidden * 2, hidden, bias=False)

    from vllm.model_executor.models.deepseek_mtp import SharedHead

    # NOTE(review): "v32" presumably means DeepSeek-V3.2 configs, detected by
    # the presence of `index_topk` — confirm.
    self.is_v32 = hasattr(hf_config, "index_topk")
    topk_indices_buffer = None
    if self.is_v32:
        # int32 scratch of shape (max_num_batched_tokens, index_topk),
        # handed to the decoder layer below.
        topk_indices_buffer = torch.empty(
            vllm_config.scheduler_config.max_num_batched_tokens,
            hf_config.index_topk,
            dtype=torch.int32,
            device="cuda",
        )

    self.shared_head = SharedHead(config=hf_config,
                                  quant_config=vllm_config.quant_config)
    self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix,
                                            topk_indices_buffer)
|
||
|
||
class DeepSeekMultiTokenPredictorLayer(nn.Module):
    """vacc-specific ``forward`` for the DeepSeek multi-token-predictor layer.

    Only ``forward`` is defined here. The submodules it reads
    (``embed_tokens``, ``enorm``, ``hnorm``, ``eh_proj``, ``mtp_block``) are
    created by ``DeepSeekMultiTokenPredictorLayer__init__`` in this file.
    NOTE(review): presumably this ``forward`` is grafted onto vLLM's class by
    a patching step outside this file — confirm.
    """

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        """Run one MTP step and return ``residual + hidden_states``.

        Args:
            input_ids: Token ids. They are embedded with ``embed_tokens`` when
                ``inputs_embeds`` is not supplied.
            positions: Token positions. They are passed to the stage-0 kernel
                and to ``mtp_block`` on the non-fused path.
            previous_hidden_states: Hidden states passed into the stage-0
                kernel alongside the embeddings.
            inputs_embeds: Optional precomputed embeddings.
            spec_step_index: Accepted for interface compatibility; not read.

        Returns:
            ``residual + hidden_states`` from the decoder block. The sum is
            accumulated in place into the residual tensor.
        """
        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        if isinstance(attn_metadata, dict):
            # Keyed metadata: use the value of the first entry.
            attn_metadata = next(iter(attn_metadata.values()))

        # Lazily build the captured-weight bundle used by the fused kernels.
        if not hasattr(self, "weight_capture"):
            from vllm_vacc.vllm.model_executor.models.weight_capture.deepseek_weight_capture import DeepseekMTPWegitCapture
            self.weight_capture = DeepseekMTPWegitCapture(self.mtp_block)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        assert inputs_embeds is not None

        # Stage 0 combines embeddings and previous hidden states with the
        # enorm/hnorm weights. NOTE(review): 256 is a token-count threshold
        # that picks between two kernels; its rationale is not documented.
        if inputs_embeds.shape[0] > 256:
            from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler, DeepseekMTPMemoryRecycler
            # Reuse a preallocated output buffer when the recycler offers one.
            stage0_out = (memory_recycler.DEEPSEEK_MTP_LAYER_INPUT
                          if isinstance(memory_recycler, DeepseekMTPMemoryRecycler)
                          else None)

            from torch_vacc.vacc.custom_ops import fuse_mtp_stage0
            tp_group = get_tp_group()
            stage0 = fuse_mtp_stage0(
                inputs_embeds,
                previous_hidden_states,
                positions,
                self.enorm.weight,
                self.hnorm.weight,
                self.enorm.variance_epsilon,
                world_size=tp_group.world_size,
                rank=tp_group.rank_in_group,
                group_id=tp_group.group_id,
                dev_info=tp_group.rank_device_infos,
                output=stage0_out,
            )
            # NOTE(review): with USE_PARALLEL_MTP_EH_PROJ, eh_proj is a
            # RowParallelLinear. This assumes fuse_mtp_stage0 (given
            # world_size/rank) already yields the rank-local input slice
            # — confirm.
            hidden_states = self.eh_proj(stage0)
        else:
            # Single fused kernel: stage 0 plus the eh_proj projection.
            dist_args = self.weight_capture.layer_moe.dist_args
            hidden_states = torch.vacc.fuse_mtp_allreduce(
                inputs_embeds,
                previous_hidden_states,
                positions,
                self.enorm.weight,
                self.hnorm.weight,
                self.eh_proj.weight,
                self.enorm.variance_epsilon,
                world_size=dist_args._0_world_size,
                rank=dist_args._1_rank,
                group_id=dist_args._2_group_id,
                dev_info=dist_args._3_dev_info)

        # The fused decode kernel is used only for decode batches, and only
        # when USE_DECODER_LAYER_FUSE_MODE is enabled.
        use_fused_decode = (attn_metadata.prefill_metadata is None
                            and USE_DECODER_LAYER_FUSE_MODE)
        if use_fused_decode:
            from torch_vacc.vacc.custom_ops import fuse_mla_moe_v2_allreduce_decode

            layer = self.mtp_block
            layer_id = 0  # index into the per-layer captured-weight lists
            mla_attn = layer.self_attn.mla_attn
            rotary_emb = mla_attn.impl.rotary_emb
            decode_meta = attn_metadata.decode_metadata

            kv_cache = mla_attn.kv_cache[forward_context.virtual_engine]
            # Rotary cos/sin rows are taken at index seq_len - 1 per sequence.
            decode_positions = [seq_len - 1 for seq_len in decode_meta.seq_lens]
            cos_cache = [rotary_emb.cos_cache[p] for p in decode_positions]
            sin_cache = [rotary_emb.sin_cache[p] for p in decode_positions]

            captured = self.weight_capture.layer_moe
            attn_args = captured.attn_args
            moe_args = captured.moe_args
            dist_args = captured.dist_args

            # For the MTP layer the residual is passed in as None, and the
            # kernel must return the residual alongside the hidden states.
            hidden_states, residual = fuse_mla_moe_v2_allreduce_decode(
                hidden_states=hidden_states,
                residual=None,
                hidden_states_norm_weight=attn_args._a_hidden_states_norm_weight[layer_id],
                q_a_proj_weight=attn_args._0_merge_q_kv_weights[layer_id],
                q_a_proj_weight_scale_inv=attn_args._1_merge_q_kv_scale_inv[layer_id],
                q_a_layernorm_weight=attn_args._2_q_a_layernorm_weight[layer_id],
                w_q=attn_args._3_W_Q[layer_id],
                w_q_scale=attn_args._4_W_Q_scales[layer_id],
                w_uk=attn_args._5_W_UK[layer_id],
                w_uk_scale=attn_args._6_W_UK_scales[layer_id],
                w_qr=attn_args._7_W_QR[layer_id],
                w_qr_scale=attn_args._8_W_QR_scales[layer_id],
                kv_a_layernorm_weight=attn_args._9_kv_a_layernorm_weight[layer_id],
                sin_cache=sin_cache,
                cos_cache=cos_cache,
                slot_mapping=attn_metadata.slot_mapping,
                kv_cache=kv_cache,
                block_tables=decode_meta.block_tables,
                block_group_size=attn_args._15_env_blk_grp_size,
                w_uv=attn_args._16_W_UV[layer_id],
                w_uv_scale=attn_args._17_W_UV_scales[layer_id],
                o_proj_weight=attn_args._18_o_proj_weight[layer_id],
                o_proj_weight_scale_inv=attn_args._19_o_proj_weight_scale_inv[layer_id],
                # MLA parameters
                seq_lens=decode_meta.seq_lens,
                sm_scale=attn_args._21_sm_scale,
                head_num=attn_args._22_head_num,
                # flash attention switch
                flash_attention=(USE_FLASH_ATTENTION == 1),
                # MoE weights
                rms_weight=moe_args._0_moe_rms_weight[layer_id],
                mlp_weight_13=moe_args._1_moe_share_mlp_w13[layer_id],
                mlp_weight_2=moe_args._2_moe_share_mlp_w2[layer_id],
                mlp_weight_scale_13=moe_args._3_moe_share_mlp_w13_scale[layer_id],
                mlp_weight_scale_2=moe_args._4_moe_share_mlp_w2_scale[layer_id],
                moe_weight_13=moe_args._5_moe_w13[layer_id],
                moe_weight_2=moe_args._6_moe_w2[layer_id],
                moe_weight_scale_13=moe_args._7_moe_w13_scale[layer_id],
                moe_weight_scale_2=moe_args._8_moe_w2_scale[layer_id],
                mm_weight=moe_args._9_gate_weight[layer_id],
                moe_bias=moe_args._10_moe_bias[layer_id],
                # MoE block sizes
                mlp_block_size_w13=moe_args._11_moe_mlp_w13_block_size,
                mlp_block_size_w2=moe_args._12_moe_mlp_w2_block_size,
                moe_block_size_w13=moe_args._13_moe_w13_block_size,
                moe_block_size_w2=moe_args._14_moe_w2_block_size,
                # collective-communication (vccl) info
                world_size=dist_args._0_world_size,
                rank=dist_args._1_rank,
                group_id=dist_args._2_group_id,
                dev_info=dist_args._3_dev_info)
        else:
            hidden_states, residual = self.mtp_block(positions=positions,
                                                     hidden_states=hidden_states,
                                                     residual=None)

        # residual + hidden_states, accumulated in place into `residual`.
        return residual.add_(hidden_states)
|
||
|
||
class DeepSeekMTP(nn.Module):
    """vacc-specific weight loading for the DeepSeek MTP draft model.

    Only ``load_weights`` is defined here. It reads ``self.config``,
    ``self.model`` and ``self._rewrite_spec_layer_name``, none of which are
    defined in this class. NOTE(review): presumably they come from vLLM's
    ``DeepSeekMTP`` once this method is patched onto it — confirm.
    """

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors that belong to the MTP (speculative) layers.

        Each ``(name, tensor)`` pair is routed to one of three loaders:

        * stacked dense projections: ``gate_proj`` / ``up_proj`` are loaded as
          shards 0 / 1 of ``gate_up_proj``;
        * routed-expert weights, using ``FusedMoE.make_expert_params_mapping``;
        * everything else, using the parameter's ``weight_loader`` or
          ``default_weight_loader``.

        Some entries are skipped:

        * names containing ``rotary_emb.inv_freq``;
        * names for which ``get_spec_layer_idx_from_weight_name`` returns
          None;
        * extra ``.bias`` entries absent from the model.

        When ``USE_MERGE_Q_KV_GEN_AND_Q_QR`` is set, each MTP layer's
        attention then merges its q/kv weights. Pipeline-parallel
        placeholders are skipped in that pass.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of (rewritten) parameter names that were loaded.
        """
        stacked_params_mapping = [
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        from vllm.model_executor.models.deepseek_v2 import get_spec_layer_idx_from_weight_name
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is None:
                continue
            name = self._rewrite_spec_layer_name(spec_layer, name)
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if USE_MERGE_Q_KV_GEN_AND_Q_QR:
            from vllm.model_executor.models.utils import PPMissingLayer
            for layer_id in self.model.layers:
                layer = self.model.layers[layer_id]
                # `layer_id` is only a key into self.model.layers, so the
                # placeholder check must look at the module itself. Checking
                # the key never matched, which let `.mtp_block` be accessed
                # on a PPMissingLayer.
                if isinstance(layer, PPMissingLayer):
                    continue
                layer.mtp_block.self_attn.merge_qkv_weights()
        return loaded_params
|
||
|
||
class SharedHead(nn.Module):
    """vacc override of the MTP shared head's ``forward`` (normalization only)."""

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalize ``hidden_states`` with ``self.norm``.

        The fused ``torch_vacc`` ``rms_norm`` kernel is tried first. It is
        passed ``output=hidden_states``, so the result is written into the
        input tensor. If the import or the kernel call raises, a message is
        printed and ``self.norm`` is applied instead.

        NOTE(review): the fused call is not given ``self.norm``'s epsilon. This
        assumes the kernel's default epsilon matches — confirm.
        """
        try:
            from torch_vacc.vacc.custom_ops import rms_norm as fused_rms_norm
            fused_out = fused_rms_norm(hidden_states, self.norm.weight,
                                       output=hidden_states)
        except Exception as e:
            # Best effort: report, then fall through to the reference module.
            print(f"fuse rms_norm run fail, now use unfused ops: {e}")
        else:
            return fused_out

        return self.norm(hidden_states)