Revert "[Refactor] Unify full-graph parameter update logic (#6041)" (#6227)

This reverts commit 8966a99710.

The reverted commit breaks the following test:
`tests/e2e/singlecard/spec_decode/test_mtp_eagle_correctness.py::test_deepseek_mtp_correctness[True-FULL_DECODE_ONLY-2-wemaster/deepseek_mtp_main_random_bf16]`

- vLLM version: v0.14.0
- vLLM main: d68209402d
This commit is contained in:
wangxiyuan
2026-01-25 15:25:38 +08:00
committed by GitHub
parent 7799c4ca3b
commit 95649344aa
10 changed files with 415 additions and 420 deletions

View File

@@ -24,15 +24,12 @@ from vllm.forward_context import BatchDescriptor, ForwardContext
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import (AscendMetadata,
AscendMetadataForDecode)
from vllm_ascend.attention.context_parallel.attention_cp import \
AscendAttentionCPImpl
from vllm_ascend.attention.context_parallel.mla_cp import AscendMlaCPImpl
from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
AscendMLAMetadata)
from vllm_ascend.compilation.acl_graph import (
ACLGraphEntry, ACLGraphWrapper, get_draft_graph_params, get_graph_params,
set_draft_graph_params, set_graph_params,
update_draft_graph_params_workspaces)
set_draft_graph_params, set_graph_params, update_attn_dcp_pcp_params,
update_draft_graph_params_workspaces, update_mla_attn_dcp_pcp_params)
class TestACLGraphEntry(TestBase):
@@ -814,9 +811,8 @@ class TestPCPDCPGraphParams(TestBase):
out, lse))
with patch("torch_npu._C._npu_setStream", return_value=None):
AscendMlaCPImpl.update_graph_params(
self.update_stream, forward_context, 4
)
update_mla_attn_dcp_pcp_params(self.update_stream, forward_context,
4)
_mock_graph_task_end.assert_called_once()
@@ -856,8 +852,6 @@ class TestPCPDCPGraphParams(TestBase):
out, lse, 2, 0, 0))
with patch("torch_npu._C._npu_setStream", return_value=None):
AscendAttentionCPImpl.update_graph_params(
self.update_stream, forward_context, 4, None
)
update_attn_dcp_pcp_params(self.update_stream, forward_context, 4)
_mock_graph_task_end.assert_called_once()

View File

@@ -333,11 +333,11 @@ class TestEagleProposerDummyRun(TestBase):
self.proposer.dummy_run(num_tokens=64, with_prefill=True, num_reqs=4)
self.assertTrue(self.proposer._runnable.call_count == 1)
@patch("vllm_ascend.spec_decode.eagle_proposer.update_full_graph_params")
@patch("vllm_ascend.spec_decode.eagle_proposer.update_attn_params")
@patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
@patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
def test_dummy_run_in_graph_capture(self, mock_context, mock_get_context,
mock_update_full_graph_params):
mock_update_attn_params):
last_use_cuda_graph = self.proposer.use_cuda_graph
mock_return_context = MagicMock()
mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
@@ -352,14 +352,14 @@ class TestEagleProposerDummyRun(TestBase):
in_graph_capturing=True,
aclgraph_runtime_mode=CUDAGraphMode.FULL)
self.assertTrue(self.proposer._runnable.call_count == 1)
mock_update_full_graph_params.assert_not_called()
mock_update_attn_params.assert_not_called()
self.proposer.use_cuda_graph = last_use_cuda_graph
@patch("vllm_ascend.spec_decode.eagle_proposer.update_full_graph_params")
@patch("vllm_ascend.spec_decode.eagle_proposer.update_attn_params")
@patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
@patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
def test_dummy_run_in_graph_run(self, mock_context, mock_get_context,
mock_update_full_graph_params):
mock_update_attn_params):
last_use_cuda_graph = self.proposer.use_cuda_graph
mock_return_context = MagicMock()
mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
@@ -374,7 +374,7 @@ class TestEagleProposerDummyRun(TestBase):
in_graph_capturing=False,
aclgraph_runtime_mode=CUDAGraphMode.FULL)
self.assertTrue(self.proposer._runnable.call_count == 1)
self.assertTrue(mock_update_full_graph_params.call_count == 1)
self.assertTrue(mock_update_attn_params.call_count == 1)
self.proposer.use_cuda_graph = last_use_cuda_graph