[MTP][V1] Adapt mtp with graph mode in v1. (#1023)
Adapts deepseek mtp with torch air graph mode in v1. --------- Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -4,7 +4,8 @@ from vllm.config import (VllmConfig, get_layers_from_vllm_config,
|
||||
set_current_vllm_config)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
process_weights_after_loading, set_default_torch_dtype)
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
|
||||
from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
|
||||
@@ -199,6 +200,8 @@ class MtpProposer:
|
||||
loader.get_all_weights(
|
||||
self.vllm_config.speculative_config.draft_model_config,
|
||||
self.model))
|
||||
process_weights_after_loading(self.model, draft_model_config,
|
||||
target_device)
|
||||
|
||||
|
||||
# TODO Using torch instead of triton may result in poor performance
|
||||
|
||||
Reference in New Issue
Block a user