diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_fused_moe.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_fused_moe.py
index 369be649..5da3f021 100644
--- a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_fused_moe.py
+++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_fused_moe.py
@@ -184,7 +184,7 @@ def test_token_dispatcher_with_all_gather_quant(
 ):
     context_mock = MagicMock()
     context_mock.fused_moe_state = 0
-    with patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context",
+    with patch("vllm_ascend.ascend_forward_context.get_forward_context",
               return_value=context_mock):
         a = torch.randn((m, k), device=device, dtype=dtype) / 10
         w1 = torch.randn((e, k, 2 * n), device=device, dtype=torch.int8)
diff --git a/tests/e2e/singlecard/model_runner_v2/test_basic.py b/tests/e2e/singlecard/model_runner_v2/test_basic.py
index dc019a8b..5c039dd6 100644
--- a/tests/e2e/singlecard/model_runner_v2/test_basic.py
+++ b/tests/e2e/singlecard/model_runner_v2/test_basic.py
@@ -85,3 +85,31 @@ def test_egale_spec_decoding(
         },
     ) as runner:
         runner.model.generate(prompts, sampling_params)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("enforce_eager", [False])
+@pytest.mark.parametrize("compilation_config", [{"cudagraph_mode": "FULL_DECODE_ONLY"}, {}])
+@patch.dict(os.environ, {"VLLM_USE_V2_MODEL_RUNNER": "1"})
+def test_qwen3_dense_graph_mode(
+    model: str,
+    max_tokens: int,
+    enforce_eager: bool,
+    compilation_config: dict,
+) -> None:
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0)
+    with VllmRunner(
+        model,
+        max_model_len=1024,
+        enforce_eager=enforce_eager,
+        compilation_config=compilation_config,
+    ) as runner:
+        runner.model.generate(prompts, sampling_params)
diff --git a/tests/ut/_310p/attention/test_attention_v1_310.py b/tests/ut/_310p/attention/test_attention_v1_310.py
index 3370baee..dd0f2dcf 100644
--- a/tests/ut/_310p/attention/test_attention_v1_310.py
+++ b/tests/ut/_310p/attention/test_attention_v1_310.py
@@ -74,7 +74,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
 
     @patch("torch_npu._npu_reshape_and_cache")
     @patch("torch_npu._npu_flash_attention")
-    @patch("vllm_ascend.attention.attention_v1.get_forward_context")
+    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
     def test_forward_prefill_310(
         self, mock_get_forward_context, mock_npu_npu_flash_attention, mock_npu_reshape_and_cache
     ):
@@ -105,7 +105,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
     @patch("torch_npu.npu_format_cast", return_value=torch.randn((1, 128, 16, 16), dtype=torch.float16))
     @patch("torch_npu._npu_reshape_and_cache")
     @patch("torch_npu._npu_paged_attention_splitfuse")
-    @patch("vllm_ascend.attention.attention_v1.get_forward_context")
+    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
     def test_forward_chunked_prefill_310(
         self,
         mock_get_forward_context,
@@ -140,7 +140,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
     @patch("torch_npu.npu_format_cast", return_value=torch.randn((1, 128, 16, 16), dtype=torch.float16))
     @patch("torch_npu._npu_reshape_and_cache")
     @patch("torch_npu._npu_paged_attention_splitfuse")
-    @patch("vllm_ascend.attention.attention_v1.get_forward_context")
+    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
     def test_forward_prefill_cache_hit_310(
         self,
         mock_get_forward_context,
@@ -175,7 +175,7 @@ class 
TestAscendAttentionBackendImpl310(TestBase): @patch("vllm_ascend.attention.attention_v1.using_paged_attention") @patch("torch_npu._npu_paged_attention") @patch("torch_npu._npu_reshape_and_cache") - @patch("vllm_ascend.attention.attention_v1.get_forward_context") + @patch("vllm_ascend.ascend_forward_context.get_forward_context") def test_forward_paged_attention_310( self, mock_get_forward_context, mock_npu_reshape_and_cache, mock_paged_attention, mock_using_paged_attention ): diff --git a/tests/ut/attention/test_attention_cp.py b/tests/ut/attention/test_attention_cp.py index c3866b13..ddbe5c7d 100644 --- a/tests/ut/attention/test_attention_cp.py +++ b/tests/ut/attention/test_attention_cp.py @@ -95,7 +95,7 @@ class TestAscendAttentionCPImpl(TestBase): @patch('torch_npu.npu_attention_update') @patch("torch_npu.npu_fused_infer_attention_score") @patch( - 'vllm_ascend.attention.context_parallel.attention_cp.get_forward_context' + 'vllm_ascend.ascend_forward_context.get_forward_context' ) @patch_distributed_groups(dcp_size=2, pcp_size=2) def test_forward_decode_pcp_dcp(self, mock_all2all, mock_dcp, mock_pcp, diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py index bd0f5988..d32714a6 100644 --- a/tests/ut/attention/test_attention_v1.py +++ b/tests/ut/attention/test_attention_v1.py @@ -212,7 +212,7 @@ class TestAscendAttentionBackendImpl(TestBase): @patch('torch_npu._npu_reshape_and_cache') @patch('torch_npu.npu_fused_infer_attention_score') - @patch('vllm_ascend.attention.attention_v1.get_forward_context') + @patch('vllm_ascend.ascend_forward_context.get_forward_context') def test_forward_fused_infer_attention( self, mock_get_forward_context, mock_npu_fused_infer_attention_score, mock_npu_reshape_and_cache): @@ -248,7 +248,7 @@ class TestAscendAttentionBackendImpl(TestBase): @patch('vllm_ascend.attention.attention_v1.using_paged_attention') @patch('torch_npu._npu_paged_attention') @patch('torch_npu._npu_reshape_and_cache') - @patch('vllm_ascend.attention.attention_v1.get_forward_context') + @patch('vllm_ascend.ascend_forward_context.get_forward_context') def test_forward_paged_attention(self, mock_get_forward_context, mock_npu_reshape_and_cache, mock_paged_attention, @@ -279,7 +279,7 @@ class TestAscendAttentionBackendImpl(TestBase): mock_paged_attention.assert_called_once() assert output.shape == (4, 8 * 64) - @patch('vllm_ascend.attention.attention_v1.get_forward_context') + @patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch('torch_npu.npu_fused_infer_attention_score') @patch('torch_npu._npu_reshape_and_cache') def test_forward_decode_only_swa(self, mock_npu_reshape_and_cache, @@ -311,7 +311,7 @@ class TestAscendAttentionBackendImpl(TestBase): mock_fused_infer_attention_score.assert_called_once() assert output.shape == (10, 8, 64) - @patch('vllm_ascend.attention.attention_v1.get_forward_context') + @patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch('torch_npu._npu_paged_attention') @patch('torch_npu.npu_fused_infer_attention_score') @patch('torch_npu._npu_reshape_and_cache') diff --git a/tests/ut/attention/test_mla_cp.py b/tests/ut/attention/test_mla_cp.py index 39696a83..1adf6419 100755 --- a/tests/ut/attention/test_mla_cp.py +++ b/tests/ut/attention/test_mla_cp.py @@ -449,7 +449,7 @@ class TestAscendMLAImpl(TestBase): self.assertEqual(result.shape[1], N) self.assertEqual(result.shape[2], self.impl.kv_lora_rank + 1) - @patch('vllm_ascend.attention.context_parallel.mla_cp.get_forward_context') + 
@patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch("torch_npu.npu_fused_infer_attention_score") @patch('torch_npu.npu_attention_update') @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False) diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index 4a7ee9e1..8fe78566 100755 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -929,7 +929,7 @@ class TestAscendMLAImpl(TestBase): self.assertEqual(out.shape, prefix_out.shape) self.assertEqual(lse.shape, prefix_lse.shape) - @patch('vllm_ascend.attention.mla_v1.get_forward_context') + @patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj") @patch("torch_npu.npu_fused_infer_attention_score") def test_forward_decode_without_graph(self, @@ -1095,7 +1095,7 @@ class TestAscendMLAImpl(TestBase): self.assertEqual(k_pe.shape[-1], self.impl.qk_rope_head_dim) self.assertEqual(k_nope.shape[-1], self.impl.kv_lora_rank) - @patch('vllm_ascend.attention.mla_v1.get_forward_context') + @patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch("torch_npu.npu_fused_infer_attention_score") def test_forward_decode(self, mock_npu_fused_infer_attention_score, mock_get_forward_context): diff --git a/tests/ut/compilation/test_acl_graph.py b/tests/ut/compilation/test_acl_graph.py index 77d779ca..afee5ef5 100644 --- a/tests/ut/compilation/test_acl_graph.py +++ b/tests/ut/compilation/test_acl_graph.py @@ -161,12 +161,13 @@ class TestACLGraphWrapper(TestBase): vllm_config=self.mock_vllm_config, runtime_mode=CUDAGraphMode.NONE) + @patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch('vllm_ascend.compilation.acl_graph.get_forward_context') @patch('vllm_ascend.compilation.acl_graph.current_platform') @patch('vllm_ascend.compilation.acl_graph.envs') def test_call_with_none_runtime_mode(self, mock_envs, mock_current_platform, - mock_get_forward_context): + mock_get_forward_context, mock_get_forward_context_2): """Test __call__ method when runtime mode is NONE""" mock_envs.VLLM_LOGGING_LEVEL = "INFO" mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool @@ -185,16 +186,19 @@ class TestACLGraphWrapper(TestBase): self.mock_runnable.assert_called_once_with("arg1", "arg2") self.assertEqual(result, "test_output") + @patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch('vllm_ascend.compilation.acl_graph.get_forward_context') @patch('vllm_ascend.compilation.acl_graph.current_platform') @patch('vllm_ascend.compilation.acl_graph.envs') def test_call_with_mismatched_runtime_mode(self, mock_envs, mock_current_platform, - mock_get_forward_context): + mock_get_forward_context, + mock_get_forward_context_2): """Test __call__ method when runtime mode doesn't match wrapper mode""" mock_envs.VLLM_LOGGING_LEVEL = "INFO" mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool mock_get_forward_context.return_value = self.mock_forward_context + mock_get_forward_context_2.return_value = self.mock_forward_context self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE # Different from FULL wrapper = ACLGraphWrapper( @@ -214,18 +218,20 @@ class TestACLGraphWrapper(TestBase): 'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled' ) @patch('vllm_ascend.compilation.acl_graph.get_forward_context') + @patch('vllm_ascend.ascend_forward_context.get_forward_context') 
@patch('vllm_ascend.compilation.acl_graph.current_platform') @patch('vllm_ascend.compilation.acl_graph.envs') @patch('vllm_ascend.compilation.acl_graph.compilation_counter') @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors') def test_call_capture_graph_first_time( self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs, - mock_current_platform, mock_get_forward_context, + mock_current_platform, mock_get_forward_context,mock_get_forward_context_2, mock_validate_cudagraph_capturing_enabled, mock_torch): """Test __call__ method captures graph for the first time""" mock_envs.VLLM_LOGGING_LEVEL = "INFO" mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool mock_get_forward_context.return_value = self.mock_forward_context + mock_get_forward_context_2.return_value = self.mock_forward_context self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL # Mock torch.npu.NPUGraph @@ -284,6 +290,7 @@ class TestACLGraphWrapper(TestBase): 'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled' ) @patch('vllm_ascend.compilation.acl_graph.get_forward_context') + @patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch('vllm_ascend.compilation.acl_graph.current_platform') @patch('vllm_ascend.compilation.acl_graph.envs') @patch('vllm_ascend.compilation.acl_graph.compilation_counter') @@ -291,12 +298,15 @@ class TestACLGraphWrapper(TestBase): def test_call_replay_graph(self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs, mock_current_platform, mock_get_forward_context, + mock_get_forward_context_2, mock_validate_cudagraph_capturing_enabled, mock_torch): """Test __call__ method replays graph when already captured""" mock_envs.VLLM_LOGGING_LEVEL = "INFO" mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool mock_get_forward_context.return_value = self.mock_forward_context + mock_get_forward_context_2.return_value = self.mock_forward_context + self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL self.mock_forward_context.is_draft_model = False @@ -358,17 +368,19 @@ class TestACLGraphWrapper(TestBase): 'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled' ) @patch('vllm_ascend.compilation.acl_graph.get_forward_context') + @patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch('vllm_ascend.compilation.acl_graph.current_platform') @patch('vllm_ascend.compilation.acl_graph.envs') @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors') def test_call_with_debug_mode_input_address_check( self, mock_weak_ref_tensors, mock_envs, mock_current_platform, - mock_get_forward_context, + mock_get_forward_context,mock_get_forward_context_2, mock_validate_cudagraph_capturing_enabled, mock_torch): """Test __call__ method with debug mode input address checking""" mock_envs.VLLM_LOGGING_LEVEL = "DEBUG" # Enable debug mode mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool mock_get_forward_context.return_value = self.mock_forward_context + mock_get_forward_context_2.return_value = self.mock_forward_context self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL self.mock_forward_context.is_draft_model = False @@ -413,17 +425,19 @@ class TestACLGraphWrapper(TestBase): 'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled' ) @patch('vllm_ascend.compilation.acl_graph.get_forward_context') + @patch('vllm_ascend.ascend_forward_context.get_forward_context') 
@patch('vllm_ascend.compilation.acl_graph.current_platform') @patch('vllm_ascend.compilation.acl_graph.envs') @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors') def test_call_with_debug_mode_input_address_mismatch( self, mock_weak_ref_tensors, mock_envs, mock_current_platform, - mock_get_forward_context, + mock_get_forward_context,mock_get_forward_context_2, mock_validate_cudagraph_capturing_enabled, mock_torch): """Test __call__ method with debug mode input address mismatch raises AssertionError""" mock_envs.VLLM_LOGGING_LEVEL = "DEBUG" # Enable debug mode mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool mock_get_forward_context.return_value = self.mock_forward_context + mock_get_forward_context_2.return_value = self.mock_forward_context self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL # Mock torch.npu.NPUGraph @@ -471,6 +485,7 @@ class TestACLGraphWrapper(TestBase): 'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled' ) @patch('vllm_ascend.compilation.acl_graph.get_forward_context') + @patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch('vllm_ascend.compilation.acl_graph.current_platform') @patch('vllm_ascend.compilation.acl_graph.envs') @patch('vllm_ascend.compilation.acl_graph.compilation_counter') @@ -478,12 +493,13 @@ class TestACLGraphWrapper(TestBase): @patch('vllm_ascend.compilation.acl_graph.patch') def test_call_capture_graph_with_gc_disable( self, mock_patch, mock_weak_ref_tensors, mock_compilation_counter, - mock_envs, mock_current_platform, mock_get_forward_context, + mock_envs, mock_current_platform, mock_get_forward_context,mock_get_forward_context_2, mock_validate_cudagraph_capturing_enabled, mock_torch): """Test __call__ method captures graph with gc_disable option enabled""" mock_envs.VLLM_LOGGING_LEVEL = "INFO" mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool mock_get_forward_context.return_value = self.mock_forward_context + mock_get_forward_context_2.return_value = self.mock_forward_context self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL # Enable gc_disable option @@ -545,18 +561,20 @@ class TestACLGraphWrapper(TestBase): 'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled' ) @patch('vllm_ascend.compilation.acl_graph.get_forward_context') + @patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch('vllm_ascend.compilation.acl_graph.current_platform') @patch('vllm_ascend.compilation.acl_graph.envs') @patch('vllm_ascend.compilation.acl_graph.compilation_counter') @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors') def test_call_capture_graph_with_weak_ref_output( self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs, - mock_current_platform, mock_get_forward_context, + mock_current_platform, mock_get_forward_context,mock_get_forward_context_2, mock_validate_cudagraph_capturing_enabled, mock_torch): """Test __call__ method captures graph with weak_ref_output option enabled""" mock_envs.VLLM_LOGGING_LEVEL = "INFO" mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool mock_get_forward_context.return_value = self.mock_forward_context + mock_get_forward_context_2.return_value = self.mock_forward_context self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL # Enable weak_ref_output option @@ -608,18 +626,20 @@ class TestACLGraphWrapper(TestBase): # Should return the weak ref output when weak_ref_output option is enabled 
self.assertEqual(result, "weak_ref_output") - + @patch('vllm_ascend.compilation.acl_graph.get_forward_context') + @patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch('vllm_ascend.compilation.acl_graph.current_platform') @patch('vllm_ascend.compilation.acl_graph.envs') @patch('vllm_ascend.compilation.acl_graph.logger') def test_call_capture_graph_with_debug_log(self, mock_logger, mock_envs, mock_current_platform, - mock_get_forward_context): + mock_get_forward_context,mock_get_forward_context_2): """Test __call__ method captures graph with debug logging enabled""" mock_envs.VLLM_LOGGING_LEVEL = "INFO" mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool mock_get_forward_context.return_value = self.mock_forward_context + mock_get_forward_context_2.return_value = self.mock_forward_context self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL # Enable debug logging @@ -757,10 +777,11 @@ class TestPCPDCPGraphParams(TestBase): self.graph_params.events[4].append(mock_event) self.graph_params.handles[4].append(MagicMock()) + @patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch('torch.npu.graph_task_update_end', ) @patch('torch.npu.graph_task_update_begin', MagicMock()) @patch('torch_npu.npu_fused_infer_attention_score.out', MagicMock()) - def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end): + def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end, mock_context): input_positions = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) block_table = torch.zeros(2, 5, dtype=torch.long) seq_lens = torch.tensor([4, 4]) @@ -790,6 +811,7 @@ class TestPCPDCPGraphParams(TestBase): forward_context = MagicMock() forward_context.attn_metadata = {"attn_layer_0": metadata} forward_context.is_draft_model = False + mock_context.return_value = forward_context num_heads = 256 scale = 0.1 diff --git a/tests/ut/ops/test_fused_moe.py b/tests/ut/ops/test_fused_moe.py index b82ca200..f0bd0f3e 100644 --- a/tests/ut/ops/test_fused_moe.py +++ b/tests/ut/ops/test_fused_moe.py @@ -119,11 +119,9 @@ def mock_dist_env(mocker: MockerFixture): return_value=(torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]), None, 0)), \ patch('vllm_ascend.ops.fused_moe.fused_moe.get_forward_context', return_value=mock_forward_context_obj), \ - patch('vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context', + patch('vllm_ascend.ascend_forward_context.get_forward_context', return_value=mock_forward_context_obj), \ patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3), \ - patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context', - return_value=mock_forward_context_obj), \ patch('vllm_ascend.ops.fused_moe.moe_comm_method.MC2CommImpl._get_token_dispatcher', return_value=None), \ patch('vllm_ascend.ops.fused_moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher', @@ -298,7 +296,7 @@ class TestUnifiedApplyMLP(TestBase): @patch('vllm_ascend.ops.fused_moe.moe_mlp.get_weight_prefetch_method', return_value=MagicMock()) - @patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context') + @patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch('vllm_ascend.utils.get_ascend_device_type', return_value=AscendDeviceType.A3) @patch('torch_npu.npu_grouped_matmul') @@ -407,7 +405,7 @@ class TestUnifiedApplyMLP(TestBase): @patch('vllm_ascend.ops.fused_moe.moe_mlp.HAS_TRITON', False) @patch('vllm_ascend.ops.fused_moe.moe_mlp.get_weight_prefetch_method', return_value=MagicMock()) - 
@patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context') + @patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch('torch_npu.npu_grouped_matmul') @patch('torch_npu.npu_swiglu') @patch('torch_npu.npu_dynamic_quant') @@ -513,7 +511,7 @@ class TestUnifiedApplyMLP(TestBase): @patch("vllm_ascend.ops.fused_moe.moe_mlp.get_weight_prefetch_method", return_value=MagicMock()) - @patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context") + @patch("vllm_ascend.ascend_forward_context.get_forward_context") @patch("torch_npu.npu_grouped_matmul") @patch("torch_npu.npu_swiglu") @patch("torch_npu.npu_grouped_matmul_swiglu_quant") diff --git a/tests/ut/ops/test_mla.py b/tests/ut/ops/test_mla.py index 8080179f..870daed4 100644 --- a/tests/ut/ops/test_mla.py +++ b/tests/ut/ops/test_mla.py @@ -121,9 +121,10 @@ class TestAscendMultiHeadLatentAttention(TestBase): @patch("vllm_ascend.ops.mla.get_ascend_config") @patch("vllm_ascend.ops.mla.get_tensor_model_parallel_world_size") @patch("vllm_ascend.ops.mla.get_forward_context") - def test_forward(self, mock_get_forward_context, mock_tp_size, + @patch('vllm_ascend.ascend_forward_context.get_forward_context') + def test_forward(self, mock_get_forward_context_2, mock_get_forward_context, mock_tp_size, mock_ascend_config, mock_get_vllm_config, - mock_mla_forward): + mock_mla_forward,): mock_tp_size.return_value = 1 mock_ascend_config.return_value.enable_shared_expert_dp = False mock_vllm_config = MagicMock(spec=VllmConfig) @@ -159,6 +160,7 @@ class TestAscendMultiHeadLatentAttention(TestBase): mock_forward_context = MagicMock(spec=ForwardContext) mock_forward_context.flash_comm_v1_enabled = False mock_get_forward_context.return_value = mock_forward_context + mock_get_forward_context_2.return_value = mock_forward_context mock_mla_forward.return_value = (3, self.hidden_size) diff --git a/tests/ut/ops/test_moe_comm_method.py b/tests/ut/ops/test_moe_comm_method.py index 40d768af..ed805dd7 100644 --- a/tests/ut/ops/test_moe_comm_method.py +++ b/tests/ut/ops/test_moe_comm_method.py @@ -28,7 +28,7 @@ class TestMoECommMethod(TestBase): self.moe_config.dp_group = MagicMock() self.moe_config.global_redundant_expert_num = 0 - @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context") + @patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch( "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather" ) @@ -73,7 +73,7 @@ class TestMoECommMethod(TestBase): context_metadata=context_metadata) mock_pf_instance.finalize.assert_called_once_with(h_out, True, None) - @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context") + @patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch( "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithMC2") @patch("vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithMC2") @@ -116,7 +116,7 @@ class TestMoECommMethod(TestBase): context_metadata=context_metadata) mock_pf_instance.finalize.assert_called_once_with(h_out, True, None) - @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context") + @patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch( "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAll2All" ) @@ -155,7 +155,7 @@ class TestMoECommMethod(TestBase): mock_pf_instance.prepare.assert_called_once_with( hidden_states, router_logits, False, False, QuantType.NONE) - @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context") + 
@patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch( "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather" ) diff --git a/tests/ut/ops/test_prepare_finalize.py b/tests/ut/ops/test_prepare_finalize.py index bb867155..f25b5ab2 100644 --- a/tests/ut/ops/test_prepare_finalize.py +++ b/tests/ut/ops/test_prepare_finalize.py @@ -32,7 +32,7 @@ class TestPrepareAndFinalize(unittest.TestCase): @patch( "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank", return_value=0) - @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context") + @patch('vllm_ascend.ascend_forward_context.get_forward_context') def test_mc2_prepare_finalize(self, mock_get_forward_context, mock_tp_rank, mock_tp_size): mock_context = MagicMock() @@ -65,7 +65,7 @@ class TestPrepareAndFinalize(unittest.TestCase): @patch( "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank", return_value=0) - @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context") + @patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch("torch.distributed.all_gather") def test_mc2_tp_split_allgather(self, mock_all_gather, mock_get_forward_context, mock_tp_rank, @@ -169,7 +169,7 @@ class TestPrepareAndFinalize(unittest.TestCase): self.assertEqual(final_result.shape[0], 2) @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_dp_group") - @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context") + @patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch("vllm_ascend.ops.fused_moe.prepare_finalize.enable_sp", return_value=False) def test_allgather_prepare_finalize(self, mock_enable_sp, diff --git a/tests/ut/spec_decode/test_eagle_proposer.py b/tests/ut/spec_decode/test_eagle_proposer.py index 45233672..66dfc8a1 100644 --- a/tests/ut/spec_decode/test_eagle_proposer.py +++ b/tests/ut/spec_decode/test_eagle_proposer.py @@ -386,10 +386,11 @@ class TestEagleProposerDummyRun(TestBase): set_current_vllm_config(None) # cpu does not support parallel-group, let alone `sp` + @patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context", **{"return_value.flash_comm_v1_enabled": False}) @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context") - def test_dummy_run_basic(self, mock_context, mock_get_context): + def test_dummy_run_basic(self, mock_context, mock_get_context, mock_get_context_2): num_tokens = 32 with_prefill = False @@ -402,10 +403,11 @@ class TestEagleProposerDummyRun(TestBase): self.assertTrue(self.proposer._runnable.call_count == 1) # cpu does not support parallel-group, let alone `sp` + @patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context", **{"return_value.flash_comm_v1_enabled": False}) @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context") - def test_dummy_run_with_prefill(self, mock_context, mock_get_context): + def test_dummy_run_with_prefill(self, mock_context, mock_get_context, mock_get_context_2): mock_context.return_value.__enter__.return_value = None # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce` with set_current_vllm_config(self.vllm_config): @@ -413,11 +415,12 @@ class TestEagleProposerDummyRun(TestBase): self.proposer.dummy_run(num_tokens=64, with_prefill=True, num_reqs=4) self.assertTrue(self.proposer._runnable.call_count == 1) + 
@patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch("vllm_ascend.spec_decode.eagle_proposer.update_full_graph_params") @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context") @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context") def test_dummy_run_in_graph_capture(self, mock_context, mock_get_context, - mock_update_full_graph_params): + mock_update_full_graph_params, mock_get_context_2): last_use_cuda_graph = self.proposer.use_cuda_graph mock_return_context = MagicMock() mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL @@ -425,6 +428,7 @@ class TestEagleProposerDummyRun(TestBase): # cpu does not support parallel-group, let alone `sp` mock_return_context.flash_comm_v1_enabled = False mock_get_context.return_value = mock_return_context + mock_get_context_2.return_value = mock_return_context self.proposer.use_cuda_graph = True # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce` with set_current_vllm_config(self.vllm_config): @@ -435,12 +439,13 @@ class TestEagleProposerDummyRun(TestBase): self.assertTrue(self.proposer._runnable.call_count == 1) mock_update_full_graph_params.assert_not_called() self.proposer.use_cuda_graph = last_use_cuda_graph - + + @patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch("vllm_ascend.spec_decode.eagle_proposer.update_full_graph_params") @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context") @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context") def test_dummy_run_in_graph_run(self, mock_context, mock_get_context, - mock_update_full_graph_params): + mock_update_full_graph_params, mock_get_context_2): last_use_cuda_graph = self.proposer.use_cuda_graph mock_return_context = MagicMock() mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL @@ -448,6 +453,7 @@ class TestEagleProposerDummyRun(TestBase): # cpu does not support parallel-group, let alone `sp` mock_return_context.flash_comm_v1_enabled = False mock_get_context.return_value = mock_return_context + mock_get_context_2.return_value = mock_return_context self.proposer.use_cuda_graph = True # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce` with set_current_vllm_config(self.vllm_config): diff --git a/vllm_ascend/_310p/fused_moe/fused_moe.py b/vllm_ascend/_310p/fused_moe/fused_moe.py index 4a411dc4..9e23cc9c 100644 --- a/vllm_ascend/_310p/fused_moe/fused_moe.py +++ b/vllm_ascend/_310p/fused_moe/fused_moe.py @@ -18,12 +18,11 @@ from collections.abc import Callable import torch from vllm.distributed import get_dp_group, get_ep_group, get_tp_group -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig from vllm.model_executor.layers.fused_moe.layer import FusedMoE, UnquantizedFusedMoEMethod from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE -from vllm_ascend.ascend_forward_context import MoECommType +from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute from vllm_ascend.ops.fused_moe.moe_comm_method import FusedExpertsResult, _MoECommMethods from vllm_ascend.quantization.methods.base import QuantType @@ -93,7 +92,7 @@ class AscendUnquantizedFusedMoEMethod310(UnquantizedFusedMoEMethod): topk_weights = topk_weights.to(x.dtype) - moe_comm_method = get_forward_context().moe_comm_method + moe_comm_method = _EXTRA_CTX.moe_comm_method final_hidden_states = 
moe_comm_method.fused_experts( hidden_states=x, w1=layer.w13_weight, @@ -222,9 +221,8 @@ class AscendFusedMoE310(FusedMoE): ) -> torch.Tensor: assert self.quant_method is not None assert self.routed_scaling_factor == 1.0, "routed_scaling_factor != 1.0 is not supported." - forward_context = get_forward_context() - hidden_states, router_logits, _, context_metadata = forward_context.moe_comm_method.prepare( + hidden_states, router_logits, _, context_metadata = _EXTRA_CTX.moe_comm_method.prepare( hidden_states=hidden_states, router_logits=router_logits, quant_type=self.quant_type ) @@ -246,7 +244,7 @@ class AscendFusedMoE310(FusedMoE): apply_router_weight_on_input=self.apply_router_weight_on_input, ) - routed_out = forward_context.moe_comm_method.finalize( + routed_out = _EXTRA_CTX.moe_comm_method.finalize( hidden_states=fused_experts_results.routed_out, reduce_results=self.reduce_results, context_metadata=context_metadata, diff --git a/vllm_ascend/_310p/fused_moe/moe_comm_method.py b/vllm_ascend/_310p/fused_moe/moe_comm_method.py index bcc9851b..589566fc 100644 --- a/vllm_ascend/_310p/fused_moe/moe_comm_method.py +++ b/vllm_ascend/_310p/fused_moe/moe_comm_method.py @@ -16,8 +16,8 @@ from __future__ import annotations import torch -from vllm.forward_context import get_forward_context +from vllm_ascend.ascend_forward_context import _EXTRA_CTX from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult from .moe_mlp import unified_apply_mlp @@ -50,7 +50,7 @@ class AllGatherCommImpl310(AllGatherCommImpl): ) -> FusedExpertsResult: # This method is overridden to use the 310p-specific unified_apply_mlp # which provides optimized MLP computation for the 310p platform - moe_comm_method = get_forward_context().moe_comm_method + moe_comm_method = _EXTRA_CTX.moe_comm_method assert moe_comm_method is not None, "Missing communication context" dispatch_results = self.token_dispatcher.token_dispatch( diff --git a/vllm_ascend/_310p/quantization/methods/w8a8_dynamic.py b/vllm_ascend/_310p/quantization/methods/w8a8_dynamic.py index 8a9b6597..6a1a1303 100644 --- a/vllm_ascend/_310p/quantization/methods/w8a8_dynamic.py +++ b/vllm_ascend/_310p/quantization/methods/w8a8_dynamic.py @@ -21,9 +21,9 @@ from typing import Any import torch from vllm.config import get_current_vllm_config from vllm.distributed import get_ep_group -from vllm.forward_context import get_forward_context from vllm_ascend._310p.fused_moe.experts_selector import select_experts +from vllm_ascend.ascend_forward_context import _EXTRA_CTX from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute from vllm_ascend.quantization.methods.base import AscendMoEScheme, QuantType @@ -125,7 +125,7 @@ class AscendW8A8DynamicFusedMoEMethod310(AscendMoEScheme): topk_weights = topk_weights.to(self.in_dtype) - moe_comm_method = get_forward_context().moe_comm_method + moe_comm_method = _EXTRA_CTX.moe_comm_method final_hidden_states = moe_comm_method.fused_experts( hidden_states=x, diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 116f562a..afca4d97 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -4,6 +4,7 @@ from enum import Enum from typing import Any import torch +import vllm.envs as envs_vllm from vllm.config import CUDAGraphMode, VllmConfig from vllm.distributed import get_dp_group, get_ep_group, get_tensor_model_parallel_world_size from vllm.forward_context import BatchDescriptor, get_forward_context, 
set_forward_context
@@ -270,3 +271,61 @@ def select_moe_comm_method(num_tokens: int, vllm_config: VllmConfig, is_draft_mo
     else:
         raise ValueError(f"Unsupported soc_version: {soc_version}")
     return moe_comm_type
+
+
+class _ExtraForwardContextProxy:
+    """Unified forward-context access for v1/v2 model runners."""
+
+    extra_attrs = (
+        "capturing",
+        "moe_comm_type",
+        "moe_comm_method",
+        "mmrs_fusion",
+        "num_tokens",
+        "flash_comm_v1_enabled",
+        "flashcomm_v2_enabled",
+        "pad_size",
+        "padded_length",
+        "num_tokens_across_dp",
+        "mc2_mask",
+        "is_draft_model",
+        "prefetch_mlp_gate_up_proj",
+        "prefetch_mlp_down_proj",
+        "model_instance",
+        "layer_idx",
+        "max_tokens_across_dp",
+        "max_tokens_across_pcp",
+        "num_accept_tokens",
+        "in_profile_run",
+        "padded_num_tokens",
+    )
+
+    def check_extra_attr(self, name: str):
+        if name not in self.extra_attrs:
+            raise AttributeError(
+                f"{name} is not extra forward context attribute, "
+                "please get/set it from vllm's _forward_context directly."
+            )
+
+    @staticmethod
+    def _ctx():
+        return get_forward_context()
+
+    def __getattr__(self, name: str) -> Any:
+        self.check_extra_attr(name)
+        ctx = self._ctx()
+        if envs_vllm.VLLM_USE_V2_MODEL_RUNNER:
+            return ctx.additional_kwargs[name]
+        return getattr(ctx, name)
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        self.check_extra_attr(name)
+        ctx = self._ctx()
+        if envs_vllm.VLLM_USE_V2_MODEL_RUNNER:
+            ctx.additional_kwargs[name] = value
+        else:
+            setattr(ctx, name, value)
+
+
+# usage: from vllm_ascend.ascend_forward_context import _EXTRA_CTX
+_EXTRA_CTX = _ExtraForwardContextProxy()
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index a1c79d94..ce9d3e1b 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -23,7 +23,6 @@ import torch
 import torch_npu
 import vllm.envs as envs_vllm
 from vllm.config import VllmConfig, get_current_vllm_config
-from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils.math_utils import cdiv
 from vllm.v1.attention.backend import (  # type: ignore
     AttentionBackend,
@@ -40,6 +39,7 @@ from vllm.v1.attention.backends.registry import (  # type: ignore
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec
 
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.context_parallel.common_cp import AscendMetadataForDecode, AscendMetadataForPrefill
 from vllm_ascend.attention.utils import (
@@ -392,7 +392,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
     ):
         if using_paged_attention(num_tokens, vllm_config):
             # Paged Attention update logic
-            if forward_context.is_draft_model:
+            if _EXTRA_CTX.is_draft_model:
                 graph_params = get_draft_graph_params()
             else:
                 graph_params = get_graph_params()
@@ -444,7 +444,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 event.record(update_stream)
         else:
             # FIA update logic
-            if forward_context.is_draft_model:
+            if _EXTRA_CTX.is_draft_model:
                 graph_params = get_draft_graph_params()
                 attn_metadata = draft_attn_metadatas
                 attn_keys = list(attn_metadata[0].keys())
@@ -462,7 +462,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             num_layers = len(attn_keys)
             if num_layers == 0:
                 return
-            if forward_context.is_draft_model:
+            if _EXTRA_CTX.is_draft_model:
                 attn_keys = attn_keys * (len(graph_params.attn_params[num_tokens]) // num_layers)
             attn_count = 0
             with 
torch.npu.stream(update_stream): @@ -488,7 +488,7 @@ class AscendAttentionBackendImpl(AttentionImpl): softmax_lse, ) = param - if forward_context.is_draft_model: + if _EXTRA_CTX.is_draft_model: draft_step = attn_count // num_layers seq_lens = attn_metadata[draft_step][key].seq_lens_list actual_seq_lengths_q = attn_metadata[draft_step][key].actual_seq_lengths_q @@ -535,8 +535,7 @@ class AscendAttentionBackendImpl(AttentionImpl): key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata) num_tokens = attn_metadata.actual_seq_lengths_q[-1] - forward_context = get_forward_context() - if forward_context.is_draft_model: + if _EXTRA_CTX.is_draft_model: graph_params = get_draft_graph_params() else: graph_params = get_graph_params() @@ -563,7 +562,7 @@ class AscendAttentionBackendImpl(AttentionImpl): sparse_mode=3, scale=self.scale, ) - if forward_context.is_draft_model: + if _EXTRA_CTX.is_draft_model: update_draft_graph_params_workspaces(num_tokens, workspace) else: update_graph_params_workspaces(num_tokens, workspace) @@ -625,9 +624,8 @@ class AscendAttentionBackendImpl(AttentionImpl): output: torch.Tensor | None = None, ): graph_params = get_graph_params() - forward_context: ForwardContext = get_forward_context() num_tokens = query.shape[0] - if forward_context.capturing: + if _EXTRA_CTX.capturing: # Get workspace from cache or calculate it if not present. workspace = graph_params.workspaces.get(num_tokens) if workspace is None: @@ -761,11 +759,10 @@ class AscendAttentionBackendImpl(AttentionImpl): attn_metadata: AscendMetadata, output: torch.Tensor, ): - forward_context: ForwardContext = get_forward_context() # we inherit ForwardContext in model runner v2, when enable model # runner v2, there is not capturing attribute in forward_context, # just use getattr to avoid attribute error. 
- if getattr(forward_context, "capturing", False): + if _EXTRA_CTX.capturing: attn_output, num_tokens = self.full_graph_fia(query, key, value, attn_metadata, output) output[:num_tokens] = attn_output[:num_tokens] return output @@ -841,8 +838,7 @@ class AscendAttentionBackendImpl(AttentionImpl): attn_metadata: AscendMetadata, output: torch.Tensor | None = None, ) -> torch.Tensor: - forward_context: ForwardContext = get_forward_context() - if forward_context.capturing: + if _EXTRA_CTX.capturing: return self.full_graph_pa(query, attn_metadata, output) torch_npu._npu_paged_attention( query=query, diff --git a/vllm_ascend/attention/context_parallel/attention_cp.py b/vllm_ascend/attention/context_parallel/attention_cp.py index 655fec1d..af23ae90 100644 --- a/vllm_ascend/attention/context_parallel/attention_cp.py +++ b/vllm_ascend/attention/context_parallel/attention_cp.py @@ -29,10 +29,10 @@ from vllm.distributed import ( get_decode_context_model_parallel_world_size, get_pcp_group, ) -from vllm.forward_context import ForwardContext, get_forward_context from vllm.v1.attention.backend import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec +from vllm_ascend.ascend_forward_context import _EXTRA_CTX from vllm_ascend.attention.attention_v1 import ( AscendAttentionBackendImpl, AscendAttentionMetadataBuilder, @@ -559,9 +559,8 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): "actual_seq_lengths": torch.arange(attn_metadata.num_decodes_flatten) + 1, } graph_params = get_graph_params() - forward_context: ForwardContext = get_forward_context() num_tokens = query.shape[0] - if forward_context.capturing: + if _EXTRA_CTX.capturing: stream = torch_npu.npu.current_stream() event = torch.npu.ExternalEvent() diff --git a/vllm_ascend/attention/context_parallel/mla_cp.py b/vllm_ascend/attention/context_parallel/mla_cp.py index d30ce725..2ea6c529 100644 --- a/vllm_ascend/attention/context_parallel/mla_cp.py +++ b/vllm_ascend/attention/context_parallel/mla_cp.py @@ -10,7 +10,6 @@ from vllm.distributed import ( get_decode_context_model_parallel_world_size, get_pcp_group, ) -from vllm.forward_context import ForwardContext, get_forward_context from vllm.utils.math_utils import cdiv from vllm.v1.attention.backend import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec @@ -30,6 +29,7 @@ from vllm_ascend.attention.mla_v1 import ( ) # isort: on +from vllm_ascend.ascend_forward_context import _EXTRA_CTX from vllm_ascend.attention.context_parallel.common_cp import ( AscendPCPMetadata, CPChunkedContextMetadata, @@ -294,7 +294,7 @@ class AscendMlaCPImpl(AscendMLAImpl): num_dcp_pcp_tokens=None, draft_attn_metadatas=None, ): - if forward_context.is_draft_model: + if _EXTRA_CTX.is_draft_model: graph_params = get_draft_graph_params() else: graph_params = get_graph_params() @@ -659,12 +659,11 @@ class AscendMlaCPImpl(AscendMLAImpl): "softmax_lse_flag": True, } - forward_context: ForwardContext = get_forward_context() - if forward_context.is_draft_model: + if _EXTRA_CTX.is_draft_model: graph_params = get_draft_graph_params() else: graph_params = get_graph_params() - if forward_context.capturing: + if _EXTRA_CTX.capturing: stream = torch_npu.npu.current_stream() event = torch.npu.ExternalEvent() event.wait(stream) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 34d8e250..c68d4ff0 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -6,16 +6,20 @@ import torch import torch_npu import vllm.envs 
as envs_vllm from vllm.config import VllmConfig, get_current_vllm_config -from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import logger from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadataBuilder from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.utils.math_utils import cdiv, round_down -from vllm.v1.attention.backend import AttentionBackend, AttentionCGSupport, MLAAttentionImpl # type: ignore +from vllm.v1.attention.backend import ( + AttentionBackend, # type: ignore + AttentionCGSupport, + MLAAttentionImpl, +) from vllm.v1.attention.backends.utils import PAD_SLOT_ID # type: ignore from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ascend_forward_context import _EXTRA_CTX from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata, CPChunkedContextMetadata @@ -44,12 +48,7 @@ from vllm_ascend.ops.layer_shard_linear import ( ) from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla from vllm_ascend.quantization.methods import AscendW8A8LinearMethod -from vllm_ascend.utils import ( - ACL_FORMAT_FRACTAL_ND, - get_weight_prefetch_method, - maybe_trans_nz, - weak_ref_tensors, -) +from vllm_ascend.utils import ACL_FORMAT_FRACTAL_ND, get_weight_prefetch_method, maybe_trans_nz, weak_ref_tensors from vllm_ascend.worker.npu_input_batch import NPUInputBatch if TYPE_CHECKING: @@ -737,7 +736,7 @@ class AscendMLAImpl(MLAAttentionImpl): num_dcp_pcp_tokens=None, draft_attn_metadatas=None, ): - if forward_context.is_draft_model: + if _EXTRA_CTX.is_draft_model: graph_params = get_draft_graph_params() else: graph_params = get_graph_params() @@ -769,12 +768,12 @@ class AscendMLAImpl(MLAAttentionImpl): softmax_lse, ) = param seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list - if speculative_config and speculative_config.method == "mtp" and not forward_context.is_draft_model: + if speculative_config and speculative_config.method == "mtp" and not _EXTRA_CTX.is_draft_model: actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q spec_multiple = speculative_config.num_speculative_tokens + 1 seq_lens_list = seq_lens_list + [0] * (num_tokens // spec_multiple - len(seq_lens_list)) actual_seq_lengths = [spec_multiple * (i + 1) for i in range(num_tokens // spec_multiple)] - elif forward_context.is_draft_model: + elif _EXTRA_CTX.is_draft_model: actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q block_table = forward_context.attn_metadata[key].decode.block_table # TODO: This is a hack and should be fixed in the future. 
@@ -1243,12 +1242,11 @@ class AscendMLAImpl(MLAAttentionImpl): "actual_seq_lengths": actual_seq_lengths, "actual_seq_lengths_kv": decode_meta.seq_lens_list, } - forward_context: ForwardContext = get_forward_context() - if forward_context.is_draft_model: + if _EXTRA_CTX.is_draft_model: graph_params = get_draft_graph_params() else: graph_params = get_graph_params() - if forward_context.capturing: + if _EXTRA_CTX.capturing: stream = torch_npu.npu.current_stream() event = torch.npu.ExternalEvent() @@ -1261,7 +1259,7 @@ class AscendMLAImpl(MLAAttentionImpl): workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace( q_nope, k_nope, k_nope, **common_kwargs ) - if forward_context.is_draft_model: + if _EXTRA_CTX.is_draft_model: update_draft_graph_params_workspaces(num_tokens, workspace) else: update_graph_params_workspaces(num_tokens, workspace) @@ -1493,7 +1491,6 @@ class AscendMLAImpl(MLAAttentionImpl): reach_layer_for_shard_weight_series(layer) return output.fill_(0) - forward_context = get_forward_context() num_actual_tokens = self.get_num_actual_tokens(attn_metadata) assert ( attn_metadata.num_decodes is not None @@ -1505,7 +1502,7 @@ class AscendMLAImpl(MLAAttentionImpl): num_decode_tokens = attn_metadata.num_decode_tokens # Inputs and outputs may be padded for CUDA graphs output_padded = output - o_proj_input_shape = (forward_context.num_tokens, self.num_heads * self.v_head_dim) + o_proj_input_shape = (_EXTRA_CTX.num_tokens, self.num_heads * self.v_head_dim) o_proj_input = torch.empty(o_proj_input_shape, dtype=hidden_states.dtype, device=hidden_states.device) # MLA Preprocess diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index 22def96c..f7edb5fb 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -7,16 +7,20 @@ import vllm.envs as envs_vllm from torch import nn from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group -from vllm.forward_context import get_forward_context from vllm.logger import logger from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadataBuilder from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.triton_utils import HAS_TRITON -from vllm.v1.attention.backend import AttentionBackend, AttentionCGSupport, MLAAttentionImpl # type: ignore +from vllm.v1.attention.backend import ( + AttentionBackend, # type: ignore + AttentionCGSupport, + MLAAttentionImpl, +) from vllm.v1.kv_cache_interface import AttentionSpec from vllm_ascend import envs from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ascend_forward_context import _EXTRA_CTX from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata @@ -967,10 +971,9 @@ class AscendSFAImpl(MLAAttentionImpl): output: torch.Tensor | None = None, ) -> torch.Tensor: assert output is not None, "Output tensor must be provided." - forward_context = get_forward_context() if attn_metadata is None: # Profiling run. 
- if self.enable_dsa_cp_with_layer_shard and not forward_context.in_profile_run: + if self.enable_dsa_cp_with_layer_shard and not _EXTRA_CTX.in_profile_run: for layer in self.layer_sharding_kwargs or []: if is_hidden_layer(layer): reach_layer_for_shard_weight_series(layer) diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index 93c32ce7..e22d1b5a 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -19,6 +19,8 @@ from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.logger import logger from vllm.platforms import current_platform +from vllm_ascend.ascend_forward_context import _EXTRA_CTX + from ..utils import weak_ref_tensors @@ -195,7 +197,7 @@ class ACLGraphWrapper: if self.vllm_config.speculative_config else False ) - if self.runtime_mode != CUDAGraphMode.FULL or not forward_context.is_draft_model or not use_eagle: + if self.runtime_mode != CUDAGraphMode.FULL or not _EXTRA_CTX.is_draft_model or not use_eagle: torch.npu.current_stream().synchronize() entry.aclgraph.replay() return entry.output diff --git a/vllm_ascend/distributed/utils.py b/vllm_ascend/distributed/utils.py index bbcf7557..740c6e5d 100644 --- a/vllm_ascend/distributed/utils.py +++ b/vllm_ascend/distributed/utils.py @@ -3,6 +3,7 @@ import torch.distributed as dist from vllm.distributed.parallel_state import GroupCoordinator, get_dp_group from vllm.forward_context import get_forward_context +from vllm_ascend.ascend_forward_context import _EXTRA_CTX from vllm_ascend.distributed.parallel_state import get_fc3_quant_x_group @@ -16,7 +17,7 @@ def fc3_all_gather_and_maybe_unpad_impl( x = get_fc3_quant_x_group().all_gather(x, 0) dp_metadata = forward_context.dp_metadata if dp_metadata is None: - pad_size = forward_context.pad_size + pad_size = _EXTRA_CTX.pad_size if pad_size > 0: x = x[:-pad_size] else: @@ -24,7 +25,7 @@ def fc3_all_gather_and_maybe_unpad_impl( num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu result = torch.empty((num_tokens_across_dp_cpu.sum(), *x.shape[1:]), device=x.device, dtype=x.dtype) dp_size = get_dp_group().world_size - x = x.view(dp_size, forward_context.padded_length, *x.shape[1:]) + x = x.view(dp_size, _EXTRA_CTX.padded_length, *x.shape[1:]) offset = 0 for idx in range(dp_size): num_tokens_dp = num_tokens_across_dp_cpu[idx] diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py index 4304931b..eb8af0d3 100644 --- a/vllm_ascend/ops/fused_moe/fused_moe.py +++ b/vllm_ascend/ops/fused_moe/fused_moe.py @@ -37,7 +37,7 @@ if not vllm_version_is("0.16.0"): from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import DefaultMoERunner # type: ignore from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.ascend_forward_context import MoECommType +from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.eplb.core.eplb_utils import init_eplb_config from vllm_ascend.flash_common3_context import get_flash_common3_context, set_flash_common3_context @@ -148,7 +148,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): random_matrix = torch.rand(topk_ids.size(0), global_num_experts, device=topk_ids.device) topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype) - moe_comm_method = get_forward_context().moe_comm_method + moe_comm_method = _EXTRA_CTX.moe_comm_method final_hidden_states = 
moe_comm_method.fused_experts( hidden_states=x, w1=layer.w13_weight, @@ -401,12 +401,13 @@ class AscendFusedMoE(FusedMoE): # When static kernels are enabled, the forward pass runs twice (compilation + capture), # causing moe_layer_index to overflow. Wrap the index to prevent out-of-bounds errors. if self.enable_npugraph_ex_static_kernel: - forward_context.moe_layer_index = forward_context.moe_layer_index % (len(forward_context.all_moe_layers)) + moe_layer_index = forward_context.moe_layer_index % (len(forward_context.all_moe_layers)) + forward_context.moe_layer_index = moe_layer_index # Load balancing for token distribution among experts in dummy_run # TODO: The community only considers load balancing when DP > 1. # This approach may overlook some extreme scenarios. - enable_force_load_balance = forward_context.in_profile_run + enable_force_load_balance = _EXTRA_CTX.in_profile_run forward_context = get_forward_context() if self.multistream_overlap_gate: @@ -419,7 +420,7 @@ class AscendFusedMoE(FusedMoE): assert fc3_context.shared_experts is not None shared_out = fc3_context.shared_experts(hidden_states) # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel` - moe_comm_type = forward_context.moe_comm_type + moe_comm_type = _EXTRA_CTX.moe_comm_type if ( moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} and not shared_expert_dp_enabled() @@ -442,16 +443,16 @@ class AscendFusedMoE(FusedMoE): global_num_experts=self.global_num_experts, ) - if isinstance(forward_context.moe_comm_method, AllGatherCommImpl): + if isinstance(_EXTRA_CTX.moe_comm_method, AllGatherCommImpl): topk_weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(topk_weights, True, True) topk_ids = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(topk_ids, True, True) set_flash_common3_context(topk_weights=topk_weights, topk_ids=topk_ids) - hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare( + hidden_states, router_logits, mc2_mask, context_metadata = _EXTRA_CTX.moe_comm_method.prepare( hidden_states=hidden_states, router_logits=router_logits, - replace_allreduce=forward_context.flash_comm_v1_enabled, + replace_allreduce=_EXTRA_CTX.flash_comm_v1_enabled, enable_shared_expert_dp=self.enable_shared_expert_dp, quant_type=self.quant_type, ) @@ -509,7 +510,7 @@ class AscendFusedMoE(FusedMoE): self.load_counter.add_(1) else: self.moe_load.add_(local_load) - routed_out = forward_context.moe_comm_method.finalize( + routed_out = _EXTRA_CTX.moe_comm_method.finalize( hidden_states=fused_experts_results.routed_out, reduce_results=self.reduce_results, context_metadata=context_metadata, @@ -670,8 +671,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE): # NOTE: This is exactly the opposite of # `maybe_all_reduce_tensor_model_parallel` - forward_context = get_forward_context() - moe_comm_type = forward_context.moe_comm_type + moe_comm_type = _EXTRA_CTX.moe_comm_type if ( moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} and not shared_expert_dp_enabled() diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py index 14d72531..01f29ba8 100644 --- a/vllm_ascend/ops/fused_moe/moe_comm_method.py +++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py @@ -19,11 +19,10 @@ from abc import ABC, abstractmethod from dataclasses import dataclass import torch -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe import 
FusedMoEConfig import vllm_ascend.envs as envs_ascend -from vllm_ascend.ascend_forward_context import MoECommType +from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp from vllm_ascend.ops.fused_moe.prepare_finalize import ( PrepareAndFinalize, @@ -135,7 +134,7 @@ class MoECommMethod(ABC): # Check constraints assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.int8] - moe_comm_method = get_forward_context().moe_comm_method + moe_comm_method = _EXTRA_CTX.moe_comm_method assert moe_comm_method is not None, "Missing communication context" before_dispatch_evt = torch.npu.current_stream().record_event() diff --git a/vllm_ascend/ops/fused_moe/moe_mlp.py b/vllm_ascend/ops/fused_moe/moe_mlp.py index aea25579..2033b61b 100644 --- a/vllm_ascend/ops/fused_moe/moe_mlp.py +++ b/vllm_ascend/ops/fused_moe/moe_mlp.py @@ -18,10 +18,9 @@ import torch import torch_npu from torch.nn.functional import pad -from vllm.forward_context import get_forward_context from vllm.triton_utils import HAS_TRITON -from vllm_ascend.ascend_forward_context import MoECommType +from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType from vllm_ascend.device.device_op import DeviceOperator from vllm_ascend.device.mxfp_compat import ( ensure_mxfp8_moe_available, @@ -147,7 +146,7 @@ def quant_apply_mlp( weight_prefetch_method = get_weight_prefetch_method() if weight_prefetch_method: weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(hidden_states) - is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2 + is_mc2 = _EXTRA_CTX.moe_comm_type == MoECommType.MC2 if w1_scale_bias is None and w1_offset is None and is_mc2: if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb) and not use_mxfp_quant: # gmm1: gate_up_proj & act_fn: swiglu diff --git a/vllm_ascend/ops/fused_moe/prepare_finalize.py b/vllm_ascend/ops/fused_moe/prepare_finalize.py index e7b4cf98..6c7358aa 100644 --- a/vllm_ascend/ops/fused_moe/prepare_finalize.py +++ b/vllm_ascend/ops/fused_moe/prepare_finalize.py @@ -26,10 +26,10 @@ from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe import FusedMoEConfig from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ascend_forward_context import _EXTRA_CTX from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl from vllm_ascend.quantization.methods.base import QuantType from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable @@ -242,8 +242,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All): """ self.replace_allreduce = replace_allreduce self.enable_shared_expert_dp = enable_shared_expert_dp - forward_context = get_forward_context() - mc2_mask = forward_context.mc2_mask + mc2_mask = _EXTRA_CTX.mc2_mask if self.tp_size > 1: # Also slice mc2_mask split_mc2_mask = torch.tensor_split(mc2_mask, self.tp_size, dim=0) @@ -252,7 +251,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All): padded_hidden_states_shape = hidden_states.shape if not self.replace_allreduce: self.num_tokens, _ = hidden_states.shape - target_pad_length = forward_context.padded_num_tokens + target_pad_length = _EXTRA_CTX.padded_num_tokens pad_size = target_pad_length - self.num_tokens # Pad if necessary (unless shared expert DP is enabled) @@ -367,8 +366,7 @@ class 
PrepareAndFinalizeWithAllGather(PrepareAndFinalize): """ self.enable_shared_expert_dp = enable_shared_expert_dp if self.moe_config.dp_size > 1: - forward_context = get_forward_context() - max_tokens_across_dp = forward_context.max_tokens_across_dp + max_tokens_across_dp = _EXTRA_CTX.max_tokens_across_dp self.num_tokens = hidden_states.shape[0] pad_size = max_tokens_across_dp - self.num_tokens @@ -381,8 +379,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize): router_logits = self.moe_config.dp_group.all_gather(router_logits, 0) if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1: - forward_context = get_forward_context() - max_tokens_across_pcp = forward_context.max_tokens_across_pcp + max_tokens_across_pcp = _EXTRA_CTX.max_tokens_across_pcp self.num_tokens_pcp = hidden_states.shape[0] pad_size = max_tokens_across_pcp - self.num_tokens_pcp diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py index c7510941..db898645 100644 --- a/vllm_ascend/ops/linear_op.py +++ b/vllm_ascend/ops/linear_op.py @@ -57,9 +57,9 @@ from vllm.distributed import ( tensor_model_parallel_reduce_scatter, ) from vllm.distributed.parallel_state import get_tp_group -from vllm.forward_context import get_forward_context from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ascend_forward_context import _EXTRA_CTX from vllm_ascend.distributed.parallel_state import ( get_flashcomm2_odp_group, get_flashcomm2_otp_group, @@ -311,8 +311,7 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp): input_parallel = splitted_input[tp_rank].contiguous() # padding for all-to-all - forward_context = get_forward_context() - num_padding_tokens = forward_context.pad_size + num_padding_tokens = _EXTRA_CTX.pad_size if num_padding_tokens > 0: input_parallel = nn.functional.pad(input_parallel, (0, 0, 0, num_padding_tokens)) @@ -368,7 +367,7 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp): else: output = output_parallel - if not forward_context.flash_comm_v1_enabled: + if not _EXTRA_CTX.flash_comm_v1_enabled: # flashcomm1 not enabled output = get_tp_group().all_gather(output, 0) if num_padding_tokens > 0: @@ -514,9 +513,8 @@ class SequenceRowParallelOp(CustomRowParallelOp): def matmul_and_reduce(self, input_parallel: torch.Tensor, bias_: Parameter | None) -> torch.Tensor: assert self.quant_method is not None try: - forward_context = get_forward_context() - flash_comm_v1_enabled = forward_context.flash_comm_v1_enabled - mmrs_fusion = forward_context.mmrs_fusion + flash_comm_v1_enabled = _EXTRA_CTX.flash_comm_v1_enabled + mmrs_fusion = _EXTRA_CTX.mmrs_fusion except AssertionError: flash_comm_v1_enabled = False mmrs_fusion = False @@ -527,7 +525,7 @@ class SequenceRowParallelOp(CustomRowParallelOp): output_parallel = self.layer.quant_method.apply(self.layer, x, bias=bias_) return tensor_model_parallel_all_reduce(output_parallel) - pad_size = forward_context.pad_size + pad_size = _EXTRA_CTX.pad_size if pad_size > 0 and not (enable_dsa_cp() and "o_proj" in self.layer.prefix): x = F.pad(x, (0, 0, 0, pad_size)) diff --git a/vllm_ascend/ops/mla.py b/vllm_ascend/ops/mla.py index 09e0ee36..c7ae6046 100644 --- a/vllm_ascend/ops/mla.py +++ b/vllm_ascend/ops/mla.py @@ -32,6 +32,7 @@ from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backend import AttentionMetadata # type: ignore from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ascend_forward_context import _EXTRA_CTX class IndexerWrapper(nn.Module): @@ -144,7 +145,7 
@@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper): kv_cache: torch.Tensor | None = None, attn_metadata: AttentionMetadata | None = None, ) -> torch.Tensor: - need_gather_q_kv = get_forward_context().flash_comm_v1_enabled + need_gather_q_kv = _EXTRA_CTX.flash_comm_v1_enabled output_shape = hidden_states.shape # FIXME: This does not seem right, should make sure the buffer is fixed output = torch.empty(output_shape, dtype=hidden_states.dtype, device=hidden_states.device) diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py index b803cd42..3701941d 100644 --- a/vllm_ascend/ops/register_custom_ops.py +++ b/vllm_ascend/ops/register_custom_ops.py @@ -13,7 +13,7 @@ from vllm.distributed import ( from vllm.forward_context import get_forward_context from vllm.utils.torch_utils import direct_register_custom_op -from vllm_ascend.ascend_forward_context import MoECommType +from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType from vllm_ascend.ops.rotary_embedding import rope_forward_oot from vllm_ascend.ops.triton.muls_add import muls_add_triton from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch @@ -22,12 +22,12 @@ from vllm_ascend.utils import npu_stream_switch, prefetch_stream def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: try: - forward_context = get_forward_context() + get_forward_context() except AssertionError: return residual if x.size(0) != residual.size(0): - pad_size = forward_context.pad_size + pad_size = _EXTRA_CTX.pad_size if pad_size > 0: residual = F.pad(residual, (0, 0, 0, pad_size)) tp_size = get_tensor_model_parallel_world_size() @@ -43,12 +43,12 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_c except AssertionError: return x - flash_comm_v1_enabled = forward_context.flash_comm_v1_enabled + flash_comm_v1_enabled = _EXTRA_CTX.flash_comm_v1_enabled if flash_comm_v1_enabled and label: dp_metadata = forward_context.dp_metadata if dp_metadata is None or not is_ep_comm: x = tensor_model_parallel_all_gather(x, 0) - pad_size = forward_context.pad_size + pad_size = _EXTRA_CTX.pad_size if pad_size > 0: x = x[:-pad_size] else: @@ -57,7 +57,7 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_c num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu result = torch.empty((num_tokens_across_dp_cpu.sum(), *x.shape[1:]), device=x.device, dtype=x.dtype) dp_size = get_dp_group().world_size - x = x.view(dp_size, forward_context.padded_length, *x.shape[1:]) + x = x.view(dp_size, _EXTRA_CTX.padded_length, *x.shape[1:]) offset = 0 for idx in range(dp_size): num_tokens_dp = num_tokens_across_dp_cpu[idx] @@ -79,7 +79,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor dp_metadata = forward_context.dp_metadata if dp_metadata is None or not is_ep_comm: - pad_size = forward_context.pad_size + pad_size = _EXTRA_CTX.pad_size if pad_size > 0: x = F.pad(x, (0, 0, 0, pad_size)) return tensor_model_parallel_reduce_scatter(x, 0) @@ -87,7 +87,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor # padding dp_size = get_dp_group().world_size num_tokens_across_dp_cpu = get_forward_context().dp_metadata.num_tokens_across_dp_cpu - padded_x = torch.empty((dp_size, forward_context.padded_length, *x.shape[1:]), device=x.device, dtype=x.dtype) + padded_x = torch.empty((dp_size, _EXTRA_CTX.padded_length, *x.shape[1:]), device=x.device, dtype=x.dtype) offset = 
0 for idx in range(dp_size): num_tokens_dp = num_tokens_across_dp_cpu[idx] @@ -98,7 +98,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_comm: bool = False) -> torch.Tensor: - if get_forward_context().flash_comm_v1_enabled and label: + if _EXTRA_CTX.flash_comm_v1_enabled and label: return torch.empty( (x.shape[0] * get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype ) @@ -107,7 +107,7 @@ def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_c def _maybe_pad_and_reduce_fake(x: torch.Tensor, is_ep_comm: bool = False) -> torch.Tensor: - if get_forward_context().flash_comm_v1_enabled: + if _EXTRA_CTX.flash_comm_v1_enabled: return torch.empty( (x.shape[0] // get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype ) @@ -138,11 +138,10 @@ def _prefetch_postprocess_impl_fake(stop_flag: torch.Tensor) -> None: def _maybe_all_reduce_tensor_model_parallel_impl(final_hidden_states: torch.Tensor) -> torch.Tensor: - forward_context = get_forward_context() - moe_comm_type = forward_context.moe_comm_type + moe_comm_type = _EXTRA_CTX.moe_comm_type if ( moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} - or forward_context.flash_comm_v1_enabled + or _EXTRA_CTX.flash_comm_v1_enabled ): return final_hidden_states else: @@ -163,7 +162,7 @@ def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor, layer_name: str) forward_context = get_forward_context() self = forward_context.no_compile_layers[layer_name] num_tokens = input_parallel.size(0) - if forward_context.flash_comm_v1_enabled: + if _EXTRA_CTX.flash_comm_v1_enabled: num_tokens = num_tokens // self.tp_size output = torch.empty( size=(num_tokens, self.output_size_per_partition), device=input_parallel.device, dtype=input_parallel.dtype diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index d97eed6c..9e81fa3c 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -21,7 +21,6 @@ import os import torch import torch_npu from vllm.config import get_current_vllm_config -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, MRotaryEmbedding, @@ -31,6 +30,7 @@ from vllm.model_executor.layers.rotary_embedding import ( from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb from vllm.triton_utils import HAS_TRITON +from vllm_ascend.ascend_forward_context import _EXTRA_CTX from vllm_ascend.platform import NPUPlatform from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, has_rope, is_vl_model @@ -240,8 +240,8 @@ class AscendRotaryEmbedding(RotaryEmbedding): is_neox_style = self.is_neox_style if is_neox_style_override is not None: is_neox_style = is_neox_style_override - is_draft_model = get_forward_context().is_draft_model - flash_comm_v1_enabled = get_forward_context().flash_comm_v1_enabled + is_draft_model = _EXTRA_CTX.is_draft_model + flash_comm_v1_enabled = _EXTRA_CTX.flash_comm_v1_enabled if is_draft_model and self.use_mtp and flash_comm_v1_enabled: positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(positions.contiguous(), True) return torch.ops.vllm.npu_rotary_embedding( diff --git a/vllm_ascend/ops/weight_prefetch.py b/vllm_ascend/ops/weight_prefetch.py index 2464b51c..d51b711f 100644 --- 
a/vllm_ascend/ops/weight_prefetch.py +++ b/vllm_ascend/ops/weight_prefetch.py @@ -6,6 +6,7 @@ from vllm.config import get_current_vllm_config from vllm.forward_context import ForwardContext, get_forward_context from vllm_ascend.ascend_config import WeightPrefetchConfig +from vllm_ascend.ascend_forward_context import _EXTRA_CTX from vllm_ascend.ops.linear import AscendQKVParallelLinear, AscendRowParallelLinear from vllm_ascend.utils import is_moe_model @@ -95,11 +96,11 @@ class WeightPrefetchMethod: if not self.moe.is_active_this_forward: return forward_context = get_forward_context() - if not forward_context or forward_context.model_instance is None: + if not forward_context or _EXTRA_CTX.model_instance is None: return # layer_idx is subtracted by 1 because layer_idx was incremented by 1 at layernorm. - weight = forward_context.model_instance.model.layers[forward_context.layer_idx - 1].mlp.experts.w13_weight + weight = _EXTRA_CTX.model_instance.model.layers[_EXTRA_CTX.layer_idx - 1].mlp.experts.w13_weight # type: ignore # type: ignore weight_size = weight.data.element_size() * weight.data.numel() * self.moe.prefetch_ratio.get(prefix, 0) torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=None, max_weight_size=int(weight_size)) @@ -122,9 +123,7 @@ class WeightPrefetchMethod: except AssertionError: return self.mlp.is_active_this_forward = ( - forward_context.layer_idx is not None - and forward_context.num_tokens is not None - and forward_context.num_tokens < 500 + _EXTRA_CTX.layer_idx is not None and _EXTRA_CTX.num_tokens is not None and _EXTRA_CTX.num_tokens < 500 ) if not self.mlp.is_active_this_forward: return @@ -144,9 +143,9 @@ class WeightPrefetchMethod: # start point of gate_up_proj weight prefetch if curr_layer_prefix.split(".")[-2] == "self_attn": - model_instance = forward_context.model_instance + model_instance = _EXTRA_CTX.model_instance layer_idx = int(curr_layer_prefix.split(".")[2]) - weight = model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight + weight = model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight # type: ignore if self.mlp_pre_version_compatibale_config: weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_GATE_UP, 0) else: @@ -156,12 +155,12 @@ class WeightPrefetchMethod: if weight_size > MAX_PREFETCH_WEIGHT_SIZE: weight_size = MAX_PREFETCH_WEIGHT_SIZE torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=x_dependency, max_weight_size=int(weight_size)) - forward_context.prefetch_mlp_gate_up_proj = True + _EXTRA_CTX.prefetch_mlp_gate_up_proj = True def _maybe_prefetch_mlp_down_weight_preprocess(self, x_dependency: torch.Tensor, forward_context: ForwardContext): - layer_idx = forward_context.layer_idx - model_instance = forward_context.model_instance - weight = model_instance.model.layers[layer_idx].mlp.down_proj.weight + layer_idx = _EXTRA_CTX.layer_idx + model_instance = _EXTRA_CTX.model_instance + weight = model_instance.model.layers[layer_idx].mlp.down_proj.weight # type: ignore if self.mlp_pre_version_compatibale_config: weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_DOWN, 0) else: @@ -171,22 +170,22 @@ class WeightPrefetchMethod: if weight_size > MAX_PREFETCH_WEIGHT_SIZE: weight_size = MAX_PREFETCH_WEIGHT_SIZE torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=x_dependency, max_weight_size=int(weight_size)) - forward_context.prefetch_mlp_down_proj = True - forward_context.layer_idx += 1 + _EXTRA_CTX.prefetch_mlp_down_proj = True + _EXTRA_CTX.layer_idx = layer_idx + 1 # type: ignore 
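The hunks above repeatedly replace reads of per-forward state (pad_size, padded_length, flash_comm_v1_enabled, moe_comm_type, moe_comm_method, layer_idx, model_instance, the prefetch flags, in_profile_run, is_draft_model) on vLLM's ForwardContext with reads from _EXTRA_CTX. The definition of _EXTRA_CTX is not part of this diff; the sketch below only illustrates the assumed shape of that holder in vllm_ascend/ascend_forward_context.py and how call sites resolve against it, with field names taken from the attributes this patch actually touches.

from dataclasses import dataclass

@dataclass
class _AscendExtraContext:
    # Per-forward fields referenced by this patch; defaults are illustrative.
    pad_size: int = 0
    padded_length: int | None = None
    padded_num_tokens: int | None = None
    max_tokens_across_dp: int | None = None
    max_tokens_across_pcp: int | None = None
    mc2_mask: object | None = None
    flash_comm_v1_enabled: bool = False
    mmrs_fusion: bool = False
    moe_comm_type: object | None = None      # MoECommType
    moe_comm_method: object | None = None
    in_profile_run: bool = False
    is_draft_model: bool = False
    capturing: bool = False
    layer_idx: int | None = None
    num_tokens: int | None = None
    num_accept_tokens: int | None = None
    model_instance: object | None = None
    prefetch_mlp_gate_up_proj: bool = False
    prefetch_mlp_down_proj: bool = False

_EXTRA_CTX = _AscendExtraContext()

# Call sites then read/write attributes directly instead of going through
# get_forward_context(), e.g.:
#   pad_size = _EXTRA_CTX.pad_size
#   _EXTRA_CTX.layer_idx = (layer_idx or 0) + 1   # as in the hunk above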
def maybe_prefetch_mlp_weight_postprocess(self, stop_flag: torch.Tensor): if not self.mlp.is_active_this_forward: return try: - forward_context = get_forward_context() + get_forward_context() except AssertionError: return - if forward_context.prefetch_mlp_gate_up_proj or forward_context.prefetch_mlp_down_proj: + if _EXTRA_CTX.prefetch_mlp_gate_up_proj or _EXTRA_CTX.prefetch_mlp_down_proj: torch.ops.vllm.prefetch_postprocess(stop_flag) - forward_context.prefetch_mlp_gate_up_proj = False - forward_context.prefetch_mlp_down_proj = False + _EXTRA_CTX.prefetch_mlp_gate_up_proj = False + _EXTRA_CTX.prefetch_mlp_down_proj = False def maybe_prefetch_mla_or_sla_weight_in_current_stream( self, diff --git a/vllm_ascend/patch/worker/patch_v2_eagle.py b/vllm_ascend/patch/worker/patch_v2_eagle.py index 3b83f937..d3e2af36 100644 --- a/vllm_ascend/patch/worker/patch_v2_eagle.py +++ b/vllm_ascend/patch/worker/patch_v2_eagle.py @@ -153,7 +153,7 @@ def propose( # FIXME(woosuk): This is UNSAFE!! attn_metadata = build_attn_metadata( - attn_metadata_builders=self.attn_metadata_builders, + attn_groups=self.attn_groups, num_reqs=num_reqs, num_tokens=num_reqs, query_start_loc_gpu=query_start_loc, diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 8e44cdec..bdf12811 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -589,11 +589,12 @@ class NPUPlatform(Platform): if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER: return {} + # is_draft_model will be removed later, so we set it to False temporarily. + is_draft_model = False moe_comm_type = select_moe_comm_method( num_tokens, vllm_config, - # is_draft_model will be removed later, so we set it to False temporarily. - is_draft_model=False, + is_draft_model=is_draft_model, ) moe_comm_method = get_moe_comm_method(moe_comm_type) @@ -620,7 +621,7 @@ class NPUPlatform(Platform): # TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2 flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None - pad_size = None + pad_size = 0 padded_length = None if flash_comm_v1_enabled or flashcomm_v2_enabled: pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size @@ -657,6 +658,7 @@ class NPUPlatform(Platform): "padded_length": padded_length, "max_tokens_across_dp": max_tokens_across_dp, "mc2_mask": mc2_mask, + "is_draft_model": is_draft_model, } @staticmethod diff --git a/vllm_ascend/quantization/methods/w4a16.py b/vllm_ascend/quantization/methods/w4a16.py index f30ff88e..bb3bc3da 100644 --- a/vllm_ascend/quantization/methods/w4a16.py +++ b/vllm_ascend/quantization/methods/w4a16.py @@ -21,9 +21,9 @@ from typing import Any import torch import torch_npu from vllm.config import get_current_vllm_config -from vllm.forward_context import get_forward_context from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ascend_forward_context import _EXTRA_CTX from vllm_ascend.ops.fused_moe.experts_selector import select_experts from .base import AscendMoEScheme @@ -215,7 +215,7 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme): topk_ids = topk_ids.to(torch.int32) topk_weights = topk_weights.to(x.dtype) - moe_comm_method = get_forward_context().moe_comm_method + moe_comm_method = _EXTRA_CTX.moe_comm_method return moe_comm_method.fused_experts( hidden_states=x, w1=layer.w13_weight_packed, diff --git a/vllm_ascend/quantization/methods/w4a8.py b/vllm_ascend/quantization/methods/w4a8.py index 8a5ebca2..0ebeafc5 100644 --- a/vllm_ascend/quantization/methods/w4a8.py +++ 
b/vllm_ascend/quantization/methods/w4a8.py @@ -23,9 +23,9 @@ import torch import torch_npu from vllm.config import get_current_vllm_config from vllm.distributed import get_ep_group -from vllm.forward_context import get_forward_context from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ascend_forward_context import _EXTRA_CTX from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.fused_moe.experts_selector import select_experts from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD, maybe_trans_nz @@ -375,7 +375,7 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme): topk_weights = topk_weights.to(x.dtype) - moe_comm_method = get_forward_context().moe_comm_method + moe_comm_method = _EXTRA_CTX.moe_comm_method return moe_comm_method.fused_experts( hidden_states=x, w1=[layer.w13_weight], diff --git a/vllm_ascend/quantization/methods/w8a8_dynamic.py b/vllm_ascend/quantization/methods/w8a8_dynamic.py index d9d838ec..66596629 100644 --- a/vllm_ascend/quantization/methods/w8a8_dynamic.py +++ b/vllm_ascend/quantization/methods/w8a8_dynamic.py @@ -22,11 +22,10 @@ import torch import torch_npu from vllm.config import CompilationMode, get_current_vllm_config from vllm.distributed import get_ep_group -from vllm.forward_context import get_forward_context import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.ascend_forward_context import MoECommType +from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.flash_common3_context import get_flash_common3_context from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute @@ -234,10 +233,9 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme): assert topk_weights is not None topk_weights = topk_weights.to(self.in_dtype) - moe_comm_method = get_forward_context().moe_comm_method + moe_comm_method = _EXTRA_CTX.moe_comm_method fused_scale_flag = ( - get_forward_context().moe_comm_type == MoECommType.FUSED_MC2 - and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1 + _EXTRA_CTX.moe_comm_type == MoECommType.FUSED_MC2 and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1 ) if self.dynamic_eplb: w1 = layer.w13_weight_list diff --git a/vllm_ascend/quantization/methods/w8a8_mxfp8.py b/vllm_ascend/quantization/methods/w8a8_mxfp8.py index d3859f1b..1961e168 100644 --- a/vllm_ascend/quantization/methods/w8a8_mxfp8.py +++ b/vllm_ascend/quantization/methods/w8a8_mxfp8.py @@ -22,9 +22,9 @@ import torch import torch_npu from vllm.config import CompilationMode, get_current_vllm_config from vllm.distributed import get_ep_group -from vllm.forward_context import get_forward_context from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ascend_forward_context import _EXTRA_CTX from vllm_ascend.device.mxfp_compat import ( FLOAT8_E8M0FNU_DTYPE, ensure_mxfp8_linear_available, @@ -187,7 +187,7 @@ class AscendW8A8MXFP8DynamicFusedMoEMethod(AscendMoEScheme): topk_weights = topk_weights.to(x.dtype) - moe_comm_method = get_forward_context().moe_comm_method + moe_comm_method = _EXTRA_CTX.moe_comm_method return moe_comm_method.fused_experts( hidden_states=x, w1=layer.w13_weight, diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index b9d4e3bd..c567b161 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -34,7 +34,7 @@ from 
vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, EagleProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch -from vllm_ascend.ascend_forward_context import set_ascend_forward_context +from vllm_ascend.ascend_forward_context import _EXTRA_CTX, set_ascend_forward_context from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import AscendCommonAttentionMetadata @@ -398,7 +398,7 @@ class AscendEagleProposer(EagleProposer): num_tokens=num_tokens, ) forward_context = get_forward_context() - if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and not forward_context.capturing: + if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and not _EXTRA_CTX.capturing: self._update_full_graph_params(forward_context, num_tokens, multi_steps_attn_metadata) def _propose( @@ -784,8 +784,8 @@ class AscendEagleProposer(EagleProposer): input_batch_size = num_input_tokens if (self.method == "mtp" or self.use_cuda_graph) else batch_size forward_context = get_forward_context() - forward_context.num_tokens = input_batch_size - forward_context.num_accept_tokens = batch_size + _EXTRA_CTX.num_tokens = input_batch_size + _EXTRA_CTX.num_accept_tokens = batch_size for draft_step in range(self.num_speculative_tokens - 1): # Reset MOE layer index for each draft step iteration @@ -1361,15 +1361,14 @@ class AscendEagleProposer(EagleProposer): hidden_states: torch.Tensor, positions: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - forward_context = get_forward_context() if self.method == "mtp": - if forward_context.flash_comm_v1_enabled: + if _EXTRA_CTX.flash_comm_v1_enabled: hidden_states = torch.ops.vllm.maybe_pad_and_reduce(hidden_states) positions = positions.unsqueeze(-1) positions = torch.ops.vllm.maybe_pad_and_reduce(positions) positions = positions.squeeze(-1) else: - if forward_context.flash_comm_v1_enabled: + if _EXTRA_CTX.flash_comm_v1_enabled: hidden_states = split_inputs_tp_to_sp(hidden_states, hidden_states) return hidden_states, positions @@ -1388,8 +1387,7 @@ class AscendEagleProposer(EagleProposer): if hidden_states is not None: hidden_states = last_hidden_states else: - forward_context = get_forward_context() - if forward_context.flash_comm_v1_enabled: + if _EXTRA_CTX.flash_comm_v1_enabled: last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( last_hidden_states.contiguous(), True ) diff --git a/vllm_ascend/worker/v2/README.md b/vllm_ascend/worker/v2/README.md index 976372aa..1c1309e6 100644 --- a/vllm_ascend/worker/v2/README.md +++ b/vllm_ascend/worker/v2/README.md @@ -5,4 +5,5 @@ This directory contains the new model runner which is under active development. please see [Model Runner V2](https://github.com/vllm-project/vllm-ascend/issues/5208) to get specific plans. 
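The eagle_proposer and register_custom_ops hunks above rely on the flash-comm v1 padding contract: the token dimension is padded up to a multiple of the TP world size before reduce-scatter, and the same pad_size rows are stripped again after the all-gather (pad_size is computed in platform.py as (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size). A minimal standalone sketch of that arithmetic, using plain torch ops rather than the vLLM custom ops used in the patch:

import torch
import torch.nn.functional as F

def pad_for_tp(x: torch.Tensor, tp_world_size: int) -> tuple[torch.Tensor, int]:
    # Pad the token (first) dimension up to a multiple of tp_world_size.
    num_tokens = x.size(0)
    pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
    if pad_size > 0:
        x = F.pad(x, (0, 0, 0, pad_size))  # append pad rows, keep hidden dim
    return x, pad_size

def unpad_after_gather(x: torch.Tensor, pad_size: int) -> torch.Tensor:
    # Drop the padding rows again after the all-gather.
    return x[:-pad_size] if pad_size > 0 else x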
-supported vllm version: main@1339784 +supported vllm version: main@4034c3d32e30d01639459edd3ab486f56993876d +related PR: diff --git a/vllm_ascend/worker/v2/aclgraph_utils.py b/vllm_ascend/worker/v2/aclgraph_utils.py index bb37653c..a99a1518 100644 --- a/vllm_ascend/worker/v2/aclgraph_utils.py +++ b/vllm_ascend/worker/v2/aclgraph_utils.py @@ -19,16 +19,25 @@ from contextlib import contextmanager from typing import Any +import numpy as np import torch import torch.nn as nn +import vllm from vllm.config import VllmConfig -from vllm.v1.attention.backend import AttentionMetadataBuilder +from vllm.config.compilation import CUDAGraphMode +from vllm.forward_context import get_forward_context, set_forward_context +from vllm.logger import logger from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager -from vllm.v1.worker.gpu.cudagraph_utils import prepare_inputs_to_capture as prepare_inputs_to_capture_gpu from vllm.v1.worker.gpu.input_batch import InputBuffers +from vllm.v1.worker.gpu.model_states.interface import ModelState +from vllm.v1.worker.utils import AttentionGroup +from vllm_ascend.ascend_forward_context import _EXTRA_CTX +from vllm_ascend.compilation.acl_graph import set_graph_params, update_full_graph_params +from vllm_ascend.worker.v2.attn_utils import build_attn_metadata from vllm_ascend.worker.v2.utils import torch_cuda_wrapper @@ -38,44 +47,134 @@ class AclGraphManager(CudaGraphManager): def __init__( self, vllm_config: VllmConfig, - use_mrope: bool, + use_aux_hidden_state_outputs: bool, device: torch.device, + model_runner: Any, # NPUModelRunner type, in case circular import, so we pass it as Any ): - with torch_cuda_wrapper(): - super().__init__(vllm_config, use_mrope, device) + # set model runner attribute, so we can access attributes model runner + # when call `run_fullgraph` method in CudaGraphManager, + # then we don't need to # copy `execute_model` method in `NPUModelRunner` class. + self.model_runner = model_runner + super().__init__( + vllm_config, + use_aux_hidden_state_outputs, + device, + ) + # vllm-ascend need to update graph params of attention backend. + # so we need to set graph params before capture full graph. + if super().needs_capture(): + set_graph_params(self.cudagraph_sizes) + + def _capture_full_graph( + self, + num_tokens: int, + num_reqs: int, + model: nn.Module, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: torch.Tensor | None, + num_tokens_across_dp: torch.Tensor, + attn_metadata: dict[str, Any] | None, + slot_mappings: dict[str, torch.Tensor] | None, + has_lora: bool = False, + ) -> None: + """Override _capture_full_graph because we need to set capturing=True in forward context.""" + # set capturing=True in before model forward. 
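# Note: ModelWithContext is the small nn.Module wrapper defined at the bottom of
# this file; it only sets _EXTRA_CTX.capturing = True before delegating to the
# wrapped model's forward, so the state recorded into the ACL graph reflects the
# capture path rather than the warmup path.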
+ model = ModelWithContext(model) + return super()._capture_full_graph( + num_tokens, + num_reqs, + model, + input_ids, + positions, + inputs_embeds, + num_tokens_across_dp, + attn_metadata, + slot_mappings, + has_lora, + ) def capture_graph( self, num_tokens: int, + capture_cg_mode: CUDAGraphMode, model: nn.Module, + model_state: ModelState, input_buffers: InputBuffers, block_tables: BlockTables, - attn_metadata_builders: list[AttentionMetadataBuilder], + attn_groups: list[list[AttentionGroup]], kv_cache_config: KVCacheConfig, + has_lora: bool = False, + uniform_decode: bool = False, ) -> None: with torch_cuda_wrapper(), prepare_capture_inputs_wrapper(): super().capture_graph( num_tokens, + capture_cg_mode, model, + model_state, input_buffers, block_tables, - attn_metadata_builders, + attn_groups, kv_cache_config, + has_lora, + uniform_decode, ) + def run_fullgraph(self, num_tokens: int) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """Override run_fullgraph to update full graph params in run_fullgraph.""" + logger.info_once(f"run_fullgraph with num_tokens={num_tokens}") + ret = super().run_fullgraph(num_tokens) + assert self.model_runner.cudagraph_and_dp_padding is not None + + positions = self.model_runner.input_buffers.positions[:num_tokens] + _num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode = ( + self.model_runner.cudagraph_and_dp_padding + ) + cudagraph_runtime_mode = CUDAGraphMode(synced_cudagraph_mode) + + with set_forward_context( + self.model_runner.input_batch.attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + cudagraph_runtime_mode=cudagraph_runtime_mode, + num_tokens_across_dp=num_tokens_across_dp, + batch_descriptor=None, # Full graph model don't need batch_descriptor + slot_mapping=self.model_runner.input_batch.slot_mappings, + ): + forward_context = get_forward_context() + update_full_graph_params( + # FIXME(Ronald1995): support hybrid attn backend + list(self.model_runner.attn_backends.values())[0], + self.model_runner.update_stream, + forward_context, + num_tokens, + self.vllm_config, + self.model_runner.speculative_config, + positions.shape[0], + ) + return ret + + def is_uniform_decode( + self, + num_reqs: int, + num_tokens: int, + max_query_len: int, + ): + return (max_query_len == self.uniform_decode_query_len) and (num_tokens == max_query_len * num_reqs) + @contextmanager def prepare_capture_inputs_wrapper(): """Context manager to override input preparation for NPU graph capture.""" # TODO(Ronald1995): make prepare_inputs_to_capture as static method # in CudaGraphManager. - global prepare_inputs_to_capture_gpu + ori = vllm.v1.worker.gpu.cudagraph_utils.prepare_inputs_to_capture try: - ori_func = prepare_inputs_to_capture_gpu - prepare_inputs_to_capture_gpu = prepare_inputs_to_capture + vllm.v1.worker.gpu.cudagraph_utils.prepare_inputs_to_capture = prepare_inputs_to_capture yield finally: - prepare_inputs_to_capture_gpu = ori_func + vllm.v1.worker.gpu.cudagraph_utils.prepare_inputs_to_capture = ori def prepare_inputs_to_capture( @@ -83,9 +182,66 @@ def prepare_inputs_to_capture( num_tokens: int, input_buffers: InputBuffers, block_tables: BlockTables, - attn_metadata_builders: list[AttentionMetadataBuilder], + attn_groups: list[list[AttentionGroup]], max_model_len: int, kv_cache_config: KVCacheConfig, -) -> dict[str, Any]: - # TODO(Ronald1995): Implement NPU specific input preparation. 
- return {} + uniform_decode_query_len: int = 0, +) -> tuple[dict[str, Any], dict[str, torch.Tensor]]: + if uniform_decode_query_len > 0: + num_tokens_per_req = uniform_decode_query_len + else: + num_tokens_per_req = num_tokens // num_reqs + + query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req + query_start_loc_np[-1] = num_tokens + query_start_loc_cpu = torch.from_numpy(query_start_loc_np) + input_buffers.query_start_loc[: num_reqs + 1] = query_start_loc_cpu + input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens + query_start_loc = input_buffers.query_start_loc[: num_reqs + 1] + + # HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens + # rather than max_model_len. + input_buffers.seq_lens[:num_reqs] = num_tokens + input_buffers.seq_lens[num_reqs:] = 0 + input_buffers.seq_lens_cpu[:num_reqs] = num_tokens + input_buffers.seq_lens_cpu[num_reqs:] = 0 + + input_buffers.dcp_local_seq_lens[:num_reqs] = num_tokens + input_buffers.dcp_local_seq_lens[num_reqs:] = 0 + + input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables] + slot_mappings = block_tables.slot_mappings[:, :num_tokens] + slot_mappings_by_layer = build_slot_mappings_by_layer(slot_mappings, kv_cache_config) + + attn_metadata = build_attn_metadata( + attn_groups=attn_groups, + num_reqs=num_reqs, + num_tokens=num_tokens, + query_start_loc_gpu=query_start_loc, + query_start_loc_cpu=query_start_loc_cpu, + max_query_len=num_tokens_per_req, + seq_lens=input_buffers.seq_lens, + max_seq_len=max_model_len, + block_tables=input_block_tables, + slot_mappings=slot_mappings, + kv_cache_config=kv_cache_config, + seq_lens_np=input_buffers.seq_lens_np, + ) + return attn_metadata, slot_mappings_by_layer + + +class ModelWithContext(nn.Module): + """Define a wrapper model to inject forward context. + so we can inherit vllm's CudaGraphManager._capture_full_graph. + """ + + def __init__(self, original_model): + super().__init__() + self.original_model = original_model + + def forward(self, *args, **kwargs): + # In warmup phase, capturing=False by default. + # when capturing, we need to set capturing=True in forward context. + _EXTRA_CTX.capturing = True + + return self.original_model(*args, **kwargs) diff --git a/vllm_ascend/worker/v2/attn_utils.py b/vllm_ascend/worker/v2/attn_utils.py index 46a92094..d3b4aa80 100644 --- a/vllm_ascend/worker/v2/attn_utils.py +++ b/vllm_ascend/worker/v2/attn_utils.py @@ -23,8 +23,8 @@ from typing import Any import numpy as np import torch from vllm.config import VllmConfig -from vllm.v1.attention.backend import AttentionMetadataBuilder from vllm.v1.kv_cache_interface import EncoderOnlyAttentionSpec, KVCacheConfig +from vllm.v1.worker.utils import AttentionGroup from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState @@ -43,7 +43,7 @@ def get_attn_mask_builder(device: torch.device): def build_attn_metadata( *, - attn_metadata_builders: list[AttentionMetadataBuilder], + attn_groups: list[list[AttentionGroup]], num_reqs: int, num_tokens: int, query_start_loc_gpu: torch.Tensor, @@ -54,6 +54,7 @@ def build_attn_metadata( block_tables: Sequence[torch.Tensor], slot_mappings: torch.Tensor, kv_cache_config: KVCacheConfig, + dcp_local_seq_lens: torch.Tensor | None = None, # extra attributes for ascend npus. 
seq_lens_np: np.ndarray | None = None, num_computed_tokens_cpu: torch.Tensor | None = None, @@ -72,9 +73,6 @@ def build_attn_metadata( if seq_lens_np is None: seq_lens_np = np.full(num_reqs, max_seq_len, dtype=np.int32) seq_lens_cpu = torch.from_numpy(seq_lens_np)[:num_reqs] - # torch_npu._reshape_and_cache operator requires slot_mappings to - # be torch.int32. - slot_mappings = slot_mappings.to(torch.int32) attn_metadata: dict[str, Any] = {} kv_cache_groups = kv_cache_config.kv_cache_groups @@ -100,13 +98,14 @@ def build_attn_metadata( max_seq_len=max_seq_len, ) - attn_metadata_builder = attn_metadata_builders[i] - metadata = attn_metadata_builder.build( - common_prefix_len=0, - common_attn_metadata=common_attn_metadata, # type: ignore - ) - for layer_name in kv_cache_spec.layer_names: - attn_metadata[layer_name] = metadata + for attn_group in attn_groups[i]: + attn_metadata_builder = attn_group.get_metadata_builder(0) + metadata = attn_metadata_builder.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) + for layer_name in attn_group.layer_names: + attn_metadata[layer_name] = metadata return attn_metadata diff --git a/vllm_ascend/worker/v2/block_table.py b/vllm_ascend/worker/v2/block_table.py new file mode 100644 index 00000000..165612da --- /dev/null +++ b/vllm_ascend/worker/v2/block_table.py @@ -0,0 +1,58 @@ +# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/block_table.py +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +import torch +from vllm.v1.worker.gpu.block_table import BlockTables + + +class AscendBlockTables(BlockTables): + """Block table for Ascend NPUs.""" + + def __init__( + self, + block_sizes: list[int], + max_num_reqs: int, + max_num_batched_tokens: int, + max_model_len: int, + device: torch.device, + cp_size: int = 1, + cp_rank: int = 0, + cp_interleave: int = 1, + ): + super().__init__( + block_sizes, + max_num_reqs, + max_num_batched_tokens, + max_model_len, + device, + cp_size, + cp_rank, + cp_interleave, + ) + # because we will override these attribute, delete these attribute to + # make sure it's collected by python gc immediately. + del self.slot_mappings + # vllm-ascend' reshape_and_cache function requires slot_mappings to be int32. + # so we need to redefine slot_mappings to be int32. 
+ self.slot_mappings: torch.Tensor = torch.zeros( + self.num_kv_cache_groups, + self.max_num_batched_tokens, + dtype=torch.int32, + device=self.device, + ) diff --git a/vllm_ascend/worker/v2/input_batch.py b/vllm_ascend/worker/v2/input_batch.py index 31e5d90e..9a1ccea0 100644 --- a/vllm_ascend/worker/v2/input_batch.py +++ b/vllm_ascend/worker/v2/input_batch.py @@ -22,6 +22,8 @@ import numpy as np import torch from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers +from vllm_ascend.attention.attention_v1 import AscendAttentionState + class AscendInputBuffers(InputBuffers): """Input buffers for Ascend NPUs.""" @@ -37,6 +39,16 @@ class AscendInputBuffers(InputBuffers): max_num_tokens, device, ) + del self.query_start_loc + + # NOTE: For FULL mode we change +1 to +2 to reserve extra space for padding. + # See _pad_query_start_loc_for_fia. + self.query_start_loc: torch.Tensor = torch.zeros( + max_num_reqs + 2, + dtype=torch.int32, + device=device, + ) + # Create seq_lens_cpu and seq_lens_np. # npu's attention backend still needs seq_lens on CPU side. self.seq_lens_cpu: torch.Tensor = torch.zeros( @@ -56,6 +68,8 @@ class AscendInputBatch(InputBatch): # Create seq_lens_np. # npu's attention backend still needs seq_lens on CPU side. seq_lens_np: np.ndarray + # attn_state is used to build attention metadata. + attn_state: AscendAttentionState | None = None @classmethod def make_dummy( @@ -79,4 +93,11 @@ class AscendInputBatch(InputBatch): input_buffers.seq_lens_np[num_reqs:] = 0 seq_lens_np = input_buffers.seq_lens_np[:num_reqs] input_batch.seq_lens_np = seq_lens_np + # A dummy run for dp or memory profiling. + # When dummy run for dp, num_tokens is set to 1, + # so attn_state is set to DecodeOnly. + # when dummy run for memory profiling, + # attention metadata isn't needed, + # we can also set attn_state to AscendAttentionState.DecodeOnly. + input_batch.attn_state = AscendAttentionState.DecodeOnly return cls(**asdict(input_batch), seq_lens_np=seq_lens_np) diff --git a/vllm_ascend/worker/v2/model_runner.py b/vllm_ascend/worker/v2/model_runner.py index 88188d29..76fd613a 100644 --- a/vllm_ascend/worker/v2/model_runner.py +++ b/vllm_ascend/worker/v2/model_runner.py @@ -17,12 +17,16 @@ # This file is a part of the vllm-ascend project. 
# +import functools + import numpy as np import torch +import vllm from vllm.config import VllmConfig -from vllm.logger import init_logger +from vllm.config.compilation import CUDAGraphMode +from vllm.sequence import IntermediateTensors from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer +from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu from vllm.v1.worker.gpu.input_batch import ( combine_sampled_and_draft_tokens, @@ -32,23 +36,23 @@ from vllm.v1.worker.gpu.input_batch import ( ) from vllm.v1.worker.gpu.model_runner import GPUModelRunner +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.utils import set_weight_prefetch_method from vllm_ascend.worker.v2.aclgraph_utils import AclGraphManager -from vllm_ascend.worker.v2.attn_utils import build_attn_metadata, build_attn_state +from vllm_ascend.worker.v2.attn_utils import build_attn_state from vllm_ascend.worker.v2.input_batch import AscendInputBatch, AscendInputBuffers from vllm_ascend.worker.v2.sample.sampler import AscendSampler from vllm_ascend.worker.v2.spec_decode import init_speculator from vllm_ascend.worker.v2.spec_decode.eagle import AscendEagleSpeculator from vllm_ascend.worker.v2.states import AscendRequestState -from vllm_ascend.worker.v2.utils import torch_cuda_wrapper - -logger = init_logger(__name__) +from vllm_ascend.worker.v2.utils import block_table_wrapper, model_states_wrapper, torch_cuda_wrapper class NPUModelRunner(GPUModelRunner): """Model runner for Ascend NPUs.""" def __init__(self, vllm_config: VllmConfig, device: torch.device): - with torch_cuda_wrapper(): + with torch_cuda_wrapper(), block_table_wrapper(), model_states_wrapper(): super().__init__(vllm_config, device) # because we will override these attribute, delete these attribute to @@ -62,8 +66,9 @@ class NPUModelRunner(GPUModelRunner): # NPU specific initializations can be added below. self.cudagraph_manager: AclGraphManager = AclGraphManager( self.vllm_config, - self.uses_mrope, + self.use_aux_hidden_state_outputs, self.device, + self, ) # we define AscendEagleSpeculator in vllm_ascend.worker.v2.spec_decode.eagle @@ -96,6 +101,7 @@ class NPUModelRunner(GPUModelRunner): max_num_reqs=self.max_num_reqs, vocab_size=self.vocab_size, device=self.device, + req_states=self.req_states, logprobs_mode=self.model_config.logprobs_mode, num_speculative_tokens=self.num_speculative_steps + 1, ) @@ -113,6 +119,59 @@ class NPUModelRunner(GPUModelRunner): pin_memory=True, ) + # Ascend-specific configurations + self.ascend_config = get_ascend_config() + # set this just the same as model runner v1, or it will raise error. + set_weight_prefetch_method(self.ascend_config.weight_prefetch_config) + + # we need to update full graph params in run_fullgraph, + # so create a stream to update full graph params. + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + self.update_stream: torch.npu.Stream = torch.npu.Stream() + + # we need to use return value of `get_cudagraph_and_dp_padding` + # to set forward_context in `run_fullgraph`. + # so we can inherit `execute_model` method. + self.cudagraph_and_dp_padding: tuple[int, torch.Tensor | None, int] | None = None + + # we need to use input_batch to set forward_context in run_fullgraph. + # so we can inherit `execute_model` method. 
+ self.input_batch: AscendInputBatch | None = None + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: SchedulerOutput, + intermediate_tensors: IntermediateTensors | None = None, + dummy_run: bool = False, + skip_attn_for_dummy_run: bool = False, + ) -> ModelRunnerOutput | IntermediateTensors | None: + """Override GPUModelRunner.execute_model for Ascend NPUs by there reasons: + 1. when run fullgraph, we need to use ret value of `get_cudagraph_and_dp_padding` + to set forward_context in `run_fullgraph`. + """ + + # use closure to store return value of get_cudagraph_and_dp_padding in model runner. + def wrapper(func): + @functools.wraps(func) + def inner(*args, **kwargs): + self.cudagraph_and_dp_padding = func(*args, **kwargs) + return self.cudagraph_and_dp_padding + + return inner + + if self.cudagraph_and_dp_padding is None: + vllm.v1.worker.gpu.model_runner.get_cudagraph_and_dp_padding = wrapper( + vllm.v1.worker.gpu.model_runner.get_cudagraph_and_dp_padding + ) + + return super().execute_model( + scheduler_output, + intermediate_tensors, + dummy_run, + skip_attn_for_dummy_run, + ) + def prepare_inputs( self, scheduler_output: SchedulerOutput, @@ -185,33 +244,40 @@ class NPUModelRunner(GPUModelRunner): idx_mapping, total_num_logits, cu_num_logits, max_expand_len ) - # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks] - block_tables = self.block_tables.gather_block_tables(idx_mapping) - # Get query_start_loc. - query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32) + # NOTE: For FULL mode we change +1 to +2 to reserve extra space for padding. + # See _pad_query_start_loc_for_fia. + query_start_loc_np = np.empty(self.max_num_reqs + 2, dtype=np.int32) query_start_loc_np[0] = 0 np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1 : num_reqs + 1]) # Pad for full CUDA graph mode. # Some attention backends like FA3 require query_start_loc to be non-decreasing. query_start_loc_np[num_reqs + 1 :] = num_tokens + + # This is only required for vllm-ascend. + query_start_loc_np, num_reqs_padded = self._pad_query_start_loc_for_fia( + num_tokens_padded=num_tokens_after_padding, + num_tokens=num_tokens, + num_reqs=num_reqs, + query_start_loc_np=query_start_loc_np, + max_query_len=max(scheduler_output.num_scheduled_tokens.values()), + ) async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc) - query_start_loc_np = query_start_loc_np[: num_reqs + 1] - query_start_loc_cpu = torch.from_numpy(query_start_loc_np) + query_start_loc_np = query_start_loc_np[: num_reqs_padded + 1] query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] - max_query_len = num_scheduled_tokens.max().item() - # Get prefill tokens. - prepare_prefill_inputs( - self.input_buffers.input_ids, - self.req_states.next_prefill_tokens, - idx_mapping, - query_start_loc, - self.req_states.prefill_token_ids.gpu, - self.req_states.prefill_len.gpu, - self.req_states.num_computed_tokens.gpu, - ) + # Get prefill tokens if any. + if self.req_states.any_prefills(idx_mapping_np): + prepare_prefill_inputs( + self.input_buffers.input_ids, + self.req_states.next_prefill_tokens, + idx_mapping, + query_start_loc, + self.req_states.all_token_ids.gpu, + self.req_states.prefill_len.gpu, + self.req_states.num_computed_tokens.gpu, + ) # Prepare positions and seq_lens. prepare_pos_seq_lens( @@ -223,14 +289,8 @@ class NPUModelRunner(GPUModelRunner): ) seq_lens = self.input_buffers.seq_lens[:num_reqs] - # Prepare M-RoPE positions. 
- if self.uses_mrope: - self.mrope_states.prepare_mrope_positions( - idx_mapping, - query_start_loc, - self.req_states.prefill_len.gpu, - self.req_states.num_computed_tokens.gpu, - ) + # Pad for full CUDA graph mode. + self.input_buffers.seq_lens_np[num_reqs_padded:] = 0 # Some input token ids are directly read from the last sampled tokens # and draft tokens. Also, get the logits indices to sample tokens from. @@ -246,43 +306,12 @@ class NPUModelRunner(GPUModelRunner): total_num_logits, ) - # Compute slot mappings: [num_kv_cache_groups, num_tokens] - slot_mappings = self.block_tables.compute_slot_mappings( - idx_mapping, query_start_loc, self.input_buffers.positions[:num_tokens] - ) - # Layer name -> slot mapping. - slot_mappings_by_layer = build_slot_mappings_by_layer(slot_mappings, self.kv_cache_config) - # Layer name -> attention metadata. - # TODO(Ronald1995): try to add a new method `build_attn_metadata` in - # vllm gpu_model_runner_v2, maybe we don't overwrite `prepare_inputs` - # method like this. - attn_metadata = build_attn_metadata( - attn_metadata_builders=self.attn_metadata_builders, - num_reqs=num_reqs, - num_tokens=num_tokens, - query_start_loc_gpu=query_start_loc, - query_start_loc_cpu=query_start_loc_cpu, - max_query_len=max_query_len, - seq_lens=self.input_buffers.seq_lens, - max_seq_len=self.max_model_len, - block_tables=block_tables, - slot_mappings=slot_mappings, - kv_cache_config=self.kv_cache_config, - # extra attributes for ascend npus. - seq_lens_np=self.input_buffers.seq_lens_np, - num_computed_tokens_cpu=self.req_states.num_computed_tokens_cpu[idx_mapping_cpu], - attn_state=attn_state, - ) - input_ids = self.input_buffers.input_ids[:num_tokens_after_padding] positions = self.input_buffers.positions[:num_tokens_after_padding] - mrope_positions = None - if self.uses_mrope: - mrope_positions = self.mrope_states.mrope_positions - mrope_positions = mrope_positions[:, :num_tokens_after_padding] - return AscendInputBatch( + + self.input_batch = AscendInputBatch( req_ids=req_ids, - num_reqs=num_reqs, + num_reqs=num_reqs_padded, idx_mapping=idx_mapping, idx_mapping_np=idx_mapping_np, expanded_idx_mapping=expanded_idx_mapping, @@ -294,18 +323,18 @@ class NPUModelRunner(GPUModelRunner): query_start_loc=query_start_loc, query_start_loc_np=query_start_loc_np, seq_lens=seq_lens, + dcp_local_seq_lens=None, # TODO(Ronald1995): support cp. input_ids=input_ids, positions=positions, - mrope_positions=mrope_positions, - inputs_embeds=None, - attn_metadata=attn_metadata, - slot_mappings=slot_mappings_by_layer, logits_indices=logits_indices, cu_num_logits=cu_num_logits, cu_num_logits_np=cu_num_logits_np, has_structured_output_reqs=scheduler_output.has_structured_output_requests, + # extra attributes for ascend npus. seq_lens_np=self.input_buffers.seq_lens_np, + attn_state=attn_state, ) + return self.input_batch def postprocess( self, @@ -352,7 +381,7 @@ class NPUModelRunner(GPUModelRunner): self.req_states.num_computed_tokens_cpu[req_index] = self.num_computed_tokens_cpu[req_index] # update seq_lens_cpu - for i, req_id in enumerate(req_ids): + for i, req_id in enumerate(req_ids): # type: ignore req_index = self.req_states.req_id_to_index[req_id] num_computed_tokens = self.req_states.num_computed_tokens_cpu[req_index] self.input_buffers.seq_lens_cpu[i] = num_computed_tokens + num_scheduled_tokens[req_id] @@ -361,3 +390,44 @@ class NPUModelRunner(GPUModelRunner): # TODO(Ronald1995): just define the method in case calling error in # worker, implement it in the future. 
pass + + def _pad_query_start_loc_for_fia( + self, + num_tokens_padded: int, + num_tokens: int, + num_reqs: int, + query_start_loc_np: np.ndarray, + max_query_len: int, + ) -> tuple[np.ndarray, int]: + """ + This function is only designed to satisfied the constraint that when the layout is TND, + the first dimension of `hidden_states` must equal the last element of `actual_seq_lengths_q`. + """ + assert self.cudagraph_and_dp_padding is not None + _num_tokens_after_padding, _num_tokens_across_dp, synced_cudagraph_mode = self.cudagraph_and_dp_padding + cudagraph_runtime_mode = CUDAGraphMode(synced_cudagraph_mode) + if cudagraph_runtime_mode != CUDAGraphMode.FULL: + return query_start_loc_np, num_reqs + uniform_decode_query_len = self.cudagraph_manager.uniform_decode_query_len + is_uniform_decode = self.cudagraph_manager.is_uniform_decode( + num_reqs=num_reqs, + num_tokens=num_tokens, + max_query_len=max_query_len, + ) + if is_uniform_decode: + # Uniform-batch case: num_reqs must be no greater than num_reqs_padded + num_reqs_padded = num_tokens_padded // uniform_decode_query_len + + last_loc = query_start_loc_np[num_reqs] + query_start_loc_np[num_reqs + 1 : num_reqs_padded + 1] = ( + np.arange(1, num_reqs_padded + 1 - num_reqs) * uniform_decode_query_len + last_loc + ) + else: + # Mixed-batch case: num_reqs must equal num_reqs_padded + num_reqs_padded = min(num_tokens_padded, self.max_num_reqs) + + # Insert a dummy request instead of setting query_start_loc[num_reqs] = num_tokens_padded directly + query_start_loc_np[num_reqs_padded + 1] = num_tokens_padded + num_reqs_padded = num_reqs_padded + 1 + + return query_start_loc_np, num_reqs_padded diff --git a/vllm_ascend/worker/v2/model_states/__init__.py b/vllm_ascend/worker/v2/model_states/__init__.py new file mode 100644 index 00000000..5a6163c0 --- /dev/null +++ b/vllm_ascend/worker/v2/model_states/__init__.py @@ -0,0 +1,34 @@ +# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/model_states/__init__.py +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# + +import torch +import torch.nn as nn +from vllm.config import VllmConfig +from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache + + +def init_asecnd_model_state( + vllm_config: VllmConfig, + model: nn.Module, + encoder_cache: EncoderCache | None, + device: torch.device, +): + from vllm_ascend.worker.v2.model_states.default import AscendModelState + + return AscendModelState(vllm_config, model, encoder_cache, device) diff --git a/vllm_ascend/worker/v2/model_states/default.py b/vllm_ascend/worker/v2/model_states/default.py new file mode 100644 index 00000000..bde4d7dc --- /dev/null +++ b/vllm_ascend/worker/v2/model_states/default.py @@ -0,0 +1,62 @@ +# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/model_states/default.py +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +from typing import Any + +import torch +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.worker.gpu.model_states.default import DefaultModelState +from vllm.v1.worker.utils import AttentionGroup + +from vllm_ascend.worker.v2.attn_utils import build_attn_metadata +from vllm_ascend.worker.v2.input_batch import AscendInputBatch + + +class AscendModelState(DefaultModelState): + """Model state for Ascend NPUs.""" + + def prepare_attn( + self, + input_batch: AscendInputBatch, + block_tables: tuple[torch.Tensor, ...], + slot_mappings: torch.Tensor, + attn_groups: list[list[AttentionGroup]], + kv_cache_config: KVCacheConfig, + ) -> dict[str, Any]: + """Override prepare_attn method because `build_attn_metadata` is different from vllm.""" + query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np) + max_query_len = input_batch.num_scheduled_tokens.max().item() + attn_metadata = build_attn_metadata( + attn_groups=attn_groups, + num_reqs=input_batch.num_reqs, + num_tokens=input_batch.num_tokens, + query_start_loc_gpu=input_batch.query_start_loc, + query_start_loc_cpu=query_start_loc_cpu, + max_query_len=max_query_len, + seq_lens=input_batch.seq_lens, + max_seq_len=self.max_model_len, + block_tables=block_tables, + slot_mappings=slot_mappings, + kv_cache_config=kv_cache_config, + dcp_local_seq_lens=input_batch.dcp_local_seq_lens, + # extra attributes for ascend npus. 
+ seq_lens_np=input_batch.seq_lens_np, + attn_state=input_batch.attn_state, + ) + return attn_metadata diff --git a/vllm_ascend/worker/v2/sample/sampler.py b/vllm_ascend/worker/v2/sample/sampler.py index 4bbb0fa3..27ed359e 100644 --- a/vllm_ascend/worker/v2/sample/sampler.py +++ b/vllm_ascend/worker/v2/sample/sampler.py @@ -16,9 +16,6 @@ # import numpy as np import torch -from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p -from vllm.v1.worker.gpu.sample.gumbel import apply_temperature -from vllm.v1.worker.gpu.sample.min_p import apply_min_p from vllm.v1.worker.gpu.sample.sampler import Sampler from vllm_ascend.worker.v2.sample.gumbel import gumbel_sample @@ -53,21 +50,23 @@ class AscendSampler(Sampler): self.num_speculative_tokens, ) + # Apply bad words masking in place. + self.bad_words_state.apply_bad_words( + logits, + idx_mapping, + idx_mapping_np, + input_ids, + expanded_local_pos, + ) + # Apply temperature in place. - apply_temperature(logits, idx_mapping, self.sampling_states.temperature.gpu) + self.sampling_states.apply_temperature(logits, idx_mapping, idx_mapping_np) - # Apply min_p in place if any request has a non-zero min_p. - do_min_p = self.sampling_states.do_min_p(idx_mapping_np) - if do_min_p: - apply_min_p(logits, idx_mapping, self.sampling_states.min_p.gpu) + # Apply min_p in place. + self.sampling_states.apply_min_p(logits, idx_mapping, idx_mapping_np) - # Apply top_k and/or top_p. This might return a new tensor. - do_top_k = self.sampling_states.do_top_k(idx_mapping_np) - top_k = self.sampling_states.top_k.gpu[idx_mapping] if do_top_k else None - do_top_p = self.sampling_states.do_top_p(idx_mapping_np) - top_p = self.sampling_states.top_p.gpu[idx_mapping] if do_top_p else None - if do_top_k or do_top_p: - logits = apply_top_k_top_p(logits, top_k, top_p) + # Apply top_k and/or top_p. This might or might not return a new tensor. + logits = self.sampling_states.apply_top_k_top_p(logits, idx_mapping, idx_mapping_np) # Sample the next token. 
sampled = gumbel_sample( diff --git a/vllm_ascend/worker/v2/spec_decode/eagle.py b/vllm_ascend/worker/v2/spec_decode/eagle.py index c9ee8c9f..1ab557c2 100644 --- a/vllm_ascend/worker/v2/spec_decode/eagle.py +++ b/vllm_ascend/worker/v2/spec_decode/eagle.py @@ -23,7 +23,7 @@ import torch import vllm from vllm.config import VllmConfig from vllm.v1.worker.gpu.input_batch import InputBatch -from vllm.v1.worker.gpu.spec_decode.eagle import EagleSpeculator +from vllm.v1.worker.gpu.spec_decode.eagle.speculator import EagleSpeculator from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.worker.v2.attn_utils import build_attn_metadata diff --git a/vllm_ascend/worker/v2/states.py b/vllm_ascend/worker/v2/states.py index f1c24702..9ef7acd9 100644 --- a/vllm_ascend/worker/v2/states.py +++ b/vllm_ascend/worker/v2/states.py @@ -56,13 +56,13 @@ class AscendRequestState(RequestState): self, req_id, prompt_len, - prefill_token_ids, + all_token_ids, num_computed_tokens, ): super().add_request( req_id, prompt_len, - prefill_token_ids, + all_token_ids, num_computed_tokens, ) req_idx = self.req_id_to_index[req_id] diff --git a/vllm_ascend/worker/v2/utils.py b/vllm_ascend/worker/v2/utils.py index 0c28b2fb..349002b8 100644 --- a/vllm_ascend/worker/v2/utils.py +++ b/vllm_ascend/worker/v2/utils.py @@ -1,6 +1,11 @@ from contextlib import contextmanager import torch +import vllm +from vllm.logger import logger + +from vllm_ascend.worker.v2.block_table import AscendBlockTables +from vllm_ascend.worker.v2.model_states import init_asecnd_model_state @contextmanager @@ -15,6 +20,34 @@ def torch_cuda_wrapper(): torch.cuda.CUDAGraph = torch.npu.NPUGraph torch.cuda.graph = torch.npu.graph torch.cuda.synchronize = torch.npu.synchronize + torch.cuda.set_stream = torch.npu.set_stream + torch.cuda.current_device = torch.npu.current_device + torch.cuda.mem_get_info = torch.npu.mem_get_info + logger.info_once("Wrapping torch.cuda with torch.npu.") + yield + finally: + pass + + +@contextmanager +def block_table_wrapper(): + try: + # vllm-ascend needs to initialize the slot mapping with torch.int32 dtype, + # but the vllm default is torch.int64. + vllm.v1.worker.gpu.model_runner.BlockTables = AscendBlockTables + logger.info_once("Wrapping BlockTables with AscendBlockTables.") + yield + finally: + pass + + +@contextmanager +def model_states_wrapper(): + try: + # prepare_attn in AscendModelState differs from the vllm implementation, + # so we need to override init_model_state. + vllm.v1.worker.gpu.model_runner.init_model_state = init_asecnd_model_state + logger.info_once("Wrapping init_model_state with init_asecnd_model_state.") yield finally: pass
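
Note on `_pad_query_start_loc_for_fia` above: the padding it performs is easiest to see with concrete numbers. The sketch below is a standalone NumPy reproduction of the uniform-decode branch only; the values (3 real decode requests, a full-graph capture size of 8, one token per request) are hypothetical and not taken from the patch.

    import numpy as np

    # Hypothetical inputs: 3 real decode requests, padded to a full-graph
    # capture size of 8 tokens, one token per request in plain decode.
    num_reqs = 3
    uniform_decode_query_len = 1
    num_tokens_padded = 8

    # query_start_loc for the real requests: [0, 1, 2, 3, 0, 0, 0, 0, 0]
    query_start_loc = np.zeros(num_tokens_padded + 1, dtype=np.int32)
    query_start_loc[1 : num_reqs + 1] = np.arange(1, num_reqs + 1)

    # Same arithmetic as the uniform-decode branch: every padded request gets
    # uniform_decode_query_len dummy tokens after the last real boundary.
    num_reqs_padded = num_tokens_padded // uniform_decode_query_len
    last_loc = query_start_loc[num_reqs]
    query_start_loc[num_reqs + 1 : num_reqs_padded + 1] = (
        np.arange(1, num_reqs_padded + 1 - num_reqs) * uniform_decode_query_len + last_loc
    )

    print(query_start_loc)  # [0 1 2 3 4 5 6 7 8]

The last element now equals num_tokens_padded, which is exactly the TND-layout constraint the docstring describes: the first dimension of `hidden_states` matches the final entry of `actual_seq_lengths_q`.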
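
Note on the new context managers in vllm_ascend/worker/v2/utils.py: this diff does not show where block_table_wrapper and model_states_wrapper are entered. The sketch below only illustrates one plausible way to stack them, together with the existing torch_cuda_wrapper, before the v2 model runner is constructed; the build_model_runner helper and its call site are hypothetical, not part of the patch.

    from contextlib import ExitStack

    from vllm_ascend.worker.v2.utils import (
        block_table_wrapper,
        model_states_wrapper,
        torch_cuda_wrapper,
    )

    def build_model_runner(make_runner):
        """Hypothetical helper: apply all Ascend patches, then build the runner."""
        with ExitStack() as stack:
            # Redirect the torch.cuda graph/stream/memory APIs to torch.npu.
            stack.enter_context(torch_cuda_wrapper())
            # Use AscendBlockTables so slot mappings are created as torch.int32.
            stack.enter_context(block_table_wrapper())
            # Route init_model_state to the Ascend-specific implementation.
            stack.enter_context(model_states_wrapper())
            # Construct the runner while the patches are in effect.
            return make_runner()

Because the wrappers' finally blocks are no-ops, the monkeypatches stay in place after the with block exits, so entering them once at worker start-up is sufficient.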