Files
2026-03-10 13:31:25 +08:00

100 lines
3.9 KiB
Python

################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Optional
import torch
import torch.nn as nn
from fastcore.basics import patch_to
from vllm.config import VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.models.deepseek_mtp import (
DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer, SharedHead)
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
# from vllm.model_executor.sampling_metadata import SamplingMetadata
# from vllm.model_executor.layers.sampler import get_sampler
@patch_to(DeepSeekMultiTokenPredictorLayer)
def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
super(DeepSeekMultiTokenPredictorLayer, self).__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.eh_proj = nn.Linear(config.hidden_size * 2,
config.hidden_size,
bias=False)
self.is_v32 = hasattr(config, "index_topk")
if self.is_v32:
topk_tokens = config.index_topk
topk_indices_buffer = torch.empty(
vllm_config.scheduler_config.max_num_batched_tokens,
topk_tokens,
dtype=torch.int32,
device="cuda")
else:
topk_indices_buffer = None
self.shared_head = SharedHead(config=config, quant_config=quant_config)
self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix,
topk_indices_buffer)
@patch_to(DeepSeekMultiTokenPredictorLayer)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_index: int = 0,
) -> torch.Tensor:
assert inputs_embeds is not None
# masking inputs at position 0, as not needed by MTP
inputs_embeds = torch.where((positions == 0).unsqueeze(-1),
torch.zeros_like(inputs_embeds), inputs_embeds)
inputs_embeds = self.enorm(inputs_embeds.unsqueeze(0))
previous_hidden_states = self.hnorm(previous_hidden_states.unsqueeze(0))
fused_hidden_states = torch.cat([inputs_embeds, previous_hidden_states],
dim=-1)
hidden_states = self.eh_proj(fused_hidden_states)
hidden_states, residual = self.mtp_block(positions=positions,
hidden_states=hidden_states,
residual=None)
hidden_states = residual + hidden_states
return hidden_states.squeeze(0)
@patch_to(DeepSeekMultiTokenPredictor)
def compute_logits(
self,
hidden_states: torch.Tensor,
spec_step_idx: int = 0,
) -> torch.Tensor:
current_step_idx = (spec_step_idx % self.num_mtp_layers)
mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)]
logits = self.logits_processor(
mtp_layer.shared_head.head,
mtp_layer.shared_head(
hidden_states.unsqueeze(0)).squeeze(0).contiguous())
return logits