[Feature] support aclgraph for model runner v2 (#7110)

### What this PR does / why we need it?
This PR adds aclgraph support for model runner v2; see RFC #5208. It contains these modifications:
- Adapt to the latest commit of the vLLM main branch.
- Provide a unified interface to the extra forward context shared by model runner v1 and model runner v2 (sketched below).
- Implement graph mode for the main model.
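
For context, here is a condensed sketch of that unified interface, abridged from the `_ExtraForwardContextProxy` added to `vllm_ascend/ascend_forward_context.py` later in this diff (the attribute tuple is shortened here; see the full diff below for the complete list):

```python
from typing import Any

import vllm.envs as envs_vllm
from vllm.forward_context import get_forward_context


class _ExtraForwardContextProxy:
    """Unified access to Ascend-specific forward-context extras."""

    # Abridged; the real class enumerates every extra attribute it accepts.
    extra_attrs = ("capturing", "moe_comm_method", "is_draft_model")

    def __getattr__(self, name: str) -> Any:
        ctx = get_forward_context()
        if envs_vllm.VLLM_USE_V2_MODEL_RUNNER:
            # Model runner v2 keeps the extras in additional_kwargs.
            return ctx.additional_kwargs[name]
        # Model runner v1 stores them as plain attributes on the context.
        return getattr(ctx, name)


_EXTRA_CTX = _ExtraForwardContextProxy()
```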

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?

- vLLM version: v0.16.0
- vLLM main: 4034c3d32e

---------

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
Author: Ronald (committed by GitHub)
Date: 2026-03-13 09:11:46 +08:00
Parent: 1f71da80eb
Commit: c980e68d40
52 changed files with 840 additions and 309 deletions


@@ -184,7 +184,7 @@ def test_token_dispatcher_with_all_gather_quant(
):
    context_mock = MagicMock()
    context_mock.fused_moe_state = 0
-    with patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context",
+    with patch("vllm_ascend.ascend_forward_context.get_forward_context",
               return_value=context_mock):
        a = torch.randn((m, k), device=device, dtype=dtype) / 10
        w1 = torch.randn((e, k, 2 * n), device=device, dtype=torch.int8)


@@ -85,3 +85,29 @@ def test_egale_spec_decoding(
        },
    ) as runner:
        runner.model.generate(prompts, sampling_params)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("enforce_eager", [False])
+@pytest.mark.parametrize("compilation_config",
+                         [{"cudagraph_mode": "FULL_DECODE_ONLY"}, {}])
+@patch.dict(os.environ, {"VLLM_USE_V2_MODEL_RUNNER": "1"})
+def test_qwen3_dense_graph_mode(
+    model: str,
+    max_tokens: int,
+    enforce_eager: bool,
+    compilation_config: dict,
+) -> None:
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0)
+    with VllmRunner(
+        model,
+        max_model_len=1024,
+        enforce_eager=enforce_eager,
+        compilation_config=compilation_config,
+    ) as runner:
+        runner.model.generate(prompts, sampling_params)
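
For reference, a hedged sketch of how this graph mode might be exercised offline, based on the settings the new test drives (the model name is a placeholder, and the exact `compilation_config` plumbing on Ascend may differ):

```python
import os

# Taken from the test above: opt in to model runner v2 before vLLM is imported.
os.environ["VLLM_USE_V2_MODEL_RUNNER"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-8B",  # placeholder dense Qwen3 checkpoint
    max_model_len=1024,
    enforce_eager=False,  # keep graph capture enabled
    compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY"},
)
outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(max_tokens=32, temperature=0.0),
)
```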


@@ -74,7 +74,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
@patch("torch_npu._npu_reshape_and_cache") @patch("torch_npu._npu_reshape_and_cache")
@patch("torch_npu._npu_flash_attention") @patch("torch_npu._npu_flash_attention")
@patch("vllm_ascend.attention.attention_v1.get_forward_context") @patch("vllm_ascend.ascend_forward_context.get_forward_context")
def test_forward_prefill_310( def test_forward_prefill_310(
self, mock_get_forward_context, mock_npu_npu_flash_attention, mock_npu_reshape_and_cache self, mock_get_forward_context, mock_npu_npu_flash_attention, mock_npu_reshape_and_cache
): ):
@@ -105,7 +105,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
@patch("torch_npu.npu_format_cast", return_value=torch.randn((1, 128, 16, 16), dtype=torch.float16)) @patch("torch_npu.npu_format_cast", return_value=torch.randn((1, 128, 16, 16), dtype=torch.float16))
@patch("torch_npu._npu_reshape_and_cache") @patch("torch_npu._npu_reshape_and_cache")
@patch("torch_npu._npu_paged_attention_splitfuse") @patch("torch_npu._npu_paged_attention_splitfuse")
@patch("vllm_ascend.attention.attention_v1.get_forward_context") @patch("vllm_ascend.ascend_forward_context.get_forward_context")
def test_forward_chunked_prefill_310( def test_forward_chunked_prefill_310(
self, self,
mock_get_forward_context, mock_get_forward_context,
@@ -140,7 +140,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
@patch("torch_npu.npu_format_cast", return_value=torch.randn((1, 128, 16, 16), dtype=torch.float16)) @patch("torch_npu.npu_format_cast", return_value=torch.randn((1, 128, 16, 16), dtype=torch.float16))
@patch("torch_npu._npu_reshape_and_cache") @patch("torch_npu._npu_reshape_and_cache")
@patch("torch_npu._npu_paged_attention_splitfuse") @patch("torch_npu._npu_paged_attention_splitfuse")
@patch("vllm_ascend.attention.attention_v1.get_forward_context") @patch("vllm_ascend.ascend_forward_context.get_forward_context")
def test_forward_prefill_cache_hit_310( def test_forward_prefill_cache_hit_310(
self, self,
mock_get_forward_context, mock_get_forward_context,
@@ -175,7 +175,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
@patch("vllm_ascend.attention.attention_v1.using_paged_attention") @patch("vllm_ascend.attention.attention_v1.using_paged_attention")
@patch("torch_npu._npu_paged_attention") @patch("torch_npu._npu_paged_attention")
@patch("torch_npu._npu_reshape_and_cache") @patch("torch_npu._npu_reshape_and_cache")
@patch("vllm_ascend.attention.attention_v1.get_forward_context") @patch("vllm_ascend.ascend_forward_context.get_forward_context")
def test_forward_paged_attention_310( def test_forward_paged_attention_310(
self, mock_get_forward_context, mock_npu_reshape_and_cache, mock_paged_attention, mock_using_paged_attention self, mock_get_forward_context, mock_npu_reshape_and_cache, mock_paged_attention, mock_using_paged_attention
): ):


@@ -95,7 +95,7 @@ class TestAscendAttentionCPImpl(TestBase):
    @patch('torch_npu.npu_attention_update')
    @patch("torch_npu.npu_fused_infer_attention_score")
    @patch(
-        'vllm_ascend.attention.context_parallel.attention_cp.get_forward_context'
+        'vllm_ascend.ascend_forward_context.get_forward_context'
    )
    @patch_distributed_groups(dcp_size=2, pcp_size=2)
    def test_forward_decode_pcp_dcp(self, mock_all2all, mock_dcp, mock_pcp,


@@ -212,7 +212,7 @@ class TestAscendAttentionBackendImpl(TestBase):
    @patch('torch_npu._npu_reshape_and_cache')
    @patch('torch_npu.npu_fused_infer_attention_score')
-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    def test_forward_fused_infer_attention(
            self, mock_get_forward_context,
            mock_npu_fused_infer_attention_score, mock_npu_reshape_and_cache):
@@ -248,7 +248,7 @@ class TestAscendAttentionBackendImpl(TestBase):
    @patch('vllm_ascend.attention.attention_v1.using_paged_attention')
    @patch('torch_npu._npu_paged_attention')
    @patch('torch_npu._npu_reshape_and_cache')
-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    def test_forward_paged_attention(self, mock_get_forward_context,
                                     mock_npu_reshape_and_cache,
                                     mock_paged_attention,
@@ -279,7 +279,7 @@ class TestAscendAttentionBackendImpl(TestBase):
        mock_paged_attention.assert_called_once()
        assert output.shape == (4, 8 * 64)

-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('torch_npu.npu_fused_infer_attention_score')
    @patch('torch_npu._npu_reshape_and_cache')
    def test_forward_decode_only_swa(self, mock_npu_reshape_and_cache,
@@ -311,7 +311,7 @@ class TestAscendAttentionBackendImpl(TestBase):
        mock_fused_infer_attention_score.assert_called_once()
        assert output.shape == (10, 8, 64)

-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('torch_npu._npu_paged_attention')
    @patch('torch_npu.npu_fused_infer_attention_score')
    @patch('torch_npu._npu_reshape_and_cache')


@@ -449,7 +449,7 @@ class TestAscendMLAImpl(TestBase):
        self.assertEqual(result.shape[1], N)
        self.assertEqual(result.shape[2], self.impl.kv_lora_rank + 1)

-    @patch('vllm_ascend.attention.context_parallel.mla_cp.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch("torch_npu.npu_fused_infer_attention_score")
    @patch('torch_npu.npu_attention_update')
    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)


@@ -929,7 +929,7 @@ class TestAscendMLAImpl(TestBase):
        self.assertEqual(out.shape, prefix_out.shape)
        self.assertEqual(lse.shape, prefix_lse.shape)

-    @patch('vllm_ascend.attention.mla_v1.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj")
    @patch("torch_npu.npu_fused_infer_attention_score")
    def test_forward_decode_without_graph(self,
@@ -1095,7 +1095,7 @@ class TestAscendMLAImpl(TestBase):
        self.assertEqual(k_pe.shape[-1], self.impl.qk_rope_head_dim)
        self.assertEqual(k_nope.shape[-1], self.impl.kv_lora_rank)

-    @patch('vllm_ascend.attention.mla_v1.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch("torch_npu.npu_fused_infer_attention_score")
    def test_forward_decode(self, mock_npu_fused_infer_attention_score,
                            mock_get_forward_context):


@@ -161,12 +161,13 @@ class TestACLGraphWrapper(TestBase):
                                 vllm_config=self.mock_vllm_config,
                                 runtime_mode=CUDAGraphMode.NONE)

+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    def test_call_with_none_runtime_mode(self, mock_envs,
                                         mock_current_platform,
-                                        mock_get_forward_context):
+                                        mock_get_forward_context, mock_get_forward_context_2):
        """Test __call__ method when runtime mode is NONE"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
@@ -185,16 +186,19 @@ class TestACLGraphWrapper(TestBase):
        self.mock_runnable.assert_called_once_with("arg1", "arg2")
        self.assertEqual(result, "test_output")

+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    def test_call_with_mismatched_runtime_mode(self, mock_envs,
                                               mock_current_platform,
-                                               mock_get_forward_context):
+                                               mock_get_forward_context,
+                                               mock_get_forward_context_2):
        """Test __call__ method when runtime mode doesn't match wrapper mode"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
+        mock_get_forward_context_2.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE  # Different from FULL

        wrapper = ACLGraphWrapper(
@@ -214,18 +218,20 @@ class TestACLGraphWrapper(TestBase):
        'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
    )
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    @patch('vllm_ascend.compilation.acl_graph.compilation_counter')
    @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
    def test_call_capture_graph_first_time(
            self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs,
-            mock_current_platform, mock_get_forward_context,
+            mock_current_platform, mock_get_forward_context, mock_get_forward_context_2,
            mock_validate_cudagraph_capturing_enabled, mock_torch):
        """Test __call__ method captures graph for the first time"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
+        mock_get_forward_context_2.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL

        # Mock torch.npu.NPUGraph
@@ -284,6 +290,7 @@ class TestACLGraphWrapper(TestBase):
        'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
    )
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    @patch('vllm_ascend.compilation.acl_graph.compilation_counter')
@@ -291,12 +298,15 @@ class TestACLGraphWrapper(TestBase):
    def test_call_replay_graph(self, mock_weak_ref_tensors,
                               mock_compilation_counter, mock_envs,
                               mock_current_platform, mock_get_forward_context,
+                               mock_get_forward_context_2,
                               mock_validate_cudagraph_capturing_enabled,
                               mock_torch):
        """Test __call__ method replays graph when already captured"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
+        mock_get_forward_context_2.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
        self.mock_forward_context.is_draft_model = False
@@ -358,17 +368,19 @@ class TestACLGraphWrapper(TestBase):
        'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
    )
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
    def test_call_with_debug_mode_input_address_check(
            self, mock_weak_ref_tensors, mock_envs, mock_current_platform,
-            mock_get_forward_context,
+            mock_get_forward_context, mock_get_forward_context_2,
            mock_validate_cudagraph_capturing_enabled, mock_torch):
        """Test __call__ method with debug mode input address checking"""
        mock_envs.VLLM_LOGGING_LEVEL = "DEBUG"  # Enable debug mode
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
+        mock_get_forward_context_2.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
        self.mock_forward_context.is_draft_model = False
@@ -413,17 +425,19 @@ class TestACLGraphWrapper(TestBase):
        'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
    )
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
    def test_call_with_debug_mode_input_address_mismatch(
            self, mock_weak_ref_tensors, mock_envs, mock_current_platform,
-            mock_get_forward_context,
+            mock_get_forward_context, mock_get_forward_context_2,
            mock_validate_cudagraph_capturing_enabled, mock_torch):
        """Test __call__ method with debug mode input address mismatch raises AssertionError"""
        mock_envs.VLLM_LOGGING_LEVEL = "DEBUG"  # Enable debug mode
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
+        mock_get_forward_context_2.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL

        # Mock torch.npu.NPUGraph
@@ -471,6 +485,7 @@ class TestACLGraphWrapper(TestBase):
        'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
    )
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    @patch('vllm_ascend.compilation.acl_graph.compilation_counter')
@@ -478,12 +493,13 @@ class TestACLGraphWrapper(TestBase):
    @patch('vllm_ascend.compilation.acl_graph.patch')
    def test_call_capture_graph_with_gc_disable(
            self, mock_patch, mock_weak_ref_tensors, mock_compilation_counter,
-            mock_envs, mock_current_platform, mock_get_forward_context,
+            mock_envs, mock_current_platform, mock_get_forward_context, mock_get_forward_context_2,
            mock_validate_cudagraph_capturing_enabled, mock_torch):
        """Test __call__ method captures graph with gc_disable option enabled"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
+        mock_get_forward_context_2.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL

        # Enable gc_disable option
@@ -545,18 +561,20 @@ class TestACLGraphWrapper(TestBase):
        'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
    )
    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    @patch('vllm_ascend.compilation.acl_graph.compilation_counter')
    @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
    def test_call_capture_graph_with_weak_ref_output(
            self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs,
-            mock_current_platform, mock_get_forward_context,
+            mock_current_platform, mock_get_forward_context, mock_get_forward_context_2,
            mock_validate_cudagraph_capturing_enabled, mock_torch):
        """Test __call__ method captures graph with weak_ref_output option enabled"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
+        mock_get_forward_context_2.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL

        # Enable weak_ref_output option
@@ -608,18 +626,20 @@ class TestACLGraphWrapper(TestBase):
        # Should return the weak ref output when weak_ref_output option is enabled
        self.assertEqual(result, "weak_ref_output")

    @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('vllm_ascend.compilation.acl_graph.current_platform')
    @patch('vllm_ascend.compilation.acl_graph.envs')
    @patch('vllm_ascend.compilation.acl_graph.logger')
    def test_call_capture_graph_with_debug_log(self, mock_logger, mock_envs,
                                               mock_current_platform,
-                                               mock_get_forward_context):
+                                               mock_get_forward_context, mock_get_forward_context_2):
        """Test __call__ method captures graph with debug logging enabled"""
        mock_envs.VLLM_LOGGING_LEVEL = "INFO"
        mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
        mock_get_forward_context.return_value = self.mock_forward_context
+        mock_get_forward_context_2.return_value = self.mock_forward_context
        self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL

        # Enable debug logging
@@ -757,10 +777,11 @@ class TestPCPDCPGraphParams(TestBase):
        self.graph_params.events[4].append(mock_event)
        self.graph_params.handles[4].append(MagicMock())

+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('torch.npu.graph_task_update_end', )
    @patch('torch.npu.graph_task_update_begin', MagicMock())
    @patch('torch_npu.npu_fused_infer_attention_score.out', MagicMock())
-    def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end):
+    def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end, mock_context):
        input_positions = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
        block_table = torch.zeros(2, 5, dtype=torch.long)
        seq_lens = torch.tensor([4, 4])
@@ -790,6 +811,7 @@ class TestPCPDCPGraphParams(TestBase):
        forward_context = MagicMock()
        forward_context.attn_metadata = {"attn_layer_0": metadata}
        forward_context.is_draft_model = False
+        mock_context.return_value = forward_context

        num_heads = 256
        scale = 0.1


@@ -119,11 +119,9 @@ def mock_dist_env(mocker: MockerFixture):
              return_value=(torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]), None, 0)), \
        patch('vllm_ascend.ops.fused_moe.fused_moe.get_forward_context',
              return_value=mock_forward_context_obj), \
-        patch('vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context',
+        patch('vllm_ascend.ascend_forward_context.get_forward_context',
              return_value=mock_forward_context_obj), \
        patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3), \
-        patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context',
-              return_value=mock_forward_context_obj), \
        patch('vllm_ascend.ops.fused_moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
              return_value=None), \
        patch('vllm_ascend.ops.fused_moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
@@ -298,7 +296,7 @@ class TestUnifiedApplyMLP(TestBase):
    @patch('vllm_ascend.ops.fused_moe.moe_mlp.get_weight_prefetch_method',
           return_value=MagicMock())
-    @patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('vllm_ascend.utils.get_ascend_device_type',
           return_value=AscendDeviceType.A3)
    @patch('torch_npu.npu_grouped_matmul')
@@ -407,7 +405,7 @@ class TestUnifiedApplyMLP(TestBase):
    @patch('vllm_ascend.ops.fused_moe.moe_mlp.HAS_TRITON', False)
    @patch('vllm_ascend.ops.fused_moe.moe_mlp.get_weight_prefetch_method',
           return_value=MagicMock())
-    @patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch('torch_npu.npu_grouped_matmul')
    @patch('torch_npu.npu_swiglu')
    @patch('torch_npu.npu_dynamic_quant')
@@ -513,7 +511,7 @@ class TestUnifiedApplyMLP(TestBase):
@patch("vllm_ascend.ops.fused_moe.moe_mlp.get_weight_prefetch_method", @patch("vllm_ascend.ops.fused_moe.moe_mlp.get_weight_prefetch_method",
return_value=MagicMock()) return_value=MagicMock())
@patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context") @patch("vllm_ascend.ascend_forward_context.get_forward_context")
@patch("torch_npu.npu_grouped_matmul") @patch("torch_npu.npu_grouped_matmul")
@patch("torch_npu.npu_swiglu") @patch("torch_npu.npu_swiglu")
@patch("torch_npu.npu_grouped_matmul_swiglu_quant") @patch("torch_npu.npu_grouped_matmul_swiglu_quant")


@@ -121,9 +121,10 @@ class TestAscendMultiHeadLatentAttention(TestBase):
@patch("vllm_ascend.ops.mla.get_ascend_config") @patch("vllm_ascend.ops.mla.get_ascend_config")
@patch("vllm_ascend.ops.mla.get_tensor_model_parallel_world_size") @patch("vllm_ascend.ops.mla.get_tensor_model_parallel_world_size")
@patch("vllm_ascend.ops.mla.get_forward_context") @patch("vllm_ascend.ops.mla.get_forward_context")
def test_forward(self, mock_get_forward_context, mock_tp_size, @patch('vllm_ascend.ascend_forward_context.get_forward_context')
def test_forward(self, mock_get_forward_context_2, mock_get_forward_context, mock_tp_size,
mock_ascend_config, mock_get_vllm_config, mock_ascend_config, mock_get_vllm_config,
mock_mla_forward): mock_mla_forward,):
mock_tp_size.return_value = 1 mock_tp_size.return_value = 1
mock_ascend_config.return_value.enable_shared_expert_dp = False mock_ascend_config.return_value.enable_shared_expert_dp = False
mock_vllm_config = MagicMock(spec=VllmConfig) mock_vllm_config = MagicMock(spec=VllmConfig)
@@ -159,6 +160,7 @@ class TestAscendMultiHeadLatentAttention(TestBase):
        mock_forward_context = MagicMock(spec=ForwardContext)
        mock_forward_context.flash_comm_v1_enabled = False
        mock_get_forward_context.return_value = mock_forward_context
+        mock_get_forward_context_2.return_value = mock_forward_context
        mock_mla_forward.return_value = (3, self.hidden_size)


@@ -28,7 +28,7 @@ class TestMoECommMethod(TestBase):
        self.moe_config.dp_group = MagicMock()
        self.moe_config.global_redundant_expert_num = 0

-    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch(
        "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
    )
@@ -73,7 +73,7 @@ class TestMoECommMethod(TestBase):
                                                     context_metadata=context_metadata)
        mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)

-    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch(
        "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithMC2")
    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithMC2")
@@ -116,7 +116,7 @@ class TestMoECommMethod(TestBase):
                                                     context_metadata=context_metadata)
        mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)

-    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch(
        "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAll2All"
    )
@@ -155,7 +155,7 @@ class TestMoECommMethod(TestBase):
        mock_pf_instance.prepare.assert_called_once_with(
            hidden_states, router_logits, False, False, QuantType.NONE)

-    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch(
        "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
    )


@@ -32,7 +32,7 @@ class TestPrepareAndFinalize(unittest.TestCase):
    @patch(
        "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
        return_value=0)
-    @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    def test_mc2_prepare_finalize(self, mock_get_forward_context, mock_tp_rank,
                                  mock_tp_size):
        mock_context = MagicMock()
@@ -65,7 +65,7 @@ class TestPrepareAndFinalize(unittest.TestCase):
    @patch(
        "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
        return_value=0)
-    @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch("torch.distributed.all_gather")
    def test_mc2_tp_split_allgather(self, mock_all_gather,
                                    mock_get_forward_context, mock_tp_rank,
@@ -169,7 +169,7 @@ class TestPrepareAndFinalize(unittest.TestCase):
        self.assertEqual(final_result.shape[0], 2)

    @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_dp_group")
-    @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch("vllm_ascend.ops.fused_moe.prepare_finalize.enable_sp",
           return_value=False)
    def test_allgather_prepare_finalize(self, mock_enable_sp,


@@ -386,10 +386,11 @@ class TestEagleProposerDummyRun(TestBase):
        set_current_vllm_config(None)

    # cpu does not support parallel-group, let alone `sp`
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context",
           **{"return_value.flash_comm_v1_enabled": False})
    @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
-    def test_dummy_run_basic(self, mock_context, mock_get_context):
+    def test_dummy_run_basic(self, mock_context, mock_get_context, mock_get_context_2):
        num_tokens = 32
        with_prefill = False
@@ -402,10 +403,11 @@ class TestEagleProposerDummyRun(TestBase):
            self.assertTrue(self.proposer._runnable.call_count == 1)

    # cpu does not support parallel-group, let alone `sp`
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context",
           **{"return_value.flash_comm_v1_enabled": False})
    @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
-    def test_dummy_run_with_prefill(self, mock_context, mock_get_context):
+    def test_dummy_run_with_prefill(self, mock_context, mock_get_context, mock_get_context_2):
        mock_context.return_value.__enter__.return_value = None
        # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
        with set_current_vllm_config(self.vllm_config):
@@ -413,11 +415,12 @@ class TestEagleProposerDummyRun(TestBase):
            self.proposer.dummy_run(num_tokens=64, with_prefill=True, num_reqs=4)
        self.assertTrue(self.proposer._runnable.call_count == 1)

+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch("vllm_ascend.spec_decode.eagle_proposer.update_full_graph_params")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
    @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
    def test_dummy_run_in_graph_capture(self, mock_context, mock_get_context,
-                                        mock_update_full_graph_params):
+                                        mock_update_full_graph_params, mock_get_context_2):
        last_use_cuda_graph = self.proposer.use_cuda_graph
        mock_return_context = MagicMock()
        mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
@@ -425,6 +428,7 @@ class TestEagleProposerDummyRun(TestBase):
        # cpu does not support parallel-group, let alone `sp`
        mock_return_context.flash_comm_v1_enabled = False
        mock_get_context.return_value = mock_return_context
+        mock_get_context_2.return_value = mock_return_context
        self.proposer.use_cuda_graph = True
        # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
        with set_current_vllm_config(self.vllm_config):
@@ -435,12 +439,13 @@ class TestEagleProposerDummyRun(TestBase):
        self.assertTrue(self.proposer._runnable.call_count == 1)
        mock_update_full_graph_params.assert_not_called()
        self.proposer.use_cuda_graph = last_use_cuda_graph

+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
    @patch("vllm_ascend.spec_decode.eagle_proposer.update_full_graph_params")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
    @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
    def test_dummy_run_in_graph_run(self, mock_context, mock_get_context,
-                                    mock_update_full_graph_params):
+                                    mock_update_full_graph_params, mock_get_context_2):
        last_use_cuda_graph = self.proposer.use_cuda_graph
        mock_return_context = MagicMock()
        mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
@@ -448,6 +453,7 @@ class TestEagleProposerDummyRun(TestBase):
        # cpu does not support parallel-group, let alone `sp`
        mock_return_context.flash_comm_v1_enabled = False
        mock_get_context.return_value = mock_return_context
+        mock_get_context_2.return_value = mock_return_context
        self.proposer.use_cuda_graph = True
        # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
        with set_current_vllm_config(self.vllm_config):


@@ -18,12 +18,11 @@ from collections.abc import Callable
import torch
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
-from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE

-from vllm_ascend.ascend_forward_context import MoECommType
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute
from vllm_ascend.ops.fused_moe.moe_comm_method import FusedExpertsResult, _MoECommMethods
from vllm_ascend.quantization.methods.base import QuantType
@@ -93,7 +92,7 @@ class AscendUnquantizedFusedMoEMethod310(UnquantizedFusedMoEMethod):
        topk_weights = topk_weights.to(x.dtype)

-        moe_comm_method = get_forward_context().moe_comm_method
+        moe_comm_method = _EXTRA_CTX.moe_comm_method
        final_hidden_states = moe_comm_method.fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
@@ -222,9 +221,8 @@ class AscendFusedMoE310(FusedMoE):
    ) -> torch.Tensor:
        assert self.quant_method is not None
        assert self.routed_scaling_factor == 1.0, "routed_scaling_factor != 1.0 is not supported."
-        forward_context = get_forward_context()
-        hidden_states, router_logits, _, context_metadata = forward_context.moe_comm_method.prepare(
+        hidden_states, router_logits, _, context_metadata = _EXTRA_CTX.moe_comm_method.prepare(
            hidden_states=hidden_states, router_logits=router_logits, quant_type=self.quant_type
        )
@@ -246,7 +244,7 @@ class AscendFusedMoE310(FusedMoE):
            apply_router_weight_on_input=self.apply_router_weight_on_input,
        )

-        routed_out = forward_context.moe_comm_method.finalize(
+        routed_out = _EXTRA_CTX.moe_comm_method.finalize(
            hidden_states=fused_experts_results.routed_out,
            reduce_results=self.reduce_results,
            context_metadata=context_metadata,


@@ -16,8 +16,8 @@
from __future__ import annotations

import torch
-from vllm.forward_context import get_forward_context
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX

from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult

from .moe_mlp import unified_apply_mlp
@@ -50,7 +50,7 @@ class AllGatherCommImpl310(AllGatherCommImpl):
    ) -> FusedExpertsResult:
        # This method is overridden to use the 310p-specific unified_apply_mlp
        # which provides optimized MLP computation for the 310p platform
-        moe_comm_method = get_forward_context().moe_comm_method
+        moe_comm_method = _EXTRA_CTX.moe_comm_method
        assert moe_comm_method is not None, "Missing communication context"
        dispatch_results = self.token_dispatcher.token_dispatch(


@@ -21,9 +21,9 @@ from typing import Any
import torch
from vllm.config import get_current_vllm_config
from vllm.distributed import get_ep_group
-from vllm.forward_context import get_forward_context

from vllm_ascend._310p.fused_moe.experts_selector import select_experts
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute
from vllm_ascend.quantization.methods.base import AscendMoEScheme, QuantType
@@ -125,7 +125,7 @@ class AscendW8A8DynamicFusedMoEMethod310(AscendMoEScheme):
        topk_weights = topk_weights.to(self.in_dtype)

-        moe_comm_method = get_forward_context().moe_comm_method
+        moe_comm_method = _EXTRA_CTX.moe_comm_method
        final_hidden_states = moe_comm_method.fused_experts(
            hidden_states=x,


@@ -4,6 +4,7 @@ from enum import Enum
from typing import Any

import torch
+import vllm.envs as envs_vllm
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import get_dp_group, get_ep_group, get_tensor_model_parallel_world_size
from vllm.forward_context import BatchDescriptor, get_forward_context, set_forward_context
@@ -270,3 +271,61 @@ def select_moe_comm_method(num_tokens: int, vllm_config: VllmConfig, is_draft_mo
    else:
        raise ValueError(f"Unsupported soc_version: {soc_version}")
    return moe_comm_type
+
+
+class _ExtraForwardContextProxy:
+    """Unified forward-context access for v1/v2 model runners."""
+
+    extra_attrs = (
+        "capturing",
+        "moe_comm_type",
+        "moe_comm_method",
+        "mmrs_fusion",
+        "num_tokens",
+        "flash_comm_v1_enabled",
+        "flashcomm_v2_enabled",
+        "pad_size",
+        "padded_length",
+        "num_tokens_across_dp",
+        "mc2_mask",
+        "is_draft_model",
+        "prefetch_mlp_gate_up_proj",
+        "prefetch_mlp_down_proj",
+        "model_instance",
+        "layer_idx",
+        "max_tokens_across_dp",
+        "max_tokens_across_pcp",
+        "num_accept_tokens",
+        "in_profile_run",
+        "padded_num_tokens",
+    )
+
+    def check_extra_attr(self, name: str):
+        if name not in self.extra_attrs:
+            raise AttributeError(
+                f"{name} is not an extra forward context attribute, "
+                "please get/set it from vllm's _forward_context directly."
+            )
+
+    @staticmethod
+    def _ctx():
+        return get_forward_context()
+
+    def __getattr__(self, name: str) -> Any:
+        self.check_extra_attr(name)
+        ctx = self._ctx()
+        if envs_vllm.VLLM_USE_V2_MODEL_RUNNER:
+            return ctx.additional_kwargs[name]
+        return getattr(ctx, name)
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        self.check_extra_attr(name)
+        ctx = self._ctx()
+        if envs_vllm.VLLM_USE_V2_MODEL_RUNNER:
+            ctx.additional_kwargs[name] = value
+        else:
+            setattr(ctx, name, value)
+
+
+# usage: from vllm_ascend.ascend_forward_context import _EXTRA_CTX
+_EXTRA_CTX = _ExtraForwardContextProxy()
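
For reference, a minimal call-site sketch of the proxy (the attribute name comes from this diff; the surrounding function is illustrative only and mirrors the attention_v1.py changes below):

```python
# Illustrative only: the same check works under both runners because
# _EXTRA_CTX resolves where the extra attribute is stored
# (plain attribute for runner v1, additional_kwargs for runner v2).
from vllm_ascend.ascend_forward_context import _EXTRA_CTX


def pick_graph_params(get_graph_params, get_draft_graph_params):
    # The graph-params helpers are passed in here; attention_v1.py
    # imports its own versions.
    if _EXTRA_CTX.is_draft_model:
        return get_draft_graph_params()
    return get_graph_params()
```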


@@ -23,7 +23,6 @@ import torch
import torch_npu
import vllm.envs as envs_vllm
from vllm.config import VllmConfig, get_current_vllm_config
-from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import (  # type: ignore
    AttentionBackend,
@@ -40,6 +39,7 @@ from vllm.v1.attention.backends.registry import ( # type: ignore
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec

+from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.context_parallel.common_cp import AscendMetadataForDecode, AscendMetadataForPrefill
from vllm_ascend.attention.utils import (
@@ -392,7 +392,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
    ):
        if using_paged_attention(num_tokens, vllm_config):
            # Paged Attention update logic
-            if forward_context.is_draft_model:
+            if _EXTRA_CTX.is_draft_model:
                graph_params = get_draft_graph_params()
            else:
                graph_params = get_graph_params()
@@ -444,7 +444,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                event.record(update_stream)
        else:
            # FIA update logic
-            if forward_context.is_draft_model:
+            if _EXTRA_CTX.is_draft_model:
                graph_params = get_draft_graph_params()
                attn_metadata = draft_attn_metadatas
                attn_keys = list(attn_metadata[0].keys())
@@ -462,7 +462,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
        num_layers = len(attn_keys)
        if num_layers == 0:
            return
-        if forward_context.is_draft_model:
+        if _EXTRA_CTX.is_draft_model:
            attn_keys = attn_keys * (len(graph_params.attn_params[num_tokens]) // num_layers)
        attn_count = 0
        with torch.npu.stream(update_stream):
@@ -488,7 +488,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                    softmax_lse,
                ) = param
-                if forward_context.is_draft_model:
+                if _EXTRA_CTX.is_draft_model:
                    draft_step = attn_count // num_layers
                    seq_lens = attn_metadata[draft_step][key].seq_lens_list
                    actual_seq_lengths_q = attn_metadata[draft_step][key].actual_seq_lengths_q
@@ -535,8 +535,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
        key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata)
        num_tokens = attn_metadata.actual_seq_lengths_q[-1]
-        forward_context = get_forward_context()
-        if forward_context.is_draft_model:
+        if _EXTRA_CTX.is_draft_model:
            graph_params = get_draft_graph_params()
        else:
            graph_params = get_graph_params()
@@ -563,7 +562,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
sparse_mode=3, sparse_mode=3,
scale=self.scale, scale=self.scale,
) )
if forward_context.is_draft_model: if _EXTRA_CTX.is_draft_model:
update_draft_graph_params_workspaces(num_tokens, workspace) update_draft_graph_params_workspaces(num_tokens, workspace)
else: else:
update_graph_params_workspaces(num_tokens, workspace) update_graph_params_workspaces(num_tokens, workspace)
@@ -625,9 +624,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
output: torch.Tensor | None = None, output: torch.Tensor | None = None,
): ):
graph_params = get_graph_params() graph_params = get_graph_params()
forward_context: ForwardContext = get_forward_context()
num_tokens = query.shape[0] num_tokens = query.shape[0]
if forward_context.capturing: if _EXTRA_CTX.capturing:
# Get workspace from cache or calculate it if not present. # Get workspace from cache or calculate it if not present.
workspace = graph_params.workspaces.get(num_tokens) workspace = graph_params.workspaces.get(num_tokens)
if workspace is None: if workspace is None:
@@ -761,11 +759,10 @@ class AscendAttentionBackendImpl(AttentionImpl):
attn_metadata: AscendMetadata, attn_metadata: AscendMetadata,
output: torch.Tensor, output: torch.Tensor,
): ):
forward_context: ForwardContext = get_forward_context()
# we inherit ForwardContext in model runner v2, when enable model # we inherit ForwardContext in model runner v2, when enable model
# runner v2, there is not capturing attribute in forward_context, # runner v2, there is not capturing attribute in forward_context,
# just use getattr to avoid attribute error. # just use getattr to avoid attribute error.
if getattr(forward_context, "capturing", False): if _EXTRA_CTX.capturing:
attn_output, num_tokens = self.full_graph_fia(query, key, value, attn_metadata, output) attn_output, num_tokens = self.full_graph_fia(query, key, value, attn_metadata, output)
output[:num_tokens] = attn_output[:num_tokens] output[:num_tokens] = attn_output[:num_tokens]
return output return output
@@ -841,8 +838,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
attn_metadata: AscendMetadata, attn_metadata: AscendMetadata,
output: torch.Tensor | None = None, output: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context() if _EXTRA_CTX.capturing:
if forward_context.capturing:
return self.full_graph_pa(query, attn_metadata, output) return self.full_graph_pa(query, attn_metadata, output)
torch_npu._npu_paged_attention( torch_npu._npu_paged_attention(
query=query, query=query,
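The hunks above all follow one pattern: the per-call get_forward_context() lookups are replaced by the module-level _EXTRA_CTX, and the draft model keeps its own graph-parameter store so capture and replay of the main and draft ACL graphs never share workspaces. A minimal, self-contained Python sketch of that selection-and-capture pattern follows; ExtraContext and the two dicts are simplified stand-ins for illustration, not the real vllm_ascend objects.

from dataclasses import dataclass


@dataclass
class ExtraContext:
    capturing: bool = False       # True while an ACL graph is being captured
    is_draft_model: bool = False  # True inside the speculative draft model


_EXTRA_CTX = ExtraContext()
_MAIN_GRAPH_PARAMS: dict = {}
_DRAFT_GRAPH_PARAMS: dict = {}


def select_graph_params() -> dict:
    # Mirrors the "if _EXTRA_CTX.is_draft_model" branches in the hunks above.
    return _DRAFT_GRAPH_PARAMS if _EXTRA_CTX.is_draft_model else _MAIN_GRAPH_PARAMS


def record_workspace(num_tokens: int, workspace) -> None:
    # Workspaces are cached only while capturing; replay reuses the entry
    # keyed by the padded token count.
    if _EXTRA_CTX.capturing:
        select_graph_params()[num_tokens] = workspace


_EXTRA_CTX.capturing = True
_EXTRA_CTX.is_draft_model = True
record_workspace(64, "draft-workspace")
assert 64 in _DRAFT_GRAPH_PARAMS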

View File

@@ -29,10 +29,10 @@ from vllm.distributed import (
get_decode_context_model_parallel_world_size, get_decode_context_model_parallel_world_size,
get_pcp_group, get_pcp_group,
) )
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.v1.attention.backend import AttentionCGSupport from vllm.v1.attention.backend import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.attention.attention_v1 import ( from vllm_ascend.attention.attention_v1 import (
AscendAttentionBackendImpl, AscendAttentionBackendImpl,
AscendAttentionMetadataBuilder, AscendAttentionMetadataBuilder,
@@ -559,9 +559,8 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
"actual_seq_lengths": torch.arange(attn_metadata.num_decodes_flatten) + 1, "actual_seq_lengths": torch.arange(attn_metadata.num_decodes_flatten) + 1,
} }
graph_params = get_graph_params() graph_params = get_graph_params()
forward_context: ForwardContext = get_forward_context()
num_tokens = query.shape[0] num_tokens = query.shape[0]
if forward_context.capturing: if _EXTRA_CTX.capturing:
stream = torch_npu.npu.current_stream() stream = torch_npu.npu.current_stream()
event = torch.npu.ExternalEvent() event = torch.npu.ExternalEvent()

View File

@@ -10,7 +10,6 @@ from vllm.distributed import (
get_decode_context_model_parallel_world_size, get_decode_context_model_parallel_world_size,
get_pcp_group, get_pcp_group,
) )
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import AttentionCGSupport from vllm.v1.attention.backend import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
@@ -30,6 +29,7 @@ from vllm_ascend.attention.mla_v1 import (
) )
# isort: on # isort: on
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.attention.context_parallel.common_cp import ( from vllm_ascend.attention.context_parallel.common_cp import (
AscendPCPMetadata, AscendPCPMetadata,
CPChunkedContextMetadata, CPChunkedContextMetadata,
@@ -294,7 +294,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
num_dcp_pcp_tokens=None, num_dcp_pcp_tokens=None,
draft_attn_metadatas=None, draft_attn_metadatas=None,
): ):
if forward_context.is_draft_model: if _EXTRA_CTX.is_draft_model:
graph_params = get_draft_graph_params() graph_params = get_draft_graph_params()
else: else:
graph_params = get_graph_params() graph_params = get_graph_params()
@@ -659,12 +659,11 @@ class AscendMlaCPImpl(AscendMLAImpl):
"softmax_lse_flag": True, "softmax_lse_flag": True,
} }
forward_context: ForwardContext = get_forward_context() if _EXTRA_CTX.is_draft_model:
if forward_context.is_draft_model:
graph_params = get_draft_graph_params() graph_params = get_draft_graph_params()
else: else:
graph_params = get_graph_params() graph_params = get_graph_params()
if forward_context.capturing: if _EXTRA_CTX.capturing:
stream = torch_npu.npu.current_stream() stream = torch_npu.npu.current_stream()
event = torch.npu.ExternalEvent() event = torch.npu.ExternalEvent()
event.wait(stream) event.wait(stream)

View File

@@ -6,16 +6,20 @@ import torch
import torch_npu import torch_npu
import vllm.envs as envs_vllm import vllm.envs as envs_vllm
from vllm.config import VllmConfig, get_current_vllm_config from vllm.config import VllmConfig, get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import logger from vllm.logger import logger
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadataBuilder from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadataBuilder
from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.utils.math_utils import cdiv, round_down from vllm.utils.math_utils import cdiv, round_down
from vllm.v1.attention.backend import AttentionBackend, AttentionCGSupport, MLAAttentionImpl # type: ignore from vllm.v1.attention.backend import (
AttentionBackend, # type: ignore
AttentionCGSupport,
MLAAttentionImpl,
)
from vllm.v1.attention.backends.utils import PAD_SLOT_ID # type: ignore from vllm.v1.attention.backends.utils import PAD_SLOT_ID # type: ignore
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata, CPChunkedContextMetadata from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata, CPChunkedContextMetadata
@@ -44,12 +48,7 @@ from vllm_ascend.ops.layer_shard_linear import (
) )
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.quantization.methods import AscendW8A8LinearMethod from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
from vllm_ascend.utils import ( from vllm_ascend.utils import ACL_FORMAT_FRACTAL_ND, get_weight_prefetch_method, maybe_trans_nz, weak_ref_tensors
ACL_FORMAT_FRACTAL_ND,
get_weight_prefetch_method,
maybe_trans_nz,
weak_ref_tensors,
)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch from vllm_ascend.worker.npu_input_batch import NPUInputBatch
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -737,7 +736,7 @@ class AscendMLAImpl(MLAAttentionImpl):
num_dcp_pcp_tokens=None, num_dcp_pcp_tokens=None,
draft_attn_metadatas=None, draft_attn_metadatas=None,
): ):
if forward_context.is_draft_model: if _EXTRA_CTX.is_draft_model:
graph_params = get_draft_graph_params() graph_params = get_draft_graph_params()
else: else:
graph_params = get_graph_params() graph_params = get_graph_params()
@@ -769,12 +768,12 @@ class AscendMLAImpl(MLAAttentionImpl):
softmax_lse, softmax_lse,
) = param ) = param
seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
if speculative_config and speculative_config.method == "mtp" and not forward_context.is_draft_model: if speculative_config and speculative_config.method == "mtp" and not _EXTRA_CTX.is_draft_model:
actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q
spec_multiple = speculative_config.num_speculative_tokens + 1 spec_multiple = speculative_config.num_speculative_tokens + 1
seq_lens_list = seq_lens_list + [0] * (num_tokens // spec_multiple - len(seq_lens_list)) seq_lens_list = seq_lens_list + [0] * (num_tokens // spec_multiple - len(seq_lens_list))
actual_seq_lengths = [spec_multiple * (i + 1) for i in range(num_tokens // spec_multiple)] actual_seq_lengths = [spec_multiple * (i + 1) for i in range(num_tokens // spec_multiple)]
elif forward_context.is_draft_model: elif _EXTRA_CTX.is_draft_model:
actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q
block_table = forward_context.attn_metadata[key].decode.block_table block_table = forward_context.attn_metadata[key].decode.block_table
# TODO: This is a hack and should be fixed in the future. # TODO: This is a hack and should be fixed in the future.
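The seq_lens padding in the hunk above is compact, so here is the same arithmetic worked through with invented values (plain Python, illustration only): each request contributes spec_multiple = num_speculative_tokens + 1 query positions, seq_lens is zero-padded to one entry per graph slot, and actual_seq_lengths_q becomes the running multiples of spec_multiple.

num_speculative_tokens = 1                  # invented for illustration
spec_multiple = num_speculative_tokens + 1  # query positions per request
num_tokens = 8                              # padded token count seen by the graph
seq_lens_list = [17, 33, 9]                 # real per-request kv lengths this step

seq_lens_list = seq_lens_list + [0] * (num_tokens // spec_multiple - len(seq_lens_list))
actual_seq_lengths = [spec_multiple * (i + 1) for i in range(num_tokens // spec_multiple)]

assert seq_lens_list == [17, 33, 9, 0]
assert actual_seq_lengths == [2, 4, 6, 8]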
@@ -1243,12 +1242,11 @@ class AscendMLAImpl(MLAAttentionImpl):
"actual_seq_lengths": actual_seq_lengths, "actual_seq_lengths": actual_seq_lengths,
"actual_seq_lengths_kv": decode_meta.seq_lens_list, "actual_seq_lengths_kv": decode_meta.seq_lens_list,
} }
forward_context: ForwardContext = get_forward_context() if _EXTRA_CTX.is_draft_model:
if forward_context.is_draft_model:
graph_params = get_draft_graph_params() graph_params = get_draft_graph_params()
else: else:
graph_params = get_graph_params() graph_params = get_graph_params()
if forward_context.capturing: if _EXTRA_CTX.capturing:
stream = torch_npu.npu.current_stream() stream = torch_npu.npu.current_stream()
event = torch.npu.ExternalEvent() event = torch.npu.ExternalEvent()
@@ -1261,7 +1259,7 @@ class AscendMLAImpl(MLAAttentionImpl):
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace( workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
q_nope, k_nope, k_nope, **common_kwargs q_nope, k_nope, k_nope, **common_kwargs
) )
if forward_context.is_draft_model: if _EXTRA_CTX.is_draft_model:
update_draft_graph_params_workspaces(num_tokens, workspace) update_draft_graph_params_workspaces(num_tokens, workspace)
else: else:
update_graph_params_workspaces(num_tokens, workspace) update_graph_params_workspaces(num_tokens, workspace)
@@ -1493,7 +1491,6 @@ class AscendMLAImpl(MLAAttentionImpl):
reach_layer_for_shard_weight_series(layer) reach_layer_for_shard_weight_series(layer)
return output.fill_(0) return output.fill_(0)
forward_context = get_forward_context()
num_actual_tokens = self.get_num_actual_tokens(attn_metadata) num_actual_tokens = self.get_num_actual_tokens(attn_metadata)
assert ( assert (
attn_metadata.num_decodes is not None attn_metadata.num_decodes is not None
@@ -1505,7 +1502,7 @@ class AscendMLAImpl(MLAAttentionImpl):
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
# Inputs and outputs may be padded for CUDA graphs # Inputs and outputs may be padded for CUDA graphs
output_padded = output output_padded = output
o_proj_input_shape = (forward_context.num_tokens, self.num_heads * self.v_head_dim) o_proj_input_shape = (_EXTRA_CTX.num_tokens, self.num_heads * self.v_head_dim)
o_proj_input = torch.empty(o_proj_input_shape, dtype=hidden_states.dtype, device=hidden_states.device) o_proj_input = torch.empty(o_proj_input_shape, dtype=hidden_states.dtype, device=hidden_states.device)
# MLA Preprocess # MLA Preprocess

View File

@@ -7,16 +7,20 @@ import vllm.envs as envs_vllm
from torch import nn from torch import nn
from vllm.config import VllmConfig, get_current_vllm_config from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.forward_context import get_forward_context
from vllm.logger import logger from vllm.logger import logger
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadataBuilder from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadataBuilder
from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backend import AttentionBackend, AttentionCGSupport, MLAAttentionImpl # type: ignore from vllm.v1.attention.backend import (
AttentionBackend, # type: ignore
AttentionCGSupport,
MLAAttentionImpl,
)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_ascend import envs from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata
@@ -967,10 +971,9 @@ class AscendSFAImpl(MLAAttentionImpl):
output: torch.Tensor | None = None, output: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
forward_context = get_forward_context()
if attn_metadata is None: if attn_metadata is None:
# Profiling run. # Profiling run.
if self.enable_dsa_cp_with_layer_shard and not forward_context.in_profile_run: if self.enable_dsa_cp_with_layer_shard and not _EXTRA_CTX.in_profile_run:
for layer in self.layer_sharding_kwargs or []: for layer in self.layer_sharding_kwargs or []:
if is_hidden_layer(layer): if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer) reach_layer_for_shard_weight_series(layer)

View File

@@ -19,6 +19,8 @@ from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import logger from vllm.logger import logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from ..utils import weak_ref_tensors from ..utils import weak_ref_tensors
@@ -195,7 +197,7 @@ class ACLGraphWrapper:
if self.vllm_config.speculative_config if self.vllm_config.speculative_config
else False else False
) )
if self.runtime_mode != CUDAGraphMode.FULL or not forward_context.is_draft_model or not use_eagle: if self.runtime_mode != CUDAGraphMode.FULL or not _EXTRA_CTX.is_draft_model or not use_eagle:
torch.npu.current_stream().synchronize() torch.npu.current_stream().synchronize()
entry.aclgraph.replay() entry.aclgraph.replay()
return entry.output return entry.output
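The replay path above skips the pre-replay stream synchronize only when all three conditions hold: FULL graph mode, the draft model, and eagle speculative decoding. A standalone sketch of that guard, with illustrative names standing in for CUDAGraphMode and the context flags:

from enum import Enum, auto


class GraphMode(Enum):      # stand-in for CUDAGraphMode
    PIECEWISE = auto()
    FULL = auto()


def sync_before_replay(runtime_mode: GraphMode, is_draft_model: bool, use_eagle: bool) -> bool:
    # Same boolean as the guard above: synchronize unless this is a
    # FULL-mode eagle draft graph.
    return runtime_mode != GraphMode.FULL or not is_draft_model or not use_eagle


assert sync_before_replay(GraphMode.PIECEWISE, True, True) is True
assert sync_before_replay(GraphMode.FULL, False, True) is True
assert sync_before_replay(GraphMode.FULL, True, True) is False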

View File

@@ -3,6 +3,7 @@ import torch.distributed as dist
from vllm.distributed.parallel_state import GroupCoordinator, get_dp_group from vllm.distributed.parallel_state import GroupCoordinator, get_dp_group
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.distributed.parallel_state import get_fc3_quant_x_group from vllm_ascend.distributed.parallel_state import get_fc3_quant_x_group
@@ -16,7 +17,7 @@ def fc3_all_gather_and_maybe_unpad_impl(
x = get_fc3_quant_x_group().all_gather(x, 0) x = get_fc3_quant_x_group().all_gather(x, 0)
dp_metadata = forward_context.dp_metadata dp_metadata = forward_context.dp_metadata
if dp_metadata is None: if dp_metadata is None:
pad_size = forward_context.pad_size pad_size = _EXTRA_CTX.pad_size
if pad_size > 0: if pad_size > 0:
x = x[:-pad_size] x = x[:-pad_size]
else: else:
@@ -24,7 +25,7 @@ def fc3_all_gather_and_maybe_unpad_impl(
num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu
result = torch.empty((num_tokens_across_dp_cpu.sum(), *x.shape[1:]), device=x.device, dtype=x.dtype) result = torch.empty((num_tokens_across_dp_cpu.sum(), *x.shape[1:]), device=x.device, dtype=x.dtype)
dp_size = get_dp_group().world_size dp_size = get_dp_group().world_size
x = x.view(dp_size, forward_context.padded_length, *x.shape[1:]) x = x.view(dp_size, _EXTRA_CTX.padded_length, *x.shape[1:])
offset = 0 offset = 0
for idx in range(dp_size): for idx in range(dp_size):
num_tokens_dp = num_tokens_across_dp_cpu[idx] num_tokens_dp = num_tokens_across_dp_cpu[idx]
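Because every DP rank pads to the same padded_length before the all-gather, the gathered tensor can be viewed as (dp_size, padded_length, hidden) and each rank's real tokens copied back out. A self-contained sketch with toy shapes; the copy inside the loop is inferred from the surrounding pattern, since the hunk is cut off after num_tokens_dp.

import torch

dp_size, padded_length, hidden = 2, 4, 8           # toy sizes
num_tokens_across_dp_cpu = torch.tensor([3, 2])    # real tokens per DP rank
gathered = torch.randn(dp_size * padded_length, hidden)

result = torch.empty((int(num_tokens_across_dp_cpu.sum()), hidden))
x = gathered.view(dp_size, padded_length, hidden)
offset = 0
for idx in range(dp_size):
    num_tokens_dp = int(num_tokens_across_dp_cpu[idx])
    result[offset:offset + num_tokens_dp] = x[idx, :num_tokens_dp]  # drop this rank's pad rows
    offset += num_tokens_dp

assert result.shape == (5, hidden)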

View File

@@ -37,7 +37,7 @@ if not vllm_version_is("0.16.0"):
from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import DefaultMoERunner # type: ignore from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import DefaultMoERunner # type: ignore
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.eplb.core.eplb_utils import init_eplb_config from vllm_ascend.eplb.core.eplb_utils import init_eplb_config
from vllm_ascend.flash_common3_context import get_flash_common3_context, set_flash_common3_context from vllm_ascend.flash_common3_context import get_flash_common3_context, set_flash_common3_context
@@ -148,7 +148,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
random_matrix = torch.rand(topk_ids.size(0), global_num_experts, device=topk_ids.device) random_matrix = torch.rand(topk_ids.size(0), global_num_experts, device=topk_ids.device)
topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype) topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
moe_comm_method = get_forward_context().moe_comm_method moe_comm_method = _EXTRA_CTX.moe_comm_method
final_hidden_states = moe_comm_method.fused_experts( final_hidden_states = moe_comm_method.fused_experts(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
@@ -401,12 +401,13 @@ class AscendFusedMoE(FusedMoE):
# When static kernels are enabled, the forward pass runs twice (compilation + capture), # When static kernels are enabled, the forward pass runs twice (compilation + capture),
# causing moe_layer_index to overflow. Wrap the index to prevent out-of-bounds errors. # causing moe_layer_index to overflow. Wrap the index to prevent out-of-bounds errors.
if self.enable_npugraph_ex_static_kernel: if self.enable_npugraph_ex_static_kernel:
forward_context.moe_layer_index = forward_context.moe_layer_index % (len(forward_context.all_moe_layers)) moe_layer_index = forward_context.moe_layer_index % (len(forward_context.all_moe_layers))
forward_context.moe_layer_index = moe_layer_index
# Load balancing for token distribution among experts in dummy_run # Load balancing for token distribution among experts in dummy_run
# TODO: The community only considers load balancing when DP > 1. # TODO: The community only considers load balancing when DP > 1.
# This approach may overlook some extreme scenarios. # This approach may overlook some extreme scenarios.
enable_force_load_balance = forward_context.in_profile_run enable_force_load_balance = _EXTRA_CTX.in_profile_run
forward_context = get_forward_context() forward_context = get_forward_context()
if self.multistream_overlap_gate: if self.multistream_overlap_gate:
@@ -419,7 +420,7 @@ class AscendFusedMoE(FusedMoE):
assert fc3_context.shared_experts is not None assert fc3_context.shared_experts is not None
shared_out = fc3_context.shared_experts(hidden_states) shared_out = fc3_context.shared_experts(hidden_states)
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel` # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
moe_comm_type = forward_context.moe_comm_type moe_comm_type = _EXTRA_CTX.moe_comm_type
if ( if (
moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2}
and not shared_expert_dp_enabled() and not shared_expert_dp_enabled()
@@ -442,16 +443,16 @@ class AscendFusedMoE(FusedMoE):
global_num_experts=self.global_num_experts, global_num_experts=self.global_num_experts,
) )
if isinstance(forward_context.moe_comm_method, AllGatherCommImpl): if isinstance(_EXTRA_CTX.moe_comm_method, AllGatherCommImpl):
topk_weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(topk_weights, True, True) topk_weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(topk_weights, True, True)
topk_ids = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(topk_ids, True, True) topk_ids = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(topk_ids, True, True)
set_flash_common3_context(topk_weights=topk_weights, topk_ids=topk_ids) set_flash_common3_context(topk_weights=topk_weights, topk_ids=topk_ids)
hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare( hidden_states, router_logits, mc2_mask, context_metadata = _EXTRA_CTX.moe_comm_method.prepare(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
replace_allreduce=forward_context.flash_comm_v1_enabled, replace_allreduce=_EXTRA_CTX.flash_comm_v1_enabled,
enable_shared_expert_dp=self.enable_shared_expert_dp, enable_shared_expert_dp=self.enable_shared_expert_dp,
quant_type=self.quant_type, quant_type=self.quant_type,
) )
@@ -509,7 +510,7 @@ class AscendFusedMoE(FusedMoE):
self.load_counter.add_(1) self.load_counter.add_(1)
else: else:
self.moe_load.add_(local_load) self.moe_load.add_(local_load)
routed_out = forward_context.moe_comm_method.finalize( routed_out = _EXTRA_CTX.moe_comm_method.finalize(
hidden_states=fused_experts_results.routed_out, hidden_states=fused_experts_results.routed_out,
reduce_results=self.reduce_results, reduce_results=self.reduce_results,
context_metadata=context_metadata, context_metadata=context_metadata,
@@ -670,8 +671,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
# NOTE: This is exactly the opposite of # NOTE: This is exactly the opposite of
# `maybe_all_reduce_tensor_model_parallel` # `maybe_all_reduce_tensor_model_parallel`
forward_context = get_forward_context() moe_comm_type = _EXTRA_CTX.moe_comm_type
moe_comm_type = forward_context.moe_comm_type
if ( if (
moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2}
and not shared_expert_dp_enabled() and not shared_expert_dp_enabled()
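Across these hunks the MoE layer reads moe_comm_method, moe_comm_type and flash_comm_v1_enabled from _EXTRA_CTX instead of the forward context, while the call order stays prepare, then fused_experts, then finalize. A simplified, standalone sketch of that flow; the Protocol below is an illustrative stand-in, not the real vllm-ascend signatures.

from typing import Any, Protocol


class MoECommMethodLike(Protocol):  # simplified stand-in interface
    def prepare(self, hidden_states: Any, router_logits: Any) -> tuple[Any, Any, Any, Any]: ...
    def fused_experts(self, hidden_states: Any, **kwargs: Any) -> Any: ...
    def finalize(self, hidden_states: Any, reduce_results: bool, context_metadata: Any) -> Any: ...


def moe_forward(comm: MoECommMethodLike, hidden_states: Any, router_logits: Any) -> Any:
    # prepare() pads/dispatches tokens, fused_experts() runs the expert MLPs,
    # finalize() reduces and unpads the routed output.
    hidden_states, router_logits, mc2_mask, context_metadata = comm.prepare(hidden_states, router_logits)
    routed_out = comm.fused_experts(hidden_states=hidden_states)
    return comm.finalize(routed_out, reduce_results=True, context_metadata=context_metadata)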

View File

@@ -19,11 +19,10 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
import torch import torch
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig from vllm.model_executor.layers.fused_moe import FusedMoEConfig
import vllm_ascend.envs as envs_ascend import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.fused_moe.prepare_finalize import ( from vllm_ascend.ops.fused_moe.prepare_finalize import (
PrepareAndFinalize, PrepareAndFinalize,
@@ -135,7 +134,7 @@ class MoECommMethod(ABC):
# Check constraints # Check constraints
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.int8] assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.int8]
moe_comm_method = get_forward_context().moe_comm_method moe_comm_method = _EXTRA_CTX.moe_comm_method
assert moe_comm_method is not None, "Missing communication context" assert moe_comm_method is not None, "Missing communication context"
before_dispatch_evt = torch.npu.current_stream().record_event() before_dispatch_evt = torch.npu.current_stream().record_event()

View File

@@ -18,10 +18,9 @@
import torch import torch
import torch_npu import torch_npu
from torch.nn.functional import pad from torch.nn.functional import pad
from vllm.forward_context import get_forward_context
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
from vllm_ascend.device.device_op import DeviceOperator from vllm_ascend.device.device_op import DeviceOperator
from vllm_ascend.device.mxfp_compat import ( from vllm_ascend.device.mxfp_compat import (
ensure_mxfp8_moe_available, ensure_mxfp8_moe_available,
@@ -147,7 +146,7 @@ def quant_apply_mlp(
weight_prefetch_method = get_weight_prefetch_method() weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method: if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(hidden_states) weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(hidden_states)
is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2 is_mc2 = _EXTRA_CTX.moe_comm_type == MoECommType.MC2
if w1_scale_bias is None and w1_offset is None and is_mc2: if w1_scale_bias is None and w1_offset is None and is_mc2:
if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb) and not use_mxfp_quant: if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb) and not use_mxfp_quant:
# gmm1: gate_up_proj & act_fn: swiglu # gmm1: gate_up_proj & act_fn: swiglu
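The gating above decides whether the quantized MLP may take the fused grouped-matmul plus swiglu path: only in MC2 mode and only when no per-channel scale bias or offset is present. A reduced, standalone sketch of that outer condition; it deliberately omits the inner _custom_gmm_swiglu_enabled and mxfp checks, and the enum values are illustrative.

from enum import Enum, auto


class MoECommTypeSketch(Enum):  # illustrative values only
    ALLGATHER = auto()
    ALLTOALL = auto()
    MC2 = auto()


def may_use_fused_gmm_swiglu(comm_type: MoECommTypeSketch, w1_scale_bias, w1_offset) -> bool:
    # Outer gate only; the real code applies further kernel/quant switches
    # before taking the fused path.
    return w1_scale_bias is None and w1_offset is None and comm_type is MoECommTypeSketch.MC2


assert may_use_fused_gmm_swiglu(MoECommTypeSketch.MC2, None, None)
assert not may_use_fused_gmm_swiglu(MoECommTypeSketch.ALLGATHER, None, None)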

View File

@@ -26,10 +26,10 @@ from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
) )
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
from vllm_ascend.quantization.methods.base import QuantType from vllm_ascend.quantization.methods.base import QuantType
from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable
@@ -242,8 +242,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
""" """
self.replace_allreduce = replace_allreduce self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp self.enable_shared_expert_dp = enable_shared_expert_dp
forward_context = get_forward_context() mc2_mask = _EXTRA_CTX.mc2_mask
mc2_mask = forward_context.mc2_mask
if self.tp_size > 1: if self.tp_size > 1:
# Also slice mc2_mask # Also slice mc2_mask
split_mc2_mask = torch.tensor_split(mc2_mask, self.tp_size, dim=0) split_mc2_mask = torch.tensor_split(mc2_mask, self.tp_size, dim=0)
@@ -252,7 +251,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
padded_hidden_states_shape = hidden_states.shape padded_hidden_states_shape = hidden_states.shape
if not self.replace_allreduce: if not self.replace_allreduce:
self.num_tokens, _ = hidden_states.shape self.num_tokens, _ = hidden_states.shape
target_pad_length = forward_context.padded_num_tokens target_pad_length = _EXTRA_CTX.padded_num_tokens
pad_size = target_pad_length - self.num_tokens pad_size = target_pad_length - self.num_tokens
# Pad if necessary (unless shared expert DP is enabled) # Pad if necessary (unless shared expert DP is enabled)
@@ -367,8 +366,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
""" """
self.enable_shared_expert_dp = enable_shared_expert_dp self.enable_shared_expert_dp = enable_shared_expert_dp
if self.moe_config.dp_size > 1: if self.moe_config.dp_size > 1:
forward_context = get_forward_context() max_tokens_across_dp = _EXTRA_CTX.max_tokens_across_dp
max_tokens_across_dp = forward_context.max_tokens_across_dp
self.num_tokens = hidden_states.shape[0] self.num_tokens = hidden_states.shape[0]
pad_size = max_tokens_across_dp - self.num_tokens pad_size = max_tokens_across_dp - self.num_tokens
@@ -381,8 +379,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
router_logits = self.moe_config.dp_group.all_gather(router_logits, 0) router_logits = self.moe_config.dp_group.all_gather(router_logits, 0)
if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1: if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
forward_context = get_forward_context() max_tokens_across_pcp = _EXTRA_CTX.max_tokens_across_pcp
max_tokens_across_pcp = forward_context.max_tokens_across_pcp
self.num_tokens_pcp = hidden_states.shape[0] self.num_tokens_pcp = hidden_states.shape[0]
pad_size = max_tokens_across_pcp - self.num_tokens_pcp pad_size = max_tokens_across_pcp - self.num_tokens_pcp
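Both prepare paths above pad the local tokens up to a target published in the extra context (padded_num_tokens for MC2, max_tokens_across_dp or max_tokens_across_pcp for the gather paths) and slice the pad off again after finalize. A worked example with toy sizes:

import torch
import torch.nn.functional as F

num_tokens, hidden = 5, 16
target_pad_length = 8                     # e.g. the padded_num_tokens published for MC2
hidden_states = torch.randn(num_tokens, hidden)

pad_size = target_pad_length - num_tokens
if pad_size > 0:
    hidden_states = F.pad(hidden_states, (0, 0, 0, pad_size))  # pad the token dimension

assert hidden_states.shape == (target_pad_length, hidden)
unpadded = hidden_states[:num_tokens]     # finalize slices the pad off again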

View File

@@ -57,9 +57,9 @@ from vllm.distributed import (
tensor_model_parallel_reduce_scatter, tensor_model_parallel_reduce_scatter,
) )
from vllm.distributed.parallel_state import get_tp_group from vllm.distributed.parallel_state import get_tp_group
from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.distributed.parallel_state import ( from vllm_ascend.distributed.parallel_state import (
get_flashcomm2_odp_group, get_flashcomm2_odp_group,
get_flashcomm2_otp_group, get_flashcomm2_otp_group,
@@ -311,8 +311,7 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
input_parallel = splitted_input[tp_rank].contiguous() input_parallel = splitted_input[tp_rank].contiguous()
# padding for all-to-all # padding for all-to-all
forward_context = get_forward_context() num_padding_tokens = _EXTRA_CTX.pad_size
num_padding_tokens = forward_context.pad_size
if num_padding_tokens > 0: if num_padding_tokens > 0:
input_parallel = nn.functional.pad(input_parallel, (0, 0, 0, num_padding_tokens)) input_parallel = nn.functional.pad(input_parallel, (0, 0, 0, num_padding_tokens))
@@ -368,7 +367,7 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
else: else:
output = output_parallel output = output_parallel
if not forward_context.flash_comm_v1_enabled: if not _EXTRA_CTX.flash_comm_v1_enabled:
# flashcomm1 not enabled # flashcomm1 not enabled
output = get_tp_group().all_gather(output, 0) output = get_tp_group().all_gather(output, 0)
if num_padding_tokens > 0: if num_padding_tokens > 0:
@@ -514,9 +513,8 @@ class SequenceRowParallelOp(CustomRowParallelOp):
def matmul_and_reduce(self, input_parallel: torch.Tensor, bias_: Parameter | None) -> torch.Tensor: def matmul_and_reduce(self, input_parallel: torch.Tensor, bias_: Parameter | None) -> torch.Tensor:
assert self.quant_method is not None assert self.quant_method is not None
try: try:
forward_context = get_forward_context() flash_comm_v1_enabled = _EXTRA_CTX.flash_comm_v1_enabled
flash_comm_v1_enabled = forward_context.flash_comm_v1_enabled mmrs_fusion = _EXTRA_CTX.mmrs_fusion
mmrs_fusion = forward_context.mmrs_fusion
except AssertionError: except AssertionError:
flash_comm_v1_enabled = False flash_comm_v1_enabled = False
mmrs_fusion = False mmrs_fusion = False
@@ -527,7 +525,7 @@ class SequenceRowParallelOp(CustomRowParallelOp):
output_parallel = self.layer.quant_method.apply(self.layer, x, bias=bias_) output_parallel = self.layer.quant_method.apply(self.layer, x, bias=bias_)
return tensor_model_parallel_all_reduce(output_parallel) return tensor_model_parallel_all_reduce(output_parallel)
pad_size = forward_context.pad_size pad_size = _EXTRA_CTX.pad_size
if pad_size > 0 and not (enable_dsa_cp() and "o_proj" in self.layer.prefix): if pad_size > 0 and not (enable_dsa_cp() and "o_proj" in self.layer.prefix):
x = F.pad(x, (0, 0, 0, pad_size)) x = F.pad(x, (0, 0, 0, pad_size))

View File

@@ -32,6 +32,7 @@ from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionMetadata # type: ignore from vllm.v1.attention.backend import AttentionMetadata # type: ignore
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
class IndexerWrapper(nn.Module): class IndexerWrapper(nn.Module):
@@ -144,7 +145,7 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
kv_cache: torch.Tensor | None = None, kv_cache: torch.Tensor | None = None,
attn_metadata: AttentionMetadata | None = None, attn_metadata: AttentionMetadata | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
need_gather_q_kv = get_forward_context().flash_comm_v1_enabled need_gather_q_kv = _EXTRA_CTX.flash_comm_v1_enabled
output_shape = hidden_states.shape output_shape = hidden_states.shape
# FIXME: This does not seem right, should make sure the buffer is fixed # FIXME: This does not seem right, should make sure the buffer is fixed
output = torch.empty(output_shape, dtype=hidden_states.dtype, device=hidden_states.device) output = torch.empty(output_shape, dtype=hidden_states.dtype, device=hidden_states.device)

View File

@@ -13,7 +13,7 @@ from vllm.distributed import (
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
from vllm_ascend.ops.rotary_embedding import rope_forward_oot from vllm_ascend.ops.rotary_embedding import rope_forward_oot
from vllm_ascend.ops.triton.muls_add import muls_add_triton from vllm_ascend.ops.triton.muls_add import muls_add_triton
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
@@ -22,12 +22,12 @@ from vllm_ascend.utils import npu_stream_switch, prefetch_stream
def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
try: try:
forward_context = get_forward_context() get_forward_context()
except AssertionError: except AssertionError:
return residual return residual
if x.size(0) != residual.size(0): if x.size(0) != residual.size(0):
pad_size = forward_context.pad_size pad_size = _EXTRA_CTX.pad_size
if pad_size > 0: if pad_size > 0:
residual = F.pad(residual, (0, 0, 0, pad_size)) residual = F.pad(residual, (0, 0, 0, pad_size))
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
@@ -43,12 +43,12 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_c
except AssertionError: except AssertionError:
return x return x
flash_comm_v1_enabled = forward_context.flash_comm_v1_enabled flash_comm_v1_enabled = _EXTRA_CTX.flash_comm_v1_enabled
if flash_comm_v1_enabled and label: if flash_comm_v1_enabled and label:
dp_metadata = forward_context.dp_metadata dp_metadata = forward_context.dp_metadata
if dp_metadata is None or not is_ep_comm: if dp_metadata is None or not is_ep_comm:
x = tensor_model_parallel_all_gather(x, 0) x = tensor_model_parallel_all_gather(x, 0)
pad_size = forward_context.pad_size pad_size = _EXTRA_CTX.pad_size
if pad_size > 0: if pad_size > 0:
x = x[:-pad_size] x = x[:-pad_size]
else: else:
@@ -57,7 +57,7 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_c
num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu
result = torch.empty((num_tokens_across_dp_cpu.sum(), *x.shape[1:]), device=x.device, dtype=x.dtype) result = torch.empty((num_tokens_across_dp_cpu.sum(), *x.shape[1:]), device=x.device, dtype=x.dtype)
dp_size = get_dp_group().world_size dp_size = get_dp_group().world_size
x = x.view(dp_size, forward_context.padded_length, *x.shape[1:]) x = x.view(dp_size, _EXTRA_CTX.padded_length, *x.shape[1:])
offset = 0 offset = 0
for idx in range(dp_size): for idx in range(dp_size):
num_tokens_dp = num_tokens_across_dp_cpu[idx] num_tokens_dp = num_tokens_across_dp_cpu[idx]
@@ -79,7 +79,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
dp_metadata = forward_context.dp_metadata dp_metadata = forward_context.dp_metadata
if dp_metadata is None or not is_ep_comm: if dp_metadata is None or not is_ep_comm:
pad_size = forward_context.pad_size pad_size = _EXTRA_CTX.pad_size
if pad_size > 0: if pad_size > 0:
x = F.pad(x, (0, 0, 0, pad_size)) x = F.pad(x, (0, 0, 0, pad_size))
return tensor_model_parallel_reduce_scatter(x, 0) return tensor_model_parallel_reduce_scatter(x, 0)
@@ -87,7 +87,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
# padding # padding
dp_size = get_dp_group().world_size dp_size = get_dp_group().world_size
num_tokens_across_dp_cpu = get_forward_context().dp_metadata.num_tokens_across_dp_cpu num_tokens_across_dp_cpu = get_forward_context().dp_metadata.num_tokens_across_dp_cpu
padded_x = torch.empty((dp_size, forward_context.padded_length, *x.shape[1:]), device=x.device, dtype=x.dtype) padded_x = torch.empty((dp_size, _EXTRA_CTX.padded_length, *x.shape[1:]), device=x.device, dtype=x.dtype)
offset = 0 offset = 0
for idx in range(dp_size): for idx in range(dp_size):
num_tokens_dp = num_tokens_across_dp_cpu[idx] num_tokens_dp = num_tokens_across_dp_cpu[idx]
@@ -98,7 +98,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_comm: bool = False) -> torch.Tensor: def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_comm: bool = False) -> torch.Tensor:
if get_forward_context().flash_comm_v1_enabled and label: if _EXTRA_CTX.flash_comm_v1_enabled and label:
return torch.empty( return torch.empty(
(x.shape[0] * get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype (x.shape[0] * get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype
) )
@@ -107,7 +107,7 @@ def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_c
def _maybe_pad_and_reduce_fake(x: torch.Tensor, is_ep_comm: bool = False) -> torch.Tensor: def _maybe_pad_and_reduce_fake(x: torch.Tensor, is_ep_comm: bool = False) -> torch.Tensor:
if get_forward_context().flash_comm_v1_enabled: if _EXTRA_CTX.flash_comm_v1_enabled:
return torch.empty( return torch.empty(
(x.shape[0] // get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype (x.shape[0] // get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype
) )
@@ -138,11 +138,10 @@ def _prefetch_postprocess_impl_fake(stop_flag: torch.Tensor) -> None:
def _maybe_all_reduce_tensor_model_parallel_impl(final_hidden_states: torch.Tensor) -> torch.Tensor: def _maybe_all_reduce_tensor_model_parallel_impl(final_hidden_states: torch.Tensor) -> torch.Tensor:
forward_context = get_forward_context() moe_comm_type = _EXTRA_CTX.moe_comm_type
moe_comm_type = forward_context.moe_comm_type
if ( if (
moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2}
or forward_context.flash_comm_v1_enabled or _EXTRA_CTX.flash_comm_v1_enabled
): ):
return final_hidden_states return final_hidden_states
else: else:
@@ -163,7 +162,7 @@ def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor, layer_name: str)
forward_context = get_forward_context() forward_context = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
num_tokens = input_parallel.size(0) num_tokens = input_parallel.size(0)
if forward_context.flash_comm_v1_enabled: if _EXTRA_CTX.flash_comm_v1_enabled:
num_tokens = num_tokens // self.tp_size num_tokens = num_tokens // self.tp_size
output = torch.empty( output = torch.empty(
size=(num_tokens, self.output_size_per_partition), device=input_parallel.device, dtype=input_parallel.dtype size=(num_tokens, self.output_size_per_partition), device=input_parallel.device, dtype=input_parallel.dtype
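The *_fake implementations above only track shapes for graph tracing: with flashcomm v1 enabled, the all-gather op multiplies the token dimension by the TP world size and the pad-and-reduce op divides it (the real ops additionally add or strip pad rows). A shape-only sketch, with tp_size as an invented constant:

import torch

tp_size = 4


def all_gather_fake(x: torch.Tensor, flash_comm_v1_enabled: bool) -> torch.Tensor:
    # Shape-only stand-in mirroring _maybe_all_gather_and_maybe_unpad_fake.
    if flash_comm_v1_enabled:
        return torch.empty((x.shape[0] * tp_size, *x.shape[1:]), dtype=x.dtype)
    return x


def pad_and_reduce_fake(x: torch.Tensor, flash_comm_v1_enabled: bool) -> torch.Tensor:
    # Shape-only stand-in mirroring _maybe_pad_and_reduce_fake.
    if flash_comm_v1_enabled:
        return torch.empty((x.shape[0] // tp_size, *x.shape[1:]), dtype=x.dtype)
    return x


x = torch.empty(8, 16)
assert all_gather_fake(x, True).shape == (32, 16)
assert pad_and_reduce_fake(all_gather_fake(x, True), True).shape == x.shape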

View File

@@ -21,7 +21,6 @@ import os
import torch import torch
import torch_npu import torch_npu
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.rotary_embedding import ( from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, DeepseekScalingRotaryEmbedding,
MRotaryEmbedding, MRotaryEmbedding,
@@ -31,6 +30,7 @@ from vllm.model_executor.layers.rotary_embedding import (
from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.platform import NPUPlatform from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, has_rope, is_vl_model from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, has_rope, is_vl_model
@@ -240,8 +240,8 @@ class AscendRotaryEmbedding(RotaryEmbedding):
is_neox_style = self.is_neox_style is_neox_style = self.is_neox_style
if is_neox_style_override is not None: if is_neox_style_override is not None:
is_neox_style = is_neox_style_override is_neox_style = is_neox_style_override
is_draft_model = get_forward_context().is_draft_model is_draft_model = _EXTRA_CTX.is_draft_model
flash_comm_v1_enabled = get_forward_context().flash_comm_v1_enabled flash_comm_v1_enabled = _EXTRA_CTX.flash_comm_v1_enabled
if is_draft_model and self.use_mtp and flash_comm_v1_enabled: if is_draft_model and self.use_mtp and flash_comm_v1_enabled:
positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(positions.contiguous(), True) positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(positions.contiguous(), True)
return torch.ops.vllm.npu_rotary_embedding( return torch.ops.vllm.npu_rotary_embedding(

View File

@@ -6,6 +6,7 @@ from vllm.config import get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm_ascend.ascend_config import WeightPrefetchConfig from vllm_ascend.ascend_config import WeightPrefetchConfig
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.ops.linear import AscendQKVParallelLinear, AscendRowParallelLinear from vllm_ascend.ops.linear import AscendQKVParallelLinear, AscendRowParallelLinear
from vllm_ascend.utils import is_moe_model from vllm_ascend.utils import is_moe_model
@@ -95,11 +96,11 @@ class WeightPrefetchMethod:
if not self.moe.is_active_this_forward: if not self.moe.is_active_this_forward:
return return
forward_context = get_forward_context() forward_context = get_forward_context()
if not forward_context or forward_context.model_instance is None: if not forward_context or _EXTRA_CTX.model_instance is None:
return return
# layer_idx is subtracted by 1 because layer_idx was incremented by 1 at layernorm. # layer_idx is subtracted by 1 because layer_idx was incremented by 1 at layernorm.
weight = forward_context.model_instance.model.layers[forward_context.layer_idx - 1].mlp.experts.w13_weight weight = _EXTRA_CTX.model_instance.model.layers[_EXTRA_CTX.layer_idx - 1].mlp.experts.w13_weight # type: ignore # type: ignore
weight_size = weight.data.element_size() * weight.data.numel() * self.moe.prefetch_ratio.get(prefix, 0) weight_size = weight.data.element_size() * weight.data.numel() * self.moe.prefetch_ratio.get(prefix, 0)
torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=None, max_weight_size=int(weight_size)) torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=None, max_weight_size=int(weight_size))
@@ -122,9 +123,7 @@ class WeightPrefetchMethod:
except AssertionError: except AssertionError:
return return
self.mlp.is_active_this_forward = ( self.mlp.is_active_this_forward = (
forward_context.layer_idx is not None _EXTRA_CTX.layer_idx is not None and _EXTRA_CTX.num_tokens is not None and _EXTRA_CTX.num_tokens < 500
and forward_context.num_tokens is not None
and forward_context.num_tokens < 500
) )
if not self.mlp.is_active_this_forward: if not self.mlp.is_active_this_forward:
return return
@@ -144,9 +143,9 @@ class WeightPrefetchMethod:
# start point of gate_up_proj weight prefetch # start point of gate_up_proj weight prefetch
if curr_layer_prefix.split(".")[-2] == "self_attn": if curr_layer_prefix.split(".")[-2] == "self_attn":
model_instance = forward_context.model_instance model_instance = _EXTRA_CTX.model_instance
layer_idx = int(curr_layer_prefix.split(".")[2]) layer_idx = int(curr_layer_prefix.split(".")[2])
weight = model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight weight = model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight # type: ignore
if self.mlp_pre_version_compatibale_config: if self.mlp_pre_version_compatibale_config:
weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_GATE_UP, 0) weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_GATE_UP, 0)
else: else:
@@ -156,12 +155,12 @@ class WeightPrefetchMethod:
if weight_size > MAX_PREFETCH_WEIGHT_SIZE: if weight_size > MAX_PREFETCH_WEIGHT_SIZE:
weight_size = MAX_PREFETCH_WEIGHT_SIZE weight_size = MAX_PREFETCH_WEIGHT_SIZE
torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=x_dependency, max_weight_size=int(weight_size)) torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=x_dependency, max_weight_size=int(weight_size))
forward_context.prefetch_mlp_gate_up_proj = True _EXTRA_CTX.prefetch_mlp_gate_up_proj = True
def _maybe_prefetch_mlp_down_weight_preprocess(self, x_dependency: torch.Tensor, forward_context: ForwardContext): def _maybe_prefetch_mlp_down_weight_preprocess(self, x_dependency: torch.Tensor, forward_context: ForwardContext):
layer_idx = forward_context.layer_idx layer_idx = _EXTRA_CTX.layer_idx
model_instance = forward_context.model_instance model_instance = _EXTRA_CTX.model_instance
weight = model_instance.model.layers[layer_idx].mlp.down_proj.weight weight = model_instance.model.layers[layer_idx].mlp.down_proj.weight # type: ignore
if self.mlp_pre_version_compatibale_config: if self.mlp_pre_version_compatibale_config:
weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_DOWN, 0) weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_DOWN, 0)
else: else:
@@ -171,22 +170,22 @@ class WeightPrefetchMethod:
if weight_size > MAX_PREFETCH_WEIGHT_SIZE: if weight_size > MAX_PREFETCH_WEIGHT_SIZE:
weight_size = MAX_PREFETCH_WEIGHT_SIZE weight_size = MAX_PREFETCH_WEIGHT_SIZE
torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=x_dependency, max_weight_size=int(weight_size)) torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=x_dependency, max_weight_size=int(weight_size))
forward_context.prefetch_mlp_down_proj = True _EXTRA_CTX.prefetch_mlp_down_proj = True
forward_context.layer_idx += 1 _EXTRA_CTX.layer_idx = layer_idx + 1 # type: ignore
def maybe_prefetch_mlp_weight_postprocess(self, stop_flag: torch.Tensor): def maybe_prefetch_mlp_weight_postprocess(self, stop_flag: torch.Tensor):
if not self.mlp.is_active_this_forward: if not self.mlp.is_active_this_forward:
return return
try: try:
forward_context = get_forward_context() get_forward_context()
except AssertionError: except AssertionError:
return return
if forward_context.prefetch_mlp_gate_up_proj or forward_context.prefetch_mlp_down_proj: if _EXTRA_CTX.prefetch_mlp_gate_up_proj or _EXTRA_CTX.prefetch_mlp_down_proj:
torch.ops.vllm.prefetch_postprocess(stop_flag) torch.ops.vllm.prefetch_postprocess(stop_flag)
forward_context.prefetch_mlp_gate_up_proj = False _EXTRA_CTX.prefetch_mlp_gate_up_proj = False
forward_context.prefetch_mlp_down_proj = False _EXTRA_CTX.prefetch_mlp_down_proj = False
def maybe_prefetch_mla_or_sla_weight_in_current_stream( def maybe_prefetch_mla_or_sla_weight_in_current_stream(
self, self,
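Taken together, the weight-prefetch changes in this file reduce to two per-forward flags plus a running layer_idx carried on the extra context: each preprocess step queues a prefetch, sets its flag, and advances the index, and the postprocess step clears both flags. A simplified, self-contained sketch; the dataclass is a stand-in, not the real context object, and the actual prefetch call is omitted.

from dataclasses import dataclass


@dataclass
class PrefetchCtxSketch:            # stand-in for the fields on the extra context
    layer_idx: int = 0
    prefetch_mlp_gate_up_proj: bool = False
    prefetch_mlp_down_proj: bool = False


def mark_down_proj_prefetch(ctx: PrefetchCtxSketch) -> None:
    ctx.prefetch_mlp_down_proj = True
    ctx.layer_idx += 1              # advance to the next decoder layer


def postprocess(ctx: PrefetchCtxSketch) -> None:
    # Issue the queued prefetch (omitted here) and clear both flags.
    if ctx.prefetch_mlp_gate_up_proj or ctx.prefetch_mlp_down_proj:
        ctx.prefetch_mlp_gate_up_proj = False
        ctx.prefetch_mlp_down_proj = False


ctx = PrefetchCtxSketch()
mark_down_proj_prefetch(ctx)
postprocess(ctx)
assert ctx.layer_idx == 1 and not ctx.prefetch_mlp_down_proj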

View File

@@ -153,7 +153,7 @@ def propose(
# FIXME(woosuk): This is UNSAFE!! # FIXME(woosuk): This is UNSAFE!!
attn_metadata = build_attn_metadata( attn_metadata = build_attn_metadata(
attn_metadata_builders=self.attn_metadata_builders, attn_groups=self.attn_groups,
num_reqs=num_reqs, num_reqs=num_reqs,
num_tokens=num_reqs, num_tokens=num_reqs,
query_start_loc_gpu=query_start_loc, query_start_loc_gpu=query_start_loc,

View File

@@ -589,11 +589,12 @@ class NPUPlatform(Platform):
if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER: if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER:
return {} return {}
# is_draft_model will be removed later, so we set it to False temporarily.
is_draft_model = False
moe_comm_type = select_moe_comm_method( moe_comm_type = select_moe_comm_method(
num_tokens, num_tokens,
vllm_config, vllm_config,
# is_draft_model will be removed later, so we set it to False temporarily. is_draft_model=is_draft_model,
is_draft_model=False,
) )
moe_comm_method = get_moe_comm_method(moe_comm_type) moe_comm_method = get_moe_comm_method(moe_comm_type)
@@ -620,7 +621,7 @@ class NPUPlatform(Platform):
# TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2 # TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None
pad_size = None pad_size = 0
padded_length = None padded_length = None
if flash_comm_v1_enabled or flashcomm_v2_enabled: if flash_comm_v1_enabled or flashcomm_v2_enabled:
pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
@@ -657,6 +658,7 @@ class NPUPlatform(Platform):
"padded_length": padded_length, "padded_length": padded_length,
"max_tokens_across_dp": max_tokens_across_dp, "max_tokens_across_dp": max_tokens_across_dp,
"mc2_mask": mc2_mask, "mc2_mask": mc2_mask,
"is_draft_model": is_draft_model,
} }
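pad_size in the hunk above rounds the token count up to the next multiple of the TP world size so the flashcomm all-gather/reduce-scatter sees an evenly divisible token dimension; it is also now initialized to 0 instead of None, so downstream consumers can rely on an int. A worked example of the formula:

def compute_pad_size(num_tokens: int, tp_world_size: int) -> int:
    return (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size


assert compute_pad_size(5, 4) == 3   # 5 tokens padded up to 8
assert compute_pad_size(8, 4) == 0   # already aligned, no padding
assert compute_pad_size(9, 4) == 3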
@staticmethod @staticmethod

View File

@@ -21,9 +21,9 @@ from typing import Any
import torch import torch
import torch_npu import torch_npu
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.ops.fused_moe.experts_selector import select_experts from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from .base import AscendMoEScheme from .base import AscendMoEScheme
@@ -215,7 +215,7 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme):
topk_ids = topk_ids.to(torch.int32) topk_ids = topk_ids.to(torch.int32)
topk_weights = topk_weights.to(x.dtype) topk_weights = topk_weights.to(x.dtype)
moe_comm_method = get_forward_context().moe_comm_method moe_comm_method = _EXTRA_CTX.moe_comm_method
return moe_comm_method.fused_experts( return moe_comm_method.fused_experts(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight_packed, w1=layer.w13_weight_packed,

View File

@@ -23,9 +23,9 @@ import torch
import torch_npu import torch_npu
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.distributed import get_ep_group from vllm.distributed import get_ep_group
from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe.experts_selector import select_experts from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD, maybe_trans_nz from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD, maybe_trans_nz
@@ -375,7 +375,7 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
topk_weights = topk_weights.to(x.dtype) topk_weights = topk_weights.to(x.dtype)
moe_comm_method = get_forward_context().moe_comm_method moe_comm_method = _EXTRA_CTX.moe_comm_method
return moe_comm_method.fused_experts( return moe_comm_method.fused_experts(
hidden_states=x, hidden_states=x,
w1=[layer.w13_weight], w1=[layer.w13_weight],

View File

@@ -22,11 +22,10 @@ import torch
import torch_npu import torch_npu
from vllm.config import CompilationMode, get_current_vllm_config from vllm.config import CompilationMode, get_current_vllm_config
from vllm.distributed import get_ep_group from vllm.distributed import get_ep_group
from vllm.forward_context import get_forward_context
import vllm_ascend.envs as envs_ascend import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.flash_common3_context import get_flash_common3_context from vllm_ascend.flash_common3_context import get_flash_common3_context
from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute
@@ -234,10 +233,9 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
assert topk_weights is not None assert topk_weights is not None
topk_weights = topk_weights.to(self.in_dtype) topk_weights = topk_weights.to(self.in_dtype)
moe_comm_method = get_forward_context().moe_comm_method moe_comm_method = _EXTRA_CTX.moe_comm_method
fused_scale_flag = ( fused_scale_flag = (
get_forward_context().moe_comm_type == MoECommType.FUSED_MC2 _EXTRA_CTX.moe_comm_type == MoECommType.FUSED_MC2 and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1
and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1
) )
if self.dynamic_eplb: if self.dynamic_eplb:
w1 = layer.w13_weight_list w1 = layer.w13_weight_list

View File

@@ -22,9 +22,9 @@ import torch
import torch_npu import torch_npu
from vllm.config import CompilationMode, get_current_vllm_config from vllm.config import CompilationMode, get_current_vllm_config
from vllm.distributed import get_ep_group from vllm.distributed import get_ep_group
from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.device.mxfp_compat import ( from vllm_ascend.device.mxfp_compat import (
FLOAT8_E8M0FNU_DTYPE, FLOAT8_E8M0FNU_DTYPE,
ensure_mxfp8_linear_available, ensure_mxfp8_linear_available,
@@ -187,7 +187,7 @@ class AscendW8A8MXFP8DynamicFusedMoEMethod(AscendMoEScheme):
topk_weights = topk_weights.to(x.dtype) topk_weights = topk_weights.to(x.dtype)
moe_comm_method = get_forward_context().moe_comm_method moe_comm_method = _EXTRA_CTX.moe_comm_method
return moe_comm_method.fused_experts( return moe_comm_method.fused_experts(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,

View File

@@ -34,7 +34,7 @@ from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.ascend_forward_context import _EXTRA_CTX, set_ascend_forward_context
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
@@ -398,7 +398,7 @@ class AscendEagleProposer(EagleProposer):
num_tokens=num_tokens, num_tokens=num_tokens,
) )
forward_context = get_forward_context() forward_context = get_forward_context()
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and not forward_context.capturing: if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and not _EXTRA_CTX.capturing:
self._update_full_graph_params(forward_context, num_tokens, multi_steps_attn_metadata) self._update_full_graph_params(forward_context, num_tokens, multi_steps_attn_metadata)
def _propose( def _propose(
@@ -784,8 +784,8 @@ class AscendEagleProposer(EagleProposer):
input_batch_size = num_input_tokens if (self.method == "mtp" or self.use_cuda_graph) else batch_size input_batch_size = num_input_tokens if (self.method == "mtp" or self.use_cuda_graph) else batch_size
forward_context = get_forward_context() forward_context = get_forward_context()
forward_context.num_tokens = input_batch_size _EXTRA_CTX.num_tokens = input_batch_size
forward_context.num_accept_tokens = batch_size _EXTRA_CTX.num_accept_tokens = batch_size
for draft_step in range(self.num_speculative_tokens - 1): for draft_step in range(self.num_speculative_tokens - 1):
# Reset MOE layer index for each draft step iteration # Reset MOE layer index for each draft step iteration
@@ -1361,15 +1361,14 @@ class AscendEagleProposer(EagleProposer):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
forward_context = get_forward_context()
if self.method == "mtp": if self.method == "mtp":
if forward_context.flash_comm_v1_enabled: if _EXTRA_CTX.flash_comm_v1_enabled:
hidden_states = torch.ops.vllm.maybe_pad_and_reduce(hidden_states) hidden_states = torch.ops.vllm.maybe_pad_and_reduce(hidden_states)
positions = positions.unsqueeze(-1) positions = positions.unsqueeze(-1)
positions = torch.ops.vllm.maybe_pad_and_reduce(positions) positions = torch.ops.vllm.maybe_pad_and_reduce(positions)
positions = positions.squeeze(-1) positions = positions.squeeze(-1)
else: else:
if forward_context.flash_comm_v1_enabled: if _EXTRA_CTX.flash_comm_v1_enabled:
hidden_states = split_inputs_tp_to_sp(hidden_states, hidden_states) hidden_states = split_inputs_tp_to_sp(hidden_states, hidden_states)
return hidden_states, positions return hidden_states, positions
@@ -1388,8 +1387,7 @@ class AscendEagleProposer(EagleProposer):
if hidden_states is not None: if hidden_states is not None:
hidden_states = last_hidden_states hidden_states = last_hidden_states
else: else:
forward_context = get_forward_context() if _EXTRA_CTX.flash_comm_v1_enabled:
if forward_context.flash_comm_v1_enabled:
last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
last_hidden_states.contiguous(), True last_hidden_states.contiguous(), True
) )

View File

@@ -5,4 +5,5 @@ This directory contains the new model runner which is under active development.
please see [Model Runner V2](https://github.com/vllm-project/vllm-ascend/issues/5208) please see [Model Runner V2](https://github.com/vllm-project/vllm-ascend/issues/5208)
to get specific plans. to get specific plans.
supported vllm version: main@1339784 supported vllm version: main@4034c3d32e30d01639459edd3ab486f56993876d
related PR: <https://github.com/vllm-project/vllm-ascend/pull/7110>
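
A minimal usage sketch for trying graph mode with model runner v2. This is not part of the PR itself; the `LLM` keyword arguments below are assumptions based on current vLLM, and the model name is a placeholder:

import os
from vllm import LLM, SamplingParams

# Model runner v2 must be opted into before the engine is created.
os.environ["VLLM_USE_V2_MODEL_RUNNER"] = "1"

llm = LLM(
    model="your/decoder-model",                                  # placeholder model id
    enforce_eager=False,                                         # keep graph capture enabled
    compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY"},   # capture full graphs for decode batches
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32, temperature=0.0))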

View File

@@ -19,16 +19,25 @@
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any from typing import Any
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import vllm
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionMetadataBuilder from vllm.config.compilation import CUDAGraphMode
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import logger
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
from vllm.v1.worker.gpu.cudagraph_utils import prepare_inputs_to_capture as prepare_inputs_to_capture_gpu
from vllm.v1.worker.gpu.input_batch import InputBuffers from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.utils import AttentionGroup
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.compilation.acl_graph import set_graph_params, update_full_graph_params
from vllm_ascend.worker.v2.attn_utils import build_attn_metadata
from vllm_ascend.worker.v2.utils import torch_cuda_wrapper from vllm_ascend.worker.v2.utils import torch_cuda_wrapper
@@ -38,44 +47,134 @@ class AclGraphManager(CudaGraphManager):
def __init__( def __init__(
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
use_mrope: bool, use_aux_hidden_state_outputs: bool,
device: torch.device, device: torch.device,
model_runner: Any, # NPUModelRunner; typed as Any to avoid a circular import
): ):
with torch_cuda_wrapper(): # Set the model runner attribute so we can access model runner attributes
super().__init__(vllm_config, use_mrope, device) # when calling the `run_fullgraph` method of CudaGraphManager;
# then we don't need to copy the `execute_model` method of the `NPUModelRunner` class.
self.model_runner = model_runner
super().__init__(
vllm_config,
use_aux_hidden_state_outputs,
device,
)
# vllm-ascend needs to update the graph params of the attention backend,
# so we set the graph params before capturing the full graph.
if super().needs_capture():
set_graph_params(self.cudagraph_sizes)
def _capture_full_graph(
self,
num_tokens: int,
num_reqs: int,
model: nn.Module,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None,
num_tokens_across_dp: torch.Tensor,
attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None,
has_lora: bool = False,
) -> None:
"""Override _capture_full_graph because we need to set capturing=True in forward context."""
# set capturing=True in before model forward.
model = ModelWithContext(model)
return super()._capture_full_graph(
num_tokens,
num_reqs,
model,
input_ids,
positions,
inputs_embeds,
num_tokens_across_dp,
attn_metadata,
slot_mappings,
has_lora,
)
def capture_graph( def capture_graph(
self, self,
num_tokens: int, num_tokens: int,
capture_cg_mode: CUDAGraphMode,
model: nn.Module, model: nn.Module,
model_state: ModelState,
input_buffers: InputBuffers, input_buffers: InputBuffers,
block_tables: BlockTables, block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder], attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
has_lora: bool = False,
uniform_decode: bool = False,
) -> None: ) -> None:
with torch_cuda_wrapper(), prepare_capture_inputs_wrapper(): with torch_cuda_wrapper(), prepare_capture_inputs_wrapper():
super().capture_graph( super().capture_graph(
num_tokens, num_tokens,
capture_cg_mode,
model, model,
model_state,
input_buffers, input_buffers,
block_tables, block_tables,
attn_metadata_builders, attn_groups,
kv_cache_config, kv_cache_config,
has_lora,
uniform_decode,
) )
def run_fullgraph(self, num_tokens: int) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
"""Override run_fullgraph to update full graph params in run_fullgraph."""
logger.info_once(f"run_fullgraph with num_tokens={num_tokens}")
ret = super().run_fullgraph(num_tokens)
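# After launching the captured graph, rebuild the forward context from the padded
# inputs so update_full_graph_params can refresh the attention params used by the graph.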
assert self.model_runner.cudagraph_and_dp_padding is not None
positions = self.model_runner.input_buffers.positions[:num_tokens]
_num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode = (
self.model_runner.cudagraph_and_dp_padding
)
cudagraph_runtime_mode = CUDAGraphMode(synced_cudagraph_mode)
with set_forward_context(
self.model_runner.input_batch.attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=cudagraph_runtime_mode,
num_tokens_across_dp=num_tokens_across_dp,
batch_descriptor=None, # Full graph model don't need batch_descriptor
slot_mapping=self.model_runner.input_batch.slot_mappings,
):
forward_context = get_forward_context()
update_full_graph_params(
# FIXME(Ronald1995): support hybrid attn backend
list(self.model_runner.attn_backends.values())[0],
self.model_runner.update_stream,
forward_context,
num_tokens,
self.vllm_config,
self.model_runner.speculative_config,
positions.shape[0],
)
return ret
def is_uniform_decode(
self,
num_reqs: int,
num_tokens: int,
max_query_len: int,
):
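# A batch is a uniform decode batch when every request contributes exactly
# uniform_decode_query_len tokens, i.e. num_tokens == max_query_len * num_reqs.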
return (max_query_len == self.uniform_decode_query_len) and (num_tokens == max_query_len * num_reqs)
@contextmanager @contextmanager
def prepare_capture_inputs_wrapper(): def prepare_capture_inputs_wrapper():
"""Context manager to override input preparation for NPU graph capture.""" """Context manager to override input preparation for NPU graph capture."""
# TODO(Ronald1995): make prepare_inputs_to_capture as static method # TODO(Ronald1995): make prepare_inputs_to_capture as static method
# in CudaGraphManager. # in CudaGraphManager.
global prepare_inputs_to_capture_gpu ori = vllm.v1.worker.gpu.cudagraph_utils.prepare_inputs_to_capture
try: try:
ori_func = prepare_inputs_to_capture_gpu vllm.v1.worker.gpu.cudagraph_utils.prepare_inputs_to_capture = prepare_inputs_to_capture
prepare_inputs_to_capture_gpu = prepare_inputs_to_capture
yield yield
finally: finally:
prepare_inputs_to_capture_gpu = ori_func vllm.v1.worker.gpu.cudagraph_utils.prepare_inputs_to_capture = ori
def prepare_inputs_to_capture( def prepare_inputs_to_capture(
@@ -83,9 +182,66 @@ def prepare_inputs_to_capture(
num_tokens: int, num_tokens: int,
input_buffers: InputBuffers, input_buffers: InputBuffers,
block_tables: BlockTables, block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder], attn_groups: list[list[AttentionGroup]],
max_model_len: int, max_model_len: int,
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
) -> dict[str, Any]: uniform_decode_query_len: int = 0,
# TODO(Ronald1995): Implement NPU specific input preparation. ) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
return {} if uniform_decode_query_len > 0:
num_tokens_per_req = uniform_decode_query_len
else:
num_tokens_per_req = num_tokens // num_reqs
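# Synthesize an evenly split dummy batch for capture: each request owns
# num_tokens_per_req tokens and the last boundary absorbs any remainder.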
query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
query_start_loc_np[-1] = num_tokens
query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
input_buffers.query_start_loc[: num_reqs + 1] = query_start_loc_cpu
input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]
# HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
# rather than max_model_len.
input_buffers.seq_lens[:num_reqs] = num_tokens
input_buffers.seq_lens[num_reqs:] = 0
input_buffers.seq_lens_cpu[:num_reqs] = num_tokens
input_buffers.seq_lens_cpu[num_reqs:] = 0
input_buffers.dcp_local_seq_lens[:num_reqs] = num_tokens
input_buffers.dcp_local_seq_lens[num_reqs:] = 0
input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
slot_mappings = block_tables.slot_mappings[:, :num_tokens]
slot_mappings_by_layer = build_slot_mappings_by_layer(slot_mappings, kv_cache_config)
attn_metadata = build_attn_metadata(
attn_groups=attn_groups,
num_reqs=num_reqs,
num_tokens=num_tokens,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=num_tokens_per_req,
seq_lens=input_buffers.seq_lens,
max_seq_len=max_model_len,
block_tables=input_block_tables,
slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config,
seq_lens_np=input_buffers.seq_lens_np,
)
return attn_metadata, slot_mappings_by_layer
class ModelWithContext(nn.Module):
"""Define a wrapper model to inject forward context.
so we can inherit vllm's CudaGraphManager._capture_full_graph.
"""
def __init__(self, original_model):
super().__init__()
self.original_model = original_model
def forward(self, *args, **kwargs):
# In the warmup phase, capturing is False by default.
# When capturing, we need to set capturing=True in the forward context.
_EXTRA_CTX.capturing = True
return self.original_model(*args, **kwargs)

View File

@@ -23,8 +23,8 @@ from typing import Any
import numpy as np import numpy as np
import torch import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import EncoderOnlyAttentionSpec, KVCacheConfig from vllm.v1.kv_cache_interface import EncoderOnlyAttentionSpec, KVCacheConfig
from vllm.v1.worker.utils import AttentionGroup
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -43,7 +43,7 @@ def get_attn_mask_builder(device: torch.device):
def build_attn_metadata( def build_attn_metadata(
*, *,
attn_metadata_builders: list[AttentionMetadataBuilder], attn_groups: list[list[AttentionGroup]],
num_reqs: int, num_reqs: int,
num_tokens: int, num_tokens: int,
query_start_loc_gpu: torch.Tensor, query_start_loc_gpu: torch.Tensor,
@@ -54,6 +54,7 @@ def build_attn_metadata(
block_tables: Sequence[torch.Tensor], block_tables: Sequence[torch.Tensor],
slot_mappings: torch.Tensor, slot_mappings: torch.Tensor,
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
dcp_local_seq_lens: torch.Tensor | None = None,
# extra attributes for ascend npus. # extra attributes for ascend npus.
seq_lens_np: np.ndarray | None = None, seq_lens_np: np.ndarray | None = None,
num_computed_tokens_cpu: torch.Tensor | None = None, num_computed_tokens_cpu: torch.Tensor | None = None,
@@ -72,9 +73,6 @@ def build_attn_metadata(
if seq_lens_np is None: if seq_lens_np is None:
seq_lens_np = np.full(num_reqs, max_seq_len, dtype=np.int32) seq_lens_np = np.full(num_reqs, max_seq_len, dtype=np.int32)
seq_lens_cpu = torch.from_numpy(seq_lens_np)[:num_reqs] seq_lens_cpu = torch.from_numpy(seq_lens_np)[:num_reqs]
# torch_npu._reshape_and_cache operator requires slot_mappings to
# be torch.int32.
slot_mappings = slot_mappings.to(torch.int32)
attn_metadata: dict[str, Any] = {} attn_metadata: dict[str, Any] = {}
kv_cache_groups = kv_cache_config.kv_cache_groups kv_cache_groups = kv_cache_config.kv_cache_groups
@@ -100,13 +98,14 @@ def build_attn_metadata(
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
) )
attn_metadata_builder = attn_metadata_builders[i] for attn_group in attn_groups[i]:
metadata = attn_metadata_builder.build( attn_metadata_builder = attn_group.get_metadata_builder(0)
common_prefix_len=0, metadata = attn_metadata_builder.build(
common_attn_metadata=common_attn_metadata, # type: ignore common_prefix_len=0,
) common_attn_metadata=common_attn_metadata,
for layer_name in kv_cache_spec.layer_names: )
attn_metadata[layer_name] = metadata for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = metadata
return attn_metadata return attn_metadata

View File

@@ -0,0 +1,58 @@
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/block_table.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
from vllm.v1.worker.gpu.block_table import BlockTables
class AscendBlockTables(BlockTables):
"""Block table for Ascend NPUs."""
def __init__(
self,
block_sizes: list[int],
max_num_reqs: int,
max_num_batched_tokens: int,
max_model_len: int,
device: torch.device,
cp_size: int = 1,
cp_rank: int = 0,
cp_interleave: int = 1,
):
super().__init__(
block_sizes,
max_num_reqs,
max_num_batched_tokens,
max_model_len,
device,
cp_size,
cp_rank,
cp_interleave,
)
# Because we will override this attribute, delete it first to
# make sure it is collected by the Python GC immediately.
del self.slot_mappings
# vllm-ascend's reshape_and_cache function requires slot_mappings to be int32,
# so we redefine slot_mappings with int32 dtype.
self.slot_mappings: torch.Tensor = torch.zeros(
self.num_kv_cache_groups,
self.max_num_batched_tokens,
dtype=torch.int32,
device=self.device,
)

View File

@@ -22,6 +22,8 @@ import numpy as np
import torch import torch
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm_ascend.attention.attention_v1 import AscendAttentionState
class AscendInputBuffers(InputBuffers): class AscendInputBuffers(InputBuffers):
"""Input buffers for Ascend NPUs.""" """Input buffers for Ascend NPUs."""
@@ -37,6 +39,16 @@ class AscendInputBuffers(InputBuffers):
max_num_tokens, max_num_tokens,
device, device,
) )
del self.query_start_loc
# NOTE: For FULL mode we change +1 to +2 to reserve extra space for padding.
# See _pad_query_start_loc_for_fia.
self.query_start_loc: torch.Tensor = torch.zeros(
max_num_reqs + 2,
dtype=torch.int32,
device=device,
)
# Create seq_lens_cpu and seq_lens_np. # Create seq_lens_cpu and seq_lens_np.
# npu's attention backend still needs seq_lens on CPU side. # npu's attention backend still needs seq_lens on CPU side.
self.seq_lens_cpu: torch.Tensor = torch.zeros( self.seq_lens_cpu: torch.Tensor = torch.zeros(
@@ -56,6 +68,8 @@ class AscendInputBatch(InputBatch):
# Create seq_lens_np. # Create seq_lens_np.
# npu's attention backend still needs seq_lens on CPU side. # npu's attention backend still needs seq_lens on CPU side.
seq_lens_np: np.ndarray seq_lens_np: np.ndarray
# attn_state is used to build attention metadata.
attn_state: AscendAttentionState | None = None
@classmethod @classmethod
def make_dummy( def make_dummy(
@@ -79,4 +93,11 @@ class AscendInputBatch(InputBatch):
input_buffers.seq_lens_np[num_reqs:] = 0 input_buffers.seq_lens_np[num_reqs:] = 0
seq_lens_np = input_buffers.seq_lens_np[:num_reqs] seq_lens_np = input_buffers.seq_lens_np[:num_reqs]
input_batch.seq_lens_np = seq_lens_np input_batch.seq_lens_np = seq_lens_np
# A dummy run is used either for DP or for memory profiling.
# For a DP dummy run, num_tokens is set to 1,
# so attn_state is set to DecodeOnly.
# For a memory-profiling dummy run,
# attention metadata isn't needed,
# so we can also set attn_state to AscendAttentionState.DecodeOnly.
input_batch.attn_state = AscendAttentionState.DecodeOnly
return cls(**asdict(input_batch), seq_lens_np=seq_lens_np) return cls(**asdict(input_batch), seq_lens_np=seq_lens_np)

View File

@@ -17,12 +17,16 @@
# This file is a part of the vllm-ascend project. # This file is a part of the vllm-ascend project.
# #
import functools
import numpy as np import numpy as np
import torch import torch
import vllm
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.config.compilation import CUDAGraphMode
from vllm.sequence import IntermediateTensors
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
from vllm.v1.worker.gpu.input_batch import ( from vllm.v1.worker.gpu.input_batch import (
combine_sampled_and_draft_tokens, combine_sampled_and_draft_tokens,
@@ -32,23 +36,23 @@ from vllm.v1.worker.gpu.input_batch import (
) )
from vllm.v1.worker.gpu.model_runner import GPUModelRunner from vllm.v1.worker.gpu.model_runner import GPUModelRunner
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import set_weight_prefetch_method
from vllm_ascend.worker.v2.aclgraph_utils import AclGraphManager from vllm_ascend.worker.v2.aclgraph_utils import AclGraphManager
from vllm_ascend.worker.v2.attn_utils import build_attn_metadata, build_attn_state from vllm_ascend.worker.v2.attn_utils import build_attn_state
from vllm_ascend.worker.v2.input_batch import AscendInputBatch, AscendInputBuffers from vllm_ascend.worker.v2.input_batch import AscendInputBatch, AscendInputBuffers
from vllm_ascend.worker.v2.sample.sampler import AscendSampler from vllm_ascend.worker.v2.sample.sampler import AscendSampler
from vllm_ascend.worker.v2.spec_decode import init_speculator from vllm_ascend.worker.v2.spec_decode import init_speculator
from vllm_ascend.worker.v2.spec_decode.eagle import AscendEagleSpeculator from vllm_ascend.worker.v2.spec_decode.eagle import AscendEagleSpeculator
from vllm_ascend.worker.v2.states import AscendRequestState from vllm_ascend.worker.v2.states import AscendRequestState
from vllm_ascend.worker.v2.utils import torch_cuda_wrapper from vllm_ascend.worker.v2.utils import block_table_wrapper, model_states_wrapper, torch_cuda_wrapper
logger = init_logger(__name__)
class NPUModelRunner(GPUModelRunner): class NPUModelRunner(GPUModelRunner):
"""Model runner for Ascend NPUs.""" """Model runner for Ascend NPUs."""
def __init__(self, vllm_config: VllmConfig, device: torch.device): def __init__(self, vllm_config: VllmConfig, device: torch.device):
with torch_cuda_wrapper(): with torch_cuda_wrapper(), block_table_wrapper(), model_states_wrapper():
super().__init__(vllm_config, device) super().__init__(vllm_config, device)
# because we will override these attributes, delete these attributes to # because we will override these attributes, delete these attributes to
@@ -62,8 +66,9 @@ class NPUModelRunner(GPUModelRunner):
# NPU specific initializations can be added below. # NPU specific initializations can be added below.
self.cudagraph_manager: AclGraphManager = AclGraphManager( self.cudagraph_manager: AclGraphManager = AclGraphManager(
self.vllm_config, self.vllm_config,
self.uses_mrope, self.use_aux_hidden_state_outputs,
self.device, self.device,
self,
) )
# we define AscendEagleSpeculator in vllm_ascend.worker.v2.spec_decode.eagle # we define AscendEagleSpeculator in vllm_ascend.worker.v2.spec_decode.eagle
@@ -96,6 +101,7 @@ class NPUModelRunner(GPUModelRunner):
max_num_reqs=self.max_num_reqs, max_num_reqs=self.max_num_reqs,
vocab_size=self.vocab_size, vocab_size=self.vocab_size,
device=self.device, device=self.device,
req_states=self.req_states,
logprobs_mode=self.model_config.logprobs_mode, logprobs_mode=self.model_config.logprobs_mode,
num_speculative_tokens=self.num_speculative_steps + 1, num_speculative_tokens=self.num_speculative_steps + 1,
) )
@@ -113,6 +119,59 @@ class NPUModelRunner(GPUModelRunner):
pin_memory=True, pin_memory=True,
) )
# Ascend-specific configurations
self.ascend_config = get_ascend_config()
# Set this exactly as model runner v1 does, or it will raise an error.
set_weight_prefetch_method(self.ascend_config.weight_prefetch_config)
# We need to update the full graph params in run_fullgraph,
# so create a dedicated stream for that update.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.update_stream: torch.npu.Stream = torch.npu.Stream()
# We need the return value of `get_cudagraph_and_dp_padding`
# to set the forward_context in `run_fullgraph`,
# so that we can inherit the `execute_model` method.
self.cudagraph_and_dp_padding: tuple[int, torch.Tensor | None, int] | None = None
# We also need input_batch to set the forward_context in run_fullgraph,
# so that we can inherit the `execute_model` method.
self.input_batch: AscendInputBatch | None = None
@torch.inference_mode()
def execute_model(
self,
scheduler_output: SchedulerOutput,
intermediate_tensors: IntermediateTensors | None = None,
dummy_run: bool = False,
skip_attn_for_dummy_run: bool = False,
) -> ModelRunnerOutput | IntermediateTensors | None:
"""Override GPUModelRunner.execute_model for Ascend NPUs by there reasons:
1. when run fullgraph, we need to use ret value of `get_cudagraph_and_dp_padding`
to set forward_context in `run_fullgraph`.
"""
# Use a closure to store the return value of get_cudagraph_and_dp_padding on the model runner.
def wrapper(func):
@functools.wraps(func)
def inner(*args, **kwargs):
self.cudagraph_and_dp_padding = func(*args, **kwargs)
return self.cudagraph_and_dp_padding
return inner
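# The module-level helper is patched only once; afterwards every call to it goes
# through the wrapper above and refreshes self.cudagraph_and_dp_padding.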
if self.cudagraph_and_dp_padding is None:
vllm.v1.worker.gpu.model_runner.get_cudagraph_and_dp_padding = wrapper(
vllm.v1.worker.gpu.model_runner.get_cudagraph_and_dp_padding
)
return super().execute_model(
scheduler_output,
intermediate_tensors,
dummy_run,
skip_attn_for_dummy_run,
)
def prepare_inputs( def prepare_inputs(
self, self,
scheduler_output: SchedulerOutput, scheduler_output: SchedulerOutput,
@@ -185,33 +244,40 @@ class NPUModelRunner(GPUModelRunner):
idx_mapping, total_num_logits, cu_num_logits, max_expand_len idx_mapping, total_num_logits, cu_num_logits, max_expand_len
) )
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables = self.block_tables.gather_block_tables(idx_mapping)
# Get query_start_loc. # Get query_start_loc.
query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32) # NOTE: For FULL mode we change +1 to +2 to reserve extra space for padding.
# See _pad_query_start_loc_for_fia.
query_start_loc_np = np.empty(self.max_num_reqs + 2, dtype=np.int32)
query_start_loc_np[0] = 0 query_start_loc_np[0] = 0
np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1 : num_reqs + 1]) np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1 : num_reqs + 1])
# Pad for full CUDA graph mode. # Pad for full CUDA graph mode.
# Some attention backends like FA3 require query_start_loc to be non-decreasing. # Some attention backends like FA3 require query_start_loc to be non-decreasing.
query_start_loc_np[num_reqs + 1 :] = num_tokens query_start_loc_np[num_reqs + 1 :] = num_tokens
# This is only required for vllm-ascend.
query_start_loc_np, num_reqs_padded = self._pad_query_start_loc_for_fia(
num_tokens_padded=num_tokens_after_padding,
num_tokens=num_tokens,
num_reqs=num_reqs,
query_start_loc_np=query_start_loc_np,
max_query_len=max(scheduler_output.num_scheduled_tokens.values()),
)
async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc) async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc)
query_start_loc_np = query_start_loc_np[: num_reqs + 1] query_start_loc_np = query_start_loc_np[: num_reqs_padded + 1]
query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
max_query_len = num_scheduled_tokens.max().item()
# Get prefill tokens. # Get prefill tokens if any.
prepare_prefill_inputs( if self.req_states.any_prefills(idx_mapping_np):
self.input_buffers.input_ids, prepare_prefill_inputs(
self.req_states.next_prefill_tokens, self.input_buffers.input_ids,
idx_mapping, self.req_states.next_prefill_tokens,
query_start_loc, idx_mapping,
self.req_states.prefill_token_ids.gpu, query_start_loc,
self.req_states.prefill_len.gpu, self.req_states.all_token_ids.gpu,
self.req_states.num_computed_tokens.gpu, self.req_states.prefill_len.gpu,
) self.req_states.num_computed_tokens.gpu,
)
# Prepare positions and seq_lens. # Prepare positions and seq_lens.
prepare_pos_seq_lens( prepare_pos_seq_lens(
@@ -223,14 +289,8 @@ class NPUModelRunner(GPUModelRunner):
) )
seq_lens = self.input_buffers.seq_lens[:num_reqs] seq_lens = self.input_buffers.seq_lens[:num_reqs]
# Prepare M-RoPE positions. # Pad for full CUDA graph mode.
if self.uses_mrope: self.input_buffers.seq_lens_np[num_reqs_padded:] = 0
self.mrope_states.prepare_mrope_positions(
idx_mapping,
query_start_loc,
self.req_states.prefill_len.gpu,
self.req_states.num_computed_tokens.gpu,
)
# Some input token ids are directly read from the last sampled tokens # Some input token ids are directly read from the last sampled tokens
# and draft tokens. Also, get the logits indices to sample tokens from. # and draft tokens. Also, get the logits indices to sample tokens from.
@@ -246,43 +306,12 @@ class NPUModelRunner(GPUModelRunner):
total_num_logits, total_num_logits,
) )
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, self.input_buffers.positions[:num_tokens]
)
# Layer name -> slot mapping.
slot_mappings_by_layer = build_slot_mappings_by_layer(slot_mappings, self.kv_cache_config)
# Layer name -> attention metadata.
# TODO(Ronald1995): try to add a new method `build_attn_metadata` in
# vllm gpu_model_runner_v2, maybe we don't overwrite `prepare_inputs`
# method like this.
attn_metadata = build_attn_metadata(
attn_metadata_builders=self.attn_metadata_builders,
num_reqs=num_reqs,
num_tokens=num_tokens,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=max_query_len,
seq_lens=self.input_buffers.seq_lens,
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
# extra attributes for ascend npus.
seq_lens_np=self.input_buffers.seq_lens_np,
num_computed_tokens_cpu=self.req_states.num_computed_tokens_cpu[idx_mapping_cpu],
attn_state=attn_state,
)
input_ids = self.input_buffers.input_ids[:num_tokens_after_padding] input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
positions = self.input_buffers.positions[:num_tokens_after_padding] positions = self.input_buffers.positions[:num_tokens_after_padding]
mrope_positions = None
if self.uses_mrope: self.input_batch = AscendInputBatch(
mrope_positions = self.mrope_states.mrope_positions
mrope_positions = mrope_positions[:, :num_tokens_after_padding]
return AscendInputBatch(
req_ids=req_ids, req_ids=req_ids,
num_reqs=num_reqs, num_reqs=num_reqs_padded,
idx_mapping=idx_mapping, idx_mapping=idx_mapping,
idx_mapping_np=idx_mapping_np, idx_mapping_np=idx_mapping_np,
expanded_idx_mapping=expanded_idx_mapping, expanded_idx_mapping=expanded_idx_mapping,
@@ -294,18 +323,18 @@ class NPUModelRunner(GPUModelRunner):
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
query_start_loc_np=query_start_loc_np, query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens, seq_lens=seq_lens,
dcp_local_seq_lens=None, # TODO(Ronald1995): support cp.
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
mrope_positions=mrope_positions,
inputs_embeds=None,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings_by_layer,
logits_indices=logits_indices, logits_indices=logits_indices,
cu_num_logits=cu_num_logits, cu_num_logits=cu_num_logits,
cu_num_logits_np=cu_num_logits_np, cu_num_logits_np=cu_num_logits_np,
has_structured_output_reqs=scheduler_output.has_structured_output_requests, has_structured_output_reqs=scheduler_output.has_structured_output_requests,
# extra attributes for ascend npus.
seq_lens_np=self.input_buffers.seq_lens_np, seq_lens_np=self.input_buffers.seq_lens_np,
attn_state=attn_state,
) )
return self.input_batch
def postprocess( def postprocess(
self, self,
@@ -352,7 +381,7 @@ class NPUModelRunner(GPUModelRunner):
self.req_states.num_computed_tokens_cpu[req_index] = self.num_computed_tokens_cpu[req_index] self.req_states.num_computed_tokens_cpu[req_index] = self.num_computed_tokens_cpu[req_index]
# update seq_lens_cpu # update seq_lens_cpu
for i, req_id in enumerate(req_ids): for i, req_id in enumerate(req_ids): # type: ignore
req_index = self.req_states.req_id_to_index[req_id] req_index = self.req_states.req_id_to_index[req_id]
num_computed_tokens = self.req_states.num_computed_tokens_cpu[req_index] num_computed_tokens = self.req_states.num_computed_tokens_cpu[req_index]
self.input_buffers.seq_lens_cpu[i] = num_computed_tokens + num_scheduled_tokens[req_id] self.input_buffers.seq_lens_cpu[i] = num_computed_tokens + num_scheduled_tokens[req_id]
@@ -361,3 +390,44 @@ class NPUModelRunner(GPUModelRunner):
# TODO(Ronald1995): just define the method in case calling error in # TODO(Ronald1995): just define the method in case calling error in
# worker, implement it in the future. # worker, implement it in the future.
pass pass
def _pad_query_start_loc_for_fia(
self,
num_tokens_padded: int,
num_tokens: int,
num_reqs: int,
query_start_loc_np: np.ndarray,
max_query_len: int,
) -> tuple[np.ndarray, int]:
"""
This function is only designed to satisfy the constraint that, when the layout is TND,
the first dimension of `hidden_states` must equal the last element of `actual_seq_lengths_q`.
"""
assert self.cudagraph_and_dp_padding is not None
_num_tokens_after_padding, _num_tokens_across_dp, synced_cudagraph_mode = self.cudagraph_and_dp_padding
cudagraph_runtime_mode = CUDAGraphMode(synced_cudagraph_mode)
if cudagraph_runtime_mode != CUDAGraphMode.FULL:
return query_start_loc_np, num_reqs
uniform_decode_query_len = self.cudagraph_manager.uniform_decode_query_len
is_uniform_decode = self.cudagraph_manager.is_uniform_decode(
num_reqs=num_reqs,
num_tokens=num_tokens,
max_query_len=max_query_len,
)
if is_uniform_decode:
# Uniform-batch case: num_reqs must be no greater than num_reqs_padded
num_reqs_padded = num_tokens_padded // uniform_decode_query_len
last_loc = query_start_loc_np[num_reqs]
query_start_loc_np[num_reqs + 1 : num_reqs_padded + 1] = (
np.arange(1, num_reqs_padded + 1 - num_reqs) * uniform_decode_query_len + last_loc
)
else:
# Mixed-batch case: num_reqs must equal num_reqs_padded
num_reqs_padded = min(num_tokens_padded, self.max_num_reqs)
# Insert a dummy request instead of setting query_start_loc[num_reqs] = num_tokens_padded directly
query_start_loc_np[num_reqs_padded + 1] = num_tokens_padded
num_reqs_padded = num_reqs_padded + 1
return query_start_loc_np, num_reqs_padded

View File

@@ -0,0 +1,34 @@
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/model_states/__init__.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
def init_asecnd_model_state(
vllm_config: VllmConfig,
model: nn.Module,
encoder_cache: EncoderCache | None,
device: torch.device,
):
from vllm_ascend.worker.v2.model_states.default import AscendModelState
return AscendModelState(vllm_config, model, encoder_cache, device)

View File

@@ -0,0 +1,62 @@
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/model_states/default.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import Any
import torch
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.model_states.default import DefaultModelState
from vllm.v1.worker.utils import AttentionGroup
from vllm_ascend.worker.v2.attn_utils import build_attn_metadata
from vllm_ascend.worker.v2.input_batch import AscendInputBatch
class AscendModelState(DefaultModelState):
"""Model state for Ascend NPUs."""
def prepare_attn(
self,
input_batch: AscendInputBatch,
block_tables: tuple[torch.Tensor, ...],
slot_mappings: torch.Tensor,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
"""Override prepare_attn method because `build_attn_metadata` is different from vllm."""
query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
max_query_len = input_batch.num_scheduled_tokens.max().item()
attn_metadata = build_attn_metadata(
attn_groups=attn_groups,
num_reqs=input_batch.num_reqs,
num_tokens=input_batch.num_tokens,
query_start_loc_gpu=input_batch.query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=max_query_len,
seq_lens=input_batch.seq_lens,
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config,
dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
# extra attributes for ascend npus.
seq_lens_np=input_batch.seq_lens_np,
attn_state=input_batch.attn_state,
)
return attn_metadata

View File

@@ -16,9 +16,6 @@
# #
import numpy as np import numpy as np
import torch import torch
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.worker.gpu.sample.gumbel import apply_temperature
from vllm.v1.worker.gpu.sample.min_p import apply_min_p
from vllm.v1.worker.gpu.sample.sampler import Sampler from vllm.v1.worker.gpu.sample.sampler import Sampler
from vllm_ascend.worker.v2.sample.gumbel import gumbel_sample from vllm_ascend.worker.v2.sample.gumbel import gumbel_sample
@@ -53,21 +50,23 @@ class AscendSampler(Sampler):
self.num_speculative_tokens, self.num_speculative_tokens,
) )
# Apply bad words masking in place.
self.bad_words_state.apply_bad_words(
logits,
idx_mapping,
idx_mapping_np,
input_ids,
expanded_local_pos,
)
# Apply temperature in place. # Apply temperature in place.
apply_temperature(logits, idx_mapping, self.sampling_states.temperature.gpu) self.sampling_states.apply_temperature(logits, idx_mapping, idx_mapping_np)
# Apply min_p in place if any request has a non-zero min_p. # Apply min_p in place.
do_min_p = self.sampling_states.do_min_p(idx_mapping_np) self.sampling_states.apply_min_p(logits, idx_mapping, idx_mapping_np)
if do_min_p:
apply_min_p(logits, idx_mapping, self.sampling_states.min_p.gpu)
# Apply top_k and/or top_p. This might return a new tensor. # Apply top_k and/or top_p. This might or might not return a new tensor.
do_top_k = self.sampling_states.do_top_k(idx_mapping_np) logits = self.sampling_states.apply_top_k_top_p(logits, idx_mapping, idx_mapping_np)
top_k = self.sampling_states.top_k.gpu[idx_mapping] if do_top_k else None
do_top_p = self.sampling_states.do_top_p(idx_mapping_np)
top_p = self.sampling_states.top_p.gpu[idx_mapping] if do_top_p else None
if do_top_k or do_top_p:
logits = apply_top_k_top_p(logits, top_k, top_p)
# Sample the next token. # Sample the next token.
sampled = gumbel_sample( sampled = gumbel_sample(

View File

@@ -23,7 +23,7 @@ import torch
import vllm import vllm
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.worker.gpu.input_batch import InputBatch from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.spec_decode.eagle import EagleSpeculator from vllm.v1.worker.gpu.spec_decode.eagle.speculator import EagleSpeculator
from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.worker.v2.attn_utils import build_attn_metadata from vllm_ascend.worker.v2.attn_utils import build_attn_metadata

View File

@@ -56,13 +56,13 @@ class AscendRequestState(RequestState):
self, self,
req_id, req_id,
prompt_len, prompt_len,
prefill_token_ids, all_token_ids,
num_computed_tokens, num_computed_tokens,
): ):
super().add_request( super().add_request(
req_id, req_id,
prompt_len, prompt_len,
prefill_token_ids, all_token_ids,
num_computed_tokens, num_computed_tokens,
) )
req_idx = self.req_id_to_index[req_id] req_idx = self.req_id_to_index[req_id]

View File

@@ -1,6 +1,11 @@
from contextlib import contextmanager from contextlib import contextmanager
import torch import torch
import vllm
from vllm.logger import logger
from vllm_ascend.worker.v2.block_table import AscendBlockTables
from vllm_ascend.worker.v2.model_states import init_asecnd_model_state
@contextmanager @contextmanager
@@ -15,6 +20,34 @@ def torch_cuda_wrapper():
torch.cuda.CUDAGraph = torch.npu.NPUGraph torch.cuda.CUDAGraph = torch.npu.NPUGraph
torch.cuda.graph = torch.npu.graph torch.cuda.graph = torch.npu.graph
torch.cuda.synchronize = torch.npu.synchronize torch.cuda.synchronize = torch.npu.synchronize
torch.cuda.set_stream = torch.npu.set_stream
torch.cuda.current_device = torch.npu.current_device
torch.cuda.mem_get_info = torch.npu.mem_get_info
logger.info_once("Wrapping torch.cuda with torch.npu.")
yield
finally:
pass
@contextmanager
def block_table_wrapper():
try:
# vllm-ascend needs to initialize the slot mapping with torch.int32 dtype,
# but vllm's default is torch.int64.
vllm.v1.worker.gpu.model_runner.BlockTables = AscendBlockTables
logger.info_once("Wrapping BlockTables with AscendBlockTables.")
yield
finally:
pass
@contextmanager
def model_states_wrapper():
try:
# prepare_attn in AscendModelState differs from vllm's,
# so we need to override init_model_state.
vllm.v1.worker.gpu.model_runner.init_model_state = init_asecnd_model_state
logger.info_once("Wrapping init_model_state with init_asecnd_model_state.")
yield yield
finally: finally:
pass pass