[Feature] Support to use fullgraph with eagle (#5118)

### What this PR does / why we need it?
    
This PR adds support for running the EAGLE draft model in full-graph mode.

Change list:
1. Distinguish between graph_params and draft_graph_params when processing
attention in attention_v1.
2. Adapt eagle_proposer to full-graph mode, including:
    1). When full graph is enabled, wrap the model in a full-graph wrapper
        (ACLGraphWrapper) at model load time.
    2). Build new metadata, set the running mode to FULL, and mark the
        attention update in dummy_run when in full-graph mode.
    3). Fix and fill the attn_metadata fields, such as
        attn_metadata.slot_mapping.
    4). Add a descriptor.
    5). Set the running mode and trigger the metadata update.
3. Rename is_mtp_model to is_draft_model, and add the workspace update for
the draft graph params.

NOTE:
When async_scheduling=True is set, the draft model is forced to run in
eager mode (graph mode is disabled for the draft model).

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: Yizhou <136800916+yiz-liu@users.noreply.github.com>
This commit is contained in:
anon189Ty
2025-12-29 09:54:51 +08:00
committed by GitHub
parent f81cf694b2
commit 3e67e8276c
11 changed files with 348 additions and 103 deletions

View File

@@ -206,6 +206,51 @@ def test_eagle_correctness(
del llm del llm
@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
def test_eaqgle_fullgraph_correctness(
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
use_eagle3: bool,
):
'''
Compare the outputs of the original LLM and a speculative LLM; they
should be the same when using eagle or eagle3 speculative decoding
in full-graph mode.
'''
spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name()
with VllmRunner(model_name, max_model_len=1024) as ref_llm:
ref_outputs = ref_llm.model.chat(test_prompts, sampling_config)
with VllmRunner(model_name,
speculative_config={
"method": "eagle3" if use_eagle3 else "eagle",
"model": spec_model_name,
"num_speculative_tokens": 4,
},
compilation_config={
"level": 3,
"cudagraph_mode": "FULL_DECODE_ONLY",
"cudagraph_num_of_warmups": 1,
"cudagraph_capture_sizes": [5, 10, 15, 20],
},
max_model_len=1024) as runner:
spec_outputs = runner.model.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
# Heuristic: expect more than 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
def test_suffix_correctness( def test_suffix_correctness(
test_prompts: list[list[dict[str, Any]]], test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams, sampling_config: SamplingParams,

View File

@@ -27,9 +27,9 @@ from vllm_ascend.attention.attention_v1 import (AscendMetadata,
from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata, from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
AscendMLAMetadata) AscendMLAMetadata)
from vllm_ascend.compilation.acl_graph import ( from vllm_ascend.compilation.acl_graph import (
ACLGraphEntry, ACLGraphWrapper, get_graph_params, get_mtp_graph_params, ACLGraphEntry, ACLGraphWrapper, get_draft_graph_params, get_graph_params,
set_graph_params, set_mtp_graph_params, update_attn_dcp_pcp_params, set_draft_graph_params, set_graph_params, update_attn_dcp_pcp_params,
update_mla_attn_dcp_pcp_params, update_mtp_graph_params_workspaces) update_draft_graph_params_workspaces, update_mla_attn_dcp_pcp_params)
class TestACLGraphEntry(TestBase): class TestACLGraphEntry(TestBase):
@@ -713,25 +713,26 @@ class TestACLGraphWrapper(TestBase):
self.assertEqual(unwrapped, self.mock_runnable) self.assertEqual(unwrapped, self.mock_runnable)
class TestMTPGraphParams(TestBase): class TestDraftGraphParams(TestBase):
def test_set_mtp_graph_params(self): def test_set_draft_graph_params(self):
with patch('vllm_ascend.compilation.acl_graph._mtp_graph_params', with patch('vllm_ascend.compilation.acl_graph._draft_graph_params',
new=None): new=None):
set_mtp_graph_params([4]) set_draft_graph_params([4])
from vllm_ascend.compilation.acl_graph import _mtp_graph_params from vllm_ascend.compilation.acl_graph import _draft_graph_params
self.assertIsNotNone(_mtp_graph_params) self.assertIsNotNone(_draft_graph_params)
@patch('vllm_ascend.compilation.acl_graph._mtp_graph_params') @patch('vllm_ascend.compilation.acl_graph._draft_graph_params')
def test_update_mtp_graph_params_workspaces(self, mtp_graph_params_mock): def test_update_draft_graph_params_workspaces(self,
mtp_graph_params_mock.workspaces = {4: 5} draft_graph_params_mock):
update_mtp_graph_params_workspaces(4, 6) draft_graph_params_mock.workspaces = {4: 5}
self.assertEqual(mtp_graph_params_mock.workspaces[4], 6) update_draft_graph_params_workspaces(4, 6)
self.assertEqual(draft_graph_params_mock.workspaces[4], 6)
@patch('vllm_ascend.compilation.acl_graph._mtp_graph_params') @patch('vllm_ascend.compilation.acl_graph._draft_graph_params')
def test_get_mtp_graph_params(self, mtp_graph_params_mock): def test_get_draft_graph_params(self, draft_graph_params_mock):
graph_params = get_mtp_graph_params() graph_params = get_draft_graph_params()
self.assertIs(mtp_graph_params_mock, graph_params) self.assertIs(draft_graph_params_mock, graph_params)
class TestPCPDCPGraphParams(TestBase): class TestPCPDCPGraphParams(TestBase):
@@ -783,7 +784,7 @@ class TestPCPDCPGraphParams(TestBase):
decode=decode) decode=decode)
forward_context = MagicMock() forward_context = MagicMock()
forward_context.attn_metadata = {"attn_layer_0": metadata} forward_context.attn_metadata = {"attn_layer_0": metadata}
forward_context.is_mtp_model = False forward_context.is_draft_model = False
num_heads = 256 num_heads = 256
scale = 0.1 scale = 0.1
@@ -836,7 +837,7 @@ class TestPCPDCPGraphParams(TestBase):
decode_meta=decode) decode_meta=decode)
forward_context = MagicMock() forward_context = MagicMock()
forward_context.attn_metadata = {"attn_layer_0": metadata} forward_context.attn_metadata = {"attn_layer_0": metadata}
forward_context.is_mtp_model = False forward_context.is_draft_model = False
self.graph_params.attn_params[4] = [] self.graph_params.attn_params[4] = []
self.graph_params.attn_params[4].append( self.graph_params.attn_params[4].append(

View File

@@ -2,7 +2,7 @@ from unittest.mock import MagicMock, patch
import numpy as np import numpy as np
import torch import torch
from vllm.config import CacheConfig, CompilationMode, VllmConfig from vllm.config import CacheConfig, CompilationMode, CUDAGraphMode, VllmConfig
from tests.ut.base import TestBase from tests.ut.base import TestBase
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
@@ -33,11 +33,13 @@ class TestEagleProposerInitialization(TestBase):
def tearDown(self): def tearDown(self):
self.mock_cpugpubuffer.stop() self.mock_cpugpubuffer.stop()
def test_initialization_eagle(self): def test_initialization_eagle_graph(self):
self.vllm_config.speculative_config.method = "eagle" self.vllm_config.speculative_config.method = "eagle"
self.vllm_config.speculative_config.draft_model_config.get_hidden_size.return_value = 4096 self.vllm_config.speculative_config.draft_model_config.get_hidden_size.return_value = 4096
self.vllm_config.compilation_config.mode = CompilationMode.VLLM_COMPILE self.vllm_config.compilation_config.mode = CompilationMode.VLLM_COMPILE
self.vllm_config.model_config.enforce_eager = False self.vllm_config.model_config.enforce_eager = False
self.vllm_config.speculative_config.enforce_eager = False
self.vllm_config.scheduler_config.async_scheduling = False
proposer = EagleProposer(vllm_config=self.vllm_config, proposer = EagleProposer(vllm_config=self.vllm_config,
device=self.device, device=self.device,
@@ -53,7 +55,7 @@ class TestEagleProposerInitialization(TestBase):
self.assertEqual(proposer.hidden_states.shape, (1024, 4096)) self.assertEqual(proposer.hidden_states.shape, (1024, 4096))
self.assertEqual(proposer.arange.shape, (1024, )) self.assertEqual(proposer.arange.shape, (1024, ))
def test_initialization_eagle3(self): def test_initialization_eagle3_enforce_eager(self):
self.vllm_config.speculative_config.method = "eagle3" self.vllm_config.speculative_config.method = "eagle3"
self.vllm_config.speculative_config.draft_model_config.get_hidden_size.return_value = 2048 self.vllm_config.speculative_config.draft_model_config.get_hidden_size.return_value = 2048
self.vllm_config.compilation_config.mode = CompilationMode.NONE self.vllm_config.compilation_config.mode = CompilationMode.NONE
@@ -68,6 +70,23 @@ class TestEagleProposerInitialization(TestBase):
self.assertFalse(proposer.use_cuda_graph) self.assertFalse(proposer.use_cuda_graph)
self.assertEqual(proposer.hidden_states.shape, (1024, 2048)) self.assertEqual(proposer.hidden_states.shape, (1024, 2048))
def test_initialization_eagle3_full_graph_async(self):
self.vllm_config.speculative_config.method = "eagle3"
self.vllm_config.speculative_config.draft_model_config.get_hidden_size.return_value = 2048
self.vllm_config.compilation_config.mode = CompilationMode.VLLM_COMPILE
self.vllm_config.model_config.enforce_eager = False
self.vllm_config.speculative_config.enforce_eager = False
self.vllm_config.scheduler_config.async_scheduling = True
proposer = EagleProposer(vllm_config=self.vllm_config,
device=self.device,
runner=self.runner)
self.assertEqual(proposer.name, SpecDcodeType.EAGLE3)
self.assertEqual(proposer.hidden_size, 2048)
self.assertFalse(proposer.use_cuda_graph)
self.assertEqual(proposer.hidden_states.shape, (1024, 2048))
class TestEagleProposerLoadModel(TestBase): class TestEagleProposerLoadModel(TestBase):
@@ -176,6 +195,7 @@ class TestEagleProposerDummyRun(TestBase):
def setUp(self): def setUp(self):
self.vllm_config = MagicMock(spec=VllmConfig) self.vllm_config = MagicMock(spec=VllmConfig)
self.vllm_config.speculative_config = MagicMock() self.vllm_config.speculative_config = MagicMock()
self.vllm_config.speculative_config.num_speculative_tokens = 4
self.device = torch.device("cpu") self.device = torch.device("cpu")
self.runner = MagicMock() self.runner = MagicMock()
@@ -192,25 +212,64 @@ class TestEagleProposerDummyRun(TestBase):
device=self.device, device=self.device,
runner=self.runner) runner=self.runner)
self.proposer.model = MagicMock() self.proposer.model = MagicMock()
self.proposer.update_stream = MagicMock()
def tearDown(self): def tearDown(self):
self.mock_cpugpubuffer.stop() self.mock_cpugpubuffer.stop()
@patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
@patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context") @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
def test_dummy_run_basic(self, mock_context): def test_dummy_run_basic(self, mock_context, mock_get_context):
num_tokens = 32 num_tokens = 32
with_prefill = False with_prefill = False
self.proposer.dummy_run(num_tokens=num_tokens, self.proposer.dummy_run(num_tokens=num_tokens,
with_prefill=with_prefill) with_prefill=with_prefill)
mock_context.assert_called_once() self.assertTrue(self.proposer.model.call_count == 4)
@patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
@patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context") @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
def test_dummy_run_with_prefill(self, mock_context): def test_dummy_run_with_prefill(self, mock_context, mock_get_context):
mock_context.return_value.__enter__.return_value = None mock_context.return_value.__enter__.return_value = None
self.proposer.dummy_run(num_tokens=64, with_prefill=True, num_reqs=4) self.proposer.dummy_run(num_tokens=64, with_prefill=True, num_reqs=4)
self.proposer.model.assert_called_once() self.assertTrue(self.proposer.model.call_count == 4)
@patch("vllm_ascend.spec_decode.eagle_proposer.update_attn_params")
@patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
@patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
def test_dummy_run_in_graph_capture(self, mock_context, mock_get_context,
mock_update_attn_params):
last_use_cuda_graph = self.proposer.use_cuda_graph
mock_return_context = MagicMock()
mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
mock_return_context.capturing = True
mock_get_context.return_value = mock_return_context
self.proposer.use_cuda_graph = True
self.proposer.dummy_run(num_tokens=64,
in_graph_capturing=True,
aclgraph_runtime_mode=CUDAGraphMode.FULL)
self.assertTrue(self.proposer.model.call_count == 4)
mock_update_attn_params.assert_not_called()
self.proposer.use_cuda_graph = last_use_cuda_graph
@patch("vllm_ascend.spec_decode.eagle_proposer.update_attn_params")
@patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
@patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
def test_dummy_run_in_graph_run(self, mock_context, mock_get_context,
mock_update_attn_params):
last_use_cuda_graph = self.proposer.use_cuda_graph
mock_return_context = MagicMock()
mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
mock_return_context.capturing = False
mock_get_context.return_value = mock_return_context
self.proposer.use_cuda_graph = True
self.proposer.dummy_run(num_tokens=64,
in_graph_capturing=False,
aclgraph_runtime_mode=CUDAGraphMode.FULL)
self.assertTrue(self.proposer.model.call_count == 4)
self.assertTrue(mock_update_attn_params.call_count == 4)
self.proposer.use_cuda_graph = last_use_cuda_graph
class TestEagleProposerGenerateTokenIds(TestBase): class TestEagleProposerGenerateTokenIds(TestBase):

View File

@@ -36,7 +36,7 @@ def set_ascend_forward_context(
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: Optional[BatchDescriptor] = None, batch_descriptor: Optional[BatchDescriptor] = None,
model_instance: torch.nn.Module = None, model_instance: torch.nn.Module = None,
is_mtp_model=False): is_draft_model=False):
"""A context manager that stores the current forward context, """A context manager that stores the current forward context,
can be attention metadata, etc. can be attention metadata, etc.
We add some additional param into forward_context. We add some additional param into forward_context.
@@ -55,7 +55,7 @@ def set_ascend_forward_context(
from vllm_ascend.ops.fused_moe.moe_comm_method import \ from vllm_ascend.ops.fused_moe.moe_comm_method import \
get_moe_comm_method get_moe_comm_method
moe_comm_type = select_moe_comm_method(num_tokens, vllm_config, moe_comm_type = select_moe_comm_method(num_tokens, vllm_config,
is_mtp_model) is_draft_model)
forward_context.moe_comm_type = moe_comm_type forward_context.moe_comm_type = moe_comm_type
forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type) forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
@@ -110,7 +110,7 @@ def set_ascend_forward_context(
forward_context.prefetch_mlp_down_proj = False forward_context.prefetch_mlp_down_proj = False
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
forward_context.model_instance = model_instance forward_context.model_instance = model_instance
forward_context.is_mtp_model = is_mtp_model forward_context.is_draft_model = is_draft_model
if num_tokens is None and attn_metadata is not None: if num_tokens is None and attn_metadata is not None:
num_tokens = attn_metadata.num_actual_tokens num_tokens = attn_metadata.num_actual_tokens
@@ -195,7 +195,7 @@ def get_mc2_mask():
def select_moe_comm_method(num_tokens: int, def select_moe_comm_method(num_tokens: int,
vllm_config: VllmConfig, vllm_config: VllmConfig,
is_mtp_model=False) -> Optional[MoECommType]: is_draft_model=False) -> Optional[MoECommType]:
"""Select the MoE communication method according to parallel settings, """Select the MoE communication method according to parallel settings,
device generation, token count, and quantization. device generation, token count, and quantization.
@@ -210,7 +210,7 @@ def select_moe_comm_method(num_tokens: int,
Args: Args:
num_tokens (int): The number of tokens in the current batch. num_tokens (int): The number of tokens in the current batch.
vllm_config (VllmConfig): Runtime configuration for the model. vllm_config (VllmConfig): Runtime configuration for the model.
is_mtp_model (bool): Whether the model runs in MTP mode (disables fused MC2). is_draft_model (bool): Whether the model runs in MTP mode (disables fused MC2).
Raises: Raises:
ValueError: If the soc version is unsupported. ValueError: If the soc version is unsupported.
@@ -249,13 +249,13 @@ def select_moe_comm_method(num_tokens: int,
fused_decode_enable = fused_mc2_enable fused_decode_enable = fused_mc2_enable
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1: if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
fused_decode_enable = fused_mc2_enable and get_ep_group( fused_decode_enable = fused_mc2_enable and get_ep_group(
).world_size <= 16 and (not is_mtp_model) ).world_size <= 16 and (not is_draft_model)
moe_comm_type = MoECommType.FUSED_MC2 if fused_decode_enable else MoECommType.MC2 moe_comm_type = MoECommType.FUSED_MC2 if fused_decode_enable else MoECommType.MC2
else: else:
fused_prefill_enable = fused_mc2_enable fused_prefill_enable = fused_mc2_enable
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1: if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
fused_prefill_enable = fused_mc2_enable and get_ep_group( fused_prefill_enable = fused_mc2_enable and get_ep_group(
).world_size <= 16 and (not is_mtp_model) ).world_size <= 16 and (not is_draft_model)
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2: elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
fused_prefill_enable = False fused_prefill_enable = False
moe_comm_type = MoECommType.FUSED_MC2 if fused_prefill_enable else MoECommType.ALLTOALL moe_comm_type = MoECommType.FUSED_MC2 if fused_prefill_enable else MoECommType.ALLTOALL

View File

@@ -38,8 +38,9 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
AscendMetadataForPrefill, enable_cp, AscendMetadataForPrefill, enable_cp,
split_decodes_and_prefills, split_decodes_and_prefills,
using_paged_attention) using_paged_attention)
from vllm_ascend.compilation.acl_graph import (get_graph_params, from vllm_ascend.compilation.acl_graph import (
update_graph_params_workspaces) get_draft_graph_params, get_graph_params,
update_draft_graph_params_workspaces, update_graph_params_workspaces)
from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type, from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
weak_ref_tensors) weak_ref_tensors)
@@ -262,7 +263,9 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
common_attn_metadata: AscendCommonAttentionMetadata, common_attn_metadata: AscendCommonAttentionMetadata,
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly, attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
): ):
if attn_state == AscendAttentionState.DecodeOnly:
if attn_state in (AscendAttentionState.DecodeOnly,
AscendAttentionState.ChunkedPrefill):
attn_metadata = self.build( attn_metadata = self.build(
common_prefix_len=0, common_prefix_len=0,
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
@@ -319,7 +322,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
= self._get_fia_params(key, value, attn_metadata) = self._get_fia_params(key, value, attn_metadata)
num_tokens = attn_metadata.actual_seq_lengths_q[-1] num_tokens = attn_metadata.actual_seq_lengths_q[-1]
graph_params = get_graph_params() forward_context = get_forward_context()
if forward_context.is_draft_model:
graph_params = get_draft_graph_params()
else:
graph_params = get_graph_params()
actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q
# Prepare tensors for attention output # Prepare tensors for attention output
# TODO: Refactor this to step-level instead of layer-level # TODO: Refactor this to step-level instead of layer-level
@@ -343,7 +350,10 @@ class AscendAttentionBackendImpl(AttentionImpl):
sparse_mode=3, sparse_mode=3,
scale=self.scale, scale=self.scale,
) )
update_graph_params_workspaces(num_tokens, workspace) if forward_context.is_draft_model:
update_draft_graph_params_workspaces(num_tokens, workspace)
else:
update_graph_params_workspaces(num_tokens, workspace)
# Handle graph capturing mode # Handle graph capturing mode
stream = torch_npu.npu.current_stream() stream = torch_npu.npu.current_stream()

View File

@@ -26,8 +26,8 @@ from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata) from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata)
from vllm_ascend.attention.common_cp import (AscendPCPMetadata, from vllm_ascend.attention.common_cp import (AscendPCPMetadata,
CPChunkedContextMetadata) CPChunkedContextMetadata)
from vllm_ascend.compilation.acl_graph import (get_graph_params, from vllm_ascend.compilation.acl_graph import (get_draft_graph_params,
get_mtp_graph_params, get_graph_params,
update_graph_params_workspaces) update_graph_params_workspaces)
from vllm_ascend.utils import weak_ref_tensors from vllm_ascend.utils import weak_ref_tensors
@@ -555,8 +555,8 @@ class AscendMlaCPImpl(AscendMLAImpl):
"calc_type": "calc_type_ring", "calc_type": "calc_type_ring",
} }
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
if forward_context.is_mtp_model: if forward_context.is_draft_model:
graph_params = get_mtp_graph_params() graph_params = get_draft_graph_params()
else: else:
graph_params = get_graph_params() graph_params = get_graph_params()
if forward_context.capturing: if forward_context.capturing:

View File

@@ -27,9 +27,9 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
split_decodes_and_prefills, split_decodes_and_prefills,
trans_rope_weight, transdata, trans_rope_weight, transdata,
wait_for_kv_layer_from_connector) wait_for_kv_layer_from_connector)
from vllm_ascend.compilation.acl_graph import (get_graph_params, from vllm_ascend.compilation.acl_graph import (
get_mtp_graph_params, get_draft_graph_params, get_graph_params,
update_graph_params_workspaces) update_draft_graph_params_workspaces, update_graph_params_workspaces)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.shared_weight_layer import ( from vllm_ascend.ops.shared_weight_layer import (
is_hidden_layer, post_process_after_loading_for_shared_weight_series, is_hidden_layer, post_process_after_loading_for_shared_weight_series,
@@ -1184,8 +1184,8 @@ class AscendMLAImpl(MLAAttentionImpl):
"actual_seq_lengths_kv": decode_meta.seq_lens_list, "actual_seq_lengths_kv": decode_meta.seq_lens_list,
} }
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
if forward_context.is_mtp_model: if forward_context.is_draft_model:
graph_params = get_mtp_graph_params() graph_params = get_draft_graph_params()
else: else:
graph_params = get_graph_params() graph_params = get_graph_params()
if forward_context.capturing: if forward_context.capturing:
@@ -1200,7 +1200,10 @@ class AscendMLAImpl(MLAAttentionImpl):
if workspace is None: if workspace is None:
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace( workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
q_nope, k_nope, k_nope, **common_kwargs) q_nope, k_nope, k_nope, **common_kwargs)
update_graph_params_workspaces(num_tokens, workspace) if forward_context.is_draft_model:
update_draft_graph_params_workspaces(num_tokens, workspace)
else:
update_graph_params_workspaces(num_tokens, workspace)
attn_output = torch.empty( attn_output = torch.empty(
(q_nope.shape[1], q_nope.shape[0], *q_nope.shape[2:]), (q_nope.shape[1], q_nope.shape[0], *q_nope.shape[2:]),

View File

@@ -254,7 +254,10 @@ def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
def _update_attn_fia_params(update_stream, forward_context, runtime_shape): def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
graph_params = get_graph_params() if forward_context.is_draft_model:
graph_params = get_draft_graph_params()
else:
graph_params = get_graph_params()
# For Qwen3-next, since the kv_cache_config has already categorized # For Qwen3-next, since the kv_cache_config has already categorized
# linear_attn and self_attn, the attn_metadata is first arranged with # linear_attn and self_attn, the attn_metadata is first arranged with
# self_attn followed by linear_attn. Therefore, using zip directly # self_attn followed by linear_attn. Therefore, using zip directly
@@ -306,8 +309,8 @@ def update_attn_params(update_stream, forward_context, runtime_shape,
def update_mla_attn_params(update_stream, forward_context, runtime_shape, def update_mla_attn_params(update_stream, forward_context, runtime_shape,
speculative_config): speculative_config):
if forward_context.is_mtp_model: if forward_context.is_draft_model:
graph_params = get_mtp_graph_params() graph_params = get_draft_graph_params()
else: else:
graph_params = get_graph_params() graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args # FIXME: Behold! We are using a temporary hack here to update the args
@@ -326,7 +329,7 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
seq_lens_list = forward_context.attn_metadata[ seq_lens_list = forward_context.attn_metadata[
key].decode.seq_lens_list key].decode.seq_lens_list
if speculative_config and speculative_config.method == "mtp" \ if speculative_config and speculative_config.method == "mtp" \
and not forward_context.is_mtp_model: and not forward_context.is_draft_model:
actual_seq_lengths = forward_context.attn_metadata[ actual_seq_lengths = forward_context.attn_metadata[
key].decode.actual_seq_lengths_q key].decode.actual_seq_lengths_q
spec_multiple = speculative_config.num_speculative_tokens + 1 spec_multiple = speculative_config.num_speculative_tokens + 1
@@ -336,7 +339,7 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
spec_multiple * (i + 1) spec_multiple * (i + 1)
for i in range(runtime_shape // spec_multiple) for i in range(runtime_shape // spec_multiple)
] ]
elif forward_context.is_mtp_model: elif forward_context.is_draft_model:
actual_seq_lengths = forward_context.attn_metadata[ actual_seq_lengths = forward_context.attn_metadata[
key].decode.actual_seq_lengths_q key].decode.actual_seq_lengths_q
block_table = forward_context.attn_metadata[ block_table = forward_context.attn_metadata[
@@ -440,8 +443,8 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
def update_mla_attn_dcp_pcp_params(update_stream, forward_context, def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
runtime_shape): runtime_shape):
if forward_context.is_mtp_model: if forward_context.is_draft_model:
graph_params = get_mtp_graph_params() graph_params = get_draft_graph_params()
else: else:
graph_params = get_graph_params() graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args # FIXME: Behold! We are using a temporary hack here to update the args
@@ -527,14 +530,14 @@ def get_graph_params():
return _graph_params return _graph_params
_mtp_graph_params: Optional[GraphParams] = None _draft_graph_params: Optional[GraphParams] = None
def set_mtp_graph_params(aclgraph_capture_sizes: list[int]): def set_draft_graph_params(aclgraph_capture_sizes: list[int]):
global _mtp_graph_params global _draft_graph_params
if _mtp_graph_params is not None: if _draft_graph_params is not None:
raise ValueError("MTPGraph parameters have already been set!") raise ValueError("DraftGraph parameters have already been set!")
_mtp_graph_params = GraphParams( _draft_graph_params = GraphParams(
{size: [] {size: []
for size in aclgraph_capture_sizes}, for size in aclgraph_capture_sizes},
{size: None {size: None
@@ -546,11 +549,11 @@ def set_mtp_graph_params(aclgraph_capture_sizes: list[int]):
) )
def update_mtp_graph_params_workspaces(num_tokens: int, workspace: Any): def update_draft_graph_params_workspaces(num_tokens: int, workspace: Any):
global _mtp_graph_params global _draft_graph_params
if _mtp_graph_params is not None: if _draft_graph_params is not None:
_mtp_graph_params.workspaces[num_tokens] = workspace _draft_graph_params.workspaces[num_tokens] = workspace
def get_mtp_graph_params(): def get_draft_graph_params():
return _mtp_graph_params return _draft_graph_params

View File

@@ -8,6 +8,7 @@ from vllm.attention.layer import Attention
from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig, from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
get_layers_from_vllm_config) get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.logger import logger from vllm.logger import logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
@@ -25,6 +26,8 @@ from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
update_attn_params)
from vllm_ascend.ops.rotary_embedding import update_cos_sin from vllm_ascend.ops.rotary_embedding import update_cos_sin
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
@@ -48,6 +51,8 @@ class EagleProposer(Proposer):
self.speculative_config = vllm_config.speculative_config self.speculative_config = vllm_config.speculative_config
self.draft_model_config = self.speculative_config.draft_model_config self.draft_model_config = self.speculative_config.draft_model_config
self.method = self.speculative_config.method self.method = self.speculative_config.method
self.num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling
self.block_size = vllm_config.cache_config.block_size self.block_size = vllm_config.cache_config.block_size
# We need to get the hidden size from the draft model config because # We need to get the hidden size from the draft model config because
@@ -56,9 +61,17 @@ class EagleProposer(Proposer):
self.hidden_size = vllm_config.speculative_config.draft_model_config.get_hidden_size( self.hidden_size = vllm_config.speculative_config.draft_model_config.get_hidden_size(
) )
self.use_cuda_graph = (self.vllm_config.compilation_config.mode # there is synchronization between mtp steps when enabling aclgraph,
== CompilationMode.VLLM_COMPILE and # disable aclgraph when use async scheduling to avoid the
not self.vllm_config.model_config.enforce_eager) # synchronization overhead.
# NOTE: we need to set aclgraph_runtime_mode to None in both dummy_run
# and _propose.
self.use_cuda_graph = (
self.vllm_config.compilation_config.mode
== CompilationMode.VLLM_COMPILE
and not self.vllm_config.model_config.enforce_eager
and not self.use_async_scheduling
and not self.vllm_config.speculative_config.enforce_eager)
self.cudagraph_batch_sizes = list( self.cudagraph_batch_sizes = list(
sorted( sorted(
@@ -74,8 +87,7 @@ class EagleProposer(Proposer):
device=device, device=device,
with_numpy=True, with_numpy=True,
) )
self.decode_threshold = 1 + \ self.decode_threshold = 1 + self.num_speculative_tokens
self.vllm_config.speculative_config.num_speculative_tokens
# persistent buffers for cuda graph # persistent buffers for cuda graph
self.input_ids = torch.zeros( self.input_ids = torch.zeros(
@@ -160,6 +172,19 @@ class EagleProposer(Proposer):
else: else:
self.model.lm_head = model.lm_head self.model.lm_head = model.lm_head
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
) and self.use_cuda_graph:
self.update_stream = torch.npu.Stream()
self.model = ACLGraphWrapper(self.model,
self.vllm_config,
runtime_mode=CUDAGraphMode.FULL)
def get_model(self) -> nn.Module:
    """Return the raw draft model held by this proposer.

    If ``self.model`` is an ``ACLGraphWrapper``, the object returned by
    its ``unwrap()`` method is handed back; otherwise ``self.model`` is
    returned unchanged.
    """
    model = self.model
    # Nothing to strip when the model was never wrapped for aclgraph.
    if not isinstance(model, ACLGraphWrapper):
        return model
    return model.unwrap()
@torch.inference_mode() @torch.inference_mode()
def dummy_run(self, def dummy_run(self,
num_tokens: int, num_tokens: int,
@@ -174,16 +199,73 @@ class EagleProposer(Proposer):
# update global cos, sin # update global cos, sin
update_cos_sin(self.positions[:num_tokens]) update_cos_sin(self.positions[:num_tokens])
with set_ascend_forward_context(None, attn_metadata = None
self.vllm_config, if not self.use_cuda_graph:
num_tokens=num_tokens): aclgraph_runtime_mode = CUDAGraphMode.NONE
self.model( if aclgraph_runtime_mode == CUDAGraphMode.FULL and len(
input_ids=self.input_ids[:num_tokens], self.runner.attn_groups) > 0:
positions=self.positions[:num_tokens], num_computed_tokens_cpu = (
hidden_states=self.hidden_states[:num_tokens], self.runner.input_batch.
num_computed_tokens_cpu_tensor[:num_reqs])
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=self.runner.query_start_loc.gpu[:num_reqs + 1],
query_start_loc_cpu=self.runner.query_start_loc.cpu[:num_reqs +
1],
seq_lens_cpu=self.runner.seq_lens.cpu,
seq_lens=self.runner.seq_lens.gpu[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
num_input_tokens=num_tokens,
max_query_len=self.num_speculative_tokens + 1,
num_computed_tokens_cpu=num_computed_tokens_cpu,
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
block_table_tensor=self.runner.input_batch.block_table[0].
get_device_tensor()[:num_reqs],
slot_mapping=self.runner.input_batch.block_table[0].
slot_mapping.gpu,
positions=self.runner.positions.gpu,
attn_mask=self.runner.attn_mask,
spec_attn_mask=self.runner.spec_attn_mask,
attn_state=self.runner.attn_state,
decode_token_per_req=self.runner.decode_token_per_req,
max_seq_len=0,
) )
dummy_compute_logits(self.hidden_states) dummy_compute_logits(self.hidden_states)
builder = self.runner.attn_groups[0][0].get_metadata_builder()
attn_metadata_eagle = builder.build_for_graph_capture(
common_attn_metadata, AscendAttentionState.ChunkedPrefill)
attn_metadata = {}
for layer_name in [self.attn_layer_name]:
attn_metadata[layer_name] = attn_metadata_eagle
for i in range(self.num_speculative_tokens):
if i > 0 and in_graph_capturing and aclgraph_runtime_mode == CUDAGraphMode.FULL:
aclgraph_runtime_mode = CUDAGraphMode.NONE
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
num_actual_tokens=0,
in_profile_run=is_profile,
batch_descriptor=batch_descriptor,
aclgraph_runtime_mode=aclgraph_runtime_mode,
is_draft_model=True):
forward_context = get_forward_context()
self.model(
input_ids=self.input_ids[:num_tokens],
positions=self.positions[:num_tokens],
hidden_states=self.hidden_states[:num_tokens],
)
if (forward_context.cudagraph_runtime_mode
== CUDAGraphMode.FULL
and not forward_context.capturing):
update_attn_params(
self.update_stream,
forward_context,
num_tokens,
self.vllm_config,
)
def generate_token_ids(self, def generate_token_ids(self,
sampled_token_ids: torch.Tensor | list[list[int]], sampled_token_ids: torch.Tensor | list[list[int]],
sampling_metadata: SamplingMetadata = None, sampling_metadata: SamplingMetadata = None,
@@ -343,7 +425,7 @@ class EagleProposer(Proposer):
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
if self.name == SpecDcodeType.EAGLE3: if self.name == SpecDcodeType.EAGLE3:
assert isinstance(self.model, Eagle3LlamaForCausalLM) assert isinstance(self.get_model(), Eagle3LlamaForCausalLM)
target_hidden_states = self.model.combine_hidden_states( target_hidden_states = self.model.combine_hidden_states(
target_hidden_states) target_hidden_states)
assert target_hidden_states.shape[-1] == self.hidden_size assert target_hidden_states.shape[-1] == self.hidden_size
@@ -361,6 +443,14 @@ class EagleProposer(Proposer):
else: else:
num_input_tokens = num_tokens num_input_tokens = num_tokens
has_lora = len(self.runner.input_batch.lora_id_to_lora_request) > 0
if self.use_cuda_graph:
aclgraph_runtime_mode, batch_descriptor = \
self.runner.cudagraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=True, has_lora=has_lora)
else:
aclgraph_runtime_mode = CUDAGraphMode.NONE
batch_descriptor = None
# copy inputs to buffer for cudagraph # copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states self.hidden_states[:num_tokens] = target_hidden_states
@@ -372,27 +462,40 @@ class EagleProposer(Proposer):
# update global cos, sin # update global cos, sin
update_cos_sin(self.positions[:num_input_tokens]) update_cos_sin(self.positions[:num_input_tokens])
with set_ascend_forward_context(attn_metadata, with set_ascend_forward_context(
self.vllm_config, {self.attn_layer_name: attn_metadata},
num_tokens=num_input_tokens): self.vllm_config,
num_tokens=num_input_tokens,
num_actual_tokens=num_tokens,
batch_descriptor=batch_descriptor,
aclgraph_runtime_mode=aclgraph_runtime_mode,
is_draft_model=True):
last_hidden_states, hidden_states = self.model( last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens], input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens], positions=self.positions[:num_input_tokens],
hidden_states=self.hidden_states[:num_input_tokens], hidden_states=self.hidden_states[:num_input_tokens],
) )
forward_context = get_forward_context()
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
# TODO: support mla in future.
update_attn_params(
self.update_stream,
forward_context,
num_input_tokens,
self.vllm_config,
)
sample_hidden_states = last_hidden_states[last_token_indices] sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states) logits = self.model.compute_logits(sample_hidden_states)
draft_token_ids = logits.argmax(dim=-1) draft_token_ids = logits.argmax(dim=-1)
# Early exit if there is only one draft token to be generated. # Early exit if there is only one draft token to be generated.
if self.vllm_config.speculative_config.num_speculative_tokens == 1: if self.num_speculative_tokens == 1:
# [batch_size, 1] # [batch_size, 1]
return draft_token_ids.view(-1, 1) return draft_token_ids.view(-1, 1)
# Generate the remaining draft tokens. # Generate the remaining draft tokens.
draft_token_ids_tensor = torch.zeros( draft_token_ids_tensor = torch.zeros(
(self.vllm_config.speculative_config.num_speculative_tokens, (self.num_speculative_tokens, *draft_token_ids.shape),
*draft_token_ids.shape),
dtype=draft_token_ids.dtype, dtype=draft_token_ids.dtype,
device=self.device) device=self.device)
draft_token_ids_tensor[0] = draft_token_ids draft_token_ids_tensor[0] = draft_token_ids
@@ -417,9 +520,13 @@ class EagleProposer(Proposer):
1:].tolist() 1:].tolist()
attn_metadata.seq_lens_list = attn_metadata.seq_lens.tolist() attn_metadata.seq_lens_list = attn_metadata.seq_lens.tolist()
attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
for now_speculative in range( if self.use_cuda_graph:
self.vllm_config.speculative_config.num_speculative_tokens - aclgraph_runtime_mode, batch_descriptor = \
1): self.runner.cudagraph_dispatcher.dispatch(num_tokens=input_batch_size, uniform_decode=True, has_lora=has_lora)
else:
aclgraph_runtime_mode = CUDAGraphMode.NONE
batch_descriptor = None
for now_speculative in range(self.num_speculative_tokens - 1):
# Update the inputs. # Update the inputs.
# cast to int32 is crucial when eagle model is compiled. # cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default. # tensor.argmax() returns int64 by default.
@@ -467,6 +574,8 @@ class EagleProposer(Proposer):
# NOTE: ASCEND slot_mapping must on cpu # NOTE: ASCEND slot_mapping must on cpu
attn_metadata.slot_mapping[:slot_mapping_tmp.shape[0]].copy_( attn_metadata.slot_mapping[:slot_mapping_tmp.shape[0]].copy_(
slot_mapping_tmp.to(torch.int32)) slot_mapping_tmp.to(torch.int32))
attn_metadata.slot_mapping[slot_mapping_tmp.shape[0]:].fill_(
PADDING_SLOT_ID)
# copy inputs to buffer for cudagraph # copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions self.positions[:batch_size] = clamped_positions
@@ -474,20 +583,33 @@ class EagleProposer(Proposer):
attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask() attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask()
attn_metadata.attn_mask = attn_mask attn_metadata.attn_mask = attn_mask
# Run the model.
# update global cos, sin # update global cos, sin
update_cos_sin(self.positions[:input_batch_size]) update_cos_sin(self.positions[:input_batch_size])
with set_ascend_forward_context(attn_metadata, # Run the model.
self.vllm_config, with set_ascend_forward_context(
num_tokens=input_batch_size): {self.attn_layer_name: attn_metadata},
self.vllm_config,
num_tokens=input_batch_size,
num_actual_tokens=batch_size,
batch_descriptor=batch_descriptor,
aclgraph_runtime_mode=aclgraph_runtime_mode,
is_draft_model=True):
last_hidden_states, hidden_states = self.model( last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:input_batch_size], input_ids=self.input_ids[:input_batch_size],
positions=self.positions[:input_batch_size], positions=self.positions[:input_batch_size],
hidden_states=self.hidden_states[:input_batch_size], hidden_states=self.hidden_states[:input_batch_size],
) )
forward_context = get_forward_context()
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
update_attn_params(
self.update_stream,
forward_context,
input_batch_size,
self.vllm_config,
)
hidden_states = hidden_states[:batch_size] hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size]) logits = self.model.compute_logits(last_hidden_states[:batch_size])
@@ -719,7 +841,7 @@ class EagleProposer(Proposer):
common_attn_metadata.slot_mapping[token_indices]) common_attn_metadata.slot_mapping[token_indices])
common_attn_metadata.slot_mapping[token_indices.shape[0]:].fill_(-1) common_attn_metadata.slot_mapping[token_indices.shape[0]:].fill_(-1)
# NOTE: Currently positions and seq_lens are not used in mla_v1 forward # NOTE: Currently positions and seq_lens are not used in attn forward
# so we do not need to fixed them. But if they are used in the future, # so we do not need to fixed them. But if they are used in the future,
# we should fixed them. # we should fixed them.
spec_common_attn_metadata = AscendCommonAttentionMetadata( spec_common_attn_metadata = AscendCommonAttentionMetadata(
@@ -779,7 +901,7 @@ class EagleProposer(Proposer):
total_num_tokens = query_start_loc_cpu[-1].item() total_num_tokens = query_start_loc_cpu[-1].item()
token_indices = self.arange[:total_num_tokens] token_indices = self.arange[:total_num_tokens]
# NOTE: Currently positions and seq_lens are not used in mla_v1 forward # NOTE: Currently positions and seq_lens are not used in attn forward
# so we do not need to fixed them. But if they are used in the future, # so we do not need to fixed them. But if they are used in the future,
# we should fixed them. # we should fixed them.
spec_common_attn_metadata = AscendCommonAttentionMetadata( spec_common_attn_metadata = AscendCommonAttentionMetadata(
@@ -803,7 +925,8 @@ class EagleProposer(Proposer):
seq_lens=common_attn_metadata.seq_lens, seq_lens=common_attn_metadata.seq_lens,
max_seq_len=0) max_seq_len=0)
token_indices_to_sample = (common_attn_metadata.query_start_loc[1:] - query_start_loc = common_attn_metadata.query_start_loc[
1 - num_rejected_tokens_gpu) 1:1 + num_rejected_tokens_gpu.shape[0]]
token_indices_to_sample = query_start_loc - 1 - num_rejected_tokens_gpu
return spec_common_attn_metadata, token_indices, token_indices_to_sample return spec_common_attn_metadata, token_indices, token_indices_to_sample

View File

@@ -303,7 +303,7 @@ class MtpProposer(Proposer):
num_actual_tokens=0, num_actual_tokens=0,
aclgraph_runtime_mode=aclgraph_runtime_mode, aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor, batch_descriptor=batch_descriptor,
is_mtp_model=True, is_draft_model=True,
in_profile_run=is_profile): in_profile_run=is_profile):
if self.enable_shared_expert_dp: if self.enable_shared_expert_dp:
positions = positions.unsqueeze(-1) positions = positions.unsqueeze(-1)
@@ -782,7 +782,7 @@ class MtpProposer(Proposer):
aclgraph_runtime_mode=aclgraph_runtime_mode, aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor, batch_descriptor=batch_descriptor,
num_actual_tokens=num_tokens, num_actual_tokens=num_tokens,
is_mtp_model=True): is_draft_model=True):
with ProfileExecuteDuration().capture_async('mtp_forward'): with ProfileExecuteDuration().capture_async('mtp_forward'):
model_kwargs = {} model_kwargs = {}
model_kwargs["attn_metadata"] = attn_metadata model_kwargs["attn_metadata"] = attn_metadata

View File

@@ -89,8 +89,8 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper, from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
set_draft_graph_params,
set_graph_params, set_graph_params,
set_mtp_graph_params,
update_attn_dcp_pcp_params, update_attn_dcp_pcp_params,
update_attn_params, update_attn_params,
update_mla_attn_dcp_pcp_params, update_mla_attn_dcp_pcp_params,
@@ -1104,7 +1104,8 @@ class NPUModelRunner(GPUModelRunner):
self.spec_decode_common_attn_metadata is None: self.spec_decode_common_attn_metadata is None:
self.spec_decode_common_attn_metadata = common_attn_metadata self.spec_decode_common_attn_metadata = common_attn_metadata
if self.speculative_config.method in ("eagle", "eagle3") and \ if self.speculative_config.method in ("eagle", "eagle3") and \
self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(): (self.vllm_config.speculative_config.enforce_eager \
or self.use_async_scheduling):
self.spec_decode_common_attn_metadata = \ self.spec_decode_common_attn_metadata = \
self.spec_decode_common_attn_metadata.unpadded( self.spec_decode_common_attn_metadata.unpadded(
total_num_scheduled_tokens, base_num_reqs) total_num_scheduled_tokens, base_num_reqs)
@@ -2916,7 +2917,7 @@ class NPUModelRunner(GPUModelRunner):
# we set the graph params right before initializing the keys. # we set the graph params right before initializing the keys.
set_graph_params(self.cudagraph_batch_sizes) set_graph_params(self.cudagraph_batch_sizes)
if self.speculative_config: if self.speculative_config:
set_mtp_graph_params(self.cudagraph_batch_sizes) set_draft_graph_params(self.cudagraph_batch_sizes)
self.cudagraph_dispatcher.initialize_cudagraph_keys( self.cudagraph_dispatcher.initialize_cudagraph_keys(
self.compilation_config.cudagraph_mode, self.compilation_config.cudagraph_mode,