[0.11.0][MTP][Aclgraph] Fix the support aclgraph with MTP (#3912)
### What this PR does / why we need it? Fix 2 breaks of aclgraph with MTP: 1. deepseekmtp in vllm 0.11.0 does not support aclgraph and lacks the `support_torch_compile` decorator 2. There is a d2h synchronization in the original forward of the mtp predictor. The fix pr in vllm https://github.com/vllm-project/vllm/pull/27643 As we'll fix it in vllm main, this fix pr is only needed in branch v0.11.0-dev The profiling shows that MTP replays in aclgraph now: <img width="1612" height="1866" alt="a7d7f04155df4ed454b7eb20a92b2e2a" src="https://github.com/user-attachments/assets/eaa4b9ff-aeb0-416d-964f-5a06e497f155" /> ### How was this patch tested? <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> --------- Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
@@ -1,16 +1,54 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import vllm
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
from vllm.model_executor.models.deepseek_mtp import \
|
from vllm.model_executor.models.deepseek_mtp import (
|
||||||
DeepSeekMultiTokenPredictorLayer
|
DeepSeekMTP, DeepSeekMultiTokenPredictorLayer)
|
||||||
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
|
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
|
||||||
from vllm.model_executor.models.utils import maybe_prefix
|
from vllm.model_executor.models.utils import maybe_prefix
|
||||||
|
|
||||||
|
|
||||||
|
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    previous_hidden_states: torch.Tensor,
    inputs_embeds: Optional[torch.Tensor] = None,
    spec_step_index: int = 0,
) -> torch.Tensor:
    """Run one MTP predictor-layer step and return its hidden states.

    Replacement for ``DeepSeekMultiTokenPredictorLayer.forward`` that stays
    aclgraph-friendly: the position-0 masking is done with ``torch.where``
    because the original operation introduced a d2h sync, which breaks
    aclgraph capture.

    Args:
        input_ids: Unused here; kept to match the patched method's signature.
        positions: Per-token position ids.
        previous_hidden_states: Hidden states from the previous step.
        inputs_embeds: Pre-computed token embeddings; must not be ``None``.
        spec_step_index: Unused here; kept to match the patched signature.
    """
    assert inputs_embeds is not None
    # Tokens at position 0 are not needed by MTP — zero out their embeddings
    # entirely on-device (no d2h synchronization).
    at_first_position = positions.unsqueeze(-1) == 0
    masked_embeds = torch.where(at_first_position, 0, inputs_embeds)

    normed_embeds = self.enorm(masked_embeds)
    normed_previous = self.hnorm(previous_hidden_states)

    # Fuse the embedding stream with the previous hidden states, then project
    # back down to the model's hidden size.
    fused = torch.cat([normed_embeds, normed_previous], dim=-1)
    projected = self.eh_proj(fused)

    block_out, residual = self.mtp_block(
        positions=positions,
        hidden_states=projected,
        residual=None,
    )
    return residual + block_out
|
||||||
|
|
||||||
|
|
||||||
|
# Patch this only for aclgraph support, as this is not supported in vLLM 0.11.0
@support_torch_compile
class AscendDeepSeekMTP(DeepSeekMTP):
    """DeepSeek MTP model wrapped with ``@support_torch_compile``.

    Identical to the upstream ``DeepSeekMTP`` except for the decorator,
    which vLLM 0.11.0's deepseek_mtp lacks; the decorator is what enables
    aclgraph capture for this model on Ascend.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # No extra state — delegate construction entirely to DeepSeekMTP.
        super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||||
|
|
||||||
|
|
||||||
class SharedHead(nn.Module):
|
class SharedHead(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -53,3 +91,4 @@ def predictor_init(self, vllm_config: VllmConfig, prefix: str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
# Monkey-patch the upstream predictor layer: rebind its __init__ and forward
# to the aclgraph-compatible implementations defined in this module.
DeepSeekMultiTokenPredictorLayer.__init__ = predictor_init
vllm.model_executor.models.deepseek_mtp.DeepSeekMultiTokenPredictorLayer.forward = forward
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from vllm.forward_context import BatchDescriptor, get_forward_context
|
|||||||
from vllm.model_executor.model_loader import get_model_loader
|
from vllm.model_executor.model_loader import get_model_loader
|
||||||
from vllm.model_executor.model_loader.utils import (
|
from vllm.model_executor.model_loader.utils import (
|
||||||
process_weights_after_loading, set_default_torch_dtype)
|
process_weights_after_loading, set_default_torch_dtype)
|
||||||
from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
|
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||||
@@ -19,6 +18,8 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
|||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
||||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||||
|
from vllm_ascend.patch.worker.patch_deepseek_mtp import \
|
||||||
|
AscendDeepSeekMTP as DeepSeekMTP
|
||||||
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
|
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
|
||||||
from vllm_ascend.torchair.models.torchair_deepseek_mtp import \
|
from vllm_ascend.torchair.models.torchair_deepseek_mtp import \
|
||||||
TorchairDeepSeekMTP
|
TorchairDeepSeekMTP
|
||||||
|
|||||||
Reference in New Issue
Block a user