[Feat] Support MTP to running in full graph mode (#3892)
### What this PR does / why we need it?
Currently, the MTP model still runs in eager in full graph mode. This PR
adapts the MTP with the full graph capture and execution. When the graph
mode is set to "FULL_DECODE_ONLY", the MTP will run in full-graph to
improve the performance.
The change in both disable_padded_drafter_batch is True and False case
include:
1. Add _mtp_graph_params in acl_graph.py to isolate the data of main
model and the data of MTP.
2. Padding some metadata in mla_v1.py when in fullgraph mode.
3. Fixed the essential data address that will be used in model.forward.
4. Adapted according to the aclgraph capture framwork:
1). Rebuild MTP model with ACLGraphWrapper.
2). Add common attn metadata when start capture in MTP dummy_run.
3). Add common attn metadata update in MTP.
4). Addapted data update when num_speculative_tokens > 1.
5. Add a patch of MTP to adapt vllm v0.11.0.
Existing Issues:
1. When disable_padded_drafter_batch=True and running in FullGraph mode,
the data of the first-round requests in MTP is abnormal. We need to
identify the cause subsequently.
2. When disable_padded_drafter_batch=False and running in FullGraph
mode, the acceptance rate of the second and third tokens will decrease
(For example, if we set the num_speculative_tokens=3, the acceptance
rate of first token is 90%, the second is only 50% lower than 60%, the
third is only 20% lower than 30%). The reason is that the data processed
after the model runs does not match. This is a problem from another PR.
It works fine in eager and PIECEWISE mode, but has problem in FullGraph
mode. Once we have a solution, we will submit a bugfix.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main:
2918c1b49c
---------
Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
This commit is contained in:
@@ -256,6 +256,57 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
builder.chunked_prefill_enabled,
|
||||
mock_vllm_config.scheduler_config.chunked_prefill_enabled)
|
||||
|
||||
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
||||
@patch('vllm.distributed.parallel_state._DCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
||||
return_value=1)
|
||||
def test_ascend_mla_metadata_builder_build_full_graph(
|
||||
self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group):
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config.max_model_len = 1024
|
||||
mock_vllm_config.model_config.get_head_size.return_value = 64
|
||||
mock_vllm_config.model_config.dtype = torch.float16
|
||||
mock_vllm_config.cache_config.block_size = 16
|
||||
mock_vllm_config.scheduler_config.max_num_seqs = 4
|
||||
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
|
||||
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||
mock_device = 'cpu'
|
||||
|
||||
mock_dcp.world_size = 1
|
||||
dcp_group = MagicMock(spec=GroupCoordinator)
|
||||
dcp_group.rank_in_group = 0
|
||||
dcp_group.world_size = 1
|
||||
dcp_group.device_group = MagicMock()
|
||||
mock_get_dcp_group.return_value = dcp_group
|
||||
|
||||
mock_spec_config = MagicMock()
|
||||
mock_spec_config.num_speculative_tokens = 1
|
||||
mock_spec_config.disable_padded_drafter_batch = True
|
||||
mock_vllm_config.speculative_config = mock_spec_config
|
||||
|
||||
builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
|
||||
mock_device)
|
||||
common_metadata = MagicMock()
|
||||
model = MagicMock()
|
||||
common_metadata.graph_pad_size = 8
|
||||
common_metadata.num_reqs = 4
|
||||
common_metadata.num_actual_tokens = 5
|
||||
common_metadata.max_query_len = 5
|
||||
common_metadata.seq_lens_cpu = torch.Tensor([9, 10, 8, 8]).int()
|
||||
common_metadata.query_start_loc = torch.Tensor([0, 1, 2, 4, 5]).int()
|
||||
common_metadata.query_start_loc_cpu = torch.Tensor([0, 1, 2, 4,
|
||||
5]).int()
|
||||
common_metadata.positions = torch.Tensor([1, 2, 3, 4, 5, 6]).int()
|
||||
block_table = torch.Tensor([[1, 0], [2, 0], [3, 0], [4, 0]]).int()
|
||||
common_metadata.block_table_tensor = block_table
|
||||
common_metadata.prefill_context_parallel_metadata = None
|
||||
metadata = builder.build(0, common_metadata, model)
|
||||
|
||||
self.assertEqual(metadata.decode.actual_seq_lengths_q,
|
||||
[1, 2, 4, 5, 6, 6, 7, 8])
|
||||
self.assertEqual(metadata.decode.block_table.shape[0], 8)
|
||||
|
||||
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
||||
@patch('vllm.distributed.parallel_state._DCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@@ -307,6 +358,80 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
self.assertTrue(modified)
|
||||
input_batch.swap_states.assert_called_once_with(1, 2)
|
||||
|
||||
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
||||
@patch('vllm.distributed.parallel_state._DCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
||||
return_value=1)
|
||||
def test_pad_actual_seq_lens_q_mtp_disable_pad(self, mock_get_dcp_size,
|
||||
mock_dcp,
|
||||
mock_get_dcp_group):
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config.max_model_len = 1024
|
||||
mock_vllm_config.model_config.get_head_size.return_value = 64
|
||||
mock_vllm_config.model_config.dtype = torch.float16
|
||||
mock_vllm_config.cache_config.block_size = 16
|
||||
mock_vllm_config.scheduler_config.max_num_seqs = 4
|
||||
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
|
||||
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||
mock_device = 'cpu'
|
||||
mock_vllm_config.speculative_config = None
|
||||
|
||||
mock_dcp.world_size = 1
|
||||
dcp_group = MagicMock(spec=GroupCoordinator)
|
||||
dcp_group.rank_in_group = 0
|
||||
dcp_group.world_size = 1
|
||||
dcp_group.device_group = MagicMock()
|
||||
mock_get_dcp_group.return_value = dcp_group
|
||||
|
||||
builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
|
||||
mock_device)
|
||||
input_seq_lens = [1, 2, 4, 5]
|
||||
expect_output = [1, 2, 4, 5, 6, 6, 7, 8]
|
||||
num_reqs = 4
|
||||
num_reqs_pad_size = 4
|
||||
output_seq_lens = builder.pad_actual_seq_len_q_mtp_disable_pad(
|
||||
num_reqs_pad_size, num_reqs, input_seq_lens)
|
||||
self.assertEqual(output_seq_lens, expect_output)
|
||||
|
||||
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
||||
@patch('vllm.distributed.parallel_state._DCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
||||
return_value=1)
|
||||
def test_pad_actual_seq_lens_q_mtp_enable_pad(self, mock_get_dcp_size,
|
||||
mock_dcp,
|
||||
mock_get_dcp_group):
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config.max_model_len = 1024
|
||||
mock_vllm_config.model_config.get_head_size.return_value = 64
|
||||
mock_vllm_config.model_config.dtype = torch.float16
|
||||
mock_vllm_config.cache_config.block_size = 16
|
||||
mock_vllm_config.scheduler_config.max_num_seqs = 4
|
||||
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
|
||||
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||
mock_device = 'cpu'
|
||||
mock_vllm_config.speculative_config = None
|
||||
|
||||
mock_dcp.world_size = 1
|
||||
dcp_group = MagicMock(spec=GroupCoordinator)
|
||||
dcp_group.rank_in_group = 0
|
||||
dcp_group.world_size = 1
|
||||
dcp_group.device_group = MagicMock()
|
||||
mock_get_dcp_group.return_value = dcp_group
|
||||
common_metadata = MagicMock()
|
||||
common_metadata.actual_seq_lengths_q = [2, 4, 6, 8]
|
||||
|
||||
builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
|
||||
mock_device)
|
||||
input_seq_lens = [2, 4, 6]
|
||||
expect_output = [2, 4, 6, 8]
|
||||
num_reqs = 3
|
||||
num_reqs_pad_size = 1
|
||||
output_seq_lens = builder.pad_actual_seq_len_q_mtp_enable_pad(
|
||||
num_reqs_pad_size, num_reqs, input_seq_lens, common_metadata)
|
||||
self.assertEqual(output_seq_lens, expect_output)
|
||||
|
||||
|
||||
class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
|
||||
|
||||
@@ -21,7 +21,9 @@ from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.forward_context import BatchDescriptor, ForwardContext
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.compilation.acl_graph import ACLGraphEntry, ACLGraphWrapper
|
||||
from vllm_ascend.compilation.acl_graph import (
|
||||
ACLGraphEntry, ACLGraphWrapper, get_mtp_graph_params, set_mtp_graph_params,
|
||||
update_mtp_graph_params_workspaces)
|
||||
|
||||
|
||||
class TestACLGraphEntry(TestBase):
|
||||
@@ -718,3 +720,24 @@ class TestACLGraphWrapper(TestBase):
|
||||
|
||||
unwrapped = wrapper.unwrap()
|
||||
self.assertEqual(unwrapped, self.mock_runnable)
|
||||
|
||||
|
||||
class TestMTPGraphParams(TestBase):
|
||||
|
||||
def test_set_mtp_graph_params(self):
|
||||
with patch('vllm_ascend.compilation.acl_graph._mtp_graph_params',
|
||||
new=None):
|
||||
set_mtp_graph_params([4])
|
||||
from vllm_ascend.compilation.acl_graph import _mtp_graph_params
|
||||
self.assertIsNotNone(_mtp_graph_params)
|
||||
|
||||
@patch('vllm_ascend.compilation.acl_graph._mtp_graph_params')
|
||||
def test_update_mtp_graph_params_workspaces(self, mtp_graph_params_mock):
|
||||
mtp_graph_params_mock.workspaces = {4: 5}
|
||||
update_mtp_graph_params_workspaces(4, 6)
|
||||
self.assertEqual(mtp_graph_params_mock.workspaces[4], 6)
|
||||
|
||||
@patch('vllm_ascend.compilation.acl_graph._mtp_graph_params')
|
||||
def test_get_mtp_graph_params(self, mtp_graph_params_mock):
|
||||
graph_params = get_mtp_graph_params()
|
||||
self.assertIs(mtp_graph_params_mock, graph_params)
|
||||
|
||||
@@ -72,7 +72,8 @@ def set_ascend_forward_context(
|
||||
batch_descriptor: Optional[BatchDescriptor] = None,
|
||||
prefetch_stream: torch.npu.Stream = None,
|
||||
model_instance: torch.nn.Module = None,
|
||||
weight_prefetch_method: Optional[WeightPrefetchMethod] = None):
|
||||
weight_prefetch_method: Optional[WeightPrefetchMethod] = None,
|
||||
is_mtp_model=False):
|
||||
"""A context manager that stores the current forward context,
|
||||
can be attention metadata, etc.
|
||||
We add some additional param into forward_context.
|
||||
@@ -158,6 +159,7 @@ def set_ascend_forward_context(
|
||||
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
|
||||
forward_context.model_instance = model_instance
|
||||
forward_context.weight_prefetch_method = weight_prefetch_method
|
||||
forward_context.is_mtp_model = is_mtp_model
|
||||
|
||||
# TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant.
|
||||
# It will be improved later by implementing operator fusion through the FX graph.
|
||||
|
||||
@@ -8,6 +8,7 @@ import torch.distributed as dist
|
||||
import torch_npu
|
||||
from torch import nn
|
||||
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import (get_dcp_group,
|
||||
get_decode_context_model_parallel_rank,
|
||||
@@ -38,6 +39,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
trans_rope_weight, transdata,
|
||||
wait_for_kv_layer_from_connector)
|
||||
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
||||
get_mtp_graph_params,
|
||||
update_graph_params_workspaces)
|
||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||
@@ -337,6 +339,74 @@ class AscendMLAMetadataBuilder:
|
||||
# better way of doing this
|
||||
return modified_batch
|
||||
|
||||
def pad_actual_seq_len_q_mtp_enable_pad(self, num_reqs_pad_size, num_reqs,
|
||||
actual_seq_lengths_q,
|
||||
common_attn_metadata):
|
||||
"""
|
||||
Pads actual_seq_lengths_q evenly to not exceed 16 tokens per request
|
||||
in order to meet the requirement of npu_fused_infer_attention_score.
|
||||
|
||||
In Torchair scenario, the lengths of the queries must be padded to the same length.
|
||||
And npu_fused_infer_attention_score constraint requires the last element must equal to batch_size(num_tokens).
|
||||
|
||||
For example:
|
||||
batch_size=36, num_reqs_pad_size=2, num_reqs=16
|
||||
By default, each request should have inference 2 token, which means actual_seq_lengths_q should be
|
||||
[2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36].
|
||||
|
||||
However, mtp torchair + PD scenario, the actual_seq_lengths_q may be
|
||||
[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] before padding, since the first decode request only has 1 token.
|
||||
In order to meet the requirement of npu_fused_infer_attention_score, we need to pad actual_seq_lengths_q evenly to not exceed 16 tokens per request.
|
||||
after padding actual_seq_lengths_q should be similar to [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,36]
|
||||
"""
|
||||
FIA_SEQ_LEN_LIMIT = 16
|
||||
need_padding = num_reqs_pad_size != 0 and \
|
||||
len(common_attn_metadata.actual_seq_lengths_q) > num_reqs and \
|
||||
common_attn_metadata.actual_seq_lengths_q[num_reqs] - actual_seq_lengths_q[-1] > FIA_SEQ_LEN_LIMIT
|
||||
if need_padding:
|
||||
padding_seq_len_q = common_attn_metadata.actual_seq_lengths_q[
|
||||
num_reqs:num_reqs + num_reqs_pad_size]
|
||||
start_val = actual_seq_lengths_q[-1]
|
||||
end_val = padding_seq_len_q[-1]
|
||||
|
||||
num_step = len(padding_seq_len_q)
|
||||
interpolated = np.round(
|
||||
np.linspace(start_val, end_val,
|
||||
num_step + 1)[1:]).astype(int).tolist()
|
||||
assert interpolated[-1] == end_val
|
||||
assert len(interpolated) == len(padding_seq_len_q)
|
||||
actual_seq_lengths_q = actual_seq_lengths_q + interpolated
|
||||
else:
|
||||
actual_seq_lengths_q = actual_seq_lengths_q + common_attn_metadata.actual_seq_lengths_q[
|
||||
num_reqs:num_reqs + num_reqs_pad_size]
|
||||
|
||||
return actual_seq_lengths_q
|
||||
|
||||
def pad_actual_seq_len_q_mtp_disable_pad(self, num_reqs_pad_size, num_reqs,
|
||||
actual_seq_lengths_q):
|
||||
"""
|
||||
Only use for acl full graph mode.
|
||||
Pad the last element of the actual_seq_lengths_q equal to the TND(T) and
|
||||
the num of dimensions equal to the batch_size of main model.
|
||||
|
||||
For example:
|
||||
batch_size = 8, num_reqs = 4, num_speculative_tokens = 1
|
||||
input actual_seq_lengths_q = [1, 2, 4, 5] (the 3rd req was accept a token)
|
||||
After padding the actual_seq_lengths_q will be similar to [1, 2, 4, 5, 6, 6, 7, 8]
|
||||
"""
|
||||
need_padding = num_reqs_pad_size > 0
|
||||
if need_padding:
|
||||
start_val = actual_seq_lengths_q[-1]
|
||||
end_val = num_reqs + num_reqs_pad_size
|
||||
num_step = num_reqs_pad_size
|
||||
interpolated = np.round(
|
||||
np.linspace(start_val, end_val,
|
||||
num_step + 1)[1:]).astype(int).tolist()
|
||||
assert interpolated[-1] == end_val
|
||||
assert len(interpolated) == num_reqs_pad_size
|
||||
actual_seq_lengths_q = actual_seq_lengths_q + interpolated
|
||||
return actual_seq_lengths_q
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
@@ -362,17 +432,25 @@ class AscendMLAMetadataBuilder:
|
||||
# it blocks on all previous kernels.
|
||||
device = self.device
|
||||
|
||||
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
|
||||
# If graph_pad_size > -1, mean is running in fullgraph mode.
|
||||
graph_pad_size = common_attn_metadata.graph_pad_size
|
||||
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
|
||||
if graph_pad_size > num_reqs and self.speculative_config.disable_padded_drafter_batch:
|
||||
block_table = (
|
||||
common_attn_metadata.block_table_tensor[:graph_pad_size])
|
||||
else:
|
||||
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
|
||||
# NOTE: Currently, MTP-fullgraph is incompatibility pcp
|
||||
if self.pcp_size > 1:
|
||||
num_decodes_flatten = num_decodes * self.decode_threshold
|
||||
block_table = common_attn_metadata.block_table_tensor[:
|
||||
num_decodes_flatten
|
||||
+
|
||||
num_prefills]
|
||||
|
||||
if num_actual_tokens_pcp_padded is None:
|
||||
num_actual_tokens_pcp_padded = num_actual_tokens
|
||||
|
||||
# NOTE: Currently, MTP-fullgraph is incompatibility pcp
|
||||
slot_mapping = common_attn_metadata.slot_mapping[:
|
||||
num_actual_tokens_pcp_padded]
|
||||
input_positions = common_attn_metadata.positions[:
|
||||
@@ -565,6 +643,11 @@ class AscendMLAMetadataBuilder:
|
||||
block_table = block_table[:num_decodes_flatten, ...]
|
||||
else:
|
||||
block_table = block_table[:num_decodes, ...]
|
||||
# NOTE: Currently, MTP-fullgraph is incompatibility pcp
|
||||
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
|
||||
if graph_pad_size > num_decodes and \
|
||||
self.speculative_config.disable_padded_drafter_batch:
|
||||
block_table = block_table[:graph_pad_size, ...]
|
||||
seq_lens_list = seq_lens.tolist()
|
||||
|
||||
if num_computed_tokens_of_pcp_dcp is not None:
|
||||
@@ -586,6 +669,52 @@ class AscendMLAMetadataBuilder:
|
||||
else:
|
||||
cp_seq_len, batch_seq_mask = None, None
|
||||
|
||||
if graph_pad_size > num_reqs:
|
||||
if self.speculative_config.disable_padded_drafter_batch:
|
||||
num_reqs_pad_size = graph_pad_size - num_reqs
|
||||
actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_disable_pad(
|
||||
num_reqs_pad_size, num_reqs, actual_seq_lengths_q)
|
||||
seq_lens_list = seq_lens_list + [0] * (graph_pad_size - \
|
||||
num_decodes)
|
||||
num_block_pad_size = graph_pad_size - block_table.shape[0]
|
||||
if num_block_pad_size > 0:
|
||||
block_table_padding = torch.zeros(
|
||||
(num_block_pad_size, ) + block_table.shape[1:],
|
||||
dtype=block_table.dtype,
|
||||
device=block_table.device)
|
||||
block_table = torch.cat(
|
||||
[block_table, block_table_padding], dim=0)
|
||||
else:
|
||||
num_token_pad_size = graph_pad_size - num_decode_tokens
|
||||
num_reqs_pad_size = (
|
||||
graph_pad_size //
|
||||
common_attn_metadata.decode_token_per_req - num_reqs)
|
||||
num_block_table_pad_size = (
|
||||
graph_pad_size //
|
||||
common_attn_metadata.decode_token_per_req -
|
||||
num_decodes)
|
||||
seq_lens_list = seq_lens.tolist() + [0] * num_reqs_pad_size
|
||||
slot_padding = torch.full((num_token_pad_size, ),
|
||||
PAD_SLOT_ID,
|
||||
dtype=slot_mapping.dtype,
|
||||
device=slot_mapping.device)
|
||||
slot_mapping = torch.cat([slot_mapping, slot_padding])
|
||||
block_table_padding = torch.zeros(
|
||||
(num_block_table_pad_size, ) + block_table.shape[1:],
|
||||
dtype=block_table.dtype,
|
||||
device=block_table.device)
|
||||
block_table = torch.cat([block_table, block_table_padding],
|
||||
dim=0)
|
||||
position_padding = torch.zeros(
|
||||
num_token_pad_size,
|
||||
dtype=input_positions.dtype,
|
||||
device=input_positions.device)
|
||||
input_positions = torch.cat(
|
||||
[input_positions, position_padding])
|
||||
actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_enable_pad(
|
||||
num_reqs_pad_size, num_reqs, actual_seq_lengths_q,
|
||||
common_attn_metadata)
|
||||
|
||||
# TODO: After the fullgraph supports MTP, the if branch needs to deleted
|
||||
assert self.cos_cache is not None
|
||||
assert self.sin_cache is not None
|
||||
@@ -1267,8 +1396,11 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
"actual_seq_lengths": actual_seq_lengths,
|
||||
"actual_seq_lengths_kv": decode_meta.seq_lens_list,
|
||||
}
|
||||
graph_params = get_graph_params()
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
if forward_context.is_mtp_model:
|
||||
graph_params = get_mtp_graph_params()
|
||||
else:
|
||||
graph_params = get_graph_params()
|
||||
if forward_context.capturing:
|
||||
stream = torch_npu.npu.current_stream()
|
||||
|
||||
|
||||
@@ -241,7 +241,10 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
|
||||
|
||||
def update_mla_attn_params(update_stream, forward_context, runtime_shape,
|
||||
speculative_config):
|
||||
graph_params = get_graph_params()
|
||||
if forward_context.is_mtp_model:
|
||||
graph_params = get_mtp_graph_params()
|
||||
else:
|
||||
graph_params = get_graph_params()
|
||||
# FIXME: Behold! We are using a temporary hack here to update the args
|
||||
# for each layer's attention op in the graph.
|
||||
with torch.npu.stream(update_stream):
|
||||
@@ -257,7 +260,8 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
|
||||
softmax_lse) = param
|
||||
seq_lens_list = forward_context.attn_metadata[
|
||||
key].decode.seq_lens_list
|
||||
if speculative_config and speculative_config.method == "deepseek_mtp":
|
||||
if speculative_config and speculative_config.method == "deepseek_mtp" \
|
||||
and not forward_context.is_mtp_model:
|
||||
actual_seq_lengths = forward_context.attn_metadata[
|
||||
key].decode.actual_seq_lengths_q
|
||||
spec_multiple = speculative_config.num_speculative_tokens + 1
|
||||
@@ -267,6 +271,13 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
|
||||
spec_multiple * (i + 1)
|
||||
for i in range(runtime_shape // spec_multiple)
|
||||
]
|
||||
elif forward_context.is_mtp_model:
|
||||
actual_seq_lengths = forward_context.attn_metadata[
|
||||
key].decode.actual_seq_lengths_q
|
||||
block_table = forward_context.attn_metadata[
|
||||
key].decode.block_table
|
||||
seq_lens_list = seq_lens_list + [0] * (
|
||||
len(actual_seq_lengths) - len(seq_lens_list))
|
||||
else:
|
||||
seq_lens_list = seq_lens_list + [0] * (runtime_shape -
|
||||
len(seq_lens_list))
|
||||
@@ -443,3 +454,32 @@ def update_graph_params_workspaces(num_tokens: int, workspace: int):
|
||||
|
||||
def get_graph_params():
|
||||
return _graph_params
|
||||
|
||||
|
||||
_mtp_graph_params: Optional[GraphParams] = None
|
||||
|
||||
|
||||
def set_mtp_graph_params(aclgraph_capture_sizes: set[int]):
|
||||
global _mtp_graph_params
|
||||
if _mtp_graph_params is not None:
|
||||
raise ValueError("MTPGraph parameters have already been set!")
|
||||
_mtp_graph_params = GraphParams(
|
||||
{size: []
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: None
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: []
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: []
|
||||
for size in aclgraph_capture_sizes},
|
||||
)
|
||||
|
||||
|
||||
def update_mtp_graph_params_workspaces(num_tokens: int, workspace: Any):
|
||||
global _mtp_graph_params
|
||||
if _mtp_graph_params is not None:
|
||||
_mtp_graph_params.workspaces[num_tokens] = workspace
|
||||
|
||||
|
||||
def get_mtp_graph_params():
|
||||
return _mtp_graph_params
|
||||
|
||||
@@ -10,6 +10,12 @@ from vllm.model_executor.models.deepseek_mtp import \
|
||||
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
if vllm_version_is("0.11.0"):
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
|
||||
|
||||
|
||||
class SharedHead(nn.Module):
|
||||
|
||||
@@ -51,4 +57,38 @@ def predictor_init(self, vllm_config: VllmConfig, prefix: str) -> None:
|
||||
topk_indices_buffer)
|
||||
|
||||
|
||||
def predictor_forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
spec_step_index: int = 0,
|
||||
) -> torch.Tensor:
|
||||
assert inputs_embeds is not None
|
||||
# masking inputs at position 0, as not needed by MTP
|
||||
inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)
|
||||
inputs_embeds = self.enorm(inputs_embeds)
|
||||
previous_hidden_states = self.hnorm(previous_hidden_states)
|
||||
|
||||
hidden_states = self.eh_proj(
|
||||
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
|
||||
|
||||
hidden_states, residual = self.mtp_block(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=None)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Patch this only for aclgraph support, as this is not support in vLLM 0.11.0
|
||||
@support_torch_compile
|
||||
class AscendDeepSeekMTP(DeepSeekMTP):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
|
||||
DeepSeekMultiTokenPredictorLayer.__init__ = predictor_init
|
||||
if vllm_version_is("0.11.0"):
|
||||
DeepSeekMultiTokenPredictorLayer.forward = predictor_forward
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import importlib
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -7,7 +7,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from vllm.config import (CUDAGraphMode, VllmConfig,
|
||||
get_layers_from_vllm_config, set_current_vllm_config)
|
||||
from vllm.forward_context import BatchDescriptor
|
||||
from vllm.forward_context import BatchDescriptor, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
@@ -32,7 +32,11 @@ from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
|
||||
set_mtp_graph_params,
|
||||
update_mla_attn_params)
|
||||
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
|
||||
from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
|
||||
prefill_context_parallel_enable,
|
||||
@@ -52,9 +56,14 @@ logger = init_logger(__name__)
|
||||
|
||||
PADDING_SLOT_ID = -1
|
||||
|
||||
_deepseek_mtp_path = "vllm.model_executor.models.deepseek_mtp"
|
||||
_deepseek_mtp_model = "DeepSeekMTP"
|
||||
if vllm_version_is("0.11.0"):
|
||||
_deepseek_mtp_path = "vllm_ascend.patch.worker.patch_deepseek_mtp"
|
||||
_deepseek_mtp_model = "AscendDeepSeekMTP"
|
||||
|
||||
_MTP_MODELS = {
|
||||
"DeepseekV3ForCausalLM":
|
||||
("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
|
||||
"DeepseekV3ForCausalLM": (_deepseek_mtp_path, _deepseek_mtp_model),
|
||||
"Qwen3NextForCausalLM":
|
||||
("vllm_ascend.models.qwen3_next_mtp", "CustomQwen3NextMTP")
|
||||
}
|
||||
@@ -75,6 +84,9 @@ def _load_model(architecture):
|
||||
|
||||
class MtpProposer(Proposer):
|
||||
|
||||
# TODO: Find out why ModelRunner does not this explicit typing?
|
||||
model: Union[nn.Module, ACLGraphWrapper]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
@@ -203,6 +215,15 @@ class MtpProposer(Proposer):
|
||||
process_weights_after_loading(self.model, draft_model_config,
|
||||
target_device)
|
||||
|
||||
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
|
||||
):
|
||||
self.update_stream: torch.npu.Stream = torch.npu.Stream()
|
||||
set_mtp_graph_params(
|
||||
self.vllm_config.compilation_config.cudagraph_capture_sizes)
|
||||
self.model = ACLGraphWrapper(self.model,
|
||||
self.vllm_config,
|
||||
runtime_mode=CUDAGraphMode.FULL)
|
||||
|
||||
@torch.inference_mode()
|
||||
def dummy_run(self,
|
||||
num_tokens: int,
|
||||
@@ -222,12 +243,55 @@ class MtpProposer(Proposer):
|
||||
moe_comm_type = self.runner._select_moe_comm_method(
|
||||
num_tokens, with_prefill)
|
||||
|
||||
attn_metadata = None
|
||||
if skip_attn:
|
||||
attn_metadata = None
|
||||
elif aclgraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
if len(self.runner.attn_groups) > 0:
|
||||
num_computed_tokens_cpu = (
|
||||
self.runner.input_batch.
|
||||
num_computed_tokens_cpu_tensor[:num_reqs])
|
||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=self.runner.query_start_loc[:num_reqs + 1],
|
||||
query_start_loc_cpu=self.runner.
|
||||
query_start_loc_cpu[:num_reqs + 1],
|
||||
seq_lens_cpu=self.runner.seq_lens_cpu,
|
||||
seq_lens=self.runner.seq_lens_cpu[:num_reqs],
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=num_tokens,
|
||||
max_query_len=self.num_speculative_tokens + 1,
|
||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
||||
block_table_tensor=self.runner.input_batch.block_table[0].
|
||||
get_device_tensor()[:num_reqs],
|
||||
slot_mapping=self.runner.input_batch.block_table[0].
|
||||
slot_mapping,
|
||||
positions=self.runner.positions,
|
||||
attn_mask=self.runner.attn_mask,
|
||||
spec_attn_mask=self.runner.spec_attn_mask,
|
||||
attn_state=self.runner.attn_state,
|
||||
decode_token_per_req=self.runner.decode_token_per_req,
|
||||
cos=self.runner.cos,
|
||||
sin=self.runner.sin,
|
||||
)
|
||||
|
||||
builder = self.runner.attn_groups[0][0].get_metadata_builder()
|
||||
attn_metadata_mtp = builder.build_for_graph_capture(
|
||||
common_attn_metadata, AscendAttentionState.SpecDecoding,
|
||||
self.runner.get_model())
|
||||
attn_metadata = {}
|
||||
for layer_name in self.attn_layer_name:
|
||||
attn_metadata[layer_name] = attn_metadata_mtp
|
||||
else:
|
||||
attn_metadata = None
|
||||
else:
|
||||
attn_metadata = None
|
||||
|
||||
input_ids = self.input_ids[:num_tokens]
|
||||
positions = self.positions[:num_tokens]
|
||||
previous_hidden_states = self.hidden_states[:num_tokens]
|
||||
for _ in range(self.num_speculative_tokens):
|
||||
for i in range(self.num_speculative_tokens):
|
||||
if i > 0:
|
||||
aclgraph_runtime_mode = CUDAGraphMode.NONE
|
||||
with set_ascend_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
@@ -239,10 +303,19 @@ class MtpProposer(Proposer):
|
||||
in_profile_run=self.runner.in_profile_run,
|
||||
num_actual_tokens=0,
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
batch_descriptor=batch_descriptor):
|
||||
batch_descriptor=batch_descriptor,
|
||||
is_mtp_model=True):
|
||||
self.model(input_ids=input_ids,
|
||||
positions=positions,
|
||||
hidden_states=previous_hidden_states)
|
||||
forward_context = get_forward_context()
|
||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
|
||||
not forward_context.capturing:
|
||||
if self.vllm_config.model_config.use_mla:
|
||||
update_mla_attn_params(
|
||||
self.update_stream, forward_context,
|
||||
positions.shape[0],
|
||||
self.vllm_config.speculative_config)
|
||||
if with_prefill:
|
||||
break
|
||||
|
||||
@@ -324,6 +397,7 @@ class MtpProposer(Proposer):
|
||||
common_attn_metadata.query_start_loc = \
|
||||
query_start_loc_pcp_full[:num_reqs + 1]
|
||||
if self.speculative_config.disable_padded_drafter_batch:
|
||||
# NOTE: Currently, MTP-fullgraph is incompatibility with pcp
|
||||
token_indices_to_sample = None
|
||||
common_attn_metadata, token_indices =\
|
||||
self._prepare_inputs(
|
||||
@@ -358,6 +432,8 @@ class MtpProposer(Proposer):
|
||||
long_seq_metadata=long_seq_metadata,
|
||||
num_prefill_reqs=num_prefill_reqs,
|
||||
num_decode_reqs=num_decode_reqs,
|
||||
scheduler_output=scheduler_output,
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
)
|
||||
|
||||
return draft_token_ids
|
||||
@@ -460,6 +536,13 @@ class MtpProposer(Proposer):
|
||||
token_indices = torch.from_numpy(token_indices_np).to(
|
||||
device, non_blocking=True)
|
||||
|
||||
common_attn_metadata.slot_mapping[:token_indices.shape[0]].copy_(
|
||||
common_attn_metadata.slot_mapping[token_indices])
|
||||
common_attn_metadata.slot_mapping[token_indices.shape[0]:].fill_(-1)
|
||||
|
||||
# NOTE: Currently positions and seq_lens are not used in mla_v1 forward
|
||||
# so we do not need to fixed them. But if they are used in the future,
|
||||
# we should fixed them.
|
||||
spec_common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=new_query_start_loc_cpu.to(device,
|
||||
non_blocking=True),
|
||||
@@ -472,7 +555,7 @@ class MtpProposer(Proposer):
|
||||
num_actual_tokens=total_num_tokens,
|
||||
max_query_len=new_query_len_per_req.max().item(),
|
||||
block_table_tensor=common_attn_metadata.block_table_tensor,
|
||||
slot_mapping=common_attn_metadata.slot_mapping[token_indices],
|
||||
slot_mapping=common_attn_metadata.slot_mapping,
|
||||
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
||||
positions=common_attn_metadata.positions[token_indices],
|
||||
attn_mask=self.runner.attn_mask,
|
||||
@@ -502,6 +585,8 @@ class MtpProposer(Proposer):
|
||||
long_seq_metadata=None,
|
||||
num_prefill_reqs=0,
|
||||
num_decode_reqs=0,
|
||||
scheduler_output: SchedulerOutput = None,
|
||||
num_scheduled_tokens: int = 0,
|
||||
) -> torch.Tensor:
|
||||
num_tokens = target_token_ids.shape[0]
|
||||
batch_size = next_token_ids.shape[0]
|
||||
@@ -585,14 +670,11 @@ class MtpProposer(Proposer):
|
||||
|
||||
assert self.runner is not None
|
||||
|
||||
builder = self.runner.attn_groups[0][0].get_metadata_builder()
|
||||
attn_metadata_mtp = builder.build(0, common_attn_metadata,
|
||||
self.runner.get_model())
|
||||
attn_metadata = {}
|
||||
for layer_name in self.attn_layer_name:
|
||||
attn_metadata[layer_name] = attn_metadata_mtp
|
||||
|
||||
if self.use_aclgraph and num_tokens <= self.cudagraph_batch_sizes[-1]:
|
||||
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
|
||||
) and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]:
|
||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(
|
||||
num_scheduled_tokens)
|
||||
elif self.use_aclgraph and num_tokens <= self.cudagraph_batch_sizes[-1]:
|
||||
# Acl graph mode, add padding to the batch size
|
||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
|
||||
else:
|
||||
@@ -609,19 +691,39 @@ class MtpProposer(Proposer):
|
||||
|
||||
moe_comm_type = self.runner._select_moe_comm_method(
|
||||
num_input_tokens, with_prefill)
|
||||
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
|
||||
uniform_decode=False)
|
||||
|
||||
if scheduler_output:
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
uniform_decode = (max_query_len in list(
|
||||
range(1, self.num_speculative_tokens +
|
||||
2))) and (scheduler_output.total_num_scheduled_tokens
|
||||
== self.runner.input_batch.num_reqs *
|
||||
(self.num_speculative_tokens + 1))
|
||||
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
|
||||
uniform_decode=uniform_decode)
|
||||
else:
|
||||
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
|
||||
uniform_decode=False)
|
||||
aclgraph_runtime_mode, batch_descriptor = \
|
||||
self.runner.aclgraph_dispatcher.dispatch(batch_descriptor)
|
||||
if aclgraph_runtime_mode not in [
|
||||
CUDAGraphMode.PIECEWISE, CUDAGraphMode.NONE
|
||||
]:
|
||||
# Fallback to piecewise graph, when acl full graph is enabled
|
||||
logger.debug(
|
||||
"Currently the eagle proposer only supports cudagraph_mode "
|
||||
f"PIECEWISE, and is forced to set graph mode from {aclgraph_runtime_mode} "
|
||||
"to CUDAGraphMode.PIECEWISE")
|
||||
aclgraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||
|
||||
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
|
||||
) and aclgraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
graph_pad_size = num_input_tokens
|
||||
else:
|
||||
# Currently, runner.graph_pad_size will always be -1.
|
||||
graph_pad_size = self.runner.graph_pad_size
|
||||
|
||||
# If use fullgraph and disable_padded_drafter_batch=True, We need to
|
||||
# update the graph_pad_size in common_attn_metadata, to tell the
|
||||
# builder padding some elements.
|
||||
common_attn_metadata.graph_pad_size = graph_pad_size
|
||||
builder = self.runner.attn_groups[0][0].get_metadata_builder()
|
||||
attn_metadata_mtp = builder.build(0, common_attn_metadata,
|
||||
self.runner.get_model())
|
||||
attn_metadata = {}
|
||||
for layer_name in self.attn_layer_name:
|
||||
attn_metadata[layer_name] = attn_metadata_mtp
|
||||
|
||||
for step in range(self.num_speculative_tokens):
|
||||
with set_ascend_forward_context(
|
||||
@@ -635,7 +737,8 @@ class MtpProposer(Proposer):
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
batch_descriptor=batch_descriptor,
|
||||
in_profile_run=self.runner.in_profile_run,
|
||||
num_actual_tokens=num_tokens):
|
||||
num_actual_tokens=num_tokens,
|
||||
is_mtp_model=True):
|
||||
with ProfileExecuteDuration().capture_async('mtp_forward'):
|
||||
model_kwargs = {}
|
||||
model_kwargs["attn_metadata"] = attn_metadata
|
||||
@@ -644,6 +747,13 @@ class MtpProposer(Proposer):
|
||||
input_ids=self.input_ids[:num_input_tokens],
|
||||
positions=self.positions[:num_input_tokens],
|
||||
hidden_states=self.hidden_states[:num_input_tokens])
|
||||
forward_context = get_forward_context()
|
||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
if self.vllm_config.model_config.use_mla:
|
||||
update_mla_attn_params(
|
||||
self.update_stream, forward_context,
|
||||
num_input_tokens,
|
||||
self.vllm_config.speculative_config)
|
||||
|
||||
num_indices = last_token_indices.shape[0]
|
||||
if lmhead_tp_enable():
|
||||
@@ -699,12 +809,21 @@ class MtpProposer(Proposer):
|
||||
input_ids = draft_token_ids_list[-1].int()
|
||||
positions += 1
|
||||
|
||||
attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
|
||||
1:batch_size + 1].tolist()
|
||||
attn_metadata_i.decode.cos = builder.cos_cache[
|
||||
positions].unsqueeze(1).unsqueeze(2)
|
||||
attn_metadata_i.decode.sin = builder.sin_cache[
|
||||
positions].unsqueeze(1).unsqueeze(2)
|
||||
# When disable_padded_drafter_batch=False, it should not to be updating these params, maybe.
|
||||
if self.speculative_config.disable_padded_drafter_batch or \
|
||||
aclgraph_runtime_mode != CUDAGraphMode.FULL:
|
||||
attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
|
||||
1:batch_size + 1].tolist()
|
||||
if aclgraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
attn_metadata_i.decode.actual_seq_lengths_q = \
|
||||
builder.pad_actual_seq_len_q_mtp_disable_pad(
|
||||
graph_pad_size - batch_size,
|
||||
batch_size,
|
||||
attn_metadata_i.decode.actual_seq_lengths_q)
|
||||
attn_metadata_i.decode.cos = builder.cos_cache[
|
||||
positions].unsqueeze(1).unsqueeze(2)
|
||||
attn_metadata_i.decode.sin = builder.sin_cache[
|
||||
positions].unsqueeze(1).unsqueeze(2)
|
||||
# NOTE(woosuk): We should handle the case where the draft model
|
||||
# generates tokens beyond the max model length. Since it is complex
|
||||
# to remove such requests from the batch, we keep them in the batch
|
||||
@@ -735,6 +854,10 @@ class MtpProposer(Proposer):
|
||||
self.positions[:batch_size] = clamped_positions
|
||||
self.hidden_states[:hidden_states.shape[0]] = hidden_states
|
||||
attn_metadata_i.slot_mapping[:batch_size] = slot_mapping
|
||||
if self.speculative_config.disable_padded_drafter_batch:
|
||||
self.positions[batch_size:num_input_tokens] = 0
|
||||
self.input_ids[batch_size:num_input_tokens] = 0
|
||||
self.hidden_states[batch_size:num_input_tokens].fill_(0)
|
||||
|
||||
if attn_metadata_i.prefill is not None:
|
||||
attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens
|
||||
@@ -751,12 +874,19 @@ class MtpProposer(Proposer):
|
||||
attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens
|
||||
attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist(
|
||||
)
|
||||
decode_seq_lens_list = attn_metadata_i.decode.seq_lens_list
|
||||
if aclgraph_runtime_mode == CUDAGraphMode.FULL and \
|
||||
self.speculative_config.disable_padded_drafter_batch:
|
||||
attn_metadata_i.decode.seq_lens_list = decode_seq_lens_list + [
|
||||
0
|
||||
] * (graph_pad_size - len(decode_seq_lens_list))
|
||||
attn_metadata_i.decode.input_positions = self.positions[:
|
||||
num_input_tokens]
|
||||
attn_metadata_i.decode.max_seq_lens += 1
|
||||
attn_metadata_i.decode.max_seq_lens = min(
|
||||
attn_metadata_i.decode.max_seq_lens,
|
||||
self.runner.model_config.max_model_len)
|
||||
torch.npu.synchronize()
|
||||
|
||||
# mtp>1: [batch_size, k]
|
||||
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
|
||||
@@ -915,6 +1045,9 @@ class MtpProposer(Proposer):
|
||||
total_num_tokens = query_start_loc_cpu[-1].item()
|
||||
token_indices = self.arange[:total_num_tokens]
|
||||
|
||||
# NOTE: Currently positions and seq_lens are not used in mla_v1 forward
|
||||
# so we do not need to fixed them. But if they are used in the future,
|
||||
# we should fixed them.
|
||||
spec_common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=common_attn_metadata.query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
|
||||
@@ -3112,7 +3112,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.drafter.dummy_run(
|
||||
num_tokens=num_tokens,
|
||||
with_prefill=with_prefill,
|
||||
skip_attn=True,
|
||||
num_reqs=num_reqs,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
|
||||
Reference in New Issue
Block a user