[Feat] Support running MTP in full graph mode (#3892)

### What this PR does / why we need it?
Currently, the MTP model still runs in eager mode even when full graph mode is
enabled. This PR adapts MTP to full graph capture and execution: when the graph
mode is set to "FULL_DECODE_ONLY", MTP now runs in full graph mode to improve
performance.
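
A minimal launch sketch of the intended usage (hedged: the exact option spellings may differ across vLLM / vllm-ascend versions, and the model name is only an example; "FULL_DECODE_ONLY" and "deepseek_mtp" are taken from the code in this PR):

```python
# Hypothetical sketch: MTP speculative decoding with full-graph decode capture.
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V3",           # example model with an MTP head
    speculative_config={
        "method": "deepseek_mtp",               # the MTP drafter
        "num_speculative_tokens": 1,
    },
    compilation_config={
        "cudagraph_mode": "FULL_DECODE_ONLY",   # capture full graphs for decode
    },
)
print(llm.generate(["Hello"], SamplingParams(max_tokens=32))[0].outputs[0].text)
```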

The changes, covering both the disable_padded_drafter_batch=True and =False
cases, include:

1. Add `_mtp_graph_params` in acl_graph.py to isolate the main model's graph
data from the MTP's.
2. Pad some attention metadata in mla_v1.py when running in full graph mode.
3. Keep the addresses of the essential data used in model.forward fixed, as
required for full-graph replay.
4. Adapt to the aclgraph capture framework:
    1). Wrap the MTP model with ACLGraphWrapper.
    2). Build common attn metadata when capture starts in the MTP dummy_run.
    3). Update the common attn metadata in MTP.
    4). Adapt the data updates when num_speculative_tokens > 1.
5. Add an MTP patch to adapt to vLLM v0.11.0.

Existing Issues:
1. When disable_padded_drafter_batch=True and running in full graph mode, the
data of the first-round requests in MTP is abnormal. The cause still needs to
be identified.
2. When disable_padded_drafter_batch=False and running in full graph mode, the
acceptance rate of the second and third speculative tokens decreases (for
example, with num_speculative_tokens=3 the acceptance rate of the first token
is 90%, the second drops to about 50% instead of 60%, and the third to about
20% instead of 30%). The cause is that the data processed after the model runs
does not match; this comes from another PR. It works fine in eager and
PIECEWISE modes but misbehaves in full graph mode. Once we have a solution, we
will submit a bugfix.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
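Unit tests were added for the new MLA metadata builder full-graph build path, the `pad_actual_seq_len_q_mtp_*` helpers, and the MTP graph-params helpers in `acl_graph.py` (see the test diffs below).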


- vLLM version: v0.11.0
- vLLM main: 2918c1b49c

---------

Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
This commit is contained in:
anon189Ty
2025-11-20 20:34:54 +08:00
committed by GitHub
parent 15c1eb025c
commit 5c9f4a40c6
8 changed files with 536 additions and 42 deletions

View File

@@ -256,6 +256,57 @@ class TestAscendMLAMetadataBuilder(TestBase):
builder.chunked_prefill_enabled,
mock_vllm_config.scheduler_config.chunked_prefill_enabled)
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def test_ascend_mla_metadata_builder_build_full_graph(
self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group):
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
mock_vllm_config.model_config.dtype = torch.float16
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'
mock_dcp.world_size = 1
dcp_group = MagicMock(spec=GroupCoordinator)
dcp_group.rank_in_group = 0
dcp_group.world_size = 1
dcp_group.device_group = MagicMock()
mock_get_dcp_group.return_value = dcp_group
mock_spec_config = MagicMock()
mock_spec_config.num_speculative_tokens = 1
mock_spec_config.disable_padded_drafter_batch = True
mock_vllm_config.speculative_config = mock_spec_config
builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
mock_device)
common_metadata = MagicMock()
model = MagicMock()
common_metadata.graph_pad_size = 8
common_metadata.num_reqs = 4
common_metadata.num_actual_tokens = 5
common_metadata.max_query_len = 5
common_metadata.seq_lens_cpu = torch.Tensor([9, 10, 8, 8]).int()
common_metadata.query_start_loc = torch.Tensor([0, 1, 2, 4, 5]).int()
common_metadata.query_start_loc_cpu = torch.Tensor([0, 1, 2, 4,
5]).int()
common_metadata.positions = torch.Tensor([1, 2, 3, 4, 5, 6]).int()
block_table = torch.Tensor([[1, 0], [2, 0], [3, 0], [4, 0]]).int()
common_metadata.block_table_tensor = block_table
common_metadata.prefill_context_parallel_metadata = None
metadata = builder.build(0, common_metadata, model)
self.assertEqual(metadata.decode.actual_seq_lengths_q,
[1, 2, 4, 5, 6, 6, 7, 8])
self.assertEqual(metadata.decode.block_table.shape[0], 8)
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@@ -307,6 +358,80 @@ class TestAscendMLAMetadataBuilder(TestBase):
self.assertTrue(modified)
input_batch.swap_states.assert_called_once_with(1, 2)
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def test_pad_actual_seq_lens_q_mtp_disable_pad(self, mock_get_dcp_size,
mock_dcp,
mock_get_dcp_group):
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
mock_vllm_config.model_config.dtype = torch.float16
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'
mock_vllm_config.speculative_config = None
mock_dcp.world_size = 1
dcp_group = MagicMock(spec=GroupCoordinator)
dcp_group.rank_in_group = 0
dcp_group.world_size = 1
dcp_group.device_group = MagicMock()
mock_get_dcp_group.return_value = dcp_group
builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
mock_device)
input_seq_lens = [1, 2, 4, 5]
expect_output = [1, 2, 4, 5, 6, 6, 7, 8]
num_reqs = 4
num_reqs_pad_size = 4
output_seq_lens = builder.pad_actual_seq_len_q_mtp_disable_pad(
num_reqs_pad_size, num_reqs, input_seq_lens)
self.assertEqual(output_seq_lens, expect_output)
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def test_pad_actual_seq_lens_q_mtp_enable_pad(self, mock_get_dcp_size,
mock_dcp,
mock_get_dcp_group):
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
mock_vllm_config.model_config.dtype = torch.float16
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'
mock_vllm_config.speculative_config = None
mock_dcp.world_size = 1
dcp_group = MagicMock(spec=GroupCoordinator)
dcp_group.rank_in_group = 0
dcp_group.world_size = 1
dcp_group.device_group = MagicMock()
mock_get_dcp_group.return_value = dcp_group
common_metadata = MagicMock()
common_metadata.actual_seq_lengths_q = [2, 4, 6, 8]
builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
mock_device)
input_seq_lens = [2, 4, 6]
expect_output = [2, 4, 6, 8]
num_reqs = 3
num_reqs_pad_size = 1
output_seq_lens = builder.pad_actual_seq_len_q_mtp_enable_pad(
num_reqs_pad_size, num_reqs, input_seq_lens, common_metadata)
self.assertEqual(output_seq_lens, expect_output)
class TestAscendMLAMetadataBuilderBuild(TestBase):

View File

@@ -21,7 +21,9 @@ from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor, ForwardContext
from tests.ut.base import TestBase
from vllm_ascend.compilation.acl_graph import ACLGraphEntry, ACLGraphWrapper
from vllm_ascend.compilation.acl_graph import (
ACLGraphEntry, ACLGraphWrapper, get_mtp_graph_params, set_mtp_graph_params,
update_mtp_graph_params_workspaces)
class TestACLGraphEntry(TestBase):
@@ -718,3 +720,24 @@ class TestACLGraphWrapper(TestBase):
unwrapped = wrapper.unwrap()
self.assertEqual(unwrapped, self.mock_runnable)
class TestMTPGraphParams(TestBase):
def test_set_mtp_graph_params(self):
with patch('vllm_ascend.compilation.acl_graph._mtp_graph_params',
new=None):
set_mtp_graph_params([4])
from vllm_ascend.compilation.acl_graph import _mtp_graph_params
self.assertIsNotNone(_mtp_graph_params)
@patch('vllm_ascend.compilation.acl_graph._mtp_graph_params')
def test_update_mtp_graph_params_workspaces(self, mtp_graph_params_mock):
mtp_graph_params_mock.workspaces = {4: 5}
update_mtp_graph_params_workspaces(4, 6)
self.assertEqual(mtp_graph_params_mock.workspaces[4], 6)
@patch('vllm_ascend.compilation.acl_graph._mtp_graph_params')
def test_get_mtp_graph_params(self, mtp_graph_params_mock):
graph_params = get_mtp_graph_params()
self.assertIs(mtp_graph_params_mock, graph_params)

View File

@@ -72,7 +72,8 @@ def set_ascend_forward_context(
batch_descriptor: Optional[BatchDescriptor] = None,
prefetch_stream: torch.npu.Stream = None,
model_instance: torch.nn.Module = None,
weight_prefetch_method: Optional[WeightPrefetchMethod] = None):
weight_prefetch_method: Optional[WeightPrefetchMethod] = None,
is_mtp_model=False):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
We add some additional param into forward_context.
@@ -158,6 +159,7 @@ def set_ascend_forward_context(
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
forward_context.model_instance = model_instance
forward_context.weight_prefetch_method = weight_prefetch_method
forward_context.is_mtp_model = is_mtp_model
# TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant.
# It will be improved later by implementing operator fusion through the FX graph.

View File

@@ -8,6 +8,7 @@ import torch.distributed as dist
import torch_npu
from torch import nn
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import (get_dcp_group,
get_decode_context_model_parallel_rank,
@@ -38,6 +39,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
trans_rope_weight, transdata,
wait_for_kv_layer_from_connector)
from vllm_ascend.compilation.acl_graph import (get_graph_params,
get_mtp_graph_params,
update_graph_params_workspaces)
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
@@ -337,6 +339,74 @@ class AscendMLAMetadataBuilder:
# better way of doing this
return modified_batch
def pad_actual_seq_len_q_mtp_enable_pad(self, num_reqs_pad_size, num_reqs,
actual_seq_lengths_q,
common_attn_metadata):
"""
Pads actual_seq_lengths_q evenly to not exceed 16 tokens per request
in order to meet the requirement of npu_fused_infer_attention_score.
In Torchair scenario, the lengths of the queries must be padded to the same length.
And npu_fused_infer_attention_score constraint requires the last element must equal to batch_size(num_tokens).
For example:
batch_size=36, num_reqs_pad_size=2, num_reqs=16
By default, each request should have inference 2 token, which means actual_seq_lengths_q should be
[2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36].
However, mtp torchair + PD scenario, the actual_seq_lengths_q may be
[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] before padding, since the first decode request only has 1 token.
In order to meet the requirement of npu_fused_infer_attention_score, we need to pad actual_seq_lengths_q evenly to not exceed 16 tokens per request.
after padding actual_seq_lengths_q should be similar to [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,36]
"""
FIA_SEQ_LEN_LIMIT = 16
need_padding = num_reqs_pad_size != 0 and \
len(common_attn_metadata.actual_seq_lengths_q) > num_reqs and \
common_attn_metadata.actual_seq_lengths_q[num_reqs] - actual_seq_lengths_q[-1] > FIA_SEQ_LEN_LIMIT
if need_padding:
padding_seq_len_q = common_attn_metadata.actual_seq_lengths_q[
num_reqs:num_reqs + num_reqs_pad_size]
start_val = actual_seq_lengths_q[-1]
end_val = padding_seq_len_q[-1]
num_step = len(padding_seq_len_q)
interpolated = np.round(
np.linspace(start_val, end_val,
num_step + 1)[1:]).astype(int).tolist()
assert interpolated[-1] == end_val
assert len(interpolated) == len(padding_seq_len_q)
actual_seq_lengths_q = actual_seq_lengths_q + interpolated
else:
actual_seq_lengths_q = actual_seq_lengths_q + common_attn_metadata.actual_seq_lengths_q[
num_reqs:num_reqs + num_reqs_pad_size]
return actual_seq_lengths_q
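
For reference, a standalone sketch of the padding rule above (an illustrative re-implementation, not part of the patch; `pad_enable_pad` and its argument names are hypothetical, and `padded_seq_lengths_q` plays the role of `common_attn_metadata.actual_seq_lengths_q`):

```python
import numpy as np

FIA_SEQ_LEN_LIMIT = 16  # per-request limit of npu_fused_infer_attention_score

def pad_enable_pad(actual_seq_lengths_q, padded_seq_lengths_q, num_reqs, num_reqs_pad_size):
    tail = padded_seq_lengths_q[num_reqs:num_reqs + num_reqs_pad_size]
    if (num_reqs_pad_size != 0 and len(padded_seq_lengths_q) > num_reqs
            and padded_seq_lengths_q[num_reqs] - actual_seq_lengths_q[-1] > FIA_SEQ_LEN_LIMIT):
        # Large jump: spread it evenly so no padded request exceeds the limit.
        interp = np.round(np.linspace(actual_seq_lengths_q[-1], tail[-1],
                                      num_reqs_pad_size + 1)[1:]).astype(int).tolist()
        return actual_seq_lengths_q + interp
    # Small jump: the pre-computed padding is appended directly.
    return actual_seq_lengths_q + tail

print(pad_enable_pad([2, 4, 6], [2, 4, 6, 8], num_reqs=3, num_reqs_pad_size=1))  # [2, 4, 6, 8]
```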
def pad_actual_seq_len_q_mtp_disable_pad(self, num_reqs_pad_size, num_reqs,
actual_seq_lengths_q):
"""
Only use for acl full graph mode.
Pad the last element of the actual_seq_lengths_q equal to the TND(T) and
the num of dimensions equal to the batch_size of main model.
For example:
batch_size = 8, num_reqs = 4, num_speculative_tokens = 1
input actual_seq_lengths_q = [1, 2, 4, 5] (the 3rd req was accept a token)
After padding the actual_seq_lengths_q will be similar to [1, 2, 4, 5, 6, 6, 7, 8]
"""
need_padding = num_reqs_pad_size > 0
if need_padding:
start_val = actual_seq_lengths_q[-1]
end_val = num_reqs + num_reqs_pad_size
num_step = num_reqs_pad_size
interpolated = np.round(
np.linspace(start_val, end_val,
num_step + 1)[1:]).astype(int).tolist()
assert interpolated[-1] == end_val
assert len(interpolated) == num_reqs_pad_size
actual_seq_lengths_q = actual_seq_lengths_q + interpolated
return actual_seq_lengths_q
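
And a matching sketch for the disable_padded_drafter_batch path (again an illustrative re-implementation with hypothetical names; it reproduces the docstring example and the unit test above):

```python
import numpy as np

def pad_disable_pad(actual_seq_lengths_q, num_reqs, num_reqs_pad_size):
    if num_reqs_pad_size <= 0:
        return actual_seq_lengths_q
    # Interpolate up to the main model's batch size so the last element equals T.
    end_val = num_reqs + num_reqs_pad_size
    interp = np.round(np.linspace(actual_seq_lengths_q[-1], end_val,
                                  num_reqs_pad_size + 1)[1:]).astype(int).tolist()
    return actual_seq_lengths_q + interp

# batch_size = 8, num_reqs = 4: [1, 2, 4, 5] -> [1, 2, 4, 5, 6, 6, 7, 8]
print(pad_disable_pad([1, 2, 4, 5], num_reqs=4, num_reqs_pad_size=4))
```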
def build(
self,
common_prefix_len: int,
@@ -362,17 +432,25 @@ class AscendMLAMetadataBuilder:
# it blocks on all previous kernels.
device = self.device
# If graph_pad_size > -1, it means we are running in fullgraph mode.
graph_pad_size = common_attn_metadata.graph_pad_size
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
if graph_pad_size > num_reqs and self.speculative_config.disable_padded_drafter_batch:
block_table = (
common_attn_metadata.block_table_tensor[:graph_pad_size])
else:
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
# NOTE: Currently, MTP full graph is incompatible with pcp
if self.pcp_size > 1:
num_decodes_flatten = num_decodes * self.decode_threshold
block_table = common_attn_metadata.block_table_tensor[:num_decodes_flatten + num_prefills]
if num_actual_tokens_pcp_padded is None:
num_actual_tokens_pcp_padded = num_actual_tokens
# NOTE: Currently, MTP full graph is incompatible with pcp
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens_pcp_padded]
input_positions = common_attn_metadata.positions[:
@@ -565,6 +643,11 @@ class AscendMLAMetadataBuilder:
block_table = block_table[:num_decodes_flatten, ...]
else:
block_table = block_table[:num_decodes, ...]
# NOTE: Currently, MTP full graph is incompatible with pcp
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
if graph_pad_size > num_decodes and \
self.speculative_config.disable_padded_drafter_batch:
block_table = block_table[:graph_pad_size, ...]
seq_lens_list = seq_lens.tolist()
if num_computed_tokens_of_pcp_dcp is not None:
@@ -586,6 +669,52 @@ class AscendMLAMetadataBuilder:
else:
cp_seq_len, batch_seq_mask = None, None
if graph_pad_size > num_reqs:
if self.speculative_config.disable_padded_drafter_batch:
num_reqs_pad_size = graph_pad_size - num_reqs
actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_disable_pad(
num_reqs_pad_size, num_reqs, actual_seq_lengths_q)
seq_lens_list = seq_lens_list + [0] * (graph_pad_size - \
num_decodes)
num_block_pad_size = graph_pad_size - block_table.shape[0]
if num_block_pad_size > 0:
block_table_padding = torch.zeros(
(num_block_pad_size, ) + block_table.shape[1:],
dtype=block_table.dtype,
device=block_table.device)
block_table = torch.cat(
[block_table, block_table_padding], dim=0)
else:
num_token_pad_size = graph_pad_size - num_decode_tokens
num_reqs_pad_size = (
graph_pad_size //
common_attn_metadata.decode_token_per_req - num_reqs)
num_block_table_pad_size = (
graph_pad_size //
common_attn_metadata.decode_token_per_req -
num_decodes)
seq_lens_list = seq_lens.tolist() + [0] * num_reqs_pad_size
slot_padding = torch.full((num_token_pad_size, ),
PAD_SLOT_ID,
dtype=slot_mapping.dtype,
device=slot_mapping.device)
slot_mapping = torch.cat([slot_mapping, slot_padding])
block_table_padding = torch.zeros(
(num_block_table_pad_size, ) + block_table.shape[1:],
dtype=block_table.dtype,
device=block_table.device)
block_table = torch.cat([block_table, block_table_padding],
dim=0)
position_padding = torch.zeros(
num_token_pad_size,
dtype=input_positions.dtype,
device=input_positions.device)
input_positions = torch.cat(
[input_positions, position_padding])
actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_enable_pad(
num_reqs_pad_size, num_reqs, actual_seq_lengths_q,
common_attn_metadata)
# TODO: After the fullgraph supports MTP, this if branch needs to be deleted
assert self.cos_cache is not None
assert self.sin_cache is not None
@@ -1267,8 +1396,11 @@ class AscendMLAImpl(MLAAttentionImpl):
"actual_seq_lengths": actual_seq_lengths, "actual_seq_lengths": actual_seq_lengths,
"actual_seq_lengths_kv": decode_meta.seq_lens_list, "actual_seq_lengths_kv": decode_meta.seq_lens_list,
} }
graph_params = get_graph_params()
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
if forward_context.is_mtp_model:
graph_params = get_mtp_graph_params()
else:
graph_params = get_graph_params()
if forward_context.capturing:
stream = torch_npu.npu.current_stream()

View File

@@ -241,6 +241,9 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
def update_mla_attn_params(update_stream, forward_context, runtime_shape,
speculative_config):
if forward_context.is_mtp_model:
graph_params = get_mtp_graph_params()
else:
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
@@ -257,7 +260,8 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
softmax_lse) = param
seq_lens_list = forward_context.attn_metadata[
key].decode.seq_lens_list
if speculative_config and speculative_config.method == "deepseek_mtp":
if speculative_config and speculative_config.method == "deepseek_mtp" \
and not forward_context.is_mtp_model:
actual_seq_lengths = forward_context.attn_metadata[
key].decode.actual_seq_lengths_q
spec_multiple = speculative_config.num_speculative_tokens + 1
@@ -267,6 +271,13 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
spec_multiple * (i + 1)
for i in range(runtime_shape // spec_multiple)
]
elif forward_context.is_mtp_model:
actual_seq_lengths = forward_context.attn_metadata[
key].decode.actual_seq_lengths_q
block_table = forward_context.attn_metadata[
key].decode.block_table
seq_lens_list = seq_lens_list + [0] * (
len(actual_seq_lengths) - len(seq_lens_list))
else:
seq_lens_list = seq_lens_list + [0] * (runtime_shape -
len(seq_lens_list))
@@ -443,3 +454,32 @@ def update_graph_params_workspaces(num_tokens: int, workspace: int):
def get_graph_params():
return _graph_params
_mtp_graph_params: Optional[GraphParams] = None
def set_mtp_graph_params(aclgraph_capture_sizes: set[int]):
global _mtp_graph_params
if _mtp_graph_params is not None:
raise ValueError("MTPGraph parameters have already been set!")
_mtp_graph_params = GraphParams(
{size: []
for size in aclgraph_capture_sizes},
{size: None
for size in aclgraph_capture_sizes},
{size: []
for size in aclgraph_capture_sizes},
{size: []
for size in aclgraph_capture_sizes},
)
def update_mtp_graph_params_workspaces(num_tokens: int, workspace: Any):
global _mtp_graph_params
if _mtp_graph_params is not None:
_mtp_graph_params.workspaces[num_tokens] = workspace
def get_mtp_graph_params():
return _mtp_graph_params
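
A short usage sketch of how these helpers keep the drafter's captured attention args separate from the main model's (illustrative only; it mirrors the branches added in mla_v1.py and update_mla_attn_params above, and the capture sizes are example values):

```python
from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                               get_mtp_graph_params,
                                               set_mtp_graph_params)

# Called once when the MTP drafter is wrapped with ACLGraphWrapper;
# raises if the MTP graph params have already been set.
set_mtp_graph_params({1, 2, 4, 8})  # example ACL graph capture sizes

def select_graph_params(is_mtp_model: bool):
    # The attention impl picks the per-model GraphParams based on
    # forward_context.is_mtp_model, so MTP captures never alias the main model's.
    return get_mtp_graph_params() if is_mtp_model else get_graph_params()
```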

View File

@@ -10,6 +10,12 @@ from vllm.model_executor.models.deepseek_mtp import \
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
from vllm.model_executor.models.utils import maybe_prefix
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.compilation.decorators import support_torch_compile
from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
class SharedHead(nn.Module):
@@ -51,4 +57,38 @@ def predictor_init(self, vllm_config: VllmConfig, prefix: str) -> None:
topk_indices_buffer)
def predictor_forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
spec_step_index: int = 0,
) -> torch.Tensor:
assert inputs_embeds is not None
# masking inputs at position 0, as not needed by MTP
inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)
inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states)
hidden_states = self.eh_proj(
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
hidden_states, residual = self.mtp_block(positions=positions,
hidden_states=hidden_states,
residual=None)
hidden_states = residual + hidden_states
return hidden_states
# Patch this only for aclgraph support, as it is not supported in vLLM 0.11.0
@support_torch_compile
class AscendDeepSeekMTP(DeepSeekMTP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
DeepSeekMultiTokenPredictorLayer.__init__ = predictor_init
if vllm_version_is("0.11.0"):
DeepSeekMultiTokenPredictorLayer.forward = predictor_forward

View File

@@ -1,5 +1,5 @@
import importlib
from typing import Optional
from typing import Optional, Union
import numpy as np
import torch
@@ -7,7 +7,7 @@ import torch.nn as nn
import torch.nn.functional as F
from vllm.config import (CUDAGraphMode, VllmConfig,
get_layers_from_vllm_config, set_current_vllm_config)
from vllm.forward_context import BatchDescriptor
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model_loader
@@ -32,7 +32,11 @@ from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
set_mtp_graph_params,
update_mla_attn_params)
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
prefill_context_parallel_enable,
@@ -52,9 +56,14 @@ logger = init_logger(__name__)
PADDING_SLOT_ID = -1
_deepseek_mtp_path = "vllm.model_executor.models.deepseek_mtp"
_deepseek_mtp_model = "DeepSeekMTP"
if vllm_version_is("0.11.0"):
_deepseek_mtp_path = "vllm_ascend.patch.worker.patch_deepseek_mtp"
_deepseek_mtp_model = "AscendDeepSeekMTP"
_MTP_MODELS = {
"DeepseekV3ForCausalLM":
("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
"DeepseekV3ForCausalLM": (_deepseek_mtp_path, _deepseek_mtp_model),
"Qwen3NextForCausalLM":
("vllm_ascend.models.qwen3_next_mtp", "CustomQwen3NextMTP")
}
@@ -75,6 +84,9 @@ def _load_model(architecture):
class MtpProposer(Proposer):
# TODO: Find out why ModelRunner does not have this explicit typing.
model: Union[nn.Module, ACLGraphWrapper]
def __init__(
self,
vllm_config: VllmConfig,
@@ -203,6 +215,15 @@ class MtpProposer(Proposer):
process_weights_after_loading(self.model, draft_model_config,
target_device)
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
):
self.update_stream: torch.npu.Stream = torch.npu.Stream()
set_mtp_graph_params(
self.vllm_config.compilation_config.cudagraph_capture_sizes)
self.model = ACLGraphWrapper(self.model,
self.vllm_config,
runtime_mode=CUDAGraphMode.FULL)
@torch.inference_mode()
def dummy_run(self,
num_tokens: int,
@@ -222,12 +243,55 @@ class MtpProposer(Proposer):
moe_comm_type = self.runner._select_moe_comm_method(
num_tokens, with_prefill)
if skip_attn:
attn_metadata = None
elif aclgraph_runtime_mode == CUDAGraphMode.FULL:
if len(self.runner.attn_groups) > 0:
num_computed_tokens_cpu = (
self.runner.input_batch.
num_computed_tokens_cpu_tensor[:num_reqs])
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=self.runner.query_start_loc[:num_reqs + 1],
query_start_loc_cpu=self.runner.
query_start_loc_cpu[:num_reqs + 1],
seq_lens_cpu=self.runner.seq_lens_cpu,
seq_lens=self.runner.seq_lens_cpu[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=self.num_speculative_tokens + 1,
num_computed_tokens_cpu=num_computed_tokens_cpu,
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
block_table_tensor=self.runner.input_batch.block_table[0].
get_device_tensor()[:num_reqs],
slot_mapping=self.runner.input_batch.block_table[0].
slot_mapping,
positions=self.runner.positions,
attn_mask=self.runner.attn_mask,
spec_attn_mask=self.runner.spec_attn_mask,
attn_state=self.runner.attn_state,
decode_token_per_req=self.runner.decode_token_per_req,
cos=self.runner.cos,
sin=self.runner.sin,
)
builder = self.runner.attn_groups[0][0].get_metadata_builder()
attn_metadata_mtp = builder.build_for_graph_capture(
common_attn_metadata, AscendAttentionState.SpecDecoding,
self.runner.get_model())
attn_metadata = {}
for layer_name in self.attn_layer_name:
attn_metadata[layer_name] = attn_metadata_mtp
else:
attn_metadata = None
else:
attn_metadata = None
input_ids = self.input_ids[:num_tokens]
positions = self.positions[:num_tokens]
previous_hidden_states = self.hidden_states[:num_tokens]
for _ in range(self.num_speculative_tokens):
for i in range(self.num_speculative_tokens):
if i > 0:
aclgraph_runtime_mode = CUDAGraphMode.NONE
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
@@ -239,10 +303,19 @@ class MtpProposer(Proposer):
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=0,
aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor):
batch_descriptor=batch_descriptor,
is_mtp_model=True):
self.model(input_ids=input_ids,
positions=positions,
hidden_states=previous_hidden_states)
forward_context = get_forward_context()
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
not forward_context.capturing:
if self.vllm_config.model_config.use_mla:
update_mla_attn_params(
self.update_stream, forward_context,
positions.shape[0],
self.vllm_config.speculative_config)
if with_prefill:
break
@@ -324,6 +397,7 @@ class MtpProposer(Proposer):
common_attn_metadata.query_start_loc = \
query_start_loc_pcp_full[:num_reqs + 1]
if self.speculative_config.disable_padded_drafter_batch:
# NOTE: Currently, MTP full graph is incompatible with pcp
token_indices_to_sample = None
common_attn_metadata, token_indices =\
self._prepare_inputs(
@@ -358,6 +432,8 @@ class MtpProposer(Proposer):
long_seq_metadata=long_seq_metadata,
num_prefill_reqs=num_prefill_reqs,
num_decode_reqs=num_decode_reqs,
scheduler_output=scheduler_output,
num_scheduled_tokens=num_scheduled_tokens,
)
return draft_token_ids
@@ -460,6 +536,13 @@ class MtpProposer(Proposer):
token_indices = torch.from_numpy(token_indices_np).to(
device, non_blocking=True)
common_attn_metadata.slot_mapping[:token_indices.shape[0]].copy_(
common_attn_metadata.slot_mapping[token_indices])
common_attn_metadata.slot_mapping[token_indices.shape[0]:].fill_(-1)
# NOTE: Currently positions and seq_lens are not used in mla_v1 forward,
# so we do not need to fix them. But if they are used in the future,
# we should fix them.
spec_common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=new_query_start_loc_cpu.to(device,
non_blocking=True),
@@ -472,7 +555,7 @@ class MtpProposer(Proposer):
num_actual_tokens=total_num_tokens,
max_query_len=new_query_len_per_req.max().item(),
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping[token_indices],
slot_mapping=common_attn_metadata.slot_mapping,
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
positions=common_attn_metadata.positions[token_indices],
attn_mask=self.runner.attn_mask,
@@ -502,6 +585,8 @@ class MtpProposer(Proposer):
long_seq_metadata=None,
num_prefill_reqs=0,
num_decode_reqs=0,
scheduler_output: SchedulerOutput = None,
num_scheduled_tokens: int = 0,
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
@@ -585,14 +670,11 @@ class MtpProposer(Proposer):
assert self.runner is not None
builder = self.runner.attn_groups[0][0].get_metadata_builder()
attn_metadata_mtp = builder.build(0, common_attn_metadata,
self.runner.get_model())
attn_metadata = {}
for layer_name in self.attn_layer_name:
attn_metadata[layer_name] = attn_metadata_mtp
if self.use_aclgraph and num_tokens <= self.cudagraph_batch_sizes[-1]:
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
) and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_scheduled_tokens)
elif self.use_aclgraph and num_tokens <= self.cudagraph_batch_sizes[-1]:
# Acl graph mode, add padding to the batch size
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else:
@@ -609,19 +691,39 @@ class MtpProposer(Proposer):
moe_comm_type = self.runner._select_moe_comm_method(
num_input_tokens, with_prefill)
if scheduler_output:
max_query_len = common_attn_metadata.max_query_len
uniform_decode = (max_query_len in list(
range(1, self.num_speculative_tokens +
2))) and (scheduler_output.total_num_scheduled_tokens
== self.runner.input_batch.num_reqs *
(self.num_speculative_tokens + 1))
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
uniform_decode=uniform_decode)
else:
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
uniform_decode=False)
aclgraph_runtime_mode, batch_descriptor = \
self.runner.aclgraph_dispatcher.dispatch(batch_descriptor)
if aclgraph_runtime_mode not in [
CUDAGraphMode.PIECEWISE, CUDAGraphMode.NONE
]:
# Fallback to piecewise graph, when acl full graph is enabled
logger.debug(
"Currently the eagle proposer only supports cudagraph_mode "
f"PIECEWISE, and is forced to set graph mode from {aclgraph_runtime_mode} "
"to CUDAGraphMode.PIECEWISE")
aclgraph_runtime_mode = CUDAGraphMode.PIECEWISE
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
) and aclgraph_runtime_mode == CUDAGraphMode.FULL:
graph_pad_size = num_input_tokens
else:
# Currently, runner.graph_pad_size will always be -1.
graph_pad_size = self.runner.graph_pad_size
# If using fullgraph and disable_padded_drafter_batch=True, we need to
# update graph_pad_size in common_attn_metadata to tell the builder
# to pad some elements.
common_attn_metadata.graph_pad_size = graph_pad_size
builder = self.runner.attn_groups[0][0].get_metadata_builder()
attn_metadata_mtp = builder.build(0, common_attn_metadata,
self.runner.get_model())
attn_metadata = {}
for layer_name in self.attn_layer_name:
attn_metadata[layer_name] = attn_metadata_mtp
for step in range(self.num_speculative_tokens):
with set_ascend_forward_context(
@@ -635,7 +737,8 @@ class MtpProposer(Proposer):
aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=num_tokens):
num_actual_tokens=num_tokens,
is_mtp_model=True):
with ProfileExecuteDuration().capture_async('mtp_forward'):
model_kwargs = {}
model_kwargs["attn_metadata"] = attn_metadata
@@ -644,6 +747,13 @@ class MtpProposer(Proposer):
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
hidden_states=self.hidden_states[:num_input_tokens])
forward_context = get_forward_context()
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
if self.vllm_config.model_config.use_mla:
update_mla_attn_params(
self.update_stream, forward_context,
num_input_tokens,
self.vllm_config.speculative_config)
num_indices = last_token_indices.shape[0]
if lmhead_tp_enable():
@@ -699,8 +809,17 @@ class MtpProposer(Proposer):
input_ids = draft_token_ids_list[-1].int()
positions += 1
# When disable_padded_drafter_batch=False, these params should probably not be updated.
if self.speculative_config.disable_padded_drafter_batch or \
aclgraph_runtime_mode != CUDAGraphMode.FULL:
attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
1:batch_size + 1].tolist()
if aclgraph_runtime_mode == CUDAGraphMode.FULL:
attn_metadata_i.decode.actual_seq_lengths_q = \
builder.pad_actual_seq_len_q_mtp_disable_pad(
graph_pad_size - batch_size,
batch_size,
attn_metadata_i.decode.actual_seq_lengths_q)
attn_metadata_i.decode.cos = builder.cos_cache[
positions].unsqueeze(1).unsqueeze(2)
attn_metadata_i.decode.sin = builder.sin_cache[
@@ -735,6 +854,10 @@ class MtpProposer(Proposer):
self.positions[:batch_size] = clamped_positions
self.hidden_states[:hidden_states.shape[0]] = hidden_states
attn_metadata_i.slot_mapping[:batch_size] = slot_mapping
if self.speculative_config.disable_padded_drafter_batch:
self.positions[batch_size:num_input_tokens] = 0
self.input_ids[batch_size:num_input_tokens] = 0
self.hidden_states[batch_size:num_input_tokens].fill_(0)
if attn_metadata_i.prefill is not None:
attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens
@@ -751,12 +874,19 @@ class MtpProposer(Proposer):
attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens
attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist(
)
decode_seq_lens_list = attn_metadata_i.decode.seq_lens_list
if aclgraph_runtime_mode == CUDAGraphMode.FULL and \
self.speculative_config.disable_padded_drafter_batch:
attn_metadata_i.decode.seq_lens_list = decode_seq_lens_list + [
0
] * (graph_pad_size - len(decode_seq_lens_list))
attn_metadata_i.decode.input_positions = self.positions[:
num_input_tokens]
attn_metadata_i.decode.max_seq_lens += 1
attn_metadata_i.decode.max_seq_lens = min(
attn_metadata_i.decode.max_seq_lens,
self.runner.model_config.max_model_len)
torch.npu.synchronize()
# mtp>1: [batch_size, k]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
@@ -915,6 +1045,9 @@ class MtpProposer(Proposer):
total_num_tokens = query_start_loc_cpu[-1].item()
token_indices = self.arange[:total_num_tokens]
# NOTE: Currently positions and seq_lens are not used in mla_v1 forward,
# so we do not need to fix them. But if they are used in the future,
# we should fix them.
spec_common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=common_attn_metadata.query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,

View File

@@ -3112,7 +3112,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.drafter.dummy_run(
num_tokens=num_tokens,
with_prefill=with_prefill,
skip_attn=True,
num_reqs=num_reqs,
num_tokens_across_dp=num_tokens_across_dp,
aclgraph_runtime_mode=aclgraph_runtime_mode,