This reverts commit 8966a99710, which breaks the test `tests/e2e/singlecard/spec_decode/test_mtp_eagle_correctness.py::test_deepseek_mtp_correctness[True-FULL_DECODE_ONLY-2-wemaster/deepseek_mtp_main_random_bf16]`. - vLLM version: v0.14.0 - vLLM main: d68209402d
This commit is contained in:
@@ -333,11 +333,11 @@ class TestEagleProposerDummyRun(TestBase):
|
||||
self.proposer.dummy_run(num_tokens=64, with_prefill=True, num_reqs=4)
|
||||
self.assertTrue(self.proposer._runnable.call_count == 1)
|
||||
|
||||
@patch("vllm_ascend.spec_decode.eagle_proposer.update_full_graph_params")
|
||||
@patch("vllm_ascend.spec_decode.eagle_proposer.update_attn_params")
|
||||
@patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
|
||||
@patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
|
||||
def test_dummy_run_in_graph_capture(self, mock_context, mock_get_context,
|
||||
mock_update_full_graph_params):
|
||||
mock_update_attn_params):
|
||||
last_use_cuda_graph = self.proposer.use_cuda_graph
|
||||
mock_return_context = MagicMock()
|
||||
mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||
@@ -352,14 +352,14 @@ class TestEagleProposerDummyRun(TestBase):
|
||||
in_graph_capturing=True,
|
||||
aclgraph_runtime_mode=CUDAGraphMode.FULL)
|
||||
self.assertTrue(self.proposer._runnable.call_count == 1)
|
||||
mock_update_full_graph_params.assert_not_called()
|
||||
mock_update_attn_params.assert_not_called()
|
||||
self.proposer.use_cuda_graph = last_use_cuda_graph
|
||||
|
||||
@patch("vllm_ascend.spec_decode.eagle_proposer.update_full_graph_params")
|
||||
@patch("vllm_ascend.spec_decode.eagle_proposer.update_attn_params")
|
||||
@patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
|
||||
@patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
|
||||
def test_dummy_run_in_graph_run(self, mock_context, mock_get_context,
|
||||
mock_update_full_graph_params):
|
||||
mock_update_attn_params):
|
||||
last_use_cuda_graph = self.proposer.use_cuda_graph
|
||||
mock_return_context = MagicMock()
|
||||
mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||
@@ -374,7 +374,7 @@ class TestEagleProposerDummyRun(TestBase):
|
||||
in_graph_capturing=False,
|
||||
aclgraph_runtime_mode=CUDAGraphMode.FULL)
|
||||
self.assertTrue(self.proposer._runnable.call_count == 1)
|
||||
self.assertTrue(mock_update_full_graph_params.call_count == 1)
|
||||
self.assertTrue(mock_update_attn_params.call_count == 1)
|
||||
self.proposer.use_cuda_graph = last_use_cuda_graph
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user