[Feat][Graph] Support MTP for ACL Graph (#2932)
### What this PR does / why we need it?
This PR depends on the merge of #2707 and has adapted the aclgraph
functionality to support MTP.
### How was this patch tested?
- vLLM version: v0.10.2
- vLLM main:
2b85697031
---------
Signed-off-by: xuyexiong <xuyexiong@huawei.com>
This commit is contained in:
@@ -8,7 +8,7 @@ from torchair import patch_for_hcom
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import (VllmConfig, get_layers_from_vllm_config,
|
||||
set_current_vllm_config)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.forward_context import BatchDescriptor, get_forward_context
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
process_weights_after_loading, set_default_torch_dtype)
|
||||
@@ -363,8 +363,14 @@ class MtpProposer(Proposer):
|
||||
not self.runner.with_prefill
|
||||
|
||||
if is_running_torchair:
|
||||
# Torchair graph mode, padding is same as the main model
|
||||
num_input_tokens = self.runner.graph_pad_size
|
||||
elif (self.runner.use_aclgraph
|
||||
and num_tokens <= self.runner.aclgraph_batch_sizes[-1]):
|
||||
# Acl graph mode, add padding to the batch size
|
||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
|
||||
else:
|
||||
# Eager mode, no padding needed
|
||||
num_input_tokens = num_tokens
|
||||
|
||||
seq_lens = target_positions[last_token_indices] + 1
|
||||
@@ -410,7 +416,7 @@ class MtpProposer(Proposer):
|
||||
# TODO: adapt enable_dbo later
|
||||
(num_input_tokens, num_tokens_across_dp, with_prefill,
|
||||
_) = self.runner._sync_metadata_across_dp(
|
||||
num_tokens, self.runner.with_prefill, False)
|
||||
num_input_tokens, self.runner.with_prefill, False)
|
||||
else:
|
||||
# torchair mode can reuse self.runner.num_tokens_across_dp
|
||||
num_tokens_across_dp = self.runner.num_tokens_across_dp
|
||||
@@ -418,6 +424,10 @@ class MtpProposer(Proposer):
|
||||
|
||||
moe_comm_method = self.runner._select_moe_comm_method(
|
||||
num_input_tokens, with_prefill)
|
||||
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
|
||||
uniform_decode=False)
|
||||
aclgraph_runtime_mode, batch_descriptor = \
|
||||
self.runner.aclgraph_dispatcher.dispatch(batch_descriptor)
|
||||
|
||||
for step in range(self.num_speculative_tokens):
|
||||
with set_ascend_forward_context(
|
||||
@@ -428,6 +438,7 @@ class MtpProposer(Proposer):
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
reserved_mc2_mask=self.runner.reserved_mc2_mask,
|
||||
moe_comm_method=moe_comm_method,
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
in_profile_run=self.runner.in_profile_run,
|
||||
num_actual_tokens=num_tokens):
|
||||
with ProfileExecuteDuration().capture_async('mtp_forward'):
|
||||
|
||||
Reference in New Issue
Block a user