[Feature] support aclgraph for model runner v2 (#7110)
### What this PR does / why we need it?
This PR adds aclgraph support for model runner v2; see RFC
#5208. It contains the following changes:
- Adapt to the newest commit of the vLLM main branch.
- Provide a unified interface to the extra forward context for both
model runner v1 and model runner v2 (see the sketch below).
- Implement graph mode for the main model.
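
A minimal before/after sketch of the unified interface, based on the call
sites rewritten in this diff (the v1/v2 dispatch happens inside the
`_ExtraForwardContextProxy` added to `vllm_ascend.ascend_forward_context`):

```python
# Before: extra fields were read straight off vLLM's forward context.
from vllm.forward_context import get_forward_context
moe_comm_method = get_forward_context().moe_comm_method

# After: the proxy resolves the field for both model runner versions.
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
moe_comm_method = _EXTRA_CTX.moe_comm_method
```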
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- vLLM version: v0.16.0
- vLLM main: 4034c3d32e
---------
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
@@ -184,7 +184,7 @@ def test_token_dispatcher_with_all_gather_quant(
 ):
     context_mock = MagicMock()
     context_mock.fused_moe_state = 0
-    with patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context",
+    with patch("vllm_ascend.ascend_forward_context.get_forward_context",
                return_value=context_mock):
         a = torch.randn((m, k), device=device, dtype=dtype) / 10
         w1 = torch.randn((e, k, 2 * n), device=device, dtype=torch.int8)

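Note: `unittest.mock.patch` replaces a name where it is looked up, not where
it is defined. Since the helpers now resolve `get_forward_context` through
`vllm_ascend.ascend_forward_context`, that module's reference is the one the
tests must patch. A minimal illustration, reusing the mocked context from the
hunk above:

```python
from unittest.mock import MagicMock, patch

context_mock = MagicMock()
context_mock.fused_moe_state = 0
# Patch the reference the code under test actually reads.
with patch("vllm_ascend.ascend_forward_context.get_forward_context",
           return_value=context_mock):
    ...  # code under test now sees context_mock
```
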
@@ -85,3 +85,31 @@ def test_egale_spec_decoding(
         },
     ) as runner:
         runner.model.generate(prompts, sampling_params)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("enforce_eager", [False])
+@pytest.mark.parametrize("compilation_config", [{"cudagraph_mode": "FULL_DECODE_ONLY"}, {}])
+@patch.dict(os.environ, {"VLLM_USE_V2_MODEL_RUNNER": "1"})
+def test_qwen3_dense_graph_mode(
+    model: str,
+    max_tokens: int,
+    enforce_eager: bool,
+    compilation_config: dict,
+) -> None:
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0)
+    with VllmRunner(
+        model,
+        max_model_len=1024,
+        enforce_eager=enforce_eager,
+        compilation_config=compilation_config,
+    ) as runner:
+        runner.model.generate(prompts, sampling_params)

@@ -74,7 +74,7 @@ class TestAscendAttentionBackendImpl310(TestBase):

     @patch("torch_npu._npu_reshape_and_cache")
     @patch("torch_npu._npu_flash_attention")
-    @patch("vllm_ascend.attention.attention_v1.get_forward_context")
+    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
     def test_forward_prefill_310(
         self, mock_get_forward_context, mock_npu_npu_flash_attention, mock_npu_reshape_and_cache
     ):

@@ -105,7 +105,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
     @patch("torch_npu.npu_format_cast", return_value=torch.randn((1, 128, 16, 16), dtype=torch.float16))
     @patch("torch_npu._npu_reshape_and_cache")
     @patch("torch_npu._npu_paged_attention_splitfuse")
-    @patch("vllm_ascend.attention.attention_v1.get_forward_context")
+    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
     def test_forward_chunked_prefill_310(
         self,
         mock_get_forward_context,

@@ -140,7 +140,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
     @patch("torch_npu.npu_format_cast", return_value=torch.randn((1, 128, 16, 16), dtype=torch.float16))
     @patch("torch_npu._npu_reshape_and_cache")
     @patch("torch_npu._npu_paged_attention_splitfuse")
-    @patch("vllm_ascend.attention.attention_v1.get_forward_context")
+    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
     def test_forward_prefill_cache_hit_310(
         self,
         mock_get_forward_context,

@@ -175,7 +175,7 @@ class TestAscendAttentionBackendImpl310(TestBase):
     @patch("vllm_ascend.attention.attention_v1.using_paged_attention")
     @patch("torch_npu._npu_paged_attention")
     @patch("torch_npu._npu_reshape_and_cache")
-    @patch("vllm_ascend.attention.attention_v1.get_forward_context")
+    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
     def test_forward_paged_attention_310(
         self, mock_get_forward_context, mock_npu_reshape_and_cache, mock_paged_attention, mock_using_paged_attention
     ):

@@ -95,7 +95,7 @@ class TestAscendAttentionCPImpl(TestBase):
     @patch('torch_npu.npu_attention_update')
     @patch("torch_npu.npu_fused_infer_attention_score")
     @patch(
-        'vllm_ascend.attention.context_parallel.attention_cp.get_forward_context'
+        'vllm_ascend.ascend_forward_context.get_forward_context'
     )
     @patch_distributed_groups(dcp_size=2, pcp_size=2)
     def test_forward_decode_pcp_dcp(self, mock_all2all, mock_dcp, mock_pcp,

@@ -212,7 +212,7 @@ class TestAscendAttentionBackendImpl(TestBase):

     @patch('torch_npu._npu_reshape_and_cache')
     @patch('torch_npu.npu_fused_infer_attention_score')
-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     def test_forward_fused_infer_attention(
             self, mock_get_forward_context,
             mock_npu_fused_infer_attention_score, mock_npu_reshape_and_cache):

@@ -248,7 +248,7 @@ class TestAscendAttentionBackendImpl(TestBase):
     @patch('vllm_ascend.attention.attention_v1.using_paged_attention')
     @patch('torch_npu._npu_paged_attention')
     @patch('torch_npu._npu_reshape_and_cache')
-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     def test_forward_paged_attention(self, mock_get_forward_context,
                                      mock_npu_reshape_and_cache,
                                      mock_paged_attention,

@@ -279,7 +279,7 @@ class TestAscendAttentionBackendImpl(TestBase):
         mock_paged_attention.assert_called_once()
         assert output.shape == (4, 8 * 64)

-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch('torch_npu.npu_fused_infer_attention_score')
     @patch('torch_npu._npu_reshape_and_cache')
     def test_forward_decode_only_swa(self, mock_npu_reshape_and_cache,

@@ -311,7 +311,7 @@ class TestAscendAttentionBackendImpl(TestBase):
         mock_fused_infer_attention_score.assert_called_once()
         assert output.shape == (10, 8, 64)

-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch('torch_npu._npu_paged_attention')
     @patch('torch_npu.npu_fused_infer_attention_score')
     @patch('torch_npu._npu_reshape_and_cache')

@@ -449,7 +449,7 @@ class TestAscendMLAImpl(TestBase):
         self.assertEqual(result.shape[1], N)
         self.assertEqual(result.shape[2], self.impl.kv_lora_rank + 1)

-    @patch('vllm_ascend.attention.context_parallel.mla_cp.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch("torch_npu.npu_fused_infer_attention_score")
     @patch('torch_npu.npu_attention_update')
     @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)

@@ -929,7 +929,7 @@ class TestAscendMLAImpl(TestBase):
         self.assertEqual(out.shape, prefix_out.shape)
         self.assertEqual(lse.shape, prefix_lse.shape)

-    @patch('vllm_ascend.attention.mla_v1.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj")
     @patch("torch_npu.npu_fused_infer_attention_score")
     def test_forward_decode_without_graph(self,

@@ -1095,7 +1095,7 @@ class TestAscendMLAImpl(TestBase):
         self.assertEqual(k_pe.shape[-1], self.impl.qk_rope_head_dim)
         self.assertEqual(k_nope.shape[-1], self.impl.kv_lora_rank)

-    @patch('vllm_ascend.attention.mla_v1.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch("torch_npu.npu_fused_infer_attention_score")
     def test_forward_decode(self, mock_npu_fused_infer_attention_score,
                             mock_get_forward_context):

@@ -161,12 +161,13 @@ class TestACLGraphWrapper(TestBase):
                                  vllm_config=self.mock_vllm_config,
                                  runtime_mode=CUDAGraphMode.NONE)

+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
     @patch('vllm_ascend.compilation.acl_graph.current_platform')
     @patch('vllm_ascend.compilation.acl_graph.envs')
     def test_call_with_none_runtime_mode(self, mock_envs,
                                          mock_current_platform,
-                                         mock_get_forward_context):
+                                         mock_get_forward_context, mock_get_forward_context_2):
         """Test __call__ method when runtime mode is NONE"""
         mock_envs.VLLM_LOGGING_LEVEL = "INFO"
         mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool

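The wrapper tests above and below stack a second patch, so each test grows a
trailing `mock_get_forward_context_2` argument. This follows from how
`unittest.mock` maps stacked decorators to parameters: decorators apply
bottom-up, so the topmost patch's mock arrives last. A small sketch with a
hypothetical test name:

```python
from unittest.mock import patch

@patch('vllm_ascend.ascend_forward_context.get_forward_context')  # -> last arg
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')   # -> first arg
def test_example(mock_acl_graph_ctx, mock_ascend_ctx):
    # Both mocks are usually pointed at the same fake forward context.
    ...
```
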
@@ -185,16 +186,19 @@ class TestACLGraphWrapper(TestBase):
         self.mock_runnable.assert_called_once_with("arg1", "arg2")
         self.assertEqual(result, "test_output")

+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
     @patch('vllm_ascend.compilation.acl_graph.current_platform')
     @patch('vllm_ascend.compilation.acl_graph.envs')
     def test_call_with_mismatched_runtime_mode(self, mock_envs,
                                                mock_current_platform,
-                                               mock_get_forward_context):
+                                               mock_get_forward_context,
+                                               mock_get_forward_context_2):
         """Test __call__ method when runtime mode doesn't match wrapper mode"""
         mock_envs.VLLM_LOGGING_LEVEL = "INFO"
         mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
         mock_get_forward_context.return_value = self.mock_forward_context
+        mock_get_forward_context_2.return_value = self.mock_forward_context
         self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE  # Different from FULL

         wrapper = ACLGraphWrapper(

@@ -214,18 +218,20 @@ class TestACLGraphWrapper(TestBase):
         'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
     )
     @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch('vllm_ascend.compilation.acl_graph.current_platform')
     @patch('vllm_ascend.compilation.acl_graph.envs')
     @patch('vllm_ascend.compilation.acl_graph.compilation_counter')
     @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
     def test_call_capture_graph_first_time(
             self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs,
-            mock_current_platform, mock_get_forward_context,
+            mock_current_platform, mock_get_forward_context, mock_get_forward_context_2,
             mock_validate_cudagraph_capturing_enabled, mock_torch):
         """Test __call__ method captures graph for the first time"""
         mock_envs.VLLM_LOGGING_LEVEL = "INFO"
         mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
         mock_get_forward_context.return_value = self.mock_forward_context
+        mock_get_forward_context_2.return_value = self.mock_forward_context
         self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL

         # Mock torch.npu.NPUGraph

@@ -284,6 +290,7 @@ class TestACLGraphWrapper(TestBase):
         'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
     )
     @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch('vllm_ascend.compilation.acl_graph.current_platform')
     @patch('vllm_ascend.compilation.acl_graph.envs')
     @patch('vllm_ascend.compilation.acl_graph.compilation_counter')

@@ -291,12 +298,15 @@ class TestACLGraphWrapper(TestBase):
     def test_call_replay_graph(self, mock_weak_ref_tensors,
                                mock_compilation_counter, mock_envs,
                                mock_current_platform, mock_get_forward_context,
+                               mock_get_forward_context_2,
                                mock_validate_cudagraph_capturing_enabled,
                                mock_torch):
         """Test __call__ method replays graph when already captured"""
         mock_envs.VLLM_LOGGING_LEVEL = "INFO"
         mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
         mock_get_forward_context.return_value = self.mock_forward_context
+        mock_get_forward_context_2.return_value = self.mock_forward_context
+
         self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
         self.mock_forward_context.is_draft_model = False

@@ -358,17 +368,19 @@ class TestACLGraphWrapper(TestBase):
         'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
     )
     @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch('vllm_ascend.compilation.acl_graph.current_platform')
     @patch('vllm_ascend.compilation.acl_graph.envs')
     @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
     def test_call_with_debug_mode_input_address_check(
             self, mock_weak_ref_tensors, mock_envs, mock_current_platform,
-            mock_get_forward_context,
+            mock_get_forward_context, mock_get_forward_context_2,
             mock_validate_cudagraph_capturing_enabled, mock_torch):
         """Test __call__ method with debug mode input address checking"""
         mock_envs.VLLM_LOGGING_LEVEL = "DEBUG"  # Enable debug mode
         mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
         mock_get_forward_context.return_value = self.mock_forward_context
+        mock_get_forward_context_2.return_value = self.mock_forward_context
         self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
         self.mock_forward_context.is_draft_model = False

@@ -413,17 +425,19 @@ class TestACLGraphWrapper(TestBase):
         'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
     )
     @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch('vllm_ascend.compilation.acl_graph.current_platform')
     @patch('vllm_ascend.compilation.acl_graph.envs')
     @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
     def test_call_with_debug_mode_input_address_mismatch(
             self, mock_weak_ref_tensors, mock_envs, mock_current_platform,
-            mock_get_forward_context,
+            mock_get_forward_context, mock_get_forward_context_2,
            mock_validate_cudagraph_capturing_enabled, mock_torch):
         """Test __call__ method with debug mode input address mismatch raises AssertionError"""
         mock_envs.VLLM_LOGGING_LEVEL = "DEBUG"  # Enable debug mode
         mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
         mock_get_forward_context.return_value = self.mock_forward_context
+        mock_get_forward_context_2.return_value = self.mock_forward_context
         self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL

         # Mock torch.npu.NPUGraph

@@ -471,6 +485,7 @@ class TestACLGraphWrapper(TestBase):
         'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
     )
     @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch('vllm_ascend.compilation.acl_graph.current_platform')
     @patch('vllm_ascend.compilation.acl_graph.envs')
     @patch('vllm_ascend.compilation.acl_graph.compilation_counter')

@@ -478,12 +493,13 @@ class TestACLGraphWrapper(TestBase):
     @patch('vllm_ascend.compilation.acl_graph.patch')
     def test_call_capture_graph_with_gc_disable(
             self, mock_patch, mock_weak_ref_tensors, mock_compilation_counter,
-            mock_envs, mock_current_platform, mock_get_forward_context,
+            mock_envs, mock_current_platform, mock_get_forward_context, mock_get_forward_context_2,
             mock_validate_cudagraph_capturing_enabled, mock_torch):
         """Test __call__ method captures graph with gc_disable option enabled"""
         mock_envs.VLLM_LOGGING_LEVEL = "INFO"
         mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
         mock_get_forward_context.return_value = self.mock_forward_context
+        mock_get_forward_context_2.return_value = self.mock_forward_context
         self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL

         # Enable gc_disable option

@@ -545,18 +561,20 @@ class TestACLGraphWrapper(TestBase):
         'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
     )
     @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch('vllm_ascend.compilation.acl_graph.current_platform')
     @patch('vllm_ascend.compilation.acl_graph.envs')
     @patch('vllm_ascend.compilation.acl_graph.compilation_counter')
     @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
     def test_call_capture_graph_with_weak_ref_output(
             self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs,
-            mock_current_platform, mock_get_forward_context,
+            mock_current_platform, mock_get_forward_context, mock_get_forward_context_2,
             mock_validate_cudagraph_capturing_enabled, mock_torch):
         """Test __call__ method captures graph with weak_ref_output option enabled"""
         mock_envs.VLLM_LOGGING_LEVEL = "INFO"
         mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
         mock_get_forward_context.return_value = self.mock_forward_context
+        mock_get_forward_context_2.return_value = self.mock_forward_context
         self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL

         # Enable weak_ref_output option

@@ -610,16 +628,18 @@ class TestACLGraphWrapper(TestBase):
         self.assertEqual(result, "weak_ref_output")

     @patch('vllm_ascend.compilation.acl_graph.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch('vllm_ascend.compilation.acl_graph.current_platform')
     @patch('vllm_ascend.compilation.acl_graph.envs')
     @patch('vllm_ascend.compilation.acl_graph.logger')
     def test_call_capture_graph_with_debug_log(self, mock_logger, mock_envs,
                                                mock_current_platform,
-                                               mock_get_forward_context):
+                                               mock_get_forward_context, mock_get_forward_context_2):
         """Test __call__ method captures graph with debug logging enabled"""
         mock_envs.VLLM_LOGGING_LEVEL = "INFO"
         mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
         mock_get_forward_context.return_value = self.mock_forward_context
+        mock_get_forward_context_2.return_value = self.mock_forward_context
         self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL

         # Enable debug logging

@@ -757,10 +777,11 @@ class TestPCPDCPGraphParams(TestBase):
         self.graph_params.events[4].append(mock_event)
         self.graph_params.handles[4].append(MagicMock())

+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch('torch.npu.graph_task_update_end', )
     @patch('torch.npu.graph_task_update_begin', MagicMock())
     @patch('torch_npu.npu_fused_infer_attention_score.out', MagicMock())
-    def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end):
+    def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end, mock_context):
         input_positions = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
         block_table = torch.zeros(2, 5, dtype=torch.long)
         seq_lens = torch.tensor([4, 4])

@@ -790,6 +811,7 @@ class TestPCPDCPGraphParams(TestBase):
         forward_context = MagicMock()
         forward_context.attn_metadata = {"attn_layer_0": metadata}
         forward_context.is_draft_model = False
+        mock_context.return_value = forward_context

         num_heads = 256
         scale = 0.1

@@ -119,11 +119,9 @@ def mock_dist_env(mocker: MockerFixture):
              return_value=(torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]), None, 0)), \
         patch('vllm_ascend.ops.fused_moe.fused_moe.get_forward_context',
               return_value=mock_forward_context_obj), \
-        patch('vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context',
+        patch('vllm_ascend.ascend_forward_context.get_forward_context',
               return_value=mock_forward_context_obj), \
         patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3), \
-        patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context',
-              return_value=mock_forward_context_obj), \
         patch('vllm_ascend.ops.fused_moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
               return_value=None), \
         patch('vllm_ascend.ops.fused_moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',

@@ -298,7 +296,7 @@ class TestUnifiedApplyMLP(TestBase):

     @patch('vllm_ascend.ops.fused_moe.moe_mlp.get_weight_prefetch_method',
            return_value=MagicMock())
-    @patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch('vllm_ascend.utils.get_ascend_device_type',
            return_value=AscendDeviceType.A3)
     @patch('torch_npu.npu_grouped_matmul')

@@ -407,7 +405,7 @@ class TestUnifiedApplyMLP(TestBase):
     @patch('vllm_ascend.ops.fused_moe.moe_mlp.HAS_TRITON', False)
     @patch('vllm_ascend.ops.fused_moe.moe_mlp.get_weight_prefetch_method',
            return_value=MagicMock())
-    @patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch('torch_npu.npu_grouped_matmul')
     @patch('torch_npu.npu_swiglu')
     @patch('torch_npu.npu_dynamic_quant')

@@ -513,7 +511,7 @@ class TestUnifiedApplyMLP(TestBase):

     @patch("vllm_ascend.ops.fused_moe.moe_mlp.get_weight_prefetch_method",
            return_value=MagicMock())
-    @patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context")
+    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
     @patch("torch_npu.npu_grouped_matmul")
     @patch("torch_npu.npu_swiglu")
     @patch("torch_npu.npu_grouped_matmul_swiglu_quant")

@@ -121,9 +121,10 @@ class TestAscendMultiHeadLatentAttention(TestBase):
     @patch("vllm_ascend.ops.mla.get_ascend_config")
     @patch("vllm_ascend.ops.mla.get_tensor_model_parallel_world_size")
     @patch("vllm_ascend.ops.mla.get_forward_context")
-    def test_forward(self, mock_get_forward_context, mock_tp_size,
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
+    def test_forward(self, mock_get_forward_context_2, mock_get_forward_context, mock_tp_size,
                      mock_ascend_config, mock_get_vllm_config,
-                     mock_mla_forward):
+                     mock_mla_forward,):
         mock_tp_size.return_value = 1
         mock_ascend_config.return_value.enable_shared_expert_dp = False
         mock_vllm_config = MagicMock(spec=VllmConfig)

@@ -159,6 +160,7 @@ class TestAscendMultiHeadLatentAttention(TestBase):
         mock_forward_context = MagicMock(spec=ForwardContext)
         mock_forward_context.flash_comm_v1_enabled = False
         mock_get_forward_context.return_value = mock_forward_context
+        mock_get_forward_context_2.return_value = mock_forward_context

         mock_mla_forward.return_value = (3, self.hidden_size)

@@ -28,7 +28,7 @@ class TestMoECommMethod(TestBase):
         self.moe_config.dp_group = MagicMock()
         self.moe_config.global_redundant_expert_num = 0

-    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch(
         "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
     )

@@ -73,7 +73,7 @@ class TestMoECommMethod(TestBase):
                                                    context_metadata=context_metadata)
         mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)

-    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch(
         "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithMC2")
     @patch("vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithMC2")

@@ -116,7 +116,7 @@ class TestMoECommMethod(TestBase):
                                                    context_metadata=context_metadata)
         mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)

-    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch(
         "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAll2All"
     )

@@ -155,7 +155,7 @@ class TestMoECommMethod(TestBase):
         mock_pf_instance.prepare.assert_called_once_with(
             hidden_states, router_logits, False, False, QuantType.NONE)

-    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch(
         "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
     )

@@ -32,7 +32,7 @@ class TestPrepareAndFinalize(unittest.TestCase):
     @patch(
         "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
         return_value=0)
-    @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     def test_mc2_prepare_finalize(self, mock_get_forward_context, mock_tp_rank,
                                   mock_tp_size):
         mock_context = MagicMock()

@@ -65,7 +65,7 @@ class TestPrepareAndFinalize(unittest.TestCase):
     @patch(
         "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
         return_value=0)
-    @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch("torch.distributed.all_gather")
     def test_mc2_tp_split_allgather(self, mock_all_gather,
                                     mock_get_forward_context, mock_tp_rank,

@@ -169,7 +169,7 @@ class TestPrepareAndFinalize(unittest.TestCase):
         self.assertEqual(final_result.shape[0], 2)

     @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_dp_group")
-    @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch("vllm_ascend.ops.fused_moe.prepare_finalize.enable_sp",
            return_value=False)
     def test_allgather_prepare_finalize(self, mock_enable_sp,

@@ -386,10 +386,11 @@ class TestEagleProposerDummyRun(TestBase):
         set_current_vllm_config(None)

     # cpu does not support parallel-group, let alone `sp`
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context",
            **{"return_value.flash_comm_v1_enabled": False})
     @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
-    def test_dummy_run_basic(self, mock_context, mock_get_context):
+    def test_dummy_run_basic(self, mock_context, mock_get_context, mock_get_context_2):
         num_tokens = 32
         with_prefill = False

@@ -402,10 +403,11 @@ class TestEagleProposerDummyRun(TestBase):
         self.assertTrue(self.proposer._runnable.call_count == 1)

     # cpu does not support parallel-group, let alone `sp`
+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context",
            **{"return_value.flash_comm_v1_enabled": False})
     @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
-    def test_dummy_run_with_prefill(self, mock_context, mock_get_context):
+    def test_dummy_run_with_prefill(self, mock_context, mock_get_context, mock_get_context_2):
         mock_context.return_value.__enter__.return_value = None
         # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
         with set_current_vllm_config(self.vllm_config):

@@ -413,11 +415,12 @@ class TestEagleProposerDummyRun(TestBase):
             self.proposer.dummy_run(num_tokens=64, with_prefill=True, num_reqs=4)
         self.assertTrue(self.proposer._runnable.call_count == 1)

+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch("vllm_ascend.spec_decode.eagle_proposer.update_full_graph_params")
     @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
     @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
     def test_dummy_run_in_graph_capture(self, mock_context, mock_get_context,
-                                        mock_update_full_graph_params):
+                                        mock_update_full_graph_params, mock_get_context_2):
         last_use_cuda_graph = self.proposer.use_cuda_graph
         mock_return_context = MagicMock()
         mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL

@@ -425,6 +428,7 @@ class TestEagleProposerDummyRun(TestBase):
         # cpu does not support parallel-group, let alone `sp`
         mock_return_context.flash_comm_v1_enabled = False
         mock_get_context.return_value = mock_return_context
+        mock_get_context_2.return_value = mock_return_context
         self.proposer.use_cuda_graph = True
         # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
         with set_current_vllm_config(self.vllm_config):

@@ -436,11 +440,12 @@ class TestEagleProposerDummyRun(TestBase):
         mock_update_full_graph_params.assert_not_called()
         self.proposer.use_cuda_graph = last_use_cuda_graph

+    @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch("vllm_ascend.spec_decode.eagle_proposer.update_full_graph_params")
     @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
     @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
     def test_dummy_run_in_graph_run(self, mock_context, mock_get_context,
-                                    mock_update_full_graph_params):
+                                    mock_update_full_graph_params, mock_get_context_2):
         last_use_cuda_graph = self.proposer.use_cuda_graph
         mock_return_context = MagicMock()
         mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL

@@ -448,6 +453,7 @@ class TestEagleProposerDummyRun(TestBase):
         # cpu does not support parallel-group, let alone `sp`
         mock_return_context.flash_comm_v1_enabled = False
         mock_get_context.return_value = mock_return_context
+        mock_get_context_2.return_value = mock_return_context
         self.proposer.use_cuda_graph = True
         # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
         with set_current_vllm_config(self.vllm_config):

@@ -18,12 +18,11 @@ from collections.abc import Callable

 import torch
 from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
-from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE, UnquantizedFusedMoEMethod
 from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE

-from vllm_ascend.ascend_forward_context import MoECommType
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
 from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute
 from vllm_ascend.ops.fused_moe.moe_comm_method import FusedExpertsResult, _MoECommMethods
 from vllm_ascend.quantization.methods.base import QuantType

@@ -93,7 +92,7 @@ class AscendUnquantizedFusedMoEMethod310(UnquantizedFusedMoEMethod):

         topk_weights = topk_weights.to(x.dtype)

-        moe_comm_method = get_forward_context().moe_comm_method
+        moe_comm_method = _EXTRA_CTX.moe_comm_method
         final_hidden_states = moe_comm_method.fused_experts(
             hidden_states=x,
             w1=layer.w13_weight,

@@ -222,9 +221,8 @@ class AscendFusedMoE310(FusedMoE):
     ) -> torch.Tensor:
         assert self.quant_method is not None
         assert self.routed_scaling_factor == 1.0, "routed_scaling_factor != 1.0 is not supported."
-        forward_context = get_forward_context()

-        hidden_states, router_logits, _, context_metadata = forward_context.moe_comm_method.prepare(
+        hidden_states, router_logits, _, context_metadata = _EXTRA_CTX.moe_comm_method.prepare(
             hidden_states=hidden_states, router_logits=router_logits, quant_type=self.quant_type
         )

@@ -246,7 +244,7 @@ class AscendFusedMoE310(FusedMoE):
             apply_router_weight_on_input=self.apply_router_weight_on_input,
         )

-        routed_out = forward_context.moe_comm_method.finalize(
+        routed_out = _EXTRA_CTX.moe_comm_method.finalize(
             hidden_states=fused_experts_results.routed_out,
             reduce_results=self.reduce_results,
             context_metadata=context_metadata,

@@ -16,8 +16,8 @@
 from __future__ import annotations

 import torch
-from vllm.forward_context import get_forward_context

+from vllm_ascend.ascend_forward_context import _EXTRA_CTX
 from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult

 from .moe_mlp import unified_apply_mlp

@@ -50,7 +50,7 @@ class AllGatherCommImpl310(AllGatherCommImpl):
     ) -> FusedExpertsResult:
         # This method is overridden to use the 310p-specific unified_apply_mlp
         # which provides optimized MLP computation for the 310p platform
-        moe_comm_method = get_forward_context().moe_comm_method
+        moe_comm_method = _EXTRA_CTX.moe_comm_method
         assert moe_comm_method is not None, "Missing communication context"

         dispatch_results = self.token_dispatcher.token_dispatch(

@@ -21,9 +21,9 @@ from typing import Any
 import torch
 from vllm.config import get_current_vllm_config
 from vllm.distributed import get_ep_group
-from vllm.forward_context import get_forward_context

 from vllm_ascend._310p.fused_moe.experts_selector import select_experts
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX
 from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute
 from vllm_ascend.quantization.methods.base import AscendMoEScheme, QuantType

@@ -125,7 +125,7 @@ class AscendW8A8DynamicFusedMoEMethod310(AscendMoEScheme):

         topk_weights = topk_weights.to(self.in_dtype)

-        moe_comm_method = get_forward_context().moe_comm_method
+        moe_comm_method = _EXTRA_CTX.moe_comm_method

         final_hidden_states = moe_comm_method.fused_experts(
             hidden_states=x,

@@ -4,6 +4,7 @@ from enum import Enum
 from typing import Any

 import torch
+import vllm.envs as envs_vllm
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.distributed import get_dp_group, get_ep_group, get_tensor_model_parallel_world_size
 from vllm.forward_context import BatchDescriptor, get_forward_context, set_forward_context

@@ -270,3 +271,61 @@ def select_moe_comm_method(num_tokens: int, vllm_config: VllmConfig, is_draft_mo
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported soc_version: {soc_version}")
|
raise ValueError(f"Unsupported soc_version: {soc_version}")
|
||||||
return moe_comm_type
|
return moe_comm_type
|
||||||
|
|
||||||
|
|
||||||
|
class _ExtraForwardContextProxy:
|
||||||
|
"""Unified forward-context access for v1/v2 model runners."""
|
||||||
|
|
||||||
|
extra_attrs = (
|
||||||
|
"capturing",
|
||||||
|
"moe_comm_type",
|
||||||
|
"moe_comm_method",
|
||||||
|
"mmrs_fusion",
|
||||||
|
"num_tokens",
|
||||||
|
"flash_comm_v1_enabled",
|
||||||
|
"flashcomm_v2_enabled",
|
||||||
|
"pad_size",
|
||||||
|
"padded_length",
|
||||||
|
"num_tokens_across_dp",
|
||||||
|
"mc2_mask",
|
||||||
|
"is_draft_model",
|
||||||
|
"prefetch_mlp_gate_up_proj",
|
||||||
|
"prefetch_mlp_down_proj",
|
||||||
|
"model_instance",
|
||||||
|
"layer_idx",
|
||||||
|
"max_tokens_across_dp",
|
||||||
|
"max_tokens_across_pcp",
|
||||||
|
"num_accept_tokens",
|
||||||
|
"in_profile_run",
|
||||||
|
"padded_num_tokens",
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_extra_attr(self, name: str):
|
||||||
|
if name not in self.extra_attrs:
|
||||||
|
raise AttributeError(
|
||||||
|
f"{name} is not extra forward context attribute, "
|
||||||
|
"please get/set it from vllm's _forward_context directly."
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _ctx():
|
||||||
|
return get_forward_context()
|
||||||
|
|
||||||
|
def __getattr__(self, name: str) -> Any:
|
||||||
|
self.check_extra_attr(name)
|
||||||
|
ctx = self._ctx()
|
||||||
|
if envs_vllm.VLLM_USE_V2_MODEL_RUNNER:
|
||||||
|
return ctx.additional_kwargs[name]
|
||||||
|
return getattr(ctx, name)
|
||||||
|
|
||||||
|
def __setattr__(self, name: str, value: Any) -> None:
|
||||||
|
self.check_extra_attr(name)
|
||||||
|
ctx = self._ctx()
|
||||||
|
if envs_vllm.VLLM_USE_V2_MODEL_RUNNER:
|
||||||
|
ctx.additional_kwargs[name] = value
|
||||||
|
else:
|
||||||
|
setattr(ctx, name, value)
|
||||||
|
|
||||||
|
|
||||||
|
# usage: from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
||||||
|
_EXTRA_CTX = _ExtraForwardContextProxy()
|
||||||
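The `_ExtraForwardContextProxy` above is the unified interface this PR promises: the same attribute spelling resolves to a plain `ForwardContext` attribute under model runner v1 and to an `additional_kwargs` entry under model runner v2. A minimal, self-contained sketch of that dispatch; the `SimpleNamespace` context and the `USE_V2` flag below are illustrative stand-ins, not vllm APIs:

```python
import os
from types import SimpleNamespace

# Stand-ins for vllm's forward context and the v2 switch; names mirror the
# patch, but nothing here imports vllm.
_FORWARD_CONTEXT = SimpleNamespace(capturing=False, additional_kwargs={})
USE_V2 = os.environ.get("VLLM_USE_V2_MODEL_RUNNER") == "1"


class ExtraCtxProxy:
    """Route reads/writes to plain attributes (v1) or additional_kwargs (v2)."""

    extra_attrs = ("capturing", "pad_size")

    def __getattr__(self, name):
        # Only fires for names that miss normal lookup, so the class-level
        # extra_attrs tuple never recurses through here.
        if name not in self.extra_attrs:
            raise AttributeError(name)
        ctx = _FORWARD_CONTEXT
        return ctx.additional_kwargs[name] if USE_V2 else getattr(ctx, name)

    def __setattr__(self, name, value):
        # Fires for every assignment, keeping writes on the same path.
        if name not in self.extra_attrs:
            raise AttributeError(name)
        ctx = _FORWARD_CONTEXT
        if USE_V2:
            ctx.additional_kwargs[name] = value
        else:
            setattr(ctx, name, value)


EXTRA_CTX = ExtraCtxProxy()
EXTRA_CTX.capturing = True          # lands on the attribute (v1) or in the dict (v2)
assert EXTRA_CTX.capturing is True  # one spelling at every call site
```

A module-level singleton keeps call sites to a one-line change, which is why the patch below is almost entirely mechanical substitutions of `get_forward_context().<attr>` with `_EXTRA_CTX.<attr>`.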
@@ -23,7 +23,6 @@
 import torch_npu
 import vllm.envs as envs_vllm
 from vllm.config import VllmConfig, get_current_vllm_config
-from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils.math_utils import cdiv
 from vllm.v1.attention.backend import (  # type: ignore
     AttentionBackend,
@@ -40,6 +39,7 @@ from vllm.v1.attention.backends.registry import (  # type: ignore
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec

+from vllm_ascend.ascend_forward_context import _EXTRA_CTX
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.context_parallel.common_cp import AscendMetadataForDecode, AscendMetadataForPrefill
 from vllm_ascend.attention.utils import (
@@ -392,7 +392,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
     ):
         if using_paged_attention(num_tokens, vllm_config):
             # Paged Attention update logic
-            if forward_context.is_draft_model:
+            if _EXTRA_CTX.is_draft_model:
                 graph_params = get_draft_graph_params()
             else:
                 graph_params = get_graph_params()
@@ -444,7 +444,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             event.record(update_stream)
         else:
             # FIA update logic
-            if forward_context.is_draft_model:
+            if _EXTRA_CTX.is_draft_model:
                 graph_params = get_draft_graph_params()
                 attn_metadata = draft_attn_metadatas
                 attn_keys = list(attn_metadata[0].keys())
@@ -462,7 +462,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
         num_layers = len(attn_keys)
         if num_layers == 0:
            return
-        if forward_context.is_draft_model:
+        if _EXTRA_CTX.is_draft_model:
            attn_keys = attn_keys * (len(graph_params.attn_params[num_tokens]) // num_layers)
        attn_count = 0
        with torch.npu.stream(update_stream):
@@ -488,7 +488,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                softmax_lse,
            ) = param

-            if forward_context.is_draft_model:
+            if _EXTRA_CTX.is_draft_model:
                draft_step = attn_count // num_layers
                seq_lens = attn_metadata[draft_step][key].seq_lens_list
                actual_seq_lengths_q = attn_metadata[draft_step][key].actual_seq_lengths_q
@@ -535,8 +535,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
        key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata)

        num_tokens = attn_metadata.actual_seq_lengths_q[-1]
-        forward_context = get_forward_context()
-        if forward_context.is_draft_model:
+        if _EXTRA_CTX.is_draft_model:
            graph_params = get_draft_graph_params()
        else:
            graph_params = get_graph_params()
@@ -563,7 +562,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
            sparse_mode=3,
            scale=self.scale,
        )
-        if forward_context.is_draft_model:
+        if _EXTRA_CTX.is_draft_model:
            update_draft_graph_params_workspaces(num_tokens, workspace)
        else:
            update_graph_params_workspaces(num_tokens, workspace)
@@ -625,9 +624,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
        output: torch.Tensor | None = None,
    ):
        graph_params = get_graph_params()
-        forward_context: ForwardContext = get_forward_context()
        num_tokens = query.shape[0]
-        if forward_context.capturing:
+        if _EXTRA_CTX.capturing:
            # Get workspace from cache or calculate it if not present.
            workspace = graph_params.workspaces.get(num_tokens)
            if workspace is None:
@@ -761,11 +759,10 @@ class AscendAttentionBackendImpl(AttentionImpl):
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
    ):
-        forward_context: ForwardContext = get_forward_context()
        # we inherit ForwardContext in model runner v2, when enable model
        # runner v2, there is not capturing attribute in forward_context,
        # just use getattr to avoid attribute error.
-        if getattr(forward_context, "capturing", False):
+        if _EXTRA_CTX.capturing:
            attn_output, num_tokens = self.full_graph_fia(query, key, value, attn_metadata, output)
            output[:num_tokens] = attn_output[:num_tokens]
            return output
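The comment kept in the last hunk records why the old code reached for `getattr(..., "capturing", False)`: the v2 runner's context is shaped like `ForwardContext` but never defines `capturing` as a real attribute, so a direct read would raise. A small illustration of that failure mode; the two dataclasses are assumed shapes for this sketch, not vllm types:

```python
from dataclasses import dataclass, field


@dataclass
class ForwardContextV1:
    capturing: bool = False  # a real attribute on the v1 context


@dataclass
class ForwardContextV2:
    # shaped like the v2 context: no `capturing`, only an extras dict
    additional_kwargs: dict = field(default_factory=dict)


v1, v2 = ForwardContextV1(), ForwardContextV2()
assert getattr(v1, "capturing", False) is False   # direct attribute on v1
assert getattr(v2, "capturing", False) is False   # default rescues v2 from AttributeError
v2.additional_kwargs["capturing"] = True          # where v2 actually keeps the flag
```

With `_EXTRA_CTX.capturing`, the proxy's `__getattr__` performs that `additional_kwargs` lookup itself, so the defensive `getattr` default is no longer needed at the call site.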
@@ -841,8 +838,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
        attn_metadata: AscendMetadata,
        output: torch.Tensor | None = None,
    ) -> torch.Tensor:
-        forward_context: ForwardContext = get_forward_context()
-        if forward_context.capturing:
+        if _EXTRA_CTX.capturing:
            return self.full_graph_pa(query, attn_metadata, output)
        torch_npu._npu_paged_attention(
            query=query,
@@ -29,10 +29,10 @@ from vllm.distributed import (
    get_decode_context_model_parallel_world_size,
    get_pcp_group,
 )
-from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.v1.attention.backend import AttentionCGSupport
 from vllm.v1.kv_cache_interface import AttentionSpec

+from vllm_ascend.ascend_forward_context import _EXTRA_CTX
 from vllm_ascend.attention.attention_v1 import (
    AscendAttentionBackendImpl,
    AscendAttentionMetadataBuilder,
@@ -559,9 +559,8 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
            "actual_seq_lengths": torch.arange(attn_metadata.num_decodes_flatten) + 1,
        }
        graph_params = get_graph_params()
-        forward_context: ForwardContext = get_forward_context()
        num_tokens = query.shape[0]
-        if forward_context.capturing:
+        if _EXTRA_CTX.capturing:
            stream = torch_npu.npu.current_stream()

            event = torch.npu.ExternalEvent()
@@ -10,7 +10,6 @@ from vllm.distributed import (
    get_decode_context_model_parallel_world_size,
    get_pcp_group,
 )
-from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils.math_utils import cdiv
 from vllm.v1.attention.backend import AttentionCGSupport
 from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
@@ -30,6 +29,7 @@ from vllm_ascend.attention.mla_v1 import (
 )
 # isort: on

+from vllm_ascend.ascend_forward_context import _EXTRA_CTX
 from vllm_ascend.attention.context_parallel.common_cp import (
    AscendPCPMetadata,
    CPChunkedContextMetadata,
@@ -294,7 +294,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
        num_dcp_pcp_tokens=None,
        draft_attn_metadatas=None,
    ):
-        if forward_context.is_draft_model:
+        if _EXTRA_CTX.is_draft_model:
            graph_params = get_draft_graph_params()
        else:
            graph_params = get_graph_params()
@@ -659,12 +659,11 @@ class AscendMlaCPImpl(AscendMLAImpl):
            "softmax_lse_flag": True,
        }

-        forward_context: ForwardContext = get_forward_context()
-        if forward_context.is_draft_model:
+        if _EXTRA_CTX.is_draft_model:
            graph_params = get_draft_graph_params()
        else:
            graph_params = get_graph_params()
-        if forward_context.capturing:
+        if _EXTRA_CTX.capturing:
            stream = torch_npu.npu.current_stream()
            event = torch.npu.ExternalEvent()
            event.wait(stream)
@@ -6,16 +6,20 @@
 import torch_npu
 import vllm.envs as envs_vllm
 from vllm.config import VllmConfig, get_current_vllm_config
-from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import logger
 from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadataBuilder
 from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.utils.math_utils import cdiv, round_down
-from vllm.v1.attention.backend import AttentionBackend, AttentionCGSupport, MLAAttentionImpl  # type: ignore
+from vllm.v1.attention.backend import (
+    AttentionBackend,  # type: ignore
+    AttentionCGSupport,
+    MLAAttentionImpl,
+)
 from vllm.v1.attention.backends.utils import PAD_SLOT_ID  # type: ignore
 from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec

 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata, CPChunkedContextMetadata
@@ -44,12 +48,7 @@ from vllm_ascend.ops.layer_shard_linear import (
 )
 from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
 from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
-from vllm_ascend.utils import (
-    ACL_FORMAT_FRACTAL_ND,
-    get_weight_prefetch_method,
-    maybe_trans_nz,
-    weak_ref_tensors,
-)
+from vllm_ascend.utils import ACL_FORMAT_FRACTAL_ND, get_weight_prefetch_method, maybe_trans_nz, weak_ref_tensors
 from vllm_ascend.worker.npu_input_batch import NPUInputBatch

 if TYPE_CHECKING:
@@ -737,7 +736,7 @@ class AscendMLAImpl(MLAAttentionImpl):
        num_dcp_pcp_tokens=None,
        draft_attn_metadatas=None,
    ):
-        if forward_context.is_draft_model:
+        if _EXTRA_CTX.is_draft_model:
            graph_params = get_draft_graph_params()
        else:
            graph_params = get_graph_params()
@@ -769,12 +768,12 @@ class AscendMLAImpl(MLAAttentionImpl):
                softmax_lse,
            ) = param
            seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
-            if speculative_config and speculative_config.method == "mtp" and not forward_context.is_draft_model:
+            if speculative_config and speculative_config.method == "mtp" and not _EXTRA_CTX.is_draft_model:
                actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q
                spec_multiple = speculative_config.num_speculative_tokens + 1
                seq_lens_list = seq_lens_list + [0] * (num_tokens // spec_multiple - len(seq_lens_list))
                actual_seq_lengths = [spec_multiple * (i + 1) for i in range(num_tokens // spec_multiple)]
-            elif forward_context.is_draft_model:
+            elif _EXTRA_CTX.is_draft_model:
                actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q
                block_table = forward_context.attn_metadata[key].decode.block_table
                # TODO: This is a hack and should be fixed in the future.
@@ -1243,12 +1242,11 @@ class AscendMLAImpl(MLAAttentionImpl):
            "actual_seq_lengths": actual_seq_lengths,
            "actual_seq_lengths_kv": decode_meta.seq_lens_list,
        }
-        forward_context: ForwardContext = get_forward_context()
-        if forward_context.is_draft_model:
+        if _EXTRA_CTX.is_draft_model:
            graph_params = get_draft_graph_params()
        else:
            graph_params = get_graph_params()
-        if forward_context.capturing:
+        if _EXTRA_CTX.capturing:
            stream = torch_npu.npu.current_stream()

            event = torch.npu.ExternalEvent()
@@ -1261,7 +1259,7 @@ class AscendMLAImpl(MLAAttentionImpl):
            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
                q_nope, k_nope, k_nope, **common_kwargs
            )
-            if forward_context.is_draft_model:
+            if _EXTRA_CTX.is_draft_model:
                update_draft_graph_params_workspaces(num_tokens, workspace)
            else:
                update_graph_params_workspaces(num_tokens, workspace)
@@ -1493,7 +1491,6 @@ class AscendMLAImpl(MLAAttentionImpl):
                reach_layer_for_shard_weight_series(layer)
            return output.fill_(0)

-        forward_context = get_forward_context()
        num_actual_tokens = self.get_num_actual_tokens(attn_metadata)
        assert (
            attn_metadata.num_decodes is not None
@@ -1505,7 +1502,7 @@ class AscendMLAImpl(MLAAttentionImpl):
        num_decode_tokens = attn_metadata.num_decode_tokens
        # Inputs and outputs may be padded for CUDA graphs
        output_padded = output
-        o_proj_input_shape = (forward_context.num_tokens, self.num_heads * self.v_head_dim)
+        o_proj_input_shape = (_EXTRA_CTX.num_tokens, self.num_heads * self.v_head_dim)
        o_proj_input = torch.empty(o_proj_input_shape, dtype=hidden_states.dtype, device=hidden_states.device)

        # MLA Preprocess
@@ -7,16 +7,20 @@ import vllm.envs as envs_vllm
 from torch import nn
 from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
-from vllm.forward_context import get_forward_context
 from vllm.logger import logger
 from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadataBuilder
 from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.triton_utils import HAS_TRITON
-from vllm.v1.attention.backend import AttentionBackend, AttentionCGSupport, MLAAttentionImpl  # type: ignore
+from vllm.v1.attention.backend import (
+    AttentionBackend,  # type: ignore
+    AttentionCGSupport,
+    MLAAttentionImpl,
+)
 from vllm.v1.kv_cache_interface import AttentionSpec

 from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata
@@ -967,10 +971,9 @@ class AscendSFAImpl(MLAAttentionImpl):
        output: torch.Tensor | None = None,
    ) -> torch.Tensor:
        assert output is not None, "Output tensor must be provided."
-        forward_context = get_forward_context()
        if attn_metadata is None:
            # Profiling run.
-            if self.enable_dsa_cp_with_layer_shard and not forward_context.in_profile_run:
+            if self.enable_dsa_cp_with_layer_shard and not _EXTRA_CTX.in_profile_run:
                for layer in self.layer_sharding_kwargs or []:
                    if is_hidden_layer(layer):
                        reach_layer_for_shard_weight_series(layer)
@@ -19,6 +19,8 @@ from vllm.forward_context import BatchDescriptor, get_forward_context
 from vllm.logger import logger
 from vllm.platforms import current_platform

+from vllm_ascend.ascend_forward_context import _EXTRA_CTX
+
 from ..utils import weak_ref_tensors


@@ -195,7 +197,7 @@ class ACLGraphWrapper:
            if self.vllm_config.speculative_config
            else False
        )
-        if self.runtime_mode != CUDAGraphMode.FULL or not forward_context.is_draft_model or not use_eagle:
+        if self.runtime_mode != CUDAGraphMode.FULL or not _EXTRA_CTX.is_draft_model or not use_eagle:
            torch.npu.current_stream().synchronize()
        entry.aclgraph.replay()
        return entry.output
@@ -3,6 +3,7 @@ import torch.distributed as dist
 from vllm.distributed.parallel_state import GroupCoordinator, get_dp_group
 from vllm.forward_context import get_forward_context

+from vllm_ascend.ascend_forward_context import _EXTRA_CTX
 from vllm_ascend.distributed.parallel_state import get_fc3_quant_x_group


@@ -16,7 +17,7 @@ def fc3_all_gather_and_maybe_unpad_impl(
    x = get_fc3_quant_x_group().all_gather(x, 0)
    dp_metadata = forward_context.dp_metadata
    if dp_metadata is None:
-        pad_size = forward_context.pad_size
+        pad_size = _EXTRA_CTX.pad_size
        if pad_size > 0:
            x = x[:-pad_size]
    else:
@@ -24,7 +25,7 @@ def fc3_all_gather_and_maybe_unpad_impl(
        num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu
        result = torch.empty((num_tokens_across_dp_cpu.sum(), *x.shape[1:]), device=x.device, dtype=x.dtype)
        dp_size = get_dp_group().world_size
-        x = x.view(dp_size, forward_context.padded_length, *x.shape[1:])
+        x = x.view(dp_size, _EXTRA_CTX.padded_length, *x.shape[1:])
        offset = 0
        for idx in range(dp_size):
            num_tokens_dp = num_tokens_across_dp_cpu[idx]
@@ -37,7 +37,7 @@ if not vllm_version_is("0.16.0"):
    from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import DefaultMoERunner  # type: ignore

 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.ascend_forward_context import MoECommType
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.eplb.core.eplb_utils import init_eplb_config
 from vllm_ascend.flash_common3_context import get_flash_common3_context, set_flash_common3_context
@@ -148,7 +148,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
            random_matrix = torch.rand(topk_ids.size(0), global_num_experts, device=topk_ids.device)
            topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)

-        moe_comm_method = get_forward_context().moe_comm_method
+        moe_comm_method = _EXTRA_CTX.moe_comm_method
        final_hidden_states = moe_comm_method.fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
@@ -401,12 +401,13 @@ class AscendFusedMoE(FusedMoE):
        # When static kernels are enabled, the forward pass runs twice (compilation + capture),
        # causing moe_layer_index to overflow. Wrap the index to prevent out-of-bounds errors.
        if self.enable_npugraph_ex_static_kernel:
-            forward_context.moe_layer_index = forward_context.moe_layer_index % (len(forward_context.all_moe_layers))
+            moe_layer_index = forward_context.moe_layer_index % (len(forward_context.all_moe_layers))
+            forward_context.moe_layer_index = moe_layer_index

        # Load balancing for token distribution among experts in dummy_run
        # TODO: The community only considers load balancing when DP > 1.
        # This approach may overlook some extreme scenarios.
-        enable_force_load_balance = forward_context.in_profile_run
+        enable_force_load_balance = _EXTRA_CTX.in_profile_run

        forward_context = get_forward_context()
        if self.multistream_overlap_gate:
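The wrap-around in the last hunk is plain modular arithmetic; splitting it over two statements only keeps the line length in check. A sketch of why the modulo is needed when the forward pass runs twice (list contents are illustrative):

```python
# With static kernels the forward pass runs twice (compilation + capture),
# so a naive counter reaches 2 * len(all_moe_layers); the modulo keeps the
# index inside the table on the second pass.
all_moe_layers = ["layer.3", "layer.7", "layer.11"]

indices = []
moe_layer_index = 0
for _ in range(2 * len(all_moe_layers)):      # compilation pass + capture pass
    wrapped = moe_layer_index % len(all_moe_layers)
    indices.append(wrapped)
    moe_layer_index += 1

assert indices == [0, 1, 2, 0, 1, 2]          # second pass reuses valid indices
```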
@@ -419,7 +420,7 @@ class AscendFusedMoE(FusedMoE):
            assert fc3_context.shared_experts is not None
            shared_out = fc3_context.shared_experts(hidden_states)
        # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
-        moe_comm_type = forward_context.moe_comm_type
+        moe_comm_type = _EXTRA_CTX.moe_comm_type
        if (
            moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2}
            and not shared_expert_dp_enabled()
@@ -442,16 +443,16 @@ class AscendFusedMoE(FusedMoE):
            global_num_experts=self.global_num_experts,
        )

-        if isinstance(forward_context.moe_comm_method, AllGatherCommImpl):
+        if isinstance(_EXTRA_CTX.moe_comm_method, AllGatherCommImpl):
            topk_weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(topk_weights, True, True)
            topk_ids = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(topk_ids, True, True)

        set_flash_common3_context(topk_weights=topk_weights, topk_ids=topk_ids)

-        hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare(
+        hidden_states, router_logits, mc2_mask, context_metadata = _EXTRA_CTX.moe_comm_method.prepare(
            hidden_states=hidden_states,
            router_logits=router_logits,
-            replace_allreduce=forward_context.flash_comm_v1_enabled,
+            replace_allreduce=_EXTRA_CTX.flash_comm_v1_enabled,
            enable_shared_expert_dp=self.enable_shared_expert_dp,
            quant_type=self.quant_type,
        )
@@ -509,7 +510,7 @@ class AscendFusedMoE(FusedMoE):
            self.load_counter.add_(1)
        else:
            self.moe_load.add_(local_load)
-        routed_out = forward_context.moe_comm_method.finalize(
+        routed_out = _EXTRA_CTX.moe_comm_method.finalize(
            hidden_states=fused_experts_results.routed_out,
            reduce_results=self.reduce_results,
            context_metadata=context_metadata,
@@ -670,8 +671,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):

        # NOTE: This is exactly the opposite of
        # `maybe_all_reduce_tensor_model_parallel`
-        forward_context = get_forward_context()
-        moe_comm_type = forward_context.moe_comm_type
+        moe_comm_type = _EXTRA_CTX.moe_comm_type
        if (
            moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2}
            and not shared_expert_dp_enabled()
@@ -19,11 +19,10 @@ from abc import ABC, abstractmethod
 from dataclasses import dataclass

 import torch
-from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig

 import vllm_ascend.envs as envs_ascend
-from vllm_ascend.ascend_forward_context import MoECommType
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
 from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
 from vllm_ascend.ops.fused_moe.prepare_finalize import (
    PrepareAndFinalize,
@@ -135,7 +134,7 @@ class MoECommMethod(ABC):
        # Check constraints
        assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.int8]

-        moe_comm_method = get_forward_context().moe_comm_method
+        moe_comm_method = _EXTRA_CTX.moe_comm_method
        assert moe_comm_method is not None, "Missing communication context"

        before_dispatch_evt = torch.npu.current_stream().record_event()
@@ -18,10 +18,9 @@
 import torch
 import torch_npu
 from torch.nn.functional import pad
-from vllm.forward_context import get_forward_context
 from vllm.triton_utils import HAS_TRITON

-from vllm_ascend.ascend_forward_context import MoECommType
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
 from vllm_ascend.device.device_op import DeviceOperator
 from vllm_ascend.device.mxfp_compat import (
    ensure_mxfp8_moe_available,
@@ -147,7 +146,7 @@ def quant_apply_mlp(
    weight_prefetch_method = get_weight_prefetch_method()
    if weight_prefetch_method:
        weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(hidden_states)
-    is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
+    is_mc2 = _EXTRA_CTX.moe_comm_type == MoECommType.MC2
    if w1_scale_bias is None and w1_offset is None and is_mc2:
        if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb) and not use_mxfp_quant:
            # gmm1: gate_up_proj & act_fn: swiglu
@@ -26,10 +26,10 @@ from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
 )
-from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig

 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX
 from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
 from vllm_ascend.quantization.methods.base import QuantType
 from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable
@@ -242,8 +242,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
        """
        self.replace_allreduce = replace_allreduce
        self.enable_shared_expert_dp = enable_shared_expert_dp
-        forward_context = get_forward_context()
-        mc2_mask = forward_context.mc2_mask
+        mc2_mask = _EXTRA_CTX.mc2_mask
        if self.tp_size > 1:
            # Also slice mc2_mask
            split_mc2_mask = torch.tensor_split(mc2_mask, self.tp_size, dim=0)
@@ -252,7 +251,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
        padded_hidden_states_shape = hidden_states.shape
        if not self.replace_allreduce:
            self.num_tokens, _ = hidden_states.shape
-            target_pad_length = forward_context.padded_num_tokens
+            target_pad_length = _EXTRA_CTX.padded_num_tokens
            pad_size = target_pad_length - self.num_tokens

            # Pad if necessary (unless shared expert DP is enabled)
@@ -367,8 +366,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
        """
        self.enable_shared_expert_dp = enable_shared_expert_dp
        if self.moe_config.dp_size > 1:
-            forward_context = get_forward_context()
-            max_tokens_across_dp = forward_context.max_tokens_across_dp
+            max_tokens_across_dp = _EXTRA_CTX.max_tokens_across_dp

            self.num_tokens = hidden_states.shape[0]
            pad_size = max_tokens_across_dp - self.num_tokens
@@ -381,8 +379,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
            router_logits = self.moe_config.dp_group.all_gather(router_logits, 0)

        if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
-            forward_context = get_forward_context()
-            max_tokens_across_pcp = forward_context.max_tokens_across_pcp
+            max_tokens_across_pcp = _EXTRA_CTX.max_tokens_across_pcp

            self.num_tokens_pcp = hidden_states.shape[0]
            pad_size = max_tokens_across_pcp - self.num_tokens_pcp
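Both all-gather branches in the last hunks follow the same recipe: read the max token count across the group from the extra context, pad the local activations up to it, then gather, so every rank contributes a rectangle of the same height. A shape-only sketch of that invariant (all sizes are illustrative):

```python
import torch
import torch.nn.functional as F

hidden_size = 4
num_tokens_per_rank = [5, 3, 7]               # uneven load across 3 DP ranks
max_tokens_across_dp = max(num_tokens_per_rank)

gathered = []
for n in num_tokens_per_rank:
    hidden_states = torch.randn(n, hidden_size)
    pad_size = max_tokens_across_dp - n
    hidden_states = F.pad(hidden_states, (0, 0, 0, pad_size))  # pad the token dim
    gathered.append(hidden_states)

stacked = torch.cat(gathered, dim=0)          # what dp_group.all_gather yields
assert stacked.shape == (3 * max_tokens_across_dp, hidden_size)
```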
@@ -57,9 +57,9 @@ from vllm.distributed import (
    tensor_model_parallel_reduce_scatter,
 )
 from vllm.distributed.parallel_state import get_tp_group
-from vllm.forward_context import get_forward_context

 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX
 from vllm_ascend.distributed.parallel_state import (
    get_flashcomm2_odp_group,
    get_flashcomm2_otp_group,
@@ -311,8 +311,7 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
        input_parallel = splitted_input[tp_rank].contiguous()

        # padding for all-to-all
-        forward_context = get_forward_context()
-        num_padding_tokens = forward_context.pad_size
+        num_padding_tokens = _EXTRA_CTX.pad_size
        if num_padding_tokens > 0:
            input_parallel = nn.functional.pad(input_parallel, (0, 0, 0, num_padding_tokens))

@@ -368,7 +367,7 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
        else:
            output = output_parallel

-        if not forward_context.flash_comm_v1_enabled:
+        if not _EXTRA_CTX.flash_comm_v1_enabled:
            # flashcomm1 not enabled
            output = get_tp_group().all_gather(output, 0)
        if num_padding_tokens > 0:
@@ -514,9 +513,8 @@ class SequenceRowParallelOp(CustomRowParallelOp):
    def matmul_and_reduce(self, input_parallel: torch.Tensor, bias_: Parameter | None) -> torch.Tensor:
        assert self.quant_method is not None
        try:
-            forward_context = get_forward_context()
-            flash_comm_v1_enabled = forward_context.flash_comm_v1_enabled
-            mmrs_fusion = forward_context.mmrs_fusion
+            flash_comm_v1_enabled = _EXTRA_CTX.flash_comm_v1_enabled
+            mmrs_fusion = _EXTRA_CTX.mmrs_fusion
        except AssertionError:
            flash_comm_v1_enabled = False
            mmrs_fusion = False
@@ -527,7 +525,7 @@ class SequenceRowParallelOp(CustomRowParallelOp):
            output_parallel = self.layer.quant_method.apply(self.layer, x, bias=bias_)
            return tensor_model_parallel_all_reduce(output_parallel)

-        pad_size = forward_context.pad_size
+        pad_size = _EXTRA_CTX.pad_size
        if pad_size > 0 and not (enable_dsa_cp() and "o_proj" in self.layer.prefix):
            x = F.pad(x, (0, 0, 0, pad_size))

@@ -32,6 +32,7 @@ from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.v1.attention.backend import AttentionMetadata  # type: ignore

 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX


 class IndexerWrapper(nn.Module):
@@ -144,7 +145,7 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
        kv_cache: torch.Tensor | None = None,
        attn_metadata: AttentionMetadata | None = None,
    ) -> torch.Tensor:
-        need_gather_q_kv = get_forward_context().flash_comm_v1_enabled
+        need_gather_q_kv = _EXTRA_CTX.flash_comm_v1_enabled
        output_shape = hidden_states.shape
        # FIXME: This does not seem right, should make sure the buffer is fixed
        output = torch.empty(output_shape, dtype=hidden_states.dtype, device=hidden_states.device)
@@ -13,7 +13,7 @@ from vllm.distributed import (
 from vllm.forward_context import get_forward_context
 from vllm.utils.torch_utils import direct_register_custom_op

-from vllm_ascend.ascend_forward_context import MoECommType
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
 from vllm_ascend.ops.rotary_embedding import rope_forward_oot
 from vllm_ascend.ops.triton.muls_add import muls_add_triton
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
@@ -22,12 +22,12 @@ from vllm_ascend.utils import npu_stream_switch, prefetch_stream

 def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
    try:
-        forward_context = get_forward_context()
+        get_forward_context()
    except AssertionError:
        return residual

    if x.size(0) != residual.size(0):
-        pad_size = forward_context.pad_size
+        pad_size = _EXTRA_CTX.pad_size
        if pad_size > 0:
            residual = F.pad(residual, (0, 0, 0, pad_size))
        tp_size = get_tensor_model_parallel_world_size()
@@ -43,12 +43,12 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_c
    except AssertionError:
        return x

-    flash_comm_v1_enabled = forward_context.flash_comm_v1_enabled
+    flash_comm_v1_enabled = _EXTRA_CTX.flash_comm_v1_enabled
    if flash_comm_v1_enabled and label:
        dp_metadata = forward_context.dp_metadata
        if dp_metadata is None or not is_ep_comm:
            x = tensor_model_parallel_all_gather(x, 0)
-            pad_size = forward_context.pad_size
+            pad_size = _EXTRA_CTX.pad_size
            if pad_size > 0:
                x = x[:-pad_size]
        else:
@@ -57,7 +57,7 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_c
            num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu
            result = torch.empty((num_tokens_across_dp_cpu.sum(), *x.shape[1:]), device=x.device, dtype=x.dtype)
            dp_size = get_dp_group().world_size
-            x = x.view(dp_size, forward_context.padded_length, *x.shape[1:])
+            x = x.view(dp_size, _EXTRA_CTX.padded_length, *x.shape[1:])
            offset = 0
            for idx in range(dp_size):
                num_tokens_dp = num_tokens_across_dp_cpu[idx]
@@ -79,7 +79,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor

    dp_metadata = forward_context.dp_metadata
    if dp_metadata is None or not is_ep_comm:
-        pad_size = forward_context.pad_size
+        pad_size = _EXTRA_CTX.pad_size
        if pad_size > 0:
            x = F.pad(x, (0, 0, 0, pad_size))
        return tensor_model_parallel_reduce_scatter(x, 0)
@@ -87,7 +87,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
    # padding
    dp_size = get_dp_group().world_size
    num_tokens_across_dp_cpu = get_forward_context().dp_metadata.num_tokens_across_dp_cpu
-    padded_x = torch.empty((dp_size, forward_context.padded_length, *x.shape[1:]), device=x.device, dtype=x.dtype)
+    padded_x = torch.empty((dp_size, _EXTRA_CTX.padded_length, *x.shape[1:]), device=x.device, dtype=x.dtype)
    offset = 0
    for idx in range(dp_size):
        num_tokens_dp = num_tokens_across_dp_cpu[idx]
@@ -98,7 +98,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor


 def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_comm: bool = False) -> torch.Tensor:
-    if get_forward_context().flash_comm_v1_enabled and label:
+    if _EXTRA_CTX.flash_comm_v1_enabled and label:
        return torch.empty(
            (x.shape[0] * get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype
        )
@@ -107,7 +107,7 @@ def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_c


 def _maybe_pad_and_reduce_fake(x: torch.Tensor, is_ep_comm: bool = False) -> torch.Tensor:
-    if get_forward_context().flash_comm_v1_enabled:
+    if _EXTRA_CTX.flash_comm_v1_enabled:
        return torch.empty(
            (x.shape[0] // get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype
        )
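The `_fake` variants in the last hunks are the meta implementations torch.compile traces instead of the real collectives, so they only have to produce tensors of the right shape: the all-gather fake scales dim 0 up by the TP world size and the reduce-scatter fake scales it down. A standalone sketch of that contract; `tp_size` and the flag below are illustrative stand-ins:

```python
import torch

tp_size = 4
flash_comm_v1_enabled = True


def all_gather_fake(x: torch.Tensor) -> torch.Tensor:
    # Meta variant: report the gathered shape without communicating.
    if flash_comm_v1_enabled:
        return torch.empty((x.shape[0] * tp_size, *x.shape[1:]), dtype=x.dtype)
    return torch.empty_like(x)


def pad_and_reduce_fake(x: torch.Tensor) -> torch.Tensor:
    # Meta variant: report the reduce-scattered shape without communicating.
    if flash_comm_v1_enabled:
        return torch.empty((x.shape[0] // tp_size, *x.shape[1:]), dtype=x.dtype)
    return torch.empty_like(x)


x = torch.empty(8, 16)
assert all_gather_fake(x).shape == (32, 16)     # dim 0 grows by tp_size
assert pad_and_reduce_fake(x).shape == (2, 16)  # dim 0 shrinks by tp_size
```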
@@ -138,11 +138,10 @@ def _prefetch_postprocess_impl_fake(stop_flag: torch.Tensor) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def _maybe_all_reduce_tensor_model_parallel_impl(final_hidden_states: torch.Tensor) -> torch.Tensor:
|
def _maybe_all_reduce_tensor_model_parallel_impl(final_hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
forward_context = get_forward_context()
|
moe_comm_type = _EXTRA_CTX.moe_comm_type
|
||||||
moe_comm_type = forward_context.moe_comm_type
|
|
||||||
if (
|
if (
|
||||||
moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2}
|
moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2}
|
||||||
or forward_context.flash_comm_v1_enabled
|
or _EXTRA_CTX.flash_comm_v1_enabled
|
||||||
):
|
):
|
||||||
return final_hidden_states
|
return final_hidden_states
|
||||||
else:
|
else:
|
||||||
@@ -163,7 +162,7 @@ def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor, layer_name: str)
|
|||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
self = forward_context.no_compile_layers[layer_name]
|
self = forward_context.no_compile_layers[layer_name]
|
||||||
num_tokens = input_parallel.size(0)
|
num_tokens = input_parallel.size(0)
|
||||||
if forward_context.flash_comm_v1_enabled:
|
if _EXTRA_CTX.flash_comm_v1_enabled:
|
||||||
num_tokens = num_tokens // self.tp_size
|
num_tokens = num_tokens // self.tp_size
|
||||||
output = torch.empty(
|
output = torch.empty(
|
||||||
size=(num_tokens, self.output_size_per_partition), device=input_parallel.device, dtype=input_parallel.dtype
|
size=(num_tokens, self.output_size_per_partition), device=input_parallel.device, dtype=input_parallel.dtype
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import os
 import torch
 import torch_npu
 from vllm.config import get_current_vllm_config
-from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding,
     MRotaryEmbedding,
@@ -31,6 +30,7 @@ from vllm.model_executor.layers.rotary_embedding import (
 from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb
 from vllm.triton_utils import HAS_TRITON
 
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, has_rope, is_vl_model
 
@@ -240,8 +240,8 @@ class AscendRotaryEmbedding(RotaryEmbedding):
         is_neox_style = self.is_neox_style
         if is_neox_style_override is not None:
             is_neox_style = is_neox_style_override
-        is_draft_model = get_forward_context().is_draft_model
-        flash_comm_v1_enabled = get_forward_context().flash_comm_v1_enabled
+        is_draft_model = _EXTRA_CTX.is_draft_model
+        flash_comm_v1_enabled = _EXTRA_CTX.flash_comm_v1_enabled
         if is_draft_model and self.use_mtp and flash_comm_v1_enabled:
             positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(positions.contiguous(), True)
         return torch.ops.vllm.npu_rotary_embedding(
@@ -6,6 +6,7 @@ from vllm.config import get_current_vllm_config
 from vllm.forward_context import ForwardContext, get_forward_context
 
 from vllm_ascend.ascend_config import WeightPrefetchConfig
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX
 from vllm_ascend.ops.linear import AscendQKVParallelLinear, AscendRowParallelLinear
 from vllm_ascend.utils import is_moe_model
 
@@ -95,11 +96,11 @@ class WeightPrefetchMethod:
         if not self.moe.is_active_this_forward:
             return
         forward_context = get_forward_context()
-        if not forward_context or forward_context.model_instance is None:
+        if not forward_context or _EXTRA_CTX.model_instance is None:
             return
 
         # layer_idx is subtracted by 1 because layer_idx was incremented by 1 at layernorm.
-        weight = forward_context.model_instance.model.layers[forward_context.layer_idx - 1].mlp.experts.w13_weight
+        weight = _EXTRA_CTX.model_instance.model.layers[_EXTRA_CTX.layer_idx - 1].mlp.experts.w13_weight  # type: ignore
         weight_size = weight.data.element_size() * weight.data.numel() * self.moe.prefetch_ratio.get(prefix, 0)
         torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=None, max_weight_size=int(weight_size))
 
@@ -122,9 +123,7 @@ class WeightPrefetchMethod:
         except AssertionError:
             return
         self.mlp.is_active_this_forward = (
-            forward_context.layer_idx is not None
-            and forward_context.num_tokens is not None
-            and forward_context.num_tokens < 500
+            _EXTRA_CTX.layer_idx is not None and _EXTRA_CTX.num_tokens is not None and _EXTRA_CTX.num_tokens < 500
         )
         if not self.mlp.is_active_this_forward:
             return
@@ -144,9 +143,9 @@ class WeightPrefetchMethod:
 
         # start point of gate_up_proj weight prefetch
         if curr_layer_prefix.split(".")[-2] == "self_attn":
-            model_instance = forward_context.model_instance
+            model_instance = _EXTRA_CTX.model_instance
             layer_idx = int(curr_layer_prefix.split(".")[2])
-            weight = model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight
+            weight = model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight  # type: ignore
             if self.mlp_pre_version_compatibale_config:
                 weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_GATE_UP, 0)
             else:
@@ -156,12 +155,12 @@ class WeightPrefetchMethod:
         if weight_size > MAX_PREFETCH_WEIGHT_SIZE:
             weight_size = MAX_PREFETCH_WEIGHT_SIZE
         torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=x_dependency, max_weight_size=int(weight_size))
-        forward_context.prefetch_mlp_gate_up_proj = True
+        _EXTRA_CTX.prefetch_mlp_gate_up_proj = True
 
     def _maybe_prefetch_mlp_down_weight_preprocess(self, x_dependency: torch.Tensor, forward_context: ForwardContext):
-        layer_idx = forward_context.layer_idx
-        model_instance = forward_context.model_instance
-        weight = model_instance.model.layers[layer_idx].mlp.down_proj.weight
+        layer_idx = _EXTRA_CTX.layer_idx
+        model_instance = _EXTRA_CTX.model_instance
+        weight = model_instance.model.layers[layer_idx].mlp.down_proj.weight  # type: ignore
         if self.mlp_pre_version_compatibale_config:
             weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_DOWN, 0)
         else:
@@ -171,22 +170,22 @@ class WeightPrefetchMethod:
         if weight_size > MAX_PREFETCH_WEIGHT_SIZE:
             weight_size = MAX_PREFETCH_WEIGHT_SIZE
         torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=x_dependency, max_weight_size=int(weight_size))
-        forward_context.prefetch_mlp_down_proj = True
-        forward_context.layer_idx += 1
+        _EXTRA_CTX.prefetch_mlp_down_proj = True
+        _EXTRA_CTX.layer_idx = layer_idx + 1  # type: ignore
 
     def maybe_prefetch_mlp_weight_postprocess(self, stop_flag: torch.Tensor):
         if not self.mlp.is_active_this_forward:
             return
 
         try:
-            forward_context = get_forward_context()
+            get_forward_context()
         except AssertionError:
             return
 
-        if forward_context.prefetch_mlp_gate_up_proj or forward_context.prefetch_mlp_down_proj:
+        if _EXTRA_CTX.prefetch_mlp_gate_up_proj or _EXTRA_CTX.prefetch_mlp_down_proj:
             torch.ops.vllm.prefetch_postprocess(stop_flag)
-            forward_context.prefetch_mlp_gate_up_proj = False
-            forward_context.prefetch_mlp_down_proj = False
+            _EXTRA_CTX.prefetch_mlp_gate_up_proj = False
+            _EXTRA_CTX.prefetch_mlp_down_proj = False
 
     def maybe_prefetch_mla_or_sla_weight_in_current_stream(
         self,
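Throughout these weight-prefetch hunks, the requested byte count is the weight's full size scaled by a configured ratio and then clamped to a cap. A condensed sketch of that computation (the cap value here is an assumption standing in for the real `MAX_PREFETCH_WEIGHT_SIZE` constant defined elsewhere in this file):

import torch

MAX_PREFETCH_WEIGHT_SIZE = 16 * 1024 * 1024  # assumed cap, in bytes

def prefetch_bytes(weight: torch.Tensor, ratio: float) -> int:
    size = weight.data.element_size() * weight.data.numel() * ratio
    return int(min(size, MAX_PREFETCH_WEIGHT_SIZE))

w = torch.zeros(4096, 4096, dtype=torch.float16)            # 32 MiB of weights
assert prefetch_bytes(w, 1.0) == MAX_PREFETCH_WEIGHT_SIZE   # clamped to the cap
assert prefetch_bytes(w, 0.25) == 8 * 1024 * 1024           # 25% of 32 MiB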
@@ -153,7 +153,7 @@ def propose(
 
     # FIXME(woosuk): This is UNSAFE!!
     attn_metadata = build_attn_metadata(
-        attn_metadata_builders=self.attn_metadata_builders,
+        attn_groups=self.attn_groups,
        num_reqs=num_reqs,
         num_tokens=num_reqs,
         query_start_loc_gpu=query_start_loc,
@@ -589,11 +589,12 @@ class NPUPlatform(Platform):
         if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER:
             return {}
 
+        # is_draft_model will be removed later, so we set it to False temporarily.
+        is_draft_model = False
         moe_comm_type = select_moe_comm_method(
             num_tokens,
             vllm_config,
-            # is_draft_model will be removed later, so we set it to False temporarily.
-            is_draft_model=False,
+            is_draft_model=is_draft_model,
         )
         moe_comm_method = get_moe_comm_method(moe_comm_type)
 
@@ -620,7 +621,7 @@ class NPUPlatform(Platform):
 
         # TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
         flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None
-        pad_size = None
+        pad_size = 0
         padded_length = None
         if flash_comm_v1_enabled or flashcomm_v2_enabled:
             pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
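Switching `pad_size` from `None` to `0` gives the no-padding case a well-defined integer. The formula itself rounds `num_tokens` up to the next multiple of the TP world size so flashcomm's gather/reduce pair splits evenly; for example, tp_world_size=4 and num_tokens=10 give pad_size=2 (padded length 12), while an aligned num_tokens=12 gives pad_size=0:

def fc_pad_size(num_tokens: int, tp_world_size: int) -> int:
    # Round num_tokens up to the next multiple of tp_world_size.
    return (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size

assert fc_pad_size(10, 4) == 2   # 10 tokens -> pad to 12
assert fc_pad_size(12, 4) == 0   # already aligned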
@@ -657,6 +658,7 @@ class NPUPlatform(Platform):
             "padded_length": padded_length,
             "max_tokens_across_dp": max_tokens_across_dp,
             "mc2_mask": mc2_mask,
+            "is_draft_model": is_draft_model,
         }
 
     @staticmethod
@@ -21,9 +21,9 @@ from typing import Any
 import torch
 import torch_npu
 from vllm.config import get_current_vllm_config
-from vllm.forward_context import get_forward_context
 
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
 
 from .base import AscendMoEScheme
@@ -215,7 +215,7 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme):
         topk_ids = topk_ids.to(torch.int32)
         topk_weights = topk_weights.to(x.dtype)
 
-        moe_comm_method = get_forward_context().moe_comm_method
+        moe_comm_method = _EXTRA_CTX.moe_comm_method
         return moe_comm_method.fused_experts(
             hidden_states=x,
             w1=layer.w13_weight_packed,
@@ -23,9 +23,9 @@ import torch
 import torch_npu
 from vllm.config import get_current_vllm_config
 from vllm.distributed import get_ep_group
-from vllm.forward_context import get_forward_context
 
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
 from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD, maybe_trans_nz
@@ -375,7 +375,7 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
 
         topk_weights = topk_weights.to(x.dtype)
 
-        moe_comm_method = get_forward_context().moe_comm_method
+        moe_comm_method = _EXTRA_CTX.moe_comm_method
         return moe_comm_method.fused_experts(
             hidden_states=x,
             w1=[layer.w13_weight],
@@ -22,11 +22,10 @@ import torch
 import torch_npu
 from vllm.config import CompilationMode, get_current_vllm_config
 from vllm.distributed import get_ep_group
-from vllm.forward_context import get_forward_context
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.ascend_forward_context import MoECommType
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.flash_common3_context import get_flash_common3_context
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute
@@ -234,10 +233,9 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
         assert topk_weights is not None
         topk_weights = topk_weights.to(self.in_dtype)
 
-        moe_comm_method = get_forward_context().moe_comm_method
+        moe_comm_method = _EXTRA_CTX.moe_comm_method
         fused_scale_flag = (
-            get_forward_context().moe_comm_type == MoECommType.FUSED_MC2
-            and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1
+            _EXTRA_CTX.moe_comm_type == MoECommType.FUSED_MC2 and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1
         )
         if self.dynamic_eplb:
             w1 = layer.w13_weight_list
@@ -22,9 +22,9 @@ import torch
 import torch_npu
 from vllm.config import CompilationMode, get_current_vllm_config
 from vllm.distributed import get_ep_group
-from vllm.forward_context import get_forward_context
 
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX
 from vllm_ascend.device.mxfp_compat import (
     FLOAT8_E8M0FNU_DTYPE,
     ensure_mxfp8_linear_available,
@@ -187,7 +187,7 @@ class AscendW8A8MXFP8DynamicFusedMoEMethod(AscendMoEScheme):
 
         topk_weights = topk_weights.to(x.dtype)
 
-        moe_comm_method = get_forward_context().moe_comm_method
+        moe_comm_method = _EXTRA_CTX.moe_comm_method
         return moe_comm_method.fused_experts(
             hidden_states=x,
             w1=layer.w13_weight,
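All four quantized MoE schemes now share the same dispatch shape: resolve the communication strategy from `_EXTRA_CTX` at forward time, then hand the expert weights to its `fused_experts`. A reduced sketch of that pattern (the strategy interface and the `w2` argument are assumptions inferred from the call sites; the concrete strategy classes live elsewhere in vllm_ascend):

# Assumed minimal strategy interface behind moe_comm_method.
class MoECommMethod:
    def fused_experts(self, hidden_states, w1, w2, topk_weights, topk_ids, **kwargs):
        raise NotImplementedError

def apply_experts(x, layer, topk_weights, topk_ids):
    moe_comm_method = _EXTRA_CTX.moe_comm_method  # resolved per forward, not at import
    return moe_comm_method.fused_experts(
        hidden_states=x,
        w1=layer.w13_weight,
        w2=layer.w2_weight,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
    )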
@@ -34,7 +34,7 @@ from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, EagleProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
-from vllm_ascend.ascend_forward_context import set_ascend_forward_context
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX, set_ascend_forward_context
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
@@ -398,7 +398,7 @@ class AscendEagleProposer(EagleProposer):
             num_tokens=num_tokens,
         )
         forward_context = get_forward_context()
-        if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and not forward_context.capturing:
+        if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and not _EXTRA_CTX.capturing:
             self._update_full_graph_params(forward_context, num_tokens, multi_steps_attn_metadata)
 
     def _propose(
@@ -784,8 +784,8 @@ class AscendEagleProposer(EagleProposer):
         input_batch_size = num_input_tokens if (self.method == "mtp" or self.use_cuda_graph) else batch_size
 
         forward_context = get_forward_context()
-        forward_context.num_tokens = input_batch_size
-        forward_context.num_accept_tokens = batch_size
+        _EXTRA_CTX.num_tokens = input_batch_size
+        _EXTRA_CTX.num_accept_tokens = batch_size
 
         for draft_step in range(self.num_speculative_tokens - 1):
             # Reset MOE layer index for each draft step iteration
@@ -1361,15 +1361,14 @@ class AscendEagleProposer(EagleProposer):
         hidden_states: torch.Tensor,
         positions: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        forward_context = get_forward_context()
         if self.method == "mtp":
-            if forward_context.flash_comm_v1_enabled:
+            if _EXTRA_CTX.flash_comm_v1_enabled:
                 hidden_states = torch.ops.vllm.maybe_pad_and_reduce(hidden_states)
                 positions = positions.unsqueeze(-1)
                 positions = torch.ops.vllm.maybe_pad_and_reduce(positions)
                 positions = positions.squeeze(-1)
         else:
-            if forward_context.flash_comm_v1_enabled:
+            if _EXTRA_CTX.flash_comm_v1_enabled:
                 hidden_states = split_inputs_tp_to_sp(hidden_states, hidden_states)
         return hidden_states, positions
 
@@ -1388,8 +1387,7 @@ class AscendEagleProposer(EagleProposer):
         if hidden_states is not None:
             hidden_states = last_hidden_states
         else:
-            forward_context = get_forward_context()
-            if forward_context.flash_comm_v1_enabled:
+            if _EXTRA_CTX.flash_comm_v1_enabled:
                 last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                     last_hidden_states.contiguous(), True
                 )
@@ -5,4 +5,5 @@ This directory contains the new model runner which is under active development.
 please see [Model Runner V2](https://github.com/vllm-project/vllm-ascend/issues/5208)
 to get specific plans.
 
-supported vllm version: main@1339784
+supported vllm version: main@4034c3d32e30d01639459edd3ab486f56993876d
+related PR: <https://github.com/vllm-project/vllm-ascend/pull/7110>
@@ -19,16 +19,25 @@
 from contextlib import contextmanager
 from typing import Any
 
+import numpy as np
 import torch
 import torch.nn as nn
+import vllm
 from vllm.config import VllmConfig
-from vllm.v1.attention.backend import AttentionMetadataBuilder
+from vllm.config.compilation import CUDAGraphMode
+from vllm.forward_context import get_forward_context, set_forward_context
+from vllm.logger import logger
 from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer
 from vllm.v1.worker.gpu.block_table import BlockTables
 from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
-from vllm.v1.worker.gpu.cudagraph_utils import prepare_inputs_to_capture as prepare_inputs_to_capture_gpu
 from vllm.v1.worker.gpu.input_batch import InputBuffers
+from vllm.v1.worker.gpu.model_states.interface import ModelState
+from vllm.v1.worker.utils import AttentionGroup
 
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX
+from vllm_ascend.compilation.acl_graph import set_graph_params, update_full_graph_params
+from vllm_ascend.worker.v2.attn_utils import build_attn_metadata
 from vllm_ascend.worker.v2.utils import torch_cuda_wrapper
 
 
@@ -38,44 +47,134 @@ class AclGraphManager(CudaGraphManager):
     def __init__(
         self,
         vllm_config: VllmConfig,
-        use_mrope: bool,
+        use_aux_hidden_state_outputs: bool,
         device: torch.device,
+        model_runner: Any,  # NPUModelRunner; typed as Any to avoid a circular import
     ):
-        with torch_cuda_wrapper():
-            super().__init__(vllm_config, use_mrope, device)
+        # Keep a reference to the model runner so we can access its attributes
+        # when `run_fullgraph` is called on the CudaGraphManager; this way we
+        # don't need to copy the `execute_model` method into `NPUModelRunner`.
+        self.model_runner = model_runner
+        super().__init__(
+            vllm_config,
+            use_aux_hidden_state_outputs,
+            device,
+        )
+        # vllm-ascend needs to update the graph params of the attention backend,
+        # so set them before capturing the full graph.
+        if super().needs_capture():
+            set_graph_params(self.cudagraph_sizes)
+
+    def _capture_full_graph(
+        self,
+        num_tokens: int,
+        num_reqs: int,
+        model: nn.Module,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        inputs_embeds: torch.Tensor | None,
+        num_tokens_across_dp: torch.Tensor,
+        attn_metadata: dict[str, Any] | None,
+        slot_mappings: dict[str, torch.Tensor] | None,
+        has_lora: bool = False,
+    ) -> None:
+        """Override _capture_full_graph because we need to set capturing=True in forward context."""
+        # Set capturing=True before the model forward.
+        model = ModelWithContext(model)
+        return super()._capture_full_graph(
+            num_tokens,
+            num_reqs,
+            model,
+            input_ids,
+            positions,
+            inputs_embeds,
+            num_tokens_across_dp,
+            attn_metadata,
+            slot_mappings,
+            has_lora,
+        )
 
     def capture_graph(
         self,
         num_tokens: int,
+        capture_cg_mode: CUDAGraphMode,
         model: nn.Module,
+        model_state: ModelState,
         input_buffers: InputBuffers,
         block_tables: BlockTables,
-        attn_metadata_builders: list[AttentionMetadataBuilder],
+        attn_groups: list[list[AttentionGroup]],
         kv_cache_config: KVCacheConfig,
+        has_lora: bool = False,
+        uniform_decode: bool = False,
     ) -> None:
         with torch_cuda_wrapper(), prepare_capture_inputs_wrapper():
             super().capture_graph(
                 num_tokens,
+                capture_cg_mode,
                 model,
+                model_state,
                 input_buffers,
                 block_tables,
-                attn_metadata_builders,
+                attn_groups,
                 kv_cache_config,
+                has_lora,
+                uniform_decode,
             )
 
+    def run_fullgraph(self, num_tokens: int) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
+        """Override run_fullgraph to update full graph params during replay setup."""
+        logger.info_once(f"run_fullgraph with num_tokens={num_tokens}")
+        ret = super().run_fullgraph(num_tokens)
+        assert self.model_runner.cudagraph_and_dp_padding is not None
+
+        positions = self.model_runner.input_buffers.positions[:num_tokens]
+        _num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode = (
+            self.model_runner.cudagraph_and_dp_padding
+        )
+        cudagraph_runtime_mode = CUDAGraphMode(synced_cudagraph_mode)
+
+        with set_forward_context(
+            self.model_runner.input_batch.attn_metadata,
+            self.vllm_config,
+            num_tokens=num_tokens,
+            cudagraph_runtime_mode=cudagraph_runtime_mode,
+            num_tokens_across_dp=num_tokens_across_dp,
+            batch_descriptor=None,  # full graph mode doesn't need a batch_descriptor
+            slot_mapping=self.model_runner.input_batch.slot_mappings,
+        ):
+            forward_context = get_forward_context()
+            update_full_graph_params(
+                # FIXME(Ronald1995): support hybrid attn backend
+                list(self.model_runner.attn_backends.values())[0],
+                self.model_runner.update_stream,
+                forward_context,
+                num_tokens,
+                self.vllm_config,
+                self.model_runner.speculative_config,
+                positions.shape[0],
+            )
+        return ret
+
+    def is_uniform_decode(
+        self,
+        num_reqs: int,
+        num_tokens: int,
+        max_query_len: int,
+    ):
+        return (max_query_len == self.uniform_decode_query_len) and (num_tokens == max_query_len * num_reqs)
 
 
 @contextmanager
 def prepare_capture_inputs_wrapper():
     """Context manager to override input preparation for NPU graph capture."""
     # TODO(Ronald1995): make prepare_inputs_to_capture as static method
     # in CudaGraphManager.
-    global prepare_inputs_to_capture_gpu
+    ori = vllm.v1.worker.gpu.cudagraph_utils.prepare_inputs_to_capture
     try:
-        ori_func = prepare_inputs_to_capture_gpu
-        prepare_inputs_to_capture_gpu = prepare_inputs_to_capture
+        vllm.v1.worker.gpu.cudagraph_utils.prepare_inputs_to_capture = prepare_inputs_to_capture
         yield
     finally:
-        prepare_inputs_to_capture_gpu = ori_func
+        vllm.v1.worker.gpu.cudagraph_utils.prepare_inputs_to_capture = ori
 
 
 def prepare_inputs_to_capture(
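The rewritten `prepare_capture_inputs_wrapper` patches the helper on its defining module instead of rebinding a local import alias: vLLM's `capture_graph` resolves `cudagraph_utils.prepare_inputs_to_capture` at call time, so reassigning a name that lives only in this module would never be observed by the caller. A self-contained sketch of the pattern:

import contextlib
import types

# A stand-in module whose `helper` attribute callers resolve at call time.
mod = types.SimpleNamespace(helper=lambda: "original")

def patched_helper():
    return "patched"

@contextlib.contextmanager
def patch_helper():
    ori = mod.helper
    try:
        mod.helper = patched_helper  # callers resolving mod.helper now see ours
        yield
    finally:
        mod.helper = ori             # always restored, even on error

with patch_helper():
    assert mod.helper() == "patched"
assert mod.helper() == "original"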
@@ -83,9 +182,66 @@ def prepare_inputs_to_capture(
     num_tokens: int,
     input_buffers: InputBuffers,
     block_tables: BlockTables,
-    attn_metadata_builders: list[AttentionMetadataBuilder],
+    attn_groups: list[list[AttentionGroup]],
     max_model_len: int,
     kv_cache_config: KVCacheConfig,
-) -> dict[str, Any]:
-    # TODO(Ronald1995): Implement NPU specific input preparation.
-    return {}
+    uniform_decode_query_len: int = 0,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
+    if uniform_decode_query_len > 0:
+        num_tokens_per_req = uniform_decode_query_len
+    else:
+        num_tokens_per_req = num_tokens // num_reqs
+
+    query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
+    query_start_loc_np[-1] = num_tokens
+    query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
+    input_buffers.query_start_loc[: num_reqs + 1] = query_start_loc_cpu
+    input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
+    query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]
+
+    # HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
+    # rather than max_model_len.
+    input_buffers.seq_lens[:num_reqs] = num_tokens
+    input_buffers.seq_lens[num_reqs:] = 0
+    input_buffers.seq_lens_cpu[:num_reqs] = num_tokens
+    input_buffers.seq_lens_cpu[num_reqs:] = 0
+
+    input_buffers.dcp_local_seq_lens[:num_reqs] = num_tokens
+    input_buffers.dcp_local_seq_lens[num_reqs:] = 0
+
+    input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
+    slot_mappings = block_tables.slot_mappings[:, :num_tokens]
+    slot_mappings_by_layer = build_slot_mappings_by_layer(slot_mappings, kv_cache_config)
+
+    attn_metadata = build_attn_metadata(
+        attn_groups=attn_groups,
+        num_reqs=num_reqs,
+        num_tokens=num_tokens,
+        query_start_loc_gpu=query_start_loc,
+        query_start_loc_cpu=query_start_loc_cpu,
+        max_query_len=num_tokens_per_req,
+        seq_lens=input_buffers.seq_lens,
+        max_seq_len=max_model_len,
+        block_tables=input_block_tables,
+        slot_mappings=slot_mappings,
+        kv_cache_config=kv_cache_config,
+        seq_lens_np=input_buffers.seq_lens_np,
+    )
+    return attn_metadata, slot_mappings_by_layer
+
+
+class ModelWithContext(nn.Module):
+    """Wrapper model that injects the forward context, so we can inherit
+    vLLM's CudaGraphManager._capture_full_graph unchanged.
+    """
+
+    def __init__(self, original_model):
+        super().__init__()
+        self.original_model = original_model
+
+    def forward(self, *args, **kwargs):
+        # In the warmup phase, capturing=False by default. While capturing,
+        # we need to set capturing=True in the forward context.
+        _EXTRA_CTX.capturing = True
+        return self.original_model(*args, **kwargs)
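The dummy `query_start_loc` built above is an arithmetic ramp with a clamped tail. For num_reqs=4, num_tokens=8 and a uniform per-request query length of 2 it yields [0, 2, 4, 6, 8]; splitting 7 tokens over 4 requests yields [0, 1, 2, 3, 7], since the last boundary is always forced to num_tokens:

import numpy as np

def dummy_query_start_loc(num_reqs: int, num_tokens: int, per_req: int) -> np.ndarray:
    qsl = np.arange(num_reqs + 1, dtype=np.int32) * per_req
    qsl[-1] = num_tokens  # the final boundary always covers every token
    return qsl

assert dummy_query_start_loc(4, 8, 2).tolist() == [0, 2, 4, 6, 8]
assert dummy_query_start_loc(4, 7, 7 // 4).tolist() == [0, 1, 2, 3, 7]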
@@ -23,8 +23,8 @@ from typing import Any
 import numpy as np
 import torch
 from vllm.config import VllmConfig
-from vllm.v1.attention.backend import AttentionMetadataBuilder
 from vllm.v1.kv_cache_interface import EncoderOnlyAttentionSpec, KVCacheConfig
+from vllm.v1.worker.utils import AttentionGroup
 
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -43,7 +43,7 @@ def get_attn_mask_builder(device: torch.device):
 
 def build_attn_metadata(
     *,
-    attn_metadata_builders: list[AttentionMetadataBuilder],
+    attn_groups: list[list[AttentionGroup]],
     num_reqs: int,
     num_tokens: int,
     query_start_loc_gpu: torch.Tensor,
@@ -54,6 +54,7 @@ def build_attn_metadata(
     block_tables: Sequence[torch.Tensor],
     slot_mappings: torch.Tensor,
     kv_cache_config: KVCacheConfig,
+    dcp_local_seq_lens: torch.Tensor | None = None,
     # extra attributes for ascend npus.
     seq_lens_np: np.ndarray | None = None,
     num_computed_tokens_cpu: torch.Tensor | None = None,
@@ -72,9 +73,6 @@ def build_attn_metadata(
     if seq_lens_np is None:
         seq_lens_np = np.full(num_reqs, max_seq_len, dtype=np.int32)
     seq_lens_cpu = torch.from_numpy(seq_lens_np)[:num_reqs]
-    # torch_npu._reshape_and_cache operator requires slot_mappings to
-    # be torch.int32.
-    slot_mappings = slot_mappings.to(torch.int32)
 
     attn_metadata: dict[str, Any] = {}
     kv_cache_groups = kv_cache_config.kv_cache_groups
@@ -100,13 +98,14 @@ def build_attn_metadata(
             max_seq_len=max_seq_len,
         )
 
-        attn_metadata_builder = attn_metadata_builders[i]
-        metadata = attn_metadata_builder.build(
-            common_prefix_len=0,
-            common_attn_metadata=common_attn_metadata,  # type: ignore
-        )
-        for layer_name in kv_cache_spec.layer_names:
-            attn_metadata[layer_name] = metadata
+        for attn_group in attn_groups[i]:
+            attn_metadata_builder = attn_group.get_metadata_builder(0)
+            metadata = attn_metadata_builder.build(
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+            )
+            for layer_name in attn_group.layer_names:
+                attn_metadata[layer_name] = metadata
     return attn_metadata
 
 
vllm_ascend/worker/v2/block_table.py (new file)
@@ -0,0 +1,58 @@
+# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/block_table.py
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+
+import torch
+from vllm.v1.worker.gpu.block_table import BlockTables
+
+
+class AscendBlockTables(BlockTables):
+    """Block table for Ascend NPUs."""
+
+    def __init__(
+        self,
+        block_sizes: list[int],
+        max_num_reqs: int,
+        max_num_batched_tokens: int,
+        max_model_len: int,
+        device: torch.device,
+        cp_size: int = 1,
+        cp_rank: int = 0,
+        cp_interleave: int = 1,
+    ):
+        super().__init__(
+            block_sizes,
+            max_num_reqs,
+            max_num_batched_tokens,
+            max_model_len,
+            device,
+            cp_size,
+            cp_rank,
+            cp_interleave,
+        )
+        # Because we override this attribute, delete the parent's tensor first
+        # so it is collected by Python's gc immediately.
+        del self.slot_mappings
+        # vllm-ascend's reshape_and_cache function requires slot_mappings to
+        # be int32, so we redefine slot_mappings with that dtype.
+        self.slot_mappings: torch.Tensor = torch.zeros(
+            self.num_kv_cache_groups,
+            self.max_num_batched_tokens,
+            dtype=torch.int32,
+            device=self.device,
+        )
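Allocating the slot-mapping buffer as int32 up front replaces the per-step `slot_mappings.to(torch.int32)` cast removed from attn_utils above. Beyond saving a copy, this matters for graph replay if, as assumed here, captured kernels read the buffer by address: a dtype-changing cast allocates new storage, while in-place updates keep the address stable:

import torch

buf = torch.zeros(4, dtype=torch.int64)
cast = buf.to(torch.int32)                 # a dtype-changing .to() always copies,
assert cast.data_ptr() != buf.data_ptr()   # so a captured address would go stale

fixed = torch.zeros(4, dtype=torch.int32)  # allocate int32 once instead
ptr = fixed.data_ptr()
fixed[:] = torch.arange(4, dtype=torch.int32)  # and update it in place
assert fixed.data_ptr() == ptr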
@@ -22,6 +22,8 @@ import numpy as np
 import torch
 from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
 
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
+
 
 class AscendInputBuffers(InputBuffers):
     """Input buffers for Ascend NPUs."""
@@ -37,6 +39,16 @@ class AscendInputBuffers(InputBuffers):
             max_num_tokens,
             device,
         )
+        del self.query_start_loc
+
+        # NOTE: For FULL mode we change +1 to +2 to reserve extra space for padding.
+        # See _pad_query_start_loc_for_fia.
+        self.query_start_loc: torch.Tensor = torch.zeros(
+            max_num_reqs + 2,
+            dtype=torch.int32,
+            device=device,
+        )
+
         # Create seq_lens_cpu and seq_lens_np.
         # npu's attention backend still needs seq_lens on CPU side.
         self.seq_lens_cpu: torch.Tensor = torch.zeros(
@@ -56,6 +68,8 @@ class AscendInputBatch(InputBatch):
     # Create seq_lens_np.
     # npu's attention backend still needs seq_lens on CPU side.
     seq_lens_np: np.ndarray
+    # attn_state is used to build attention metadata.
+    attn_state: AscendAttentionState | None = None
 
     @classmethod
     def make_dummy(
@@ -79,4 +93,11 @@ class AscendInputBatch(InputBatch):
         input_buffers.seq_lens_np[num_reqs:] = 0
         seq_lens_np = input_buffers.seq_lens_np[:num_reqs]
         input_batch.seq_lens_np = seq_lens_np
+        # A dummy run is used either for dp or for memory profiling.
+        # When it is a dummy run for dp, num_tokens is set to 1, so
+        # attn_state is set to DecodeOnly. When it is a dummy run for
+        # memory profiling, attention metadata isn't needed, so we can
+        # also set attn_state to AscendAttentionState.DecodeOnly.
+        input_batch.attn_state = AscendAttentionState.DecodeOnly
         return cls(**asdict(input_batch), seq_lens_np=seq_lens_np)
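`make_dummy` extends the parent's dummy batch by unpacking it with `asdict` and appending the NPU-only fields, which works because the base dataclass does not define them. A toy version of the same pattern:

from dataclasses import asdict, dataclass

@dataclass
class Base:
    num_reqs: int

@dataclass
class Extended(Base):
    extra: int | None = None

base = Base(num_reqs=2)
ext = Extended(**asdict(base), extra=7)  # copy base fields, add the subclass-only ones
assert (ext.num_reqs, ext.extra) == (2, 7)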
@@ -17,12 +17,16 @@
 # This file is a part of the vllm-ascend project.
 #
 
+import functools
+
 import numpy as np
 import torch
+import vllm
 from vllm.config import VllmConfig
-from vllm.logger import init_logger
+from vllm.config.compilation import CUDAGraphMode
+from vllm.sequence import IntermediateTensors
 from vllm.v1.core.sched.output import SchedulerOutput
-from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer
+from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
 from vllm.v1.worker.gpu.input_batch import (
     combine_sampled_and_draft_tokens,
|||||||
)
|
)
|
||||||
from vllm.v1.worker.gpu.model_runner import GPUModelRunner
|
from vllm.v1.worker.gpu.model_runner import GPUModelRunner
|
||||||
|
|
||||||
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
|
from vllm_ascend.utils import set_weight_prefetch_method
|
||||||
from vllm_ascend.worker.v2.aclgraph_utils import AclGraphManager
|
from vllm_ascend.worker.v2.aclgraph_utils import AclGraphManager
|
||||||
from vllm_ascend.worker.v2.attn_utils import build_attn_metadata, build_attn_state
|
from vllm_ascend.worker.v2.attn_utils import build_attn_state
|
||||||
from vllm_ascend.worker.v2.input_batch import AscendInputBatch, AscendInputBuffers
|
from vllm_ascend.worker.v2.input_batch import AscendInputBatch, AscendInputBuffers
|
||||||
from vllm_ascend.worker.v2.sample.sampler import AscendSampler
|
from vllm_ascend.worker.v2.sample.sampler import AscendSampler
|
||||||
from vllm_ascend.worker.v2.spec_decode import init_speculator
|
from vllm_ascend.worker.v2.spec_decode import init_speculator
|
||||||
from vllm_ascend.worker.v2.spec_decode.eagle import AscendEagleSpeculator
|
from vllm_ascend.worker.v2.spec_decode.eagle import AscendEagleSpeculator
|
||||||
from vllm_ascend.worker.v2.states import AscendRequestState
|
from vllm_ascend.worker.v2.states import AscendRequestState
|
||||||
from vllm_ascend.worker.v2.utils import torch_cuda_wrapper
|
from vllm_ascend.worker.v2.utils import block_table_wrapper, model_states_wrapper, torch_cuda_wrapper
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class NPUModelRunner(GPUModelRunner):
|
class NPUModelRunner(GPUModelRunner):
|
||||||
"""Model runner for Ascend NPUs."""
|
"""Model runner for Ascend NPUs."""
|
||||||
|
|
||||||
def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
||||||
with torch_cuda_wrapper():
|
with torch_cuda_wrapper(), block_table_wrapper(), model_states_wrapper():
|
||||||
super().__init__(vllm_config, device)
|
super().__init__(vllm_config, device)
|
||||||
|
|
||||||
# because we will override these attribute, delete these attribute to
|
# because we will override these attribute, delete these attribute to
|
||||||
@@ -62,8 +66,9 @@ class NPUModelRunner(GPUModelRunner):
         # NPU specific initializations can be added below.
         self.cudagraph_manager: AclGraphManager = AclGraphManager(
             self.vllm_config,
-            self.uses_mrope,
+            self.use_aux_hidden_state_outputs,
             self.device,
+            self,
         )
 
         # we define AscendEagleSpeculator in vllm_ascend.worker.v2.spec_decode.eagle
@@ -96,6 +101,7 @@ class NPUModelRunner(GPUModelRunner):
             max_num_reqs=self.max_num_reqs,
             vocab_size=self.vocab_size,
             device=self.device,
+            req_states=self.req_states,
             logprobs_mode=self.model_config.logprobs_mode,
             num_speculative_tokens=self.num_speculative_steps + 1,
         )
@@ -113,6 +119,59 @@ class NPUModelRunner(GPUModelRunner):
             pin_memory=True,
         )
 
+        # Ascend-specific configurations
+        self.ascend_config = get_ascend_config()
+        # Set this just as model runner v1 does, or it will raise an error.
+        set_weight_prefetch_method(self.ascend_config.weight_prefetch_config)
+
+        # We need to update full graph params in run_fullgraph, so create a
+        # dedicated stream for those updates.
+        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
+            self.update_stream: torch.npu.Stream = torch.npu.Stream()
+
+        # We need the return value of `get_cudagraph_and_dp_padding` to set
+        # the forward_context in `run_fullgraph`, so that we can inherit the
+        # `execute_model` method.
+        self.cudagraph_and_dp_padding: tuple[int, torch.Tensor | None, int] | None = None
+
+        # We also need input_batch to set the forward_context in run_fullgraph.
+        self.input_batch: AscendInputBatch | None = None
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        scheduler_output: SchedulerOutput,
+        intermediate_tensors: IntermediateTensors | None = None,
+        dummy_run: bool = False,
+        skip_attn_for_dummy_run: bool = False,
+    ) -> ModelRunnerOutput | IntermediateTensors | None:
+        """Override GPUModelRunner.execute_model for Ascend NPUs for this reason:
+        when running fullgraph, we need the return value of
+        `get_cudagraph_and_dp_padding` to set the forward_context in `run_fullgraph`.
+        """
+
+        # Use a closure to store the return value of get_cudagraph_and_dp_padding
+        # on the model runner.
+        def wrapper(func):
+            @functools.wraps(func)
+            def inner(*args, **kwargs):
+                self.cudagraph_and_dp_padding = func(*args, **kwargs)
+                return self.cudagraph_and_dp_padding
+
+            return inner
+
+        if self.cudagraph_and_dp_padding is None:
+            vllm.v1.worker.gpu.model_runner.get_cudagraph_and_dp_padding = wrapper(
+                vllm.v1.worker.gpu.model_runner.get_cudagraph_and_dp_padding
+            )
+
+        return super().execute_model(
+            scheduler_output,
+            intermediate_tensors,
+            dummy_run,
+            skip_attn_for_dummy_run,
+        )
+
     def prepare_inputs(
         self,
         scheduler_output: SchedulerOutput,
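The one-time wrapping of `get_cudagraph_and_dp_padding` combines the module-attribute patch used for graph capture with a record-and-forward closure: the replacement stores each return value on the runner before passing it through unchanged. A self-contained sketch of that decorator:

import functools

def record_on(holder, attr):
    # Decorator factory: stash func's return value on holder.<attr>, then return it.
    def wrapper(func):
        @functools.wraps(func)
        def inner(*args, **kwargs):
            result = func(*args, **kwargs)
            setattr(holder, attr, result)
            return result
        return inner
    return wrapper

class Holder:
    padding = None

h = Holder()

@record_on(h, "padding")
def compute_padding(n: int) -> int:
    return (8 - n % 8) % 8

assert compute_padding(13) == 3 and h.padding == 3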
@@ -185,33 +244,40 @@ class NPUModelRunner(GPUModelRunner):
             idx_mapping, total_num_logits, cu_num_logits, max_expand_len
         )
 
-        # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
-        block_tables = self.block_tables.gather_block_tables(idx_mapping)
-
         # Get query_start_loc.
-        query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32)
+        # NOTE: For FULL mode we change +1 to +2 to reserve extra space for padding.
+        # See _pad_query_start_loc_for_fia.
+        query_start_loc_np = np.empty(self.max_num_reqs + 2, dtype=np.int32)
         query_start_loc_np[0] = 0
         np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1 : num_reqs + 1])
         # Pad for full CUDA graph mode.
         # Some attention backends like FA3 require query_start_loc to be non-decreasing.
         query_start_loc_np[num_reqs + 1 :] = num_tokens
+
+        # This is only required for vllm-ascend.
+        query_start_loc_np, num_reqs_padded = self._pad_query_start_loc_for_fia(
+            num_tokens_padded=num_tokens_after_padding,
+            num_tokens=num_tokens,
+            num_reqs=num_reqs,
+            query_start_loc_np=query_start_loc_np,
+            max_query_len=max(scheduler_output.num_scheduled_tokens.values()),
+        )
         async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc)
 
-        query_start_loc_np = query_start_loc_np[: num_reqs + 1]
-        query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
+        query_start_loc_np = query_start_loc_np[: num_reqs_padded + 1]
         query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
-        max_query_len = num_scheduled_tokens.max().item()
 
-        # Get prefill tokens.
-        prepare_prefill_inputs(
-            self.input_buffers.input_ids,
-            self.req_states.next_prefill_tokens,
-            idx_mapping,
-            query_start_loc,
-            self.req_states.prefill_token_ids.gpu,
-            self.req_states.prefill_len.gpu,
-            self.req_states.num_computed_tokens.gpu,
-        )
+        # Get prefill tokens if any.
+        if self.req_states.any_prefills(idx_mapping_np):
+            prepare_prefill_inputs(
+                self.input_buffers.input_ids,
+                self.req_states.next_prefill_tokens,
+                idx_mapping,
+                query_start_loc,
+                self.req_states.all_token_ids.gpu,
+                self.req_states.prefill_len.gpu,
+                self.req_states.num_computed_tokens.gpu,
+            )
 
         # Prepare positions and seq_lens.
         prepare_pos_seq_lens(
@@ -223,14 +289,8 @@ class NPUModelRunner(GPUModelRunner):
         )
         seq_lens = self.input_buffers.seq_lens[:num_reqs]
 
-        # Prepare M-RoPE positions.
-        if self.uses_mrope:
-            self.mrope_states.prepare_mrope_positions(
-                idx_mapping,
-                query_start_loc,
-                self.req_states.prefill_len.gpu,
-                self.req_states.num_computed_tokens.gpu,
-            )
+        # Pad for full CUDA graph mode.
+        self.input_buffers.seq_lens_np[num_reqs_padded:] = 0
 
         # Some input token ids are directly read from the last sampled tokens
         # and draft tokens. Also, get the logits indices to sample tokens from.
@@ -246,43 +306,12 @@ class NPUModelRunner(GPUModelRunner):
             total_num_logits,
         )
 
-        # Compute slot mappings: [num_kv_cache_groups, num_tokens]
-        slot_mappings = self.block_tables.compute_slot_mappings(
-            idx_mapping, query_start_loc, self.input_buffers.positions[:num_tokens]
-        )
-        # Layer name -> slot mapping.
-        slot_mappings_by_layer = build_slot_mappings_by_layer(slot_mappings, self.kv_cache_config)
-        # Layer name -> attention metadata.
-        # TODO(Ronald1995): try to add a new method `build_attn_metadata` in
-        # vllm gpu_model_runner_v2, maybe we don't overwrite `prepare_inputs`
-        # method like this.
-        attn_metadata = build_attn_metadata(
-            attn_metadata_builders=self.attn_metadata_builders,
-            num_reqs=num_reqs,
-            num_tokens=num_tokens,
-            query_start_loc_gpu=query_start_loc,
-            query_start_loc_cpu=query_start_loc_cpu,
-            max_query_len=max_query_len,
-            seq_lens=self.input_buffers.seq_lens,
-            max_seq_len=self.max_model_len,
-            block_tables=block_tables,
-            slot_mappings=slot_mappings,
-            kv_cache_config=self.kv_cache_config,
-            # extra attributes for ascend npus.
-            seq_lens_np=self.input_buffers.seq_lens_np,
-            num_computed_tokens_cpu=self.req_states.num_computed_tokens_cpu[idx_mapping_cpu],
-            attn_state=attn_state,
-        )
-
         input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
         positions = self.input_buffers.positions[:num_tokens_after_padding]
-        mrope_positions = None
-        if self.uses_mrope:
-            mrope_positions = self.mrope_states.mrope_positions
-            mrope_positions = mrope_positions[:, :num_tokens_after_padding]
-        return AscendInputBatch(
+        self.input_batch = AscendInputBatch(
             req_ids=req_ids,
-            num_reqs=num_reqs,
+            num_reqs=num_reqs_padded,
             idx_mapping=idx_mapping,
             idx_mapping_np=idx_mapping_np,
             expanded_idx_mapping=expanded_idx_mapping,
@@ -294,18 +323,18 @@ class NPUModelRunner(GPUModelRunner):
             query_start_loc=query_start_loc,
             query_start_loc_np=query_start_loc_np,
             seq_lens=seq_lens,
+            dcp_local_seq_lens=None,  # TODO(Ronald1995): support cp.
             input_ids=input_ids,
             positions=positions,
-            mrope_positions=mrope_positions,
-            inputs_embeds=None,
-            attn_metadata=attn_metadata,
-            slot_mappings=slot_mappings_by_layer,
             logits_indices=logits_indices,
             cu_num_logits=cu_num_logits,
             cu_num_logits_np=cu_num_logits_np,
             has_structured_output_reqs=scheduler_output.has_structured_output_requests,
+            # extra attributes for ascend npus.
             seq_lens_np=self.input_buffers.seq_lens_np,
+            attn_state=attn_state,
         )
+        return self.input_batch

     def postprocess(
         self,
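The slot-mapping and attention-metadata construction removed above is not gone: it reappears in `AscendModelState.prepare_attn` (added later in this diff), so `prepare_inputs` now only assembles and caches the `AscendInputBatch`.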
@@ -352,7 +381,7 @@ class NPUModelRunner(GPUModelRunner):
             self.req_states.num_computed_tokens_cpu[req_index] = self.num_computed_tokens_cpu[req_index]

         # update seq_lens_cpu
-        for i, req_id in enumerate(req_ids):
+        for i, req_id in enumerate(req_ids):  # type: ignore
             req_index = self.req_states.req_id_to_index[req_id]
             num_computed_tokens = self.req_states.num_computed_tokens_cpu[req_index]
             self.input_buffers.seq_lens_cpu[i] = num_computed_tokens + num_scheduled_tokens[req_id]
@@ -361,3 +390,44 @@ class NPUModelRunner(GPUModelRunner):
         # TODO(Ronald1995): just define the method in case calling error in
         # worker, implement it in the future.
         pass
+
+    def _pad_query_start_loc_for_fia(
+        self,
+        num_tokens_padded: int,
+        num_tokens: int,
+        num_reqs: int,
+        query_start_loc_np: np.ndarray,
+        max_query_len: int,
+    ) -> tuple[np.ndarray, int]:
+        """
+        This function is only designed to satisfy the constraint that when the layout is TND,
+        the first dimension of `hidden_states` must equal the last element of `actual_seq_lengths_q`.
+        """
+        assert self.cudagraph_and_dp_padding is not None
+        _num_tokens_after_padding, _num_tokens_across_dp, synced_cudagraph_mode = self.cudagraph_and_dp_padding
+        cudagraph_runtime_mode = CUDAGraphMode(synced_cudagraph_mode)
+        if cudagraph_runtime_mode != CUDAGraphMode.FULL:
+            return query_start_loc_np, num_reqs
+        uniform_decode_query_len = self.cudagraph_manager.uniform_decode_query_len
+        is_uniform_decode = self.cudagraph_manager.is_uniform_decode(
+            num_reqs=num_reqs,
+            num_tokens=num_tokens,
+            max_query_len=max_query_len,
+        )
+        if is_uniform_decode:
+            # Uniform-batch case: num_reqs must be no greater than num_reqs_padded.
+            num_reqs_padded = num_tokens_padded // uniform_decode_query_len
+
+            last_loc = query_start_loc_np[num_reqs]
+            query_start_loc_np[num_reqs + 1 : num_reqs_padded + 1] = (
+                np.arange(1, num_reqs_padded + 1 - num_reqs) * uniform_decode_query_len + last_loc
+            )
+        else:
+            # Mixed-batch case: num_reqs must equal num_reqs_padded.
+            num_reqs_padded = min(num_tokens_padded, self.max_num_reqs)
+
+            # Insert a dummy request instead of setting query_start_loc[num_reqs] = num_tokens_padded directly.
+            query_start_loc_np[num_reqs_padded + 1] = num_tokens_padded
+            num_reqs_padded = num_reqs_padded + 1
+
+        return query_start_loc_np, num_reqs_padded
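To make the uniform-decode branch above concrete, here is a minimal standalone sketch of the padding arithmetic with made-up values (the buffer size, `uniform_decode_query_len`, and the request counts are illustrative, not taken from the runner's state):

import numpy as np

# Illustrative setup: 3 real decode requests, each scheduled for 1 token
# (uniform decode), padded up to a 6-token captured graph size.
uniform_decode_query_len = 1
num_reqs, num_tokens_padded = 3, 6

# query_start_loc keeps one extra slot per potential padded request.
query_start_loc_np = np.zeros(8, dtype=np.int32)
query_start_loc_np[1 : num_reqs + 1] = np.arange(1, num_reqs + 1)

# Uniform-decode branch: extend the cumulative offsets so the last element
# equals num_tokens_padded, as the TND layout constraint requires.
num_reqs_padded = num_tokens_padded // uniform_decode_query_len
last_loc = query_start_loc_np[num_reqs]
query_start_loc_np[num_reqs + 1 : num_reqs_padded + 1] = (
    np.arange(1, num_reqs_padded + 1 - num_reqs) * uniform_decode_query_len + last_loc
)

print(query_start_loc_np[: num_reqs_padded + 1])  # [0 1 2 3 4 5 6]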
vllm_ascend/worker/v2/model_states/__init__.py (new file, 34 lines)
@@ -0,0 +1,34 @@
+# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/model_states/__init__.py
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+
+import torch
+import torch.nn as nn
+from vllm.config import VllmConfig
+from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
+
+
+def init_asecnd_model_state(
+    vllm_config: VllmConfig,
+    model: nn.Module,
+    encoder_cache: EncoderCache | None,
+    device: torch.device,
+):
+    from vllm_ascend.worker.v2.model_states.default import AscendModelState
+
+    return AscendModelState(vllm_config, model, encoder_cache, device)
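The import of `AscendModelState` inside the factory body, rather than at module top level, presumably sidesteps a circular import between this module and the runner; `model_states_wrapper` later in this diff patches vllm's `init_model_state` to point at this factory.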
vllm_ascend/worker/v2/model_states/default.py (new file, 62 lines)
@@ -0,0 +1,62 @@
+# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/model_states/default.py
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+
+from typing import Any
+
+import torch
+from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.worker.gpu.model_states.default import DefaultModelState
+from vllm.v1.worker.utils import AttentionGroup
+
+from vllm_ascend.worker.v2.attn_utils import build_attn_metadata
+from vllm_ascend.worker.v2.input_batch import AscendInputBatch
+
+
+class AscendModelState(DefaultModelState):
+    """Model state for Ascend NPUs."""
+
+    def prepare_attn(
+        self,
+        input_batch: AscendInputBatch,
+        block_tables: tuple[torch.Tensor, ...],
+        slot_mappings: torch.Tensor,
+        attn_groups: list[list[AttentionGroup]],
+        kv_cache_config: KVCacheConfig,
+    ) -> dict[str, Any]:
+        """Override prepare_attn method because `build_attn_metadata` is different from vllm."""
+        query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
+        max_query_len = input_batch.num_scheduled_tokens.max().item()
+        attn_metadata = build_attn_metadata(
+            attn_groups=attn_groups,
+            num_reqs=input_batch.num_reqs,
+            num_tokens=input_batch.num_tokens,
+            query_start_loc_gpu=input_batch.query_start_loc,
+            query_start_loc_cpu=query_start_loc_cpu,
+            max_query_len=max_query_len,
+            seq_lens=input_batch.seq_lens,
+            max_seq_len=self.max_model_len,
+            block_tables=block_tables,
+            slot_mappings=slot_mappings,
+            kv_cache_config=kv_cache_config,
+            dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
+            # extra attributes for ascend npus.
+            seq_lens_np=input_batch.seq_lens_np,
+            attn_state=input_batch.attn_state,
+        )
+        return attn_metadata
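As the docstring notes, the override exists because Ascend's `build_attn_metadata` threads through extra arguments (`seq_lens_np`, `attn_state`) that vllm's default `prepare_attn` does not.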
@@ -16,9 +16,6 @@
 #
 import numpy as np
 import torch
-from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
-from vllm.v1.worker.gpu.sample.gumbel import apply_temperature
-from vllm.v1.worker.gpu.sample.min_p import apply_min_p
 from vllm.v1.worker.gpu.sample.sampler import Sampler

 from vllm_ascend.worker.v2.sample.gumbel import gumbel_sample
@@ -53,21 +50,23 @@ class AscendSampler(Sampler):
             self.num_speculative_tokens,
         )

+        # Apply bad words masking in place.
+        self.bad_words_state.apply_bad_words(
+            logits,
+            idx_mapping,
+            idx_mapping_np,
+            input_ids,
+            expanded_local_pos,
+        )
+
         # Apply temperature in place.
-        apply_temperature(logits, idx_mapping, self.sampling_states.temperature.gpu)
+        self.sampling_states.apply_temperature(logits, idx_mapping, idx_mapping_np)

-        # Apply min_p in place if any request has a non-zero min_p.
-        do_min_p = self.sampling_states.do_min_p(idx_mapping_np)
-        if do_min_p:
-            apply_min_p(logits, idx_mapping, self.sampling_states.min_p.gpu)
+        # Apply min_p in place.
+        self.sampling_states.apply_min_p(logits, idx_mapping, idx_mapping_np)

-        # Apply top_k and/or top_p. This might return a new tensor.
-        do_top_k = self.sampling_states.do_top_k(idx_mapping_np)
-        top_k = self.sampling_states.top_k.gpu[idx_mapping] if do_top_k else None
-        do_top_p = self.sampling_states.do_top_p(idx_mapping_np)
-        top_p = self.sampling_states.top_p.gpu[idx_mapping] if do_top_p else None
-        if do_top_k or do_top_p:
-            logits = apply_top_k_top_p(logits, top_k, top_p)
+        # Apply top_k and/or top_p. This might or might not return a new tensor.
+        logits = self.sampling_states.apply_top_k_top_p(logits, idx_mapping, idx_mapping_np)

         # Sample the next token.
         sampled = gumbel_sample(
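For readers unfamiliar with the filtering that `apply_top_k_top_p` performs, here is a generic reference sketch of combined top-k/top-p logit masking (a common formulation, not vllm's or vllm-ascend's actual kernel; the function name and signature are illustrative):

import torch

def topk_topp_filter(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    """Mask logits outside the top-k set and outside the top-p nucleus."""
    if top_k > 0:
        # Smallest logit that survives top-k, per row.
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1:]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        probs = torch.softmax(sorted_logits, dim=-1)
        cum_probs = probs.cumsum(dim=-1)
        # Drop tokens once the cumulative mass *before* them already exceeds
        # top_p, so the highest-probability token is always kept.
        drop = (cum_probs - probs) > top_p
        sorted_logits = sorted_logits.masked_fill(drop, float("-inf"))
        logits = torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)
    return logits

# Example: one row over a 5-token vocab.
row = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])
print(topk_topp_filter(row, top_k=3, top_p=0.9))  # excluded logits become -inf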
@@ -23,7 +23,7 @@ import torch
 import vllm
 from vllm.config import VllmConfig
 from vllm.v1.worker.gpu.input_batch import InputBatch
-from vllm.v1.worker.gpu.spec_decode.eagle import EagleSpeculator
+from vllm.v1.worker.gpu.spec_decode.eagle.speculator import EagleSpeculator

 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.worker.v2.attn_utils import build_attn_metadata
@@ -56,13 +56,13 @@ class AscendRequestState(RequestState):
         self,
         req_id,
         prompt_len,
-        prefill_token_ids,
+        all_token_ids,
         num_computed_tokens,
     ):
         super().add_request(
             req_id,
             prompt_len,
-            prefill_token_ids,
+            all_token_ids,
             num_computed_tokens,
         )
         req_idx = self.req_id_to_index[req_id]
@@ -1,6 +1,11 @@
 from contextlib import contextmanager

 import torch
+import vllm
+from vllm.logger import logger
+
+from vllm_ascend.worker.v2.block_table import AscendBlockTables
+from vllm_ascend.worker.v2.model_states import init_asecnd_model_state


 @contextmanager
@@ -15,6 +20,34 @@ def torch_cuda_wrapper():
         torch.cuda.CUDAGraph = torch.npu.NPUGraph
         torch.cuda.graph = torch.npu.graph
         torch.cuda.synchronize = torch.npu.synchronize
+        torch.cuda.set_stream = torch.npu.set_stream
+        torch.cuda.current_device = torch.npu.current_device
+        torch.cuda.mem_get_info = torch.npu.mem_get_info
+        logger.info_once("Wrapping torch.cuda with torch.npu.")
+        yield
+    finally:
+        pass
+
+
+@contextmanager
+def block_table_wrapper():
+    try:
+        # vllm-ascend needs to initialize the slot mapping as torch.int32 dtype,
+        # but vllm's default is torch.int64 dtype.
+        vllm.v1.worker.gpu.model_runner.BlockTables = AscendBlockTables
+        logger.info_once("Wrapping BlockTables with AscendBlockTables.")
+        yield
+    finally:
+        pass
+
+
+@contextmanager
+def model_states_wrapper():
+    try:
+        # prepare_attn in AscendModelState is different from vllm,
+        # so we need to override init_model_state.
+        vllm.v1.worker.gpu.model_runner.init_model_state = init_asecnd_model_state
+        logger.info_once("Wrapping init_model_state with init_asecnd_model_state.")
         yield
     finally:
         pass
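As a usage illustration only (the actual call site lives in the NPU worker, outside this hunk), these wrappers would compose like any other context managers; `run_model_runner` below is a hypothetical entry point:

from contextlib import ExitStack

# Hypothetical composition site; the real runner wires these up itself.
with ExitStack() as stack:
    stack.enter_context(torch_cuda_wrapper())
    stack.enter_context(block_table_wrapper())
    stack.enter_context(model_states_wrapper())
    # Within this scope, vllm's GPU model runner transparently uses NPU
    # graphs, int32 slot mappings, and the Ascend model-state factory.
    run_model_runner()  # hypothetical entry point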