[Feature] support aclgraph for model runner v2 (#7110)

### What this PR does / why we need it?
This PR adds ACL graph (aclgraph) support for model runner v2; see RFC
#5208 for details. The PR contains these modifications:
- adapt to the latest commit of the vLLM main branch.
- provide a unified interface for the extra forward context, shared by both
model runner v1 and model runner v2.
- implement graph mode for the main model (a usage sketch follows below).
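As a rough usage sketch (the model name and prompt are placeholders, and the environment variable plus `cudagraph_mode` value simply mirror the new e2e test added in this commit), graph mode with model runner v2 can be exercised like this:

```python
# Hedged sketch: mirrors the e2e test added below; the model name is a
# placeholder, not something this PR prescribes.
import os

from vllm import LLM, SamplingParams

os.environ["VLLM_USE_V2_MODEL_RUNNER"] = "1"  # opt in to model runner v2

llm = LLM(
    model="Qwen/Qwen3-0.6B",  # placeholder dense model
    max_model_len=1024,
    enforce_eager=False,  # keep graph capture enabled
    compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY"},
)
outputs = llm.generate(
    ["Hello, my name is"],
    SamplingParams(max_tokens=32, temperature=0.0),
)
print(outputs[0].outputs[0].text)
```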

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?

- vLLM version: v0.16.0
- vLLM main: 4034c3d32e

---------

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
Ronald
2026-03-13 09:11:46 +08:00
committed by GitHub
parent 1f71da80eb
commit c980e68d40
52 changed files with 840 additions and 309 deletions


@@ -184,7 +184,7 @@ def test_token_dispatcher_with_all_gather_quant(
):
context_mock = MagicMock()
context_mock.fused_moe_state = 0
with patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context",
with patch("vllm_ascend.ascend_forward_context.get_forward_context",
return_value=context_mock):
a = torch.randn((m, k), device=device, dtype=dtype) / 10
w1 = torch.randn((e, k, 2 * n), device=device, dtype=torch.int8)
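Most of the test churn in this commit follows the pattern above: mock targets for `get_forward_context` move from per-module import sites to the unified `vllm_ascend.ascend_forward_context` location. With `unittest.mock.patch`, the target string must name the attribute where the code under test actually looks it up, so once the runtime resolves the context through the unified module, one patch target covers every caller. A self-contained sketch of that rule, using hypothetical `ctx`/`consumer` stand-in modules rather than real vllm_ascend code:

```python
# Hypothetical stand-in modules illustrating why the patch targets moved.
import sys
import types
from unittest.mock import patch

ctx = types.ModuleType("ctx")
ctx.get_forward_context = lambda: "real-context"
sys.modules["ctx"] = ctx

consumer = types.ModuleType("consumer")
# Lazy lookup through the providing module (the unified-interface style):
# patching "ctx.get_forward_context" now affects this and every consumer.
consumer.run = lambda: sys.modules["ctx"].get_forward_context()
sys.modules["consumer"] = consumer

with patch("ctx.get_forward_context", return_value="mocked-context"):
    assert consumer.run() == "mocked-context"
```

Had `consumer` instead done `from ctx import get_forward_context` at import time, the test would have to patch `consumer.get_forward_context`; that per-module targeting is exactly what these diffs retire.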


@@ -85,3 +85,29 @@ def test_egale_spec_decoding(
},
) as runner:
runner.model.generate(prompts, sampling_params)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("enforce_eager", [False])
@pytest.mark.parametrize("compilation_config", [{"cudagraph_mode": "FULL_DECODE_ONLY"}, {}])
@patch.dict(os.environ, {"VLLM_USE_V2_MODEL_RUNNER": "1"})
def test_qwen3_dense_graph_mode(
model: str,
max_tokens: int,
enforce_eager: bool,
) -> None:
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0)
with VllmRunner(
model,
max_model_len=1024,
enforce_eager=enforce_eager,
) as runner:
runner.model.generate(prompts, sampling_params)


@@ -74,7 +74,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
@patch("torch_npu._npu_reshape_and_cache")
@patch("torch_npu._npu_flash_attention")
@patch("vllm_ascend.attention.attention_v1.get_forward_context")
@patch("vllm_ascend.ascend_forward_context.get_forward_context")
def test_forward_prefill_310(
self, mock_get_forward_context, mock_npu_npu_flash_attention, mock_npu_reshape_and_cache
):
@@ -105,7 +105,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
@patch("torch_npu.npu_format_cast", return_value=torch.randn((1, 128, 16, 16), dtype=torch.float16))
@patch("torch_npu._npu_reshape_and_cache")
@patch("torch_npu._npu_paged_attention_splitfuse")
@patch("vllm_ascend.attention.attention_v1.get_forward_context")
@patch("vllm_ascend.ascend_forward_context.get_forward_context")
def test_forward_chunked_prefill_310(
self,
mock_get_forward_context,
@@ -140,7 +140,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
@patch("torch_npu.npu_format_cast", return_value=torch.randn((1, 128, 16, 16), dtype=torch.float16))
@patch("torch_npu._npu_reshape_and_cache")
@patch("torch_npu._npu_paged_attention_splitfuse")
@patch("vllm_ascend.attention.attention_v1.get_forward_context")
@patch("vllm_ascend.ascend_forward_context.get_forward_context")
def test_forward_prefill_cache_hit_310(
self,
mock_get_forward_context,
@@ -175,7 +175,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
@patch("vllm_ascend.attention.attention_v1.using_paged_attention")
@patch("torch_npu._npu_paged_attention")
@patch("torch_npu._npu_reshape_and_cache")
@patch("vllm_ascend.attention.attention_v1.get_forward_context")
@patch("vllm_ascend.ascend_forward_context.get_forward_context")
def test_forward_paged_attention_310(
self, mock_get_forward_context, mock_npu_reshape_and_cache, mock_paged_attention, mock_using_paged_attention
):


@@ -95,7 +95,7 @@ class TestAscendAttentionCPImpl(TestBase):
@patch('torch_npu.npu_attention_update')
@patch("torch_npu.npu_fused_infer_attention_score")
@patch(
-'vllm_ascend.attention.context_parallel.attention_cp.get_forward_context'
+'vllm_ascend.ascend_forward_context.get_forward_context'
)
@patch_distributed_groups(dcp_size=2, pcp_size=2)
def test_forward_decode_pcp_dcp(self, mock_all2all, mock_dcp, mock_pcp,


@@ -212,7 +212,7 @@ class TestAscendAttentionBackendImpl(TestBase):
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu.npu_fused_infer_attention_score')
-@patch('vllm_ascend.attention.attention_v1.get_forward_context')
+@patch('vllm_ascend.ascend_forward_context.get_forward_context')
def test_forward_fused_infer_attention(
self, mock_get_forward_context,
mock_npu_fused_infer_attention_score, mock_npu_reshape_and_cache):
@@ -248,7 +248,7 @@ class TestAscendAttentionBackendImpl(TestBase):
@patch('vllm_ascend.attention.attention_v1.using_paged_attention')
@patch('torch_npu._npu_paged_attention')
@patch('torch_npu._npu_reshape_and_cache')
-@patch('vllm_ascend.attention.attention_v1.get_forward_context')
+@patch('vllm_ascend.ascend_forward_context.get_forward_context')
def test_forward_paged_attention(self, mock_get_forward_context,
mock_npu_reshape_and_cache,
mock_paged_attention,
@@ -279,7 +279,7 @@ class TestAscendAttentionBackendImpl(TestBase):
mock_paged_attention.assert_called_once()
assert output.shape == (4, 8 * 64)
-@patch('vllm_ascend.attention.attention_v1.get_forward_context')
+@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch('torch_npu.npu_fused_infer_attention_score')
@patch('torch_npu._npu_reshape_and_cache')
def test_forward_decode_only_swa(self, mock_npu_reshape_and_cache,
@@ -311,7 +311,7 @@ class TestAscendAttentionBackendImpl(TestBase):
mock_fused_infer_attention_score.assert_called_once()
assert output.shape == (10, 8, 64)
-@patch('vllm_ascend.attention.attention_v1.get_forward_context')
+@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch('torch_npu._npu_paged_attention')
@patch('torch_npu.npu_fused_infer_attention_score')
@patch('torch_npu._npu_reshape_and_cache')


@@ -449,7 +449,7 @@ class TestAscendMLAImpl(TestBase):
self.assertEqual(result.shape[1], N)
self.assertEqual(result.shape[2], self.impl.kv_lora_rank + 1)
-@patch('vllm_ascend.attention.context_parallel.mla_cp.get_forward_context')
+@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch("torch_npu.npu_fused_infer_attention_score")
@patch('torch_npu.npu_attention_update')
@patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)


@@ -929,7 +929,7 @@ class TestAscendMLAImpl(TestBase):
self.assertEqual(out.shape, prefix_out.shape)
self.assertEqual(lse.shape, prefix_lse.shape)
-@patch('vllm_ascend.attention.mla_v1.get_forward_context')
+@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj")
@patch("torch_npu.npu_fused_infer_attention_score")
def test_forward_decode_without_graph(self,
@@ -1095,7 +1095,7 @@ class TestAscendMLAImpl(TestBase):
self.assertEqual(k_pe.shape[-1], self.impl.qk_rope_head_dim)
self.assertEqual(k_nope.shape[-1], self.impl.kv_lora_rank)
-@patch('vllm_ascend.attention.mla_v1.get_forward_context')
+@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch("torch_npu.npu_fused_infer_attention_score")
def test_forward_decode(self, mock_npu_fused_infer_attention_score,
mock_get_forward_context):


@@ -161,12 +161,13 @@ class TestACLGraphWrapper(TestBase):
vllm_config=self.mock_vllm_config,
runtime_mode=CUDAGraphMode.NONE)
+@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
@patch('vllm_ascend.compilation.acl_graph.current_platform')
@patch('vllm_ascend.compilation.acl_graph.envs')
def test_call_with_none_runtime_mode(self, mock_envs,
mock_current_platform,
-mock_get_forward_context):
+mock_get_forward_context, mock_get_forward_context_2):
"""Test __call__ method when runtime mode is NONE"""
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
@@ -185,16 +186,19 @@ class TestACLGraphWrapper(TestBase):
self.mock_runnable.assert_called_once_with("arg1", "arg2")
self.assertEqual(result, "test_output")
+@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
@patch('vllm_ascend.compilation.acl_graph.current_platform')
@patch('vllm_ascend.compilation.acl_graph.envs')
def test_call_with_mismatched_runtime_mode(self, mock_envs,
mock_current_platform,
-mock_get_forward_context):
+mock_get_forward_context,
+mock_get_forward_context_2):
"""Test __call__ method when runtime mode doesn't match wrapper mode"""
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
mock_get_forward_context.return_value = self.mock_forward_context
+mock_get_forward_context_2.return_value = self.mock_forward_context
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE # Different from FULL
wrapper = ACLGraphWrapper(
@@ -214,18 +218,20 @@ class TestACLGraphWrapper(TestBase):
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
)
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
+@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch('vllm_ascend.compilation.acl_graph.current_platform')
@patch('vllm_ascend.compilation.acl_graph.envs')
@patch('vllm_ascend.compilation.acl_graph.compilation_counter')
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
def test_call_capture_graph_first_time(
self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs,
-mock_current_platform, mock_get_forward_context,
+mock_current_platform, mock_get_forward_context, mock_get_forward_context_2,
mock_validate_cudagraph_capturing_enabled, mock_torch):
"""Test __call__ method captures graph for the first time"""
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
mock_get_forward_context.return_value = self.mock_forward_context
+mock_get_forward_context_2.return_value = self.mock_forward_context
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
# Mock torch.npu.NPUGraph
@@ -284,6 +290,7 @@ class TestACLGraphWrapper(TestBase):
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
)
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
+@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch('vllm_ascend.compilation.acl_graph.current_platform')
@patch('vllm_ascend.compilation.acl_graph.envs')
@patch('vllm_ascend.compilation.acl_graph.compilation_counter')
@@ -291,12 +298,15 @@ class TestACLGraphWrapper(TestBase):
def test_call_replay_graph(self, mock_weak_ref_tensors,
mock_compilation_counter, mock_envs,
mock_current_platform, mock_get_forward_context,
+mock_get_forward_context_2,
mock_validate_cudagraph_capturing_enabled,
mock_torch):
"""Test __call__ method replays graph when already captured"""
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
mock_get_forward_context.return_value = self.mock_forward_context
+mock_get_forward_context_2.return_value = self.mock_forward_context
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
self.mock_forward_context.is_draft_model = False
@@ -358,17 +368,19 @@ class TestACLGraphWrapper(TestBase):
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
)
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
+@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch('vllm_ascend.compilation.acl_graph.current_platform')
@patch('vllm_ascend.compilation.acl_graph.envs')
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
def test_call_with_debug_mode_input_address_check(
self, mock_weak_ref_tensors, mock_envs, mock_current_platform,
-mock_get_forward_context,
+mock_get_forward_context, mock_get_forward_context_2,
mock_validate_cudagraph_capturing_enabled, mock_torch):
"""Test __call__ method with debug mode input address checking"""
mock_envs.VLLM_LOGGING_LEVEL = "DEBUG" # Enable debug mode
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
mock_get_forward_context.return_value = self.mock_forward_context
+mock_get_forward_context_2.return_value = self.mock_forward_context
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
self.mock_forward_context.is_draft_model = False
@@ -413,17 +425,19 @@ class TestACLGraphWrapper(TestBase):
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
)
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
+@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch('vllm_ascend.compilation.acl_graph.current_platform')
@patch('vllm_ascend.compilation.acl_graph.envs')
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
def test_call_with_debug_mode_input_address_mismatch(
self, mock_weak_ref_tensors, mock_envs, mock_current_platform,
-mock_get_forward_context,
+mock_get_forward_context, mock_get_forward_context_2,
mock_validate_cudagraph_capturing_enabled, mock_torch):
"""Test __call__ method with debug mode input address mismatch raises AssertionError"""
mock_envs.VLLM_LOGGING_LEVEL = "DEBUG" # Enable debug mode
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
mock_get_forward_context.return_value = self.mock_forward_context
+mock_get_forward_context_2.return_value = self.mock_forward_context
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
# Mock torch.npu.NPUGraph
@@ -471,6 +485,7 @@ class TestACLGraphWrapper(TestBase):
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
)
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
+@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch('vllm_ascend.compilation.acl_graph.current_platform')
@patch('vllm_ascend.compilation.acl_graph.envs')
@patch('vllm_ascend.compilation.acl_graph.compilation_counter')
@@ -478,12 +493,13 @@ class TestACLGraphWrapper(TestBase):
@patch('vllm_ascend.compilation.acl_graph.patch')
def test_call_capture_graph_with_gc_disable(
self, mock_patch, mock_weak_ref_tensors, mock_compilation_counter,
-mock_envs, mock_current_platform, mock_get_forward_context,
+mock_envs, mock_current_platform, mock_get_forward_context, mock_get_forward_context_2,
mock_validate_cudagraph_capturing_enabled, mock_torch):
"""Test __call__ method captures graph with gc_disable option enabled"""
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
mock_get_forward_context.return_value = self.mock_forward_context
+mock_get_forward_context_2.return_value = self.mock_forward_context
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
# Enable gc_disable option
@@ -545,18 +561,20 @@ class TestACLGraphWrapper(TestBase):
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
)
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
+@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch('vllm_ascend.compilation.acl_graph.current_platform')
@patch('vllm_ascend.compilation.acl_graph.envs')
@patch('vllm_ascend.compilation.acl_graph.compilation_counter')
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
def test_call_capture_graph_with_weak_ref_output(
self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs,
-mock_current_platform, mock_get_forward_context,
+mock_current_platform, mock_get_forward_context, mock_get_forward_context_2,
mock_validate_cudagraph_capturing_enabled, mock_torch):
"""Test __call__ method captures graph with weak_ref_output option enabled"""
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
mock_get_forward_context.return_value = self.mock_forward_context
+mock_get_forward_context_2.return_value = self.mock_forward_context
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
# Enable weak_ref_output option
@@ -608,18 +626,20 @@ class TestACLGraphWrapper(TestBase):
# Should return the weak ref output when weak_ref_output option is enabled
self.assertEqual(result, "weak_ref_output")
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
+@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch('vllm_ascend.compilation.acl_graph.current_platform')
@patch('vllm_ascend.compilation.acl_graph.envs')
@patch('vllm_ascend.compilation.acl_graph.logger')
def test_call_capture_graph_with_debug_log(self, mock_logger, mock_envs,
mock_current_platform,
-mock_get_forward_context):
+mock_get_forward_context, mock_get_forward_context_2):
"""Test __call__ method captures graph with debug logging enabled"""
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
mock_get_forward_context.return_value = self.mock_forward_context
+mock_get_forward_context_2.return_value = self.mock_forward_context
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
# Enable debug logging
@@ -757,10 +777,11 @@ class TestPCPDCPGraphParams(TestBase):
self.graph_params.events[4].append(mock_event)
self.graph_params.handles[4].append(MagicMock())
+@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch('torch.npu.graph_task_update_end', )
@patch('torch.npu.graph_task_update_begin', MagicMock())
@patch('torch_npu.npu_fused_infer_attention_score.out', MagicMock())
-def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end):
+def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end, mock_context):
input_positions = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
block_table = torch.zeros(2, 5, dtype=torch.long)
seq_lens = torch.tensor([4, 4])
@@ -790,6 +811,7 @@ class TestPCPDCPGraphParams(TestBase):
forward_context = MagicMock()
forward_context.attn_metadata = {"attn_layer_0": metadata}
forward_context.is_draft_model = False
+mock_context.return_value = forward_context
num_heads = 256
scale = 0.1
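The pattern throughout this test file is worth spelling out: each test stacks a second `@patch` for the unified `vllm_ascend.ascend_forward_context` location alongside the existing `acl_graph` target and points both mocks at the same forward context, so the lookup succeeds whichever path the code under test takes. Note that `unittest.mock` applies stacked decorators bottom-up: the `@patch` nearest the function supplies the earliest mock parameter. A hedged sketch (the target strings are the real ones from the diff; the test body is illustrative):

```python
from unittest.mock import MagicMock, patch

@patch("vllm_ascend.ascend_forward_context.get_forward_context")
@patch("vllm_ascend.compilation.acl_graph.get_forward_context")
def test_example(mock_inner, mock_outer):
    # Bottom-up: the @patch closest to the function provides the first
    # argument, so mock_inner is the acl_graph lookup here.
    ctx = MagicMock()
    mock_inner.return_value = ctx
    mock_outer.return_value = ctx  # both lookups must agree on one context
```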


@@ -119,11 +119,9 @@ def mock_dist_env(mocker: MockerFixture):
return_value=(torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]), None, 0)), \
patch('vllm_ascend.ops.fused_moe.fused_moe.get_forward_context',
return_value=mock_forward_context_obj), \
-patch('vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context',
+patch('vllm_ascend.ascend_forward_context.get_forward_context',
return_value=mock_forward_context_obj), \
patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3), \
-patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context',
-return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.fused_moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
return_value=None), \
patch('vllm_ascend.ops.fused_moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
@@ -298,7 +296,7 @@ class TestUnifiedApplyMLP(TestBase):
@patch('vllm_ascend.ops.fused_moe.moe_mlp.get_weight_prefetch_method',
return_value=MagicMock())
-@patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
+@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch('vllm_ascend.utils.get_ascend_device_type',
return_value=AscendDeviceType.A3)
@patch('torch_npu.npu_grouped_matmul')
@@ -407,7 +405,7 @@ class TestUnifiedApplyMLP(TestBase):
@patch('vllm_ascend.ops.fused_moe.moe_mlp.HAS_TRITON', False)
@patch('vllm_ascend.ops.fused_moe.moe_mlp.get_weight_prefetch_method',
return_value=MagicMock())
-@patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
+@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
@@ -513,7 +511,7 @@ class TestUnifiedApplyMLP(TestBase):
@patch("vllm_ascend.ops.fused_moe.moe_mlp.get_weight_prefetch_method",
return_value=MagicMock())
@patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context")
@patch("vllm_ascend.ascend_forward_context.get_forward_context")
@patch("torch_npu.npu_grouped_matmul")
@patch("torch_npu.npu_swiglu")
@patch("torch_npu.npu_grouped_matmul_swiglu_quant")


@@ -121,9 +121,10 @@ class TestAscendMultiHeadLatentAttention(TestBase):
@patch("vllm_ascend.ops.mla.get_ascend_config")
@patch("vllm_ascend.ops.mla.get_tensor_model_parallel_world_size")
@patch("vllm_ascend.ops.mla.get_forward_context")
-def test_forward(self, mock_get_forward_context, mock_tp_size,
+@patch('vllm_ascend.ascend_forward_context.get_forward_context')
+def test_forward(self, mock_get_forward_context_2, mock_get_forward_context, mock_tp_size,
mock_ascend_config, mock_get_vllm_config,
-mock_mla_forward):
+mock_mla_forward,):
mock_tp_size.return_value = 1
mock_ascend_config.return_value.enable_shared_expert_dp = False
mock_vllm_config = MagicMock(spec=VllmConfig)
@@ -159,6 +160,7 @@ class TestAscendMultiHeadLatentAttention(TestBase):
mock_forward_context = MagicMock(spec=ForwardContext)
mock_forward_context.flash_comm_v1_enabled = False
mock_get_forward_context.return_value = mock_forward_context
+mock_get_forward_context_2.return_value = mock_forward_context
mock_mla_forward.return_value = (3, self.hidden_size)


@@ -28,7 +28,7 @@ class TestMoECommMethod(TestBase):
self.moe_config.dp_group = MagicMock()
self.moe_config.global_redundant_expert_num = 0
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch(
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
)
@@ -73,7 +73,7 @@ class TestMoECommMethod(TestBase):
context_metadata=context_metadata)
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch(
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithMC2")
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithMC2")
@@ -116,7 +116,7 @@ class TestMoECommMethod(TestBase):
context_metadata=context_metadata)
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch(
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAll2All"
)
@@ -155,7 +155,7 @@ class TestMoECommMethod(TestBase):
mock_pf_instance.prepare.assert_called_once_with(
hidden_states, router_logits, False, False, QuantType.NONE)
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch(
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
)


@@ -32,7 +32,7 @@ class TestPrepareAndFinalize(unittest.TestCase):
@patch(
"vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
return_value=0)
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
def test_mc2_prepare_finalize(self, mock_get_forward_context, mock_tp_rank,
mock_tp_size):
mock_context = MagicMock()
@@ -65,7 +65,7 @@ class TestPrepareAndFinalize(unittest.TestCase):
@patch(
"vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
return_value=0)
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch("torch.distributed.all_gather")
def test_mc2_tp_split_allgather(self, mock_all_gather,
mock_get_forward_context, mock_tp_rank,
@@ -169,7 +169,7 @@ class TestPrepareAndFinalize(unittest.TestCase):
self.assertEqual(final_result.shape[0], 2)
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_dp_group")
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.enable_sp",
return_value=False)
def test_allgather_prepare_finalize(self, mock_enable_sp,


@@ -386,10 +386,11 @@ class TestEagleProposerDummyRun(TestBase):
set_current_vllm_config(None)
# cpu does not support parallel-group, let alone `sp`
+@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context",
**{"return_value.flash_comm_v1_enabled": False})
@patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
-def test_dummy_run_basic(self, mock_context, mock_get_context):
+def test_dummy_run_basic(self, mock_context, mock_get_context, mock_get_context_2):
num_tokens = 32
with_prefill = False
@@ -402,10 +403,11 @@ class TestEagleProposerDummyRun(TestBase):
self.assertTrue(self.proposer._runnable.call_count == 1)
# cpu does not support parallel-group, let alone `sp`
+@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context",
**{"return_value.flash_comm_v1_enabled": False})
@patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
-def test_dummy_run_with_prefill(self, mock_context, mock_get_context):
+def test_dummy_run_with_prefill(self, mock_context, mock_get_context, mock_get_context_2):
mock_context.return_value.__enter__.return_value = None
# cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
with set_current_vllm_config(self.vllm_config):
@@ -413,11 +415,12 @@ class TestEagleProposerDummyRun(TestBase):
self.proposer.dummy_run(num_tokens=64, with_prefill=True, num_reqs=4)
self.assertTrue(self.proposer._runnable.call_count == 1)
+@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch("vllm_ascend.spec_decode.eagle_proposer.update_full_graph_params")
@patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
@patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
def test_dummy_run_in_graph_capture(self, mock_context, mock_get_context,
-mock_update_full_graph_params):
+mock_update_full_graph_params, mock_get_context_2):
last_use_cuda_graph = self.proposer.use_cuda_graph
mock_return_context = MagicMock()
mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
@@ -425,6 +428,7 @@ class TestEagleProposerDummyRun(TestBase):
# cpu does not support parallel-group, let alone `sp`
mock_return_context.flash_comm_v1_enabled = False
mock_get_context.return_value = mock_return_context
+mock_get_context_2.return_value = mock_return_context
self.proposer.use_cuda_graph = True
# cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
with set_current_vllm_config(self.vllm_config):
@@ -435,12 +439,13 @@ class TestEagleProposerDummyRun(TestBase):
self.assertTrue(self.proposer._runnable.call_count == 1)
mock_update_full_graph_params.assert_not_called()
self.proposer.use_cuda_graph = last_use_cuda_graph
+@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch("vllm_ascend.spec_decode.eagle_proposer.update_full_graph_params")
@patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
@patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
def test_dummy_run_in_graph_run(self, mock_context, mock_get_context,
-mock_update_full_graph_params):
+mock_update_full_graph_params, mock_get_context_2):
last_use_cuda_graph = self.proposer.use_cuda_graph
mock_return_context = MagicMock()
mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
@@ -448,6 +453,7 @@ class TestEagleProposerDummyRun(TestBase):
# cpu does not support parallel-group, let alone `sp`
mock_return_context.flash_comm_v1_enabled = False
mock_get_context.return_value = mock_return_context
+mock_get_context_2.return_value = mock_return_context
self.proposer.use_cuda_graph = True
# cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
with set_current_vllm_config(self.vllm_config):