### What this PR does / why we need it?
Currently, the MTP model still runs in eager mode even when full graph mode
is enabled. This PR adapts MTP to full graph capture and execution. When the
graph mode is set to "FULL_DECODE_ONLY", MTP runs in full-graph mode to
improve performance.
The changes, which apply whether disable_padded_drafter_batch is True or
False, include:
1. Add _mtp_graph_params in acl_graph.py to isolate the data of main
model and the data of MTP.
2. Pad some metadata in mla_v1.py when running in full-graph mode.
3. Fix the addresses of the essential data used in model.forward.
4. Adapt to the aclgraph capture framework:
1). Rebuild MTP model with ACLGraphWrapper.
2). Add common attn metadata when start capture in MTP dummy_run.
3). Add common attn metadata update in MTP.
4). Adapt the data update when num_speculative_tokens > 1.
5. Add a patch of MTP to adapt vllm v0.11.0.
Existing Issues:
1. When disable_padded_drafter_batch=True and running in FullGraph mode,
the data of the first-round requests in MTP is abnormal. The cause still
needs to be identified.
2. When disable_padded_drafter_batch=False and running in FullGraph
mode, the acceptance rate of the second and third draft tokens decreases.
For example, with num_speculative_tokens=3, the acceptance rate of the
first token is 90%, but the second drops to only about 50% (down from
about 60%) and the third to only about 20% (down from about 30%). The
cause is that the data processed after the model runs does not match;
this problem was introduced by another PR. It works fine in eager and
PIECEWISE modes, but not in FullGraph mode. Once we have a solution, we
will submit a bugfix.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main:
2918c1b49c
---------
Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
95 lines
3.4 KiB
Python
95 lines
3.4 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from transformers import PretrainedConfig
|
|
from vllm.config import VllmConfig
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
|
from vllm.model_executor.models.deepseek_mtp import \
|
|
DeepSeekMultiTokenPredictorLayer
|
|
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
|
|
from vllm.model_executor.models.utils import maybe_prefix
|
|
|
|
from vllm_ascend.utils import vllm_version_is
|
|
|
|
if vllm_version_is("0.11.0"):
|
|
from vllm.compilation.decorators import support_torch_compile
|
|
from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
|
|
|
|
|
|
class SharedHead(nn.Module):
    """Norm + LM head owned by an MTP predictor layer.

    ``forward`` applies only the RMSNorm. ``self.head`` is built here but is
    not called by ``forward``; presumably the caller uses it directly when
    computing logits. NOTE(review): confirm against the caller.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        quant_config: QuantizationConfig | None = None,
    ) -> None:
        """Build the final norm and the vocab-parallel head.

        Args:
            config: HF model config. ``hidden_size``, ``rms_norm_eps`` and
                ``vocab_size`` are read from it.
            prefix: Weight-name prefix. The head's prefix is derived from it
                via ``maybe_prefix(prefix, "head")``.
            quant_config: Optional quantization config, forwarded to
                ``ParallelLMHead``. ``None`` means no quantization config is
                passed.
        """
        super().__init__()
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "head"),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_states`` passed through the RMSNorm.

        The LM head is intentionally not applied here.
        """
        return self.norm(hidden_states)
|
|
|
|
|
|
def predictor_init(self, vllm_config: VllmConfig, prefix: str) -> None:
    """Replacement ``__init__`` installed on ``DeepSeekMultiTokenPredictorLayer``.

    Builds the MTP predictor submodules from ``vllm_config.model_config.hf_config``:
    two RMSNorms (``enorm`` for embeddings, ``hnorm`` for previous hidden
    states), the ``eh_proj`` linear layer, a ``SharedHead`` and the
    ``DeepseekV2DecoderLayer`` used as ``mtp_block``.

    Args:
        vllm_config: Global vLLM config. ``model_config.hf_config`` and
            ``quant_config`` are read from it.
        prefix: Weight-name prefix, forwarded to ``SharedHead`` and the
            decoder layer.
    """
    # This is a free function assigned onto the class, so a zero-argument
    # super() is unavailable; initialise nn.Module state explicitly.
    nn.Module.__init__(self)

    hf_config = vllm_config.model_config.hf_config
    quant_config = vllm_config.quant_config
    hidden_size = hf_config.hidden_size
    norm_eps = hf_config.rms_norm_eps

    # Submodules are registered in this exact order: enorm, hnorm, eh_proj,
    # shared_head, mtp_block.
    self.enorm = RMSNorm(hidden_size, eps=norm_eps)
    self.hnorm = RMSNorm(hidden_size, eps=norm_eps)
    # Maps the concatenation [embeddings ; previous hidden] (2 * hidden_size
    # features) back down to hidden_size.
    self.eh_proj = nn.Linear(hidden_size * 2, hidden_size, bias=False)

    self.shared_head = SharedHead(config=hf_config,
                                  prefix=prefix,
                                  quant_config=quant_config)

    # We don't need topk_indices_buffer in Ascend, so None is passed in the
    # third positional slot.
    topk_indices_buffer = None
    self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix,
                                            topk_indices_buffer)
|
|
|
|
|
|
def predictor_forward(
|
|
self,
|
|
input_ids: torch.Tensor,
|
|
positions: torch.Tensor,
|
|
previous_hidden_states: torch.Tensor,
|
|
inputs_embeds: torch.Tensor | None = None,
|
|
spec_step_index: int = 0,
|
|
) -> torch.Tensor:
|
|
assert inputs_embeds is not None
|
|
# masking inputs at position 0, as not needed by MTP
|
|
inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)
|
|
inputs_embeds = self.enorm(inputs_embeds)
|
|
previous_hidden_states = self.hnorm(previous_hidden_states)
|
|
|
|
hidden_states = self.eh_proj(
|
|
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
|
|
|
|
hidden_states, residual = self.mtp_block(positions=positions,
|
|
hidden_states=hidden_states,
|
|
residual=None)
|
|
hidden_states = residual + hidden_states
|
|
return hidden_states
|
|
|
|
|
|
if vllm_version_is("0.11.0"):

    # Patch this only for aclgraph support, as this is not supported in vLLM
    # 0.11.0. The class is defined only under that version because
    # ``support_torch_compile`` and ``DeepSeekMTP`` are imported only for
    # 0.11.0 at the top of this module. An unguarded definition would raise
    # NameError on import with any other vLLM version.
    @support_torch_compile
    class AscendDeepSeekMTP(DeepSeekMTP):
        """``DeepSeekMTP`` decorated with vLLM's ``support_torch_compile``.

        The constructor only forwards its keyword-only arguments to
        ``DeepSeekMTP.__init__`` unchanged.
        """

        def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
            super().__init__(vllm_config=vllm_config, prefix=prefix)


# The replacement constructor is installed on every vLLM version.
DeepSeekMultiTokenPredictorLayer.__init__ = predictor_init
# The replacement forward is only installed for vLLM 0.11.0.
if vllm_version_is("0.11.0"):
    DeepSeekMultiTokenPredictorLayer.forward = predictor_forward
|