[MTP][V1] Adapt mtp with graph mode in v1. (#1023)

Adapts deepseek mtp with torch air graph mode in v1.

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
whx
2025-06-09 22:21:42 +08:00
committed by GitHub
parent 5ac4872f5e
commit cd2f14a1b3
4 changed files with 87 additions and 24 deletions

View File

@@ -4,7 +4,8 @@ from vllm.config import (VllmConfig, get_layers_from_vllm_config,
set_current_vllm_config)
from vllm.forward_context import set_forward_context
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.model_loader.utils import (
process_weights_after_loading, set_default_torch_dtype)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
@@ -199,6 +200,8 @@ class MtpProposer:
loader.get_all_weights(
self.vllm_config.speculative_config.draft_model_config,
self.model))
process_weights_after_loading(self.model, draft_model_config,
target_device)
# TODO Using torch instead of triton may result in poor performance