[Feat]mtp aclgraph support (#3244)
### What this PR does / why we need it? Currently, MTP Model in deepseek can not be capture in ACLGraph. This PR is use to allow MTP to be captured in ACLGraph mode. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
This commit is contained in:
@@ -23,6 +23,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
|
||||
get_current_vllm_config)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
@@ -179,6 +180,7 @@ class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
|
||||
return logits
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class CustomDeepSeekMTP(DeepSeekMTP):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
|
||||
Reference in New Issue
Block a user