diff --git a/tests/ut/compilation/test_acl_graph.py b/tests/ut/compilation/test_acl_graph.py index 77d779ca..ce2cf592 100644 --- a/tests/ut/compilation/test_acl_graph.py +++ b/tests/ut/compilation/test_acl_graph.py @@ -24,15 +24,12 @@ from vllm.forward_context import BatchDescriptor, ForwardContext from tests.ut.base import TestBase from vllm_ascend.attention.attention_v1 import (AscendMetadata, AscendMetadataForDecode) -from vllm_ascend.attention.context_parallel.attention_cp import \ - AscendAttentionCPImpl -from vllm_ascend.attention.context_parallel.mla_cp import AscendMlaCPImpl from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata, AscendMLAMetadata) from vllm_ascend.compilation.acl_graph import ( ACLGraphEntry, ACLGraphWrapper, get_draft_graph_params, get_graph_params, - set_draft_graph_params, set_graph_params, - update_draft_graph_params_workspaces) + set_draft_graph_params, set_graph_params, update_attn_dcp_pcp_params, + update_draft_graph_params_workspaces, update_mla_attn_dcp_pcp_params) class TestACLGraphEntry(TestBase): @@ -814,9 +811,8 @@ class TestPCPDCPGraphParams(TestBase): out, lse)) with patch("torch_npu._C._npu_setStream", return_value=None): - AscendMlaCPImpl.update_graph_params( - self.update_stream, forward_context, 4 - ) + update_mla_attn_dcp_pcp_params(self.update_stream, forward_context, + 4) _mock_graph_task_end.assert_called_once() @@ -856,8 +852,6 @@ class TestPCPDCPGraphParams(TestBase): out, lse, 2, 0, 0)) with patch("torch_npu._C._npu_setStream", return_value=None): - AscendAttentionCPImpl.update_graph_params( - self.update_stream, forward_context, 4, None - ) + update_attn_dcp_pcp_params(self.update_stream, forward_context, 4) _mock_graph_task_end.assert_called_once() diff --git a/tests/ut/spec_decode/test_eagle_proposer.py b/tests/ut/spec_decode/test_eagle_proposer.py index 5f3c8eb8..2e83474b 100644 --- a/tests/ut/spec_decode/test_eagle_proposer.py +++ b/tests/ut/spec_decode/test_eagle_proposer.py @@ -333,11 +333,11 @@ class TestEagleProposerDummyRun(TestBase): self.proposer.dummy_run(num_tokens=64, with_prefill=True, num_reqs=4) self.assertTrue(self.proposer._runnable.call_count == 1) - @patch("vllm_ascend.spec_decode.eagle_proposer.update_full_graph_params") + @patch("vllm_ascend.spec_decode.eagle_proposer.update_attn_params") @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context") @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context") def test_dummy_run_in_graph_capture(self, mock_context, mock_get_context, - mock_update_full_graph_params): + mock_update_attn_params): last_use_cuda_graph = self.proposer.use_cuda_graph mock_return_context = MagicMock() mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL @@ -352,14 +352,14 @@ class TestEagleProposerDummyRun(TestBase): in_graph_capturing=True, aclgraph_runtime_mode=CUDAGraphMode.FULL) self.assertTrue(self.proposer._runnable.call_count == 1) - mock_update_full_graph_params.assert_not_called() + mock_update_attn_params.assert_not_called() self.proposer.use_cuda_graph = last_use_cuda_graph - @patch("vllm_ascend.spec_decode.eagle_proposer.update_full_graph_params") + @patch("vllm_ascend.spec_decode.eagle_proposer.update_attn_params") @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context") @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context") def test_dummy_run_in_graph_run(self, mock_context, mock_get_context, - mock_update_full_graph_params): + mock_update_attn_params): last_use_cuda_graph = 
self.proposer.use_cuda_graph mock_return_context = MagicMock() mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL @@ -374,7 +374,7 @@ class TestEagleProposerDummyRun(TestBase): in_graph_capturing=False, aclgraph_runtime_mode=CUDAGraphMode.FULL) self.assertTrue(self.proposer._runnable.call_count == 1) - self.assertTrue(mock_update_full_graph_params.call_count == 1) + self.assertTrue(mock_update_attn_params.call_count == 1) self.proposer.use_cuda_graph = last_use_cuda_graph diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 5a8f6a10..2f0a8348 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -371,144 +371,6 @@ class AscendAttentionBackendImpl(AttentionImpl): self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer ) - @staticmethod - def update_graph_params( - update_stream, - forward_context, - num_tokens, - vllm_config, - speculative_config=None, - num_dcp_pcp_tokens=None, - ): - if using_paged_attention(num_tokens, vllm_config): - # Paged Attention update logic - if forward_context.is_draft_model: - graph_params = get_draft_graph_params() - else: - graph_params = get_graph_params() - with torch.npu.stream(update_stream): - for key, param, handle, event in zip( - forward_context.attn_metadata, - graph_params.attn_params[num_tokens], - graph_params.handles[num_tokens], - graph_params.events[num_tokens], - ): - ( - query, - key_cache, - value_cache, - num_kv_heads, - num_heads, - scale, - block_table, - seq_lens, - output, - ) = param - seq_lens = forward_context.attn_metadata[key].seq_lens - - workspace = torch_npu._npu_paged_attention_get_workspace( - query=query, - key_cache=key_cache, - value_cache=value_cache, - num_kv_heads=num_kv_heads, - num_heads=num_heads, - scale_value=scale, - block_table=block_table, - context_lens=seq_lens, - out=output, - ) - torch.npu.graph_task_update_begin(update_stream, handle) - torch_npu._npu_paged_attention( - query=query, - key_cache=key_cache, - value_cache=value_cache, - num_kv_heads=num_kv_heads, - num_heads=num_heads, - scale_value=scale, - block_table=block_table, - context_lens=seq_lens, - out=output, - workspace=workspace, - ) - torch.npu.graph_task_update_end(update_stream) - event.record(update_stream) - else: - # FIA update logic - if forward_context.is_draft_model: - graph_params = get_draft_graph_params() - attn_metadata = forward_context.draft_attn_metadatas - attn_keys = list(attn_metadata[0].keys()) - else: - graph_params = get_graph_params() - attn_metadata = forward_context.attn_metadata - attn_keys = list(attn_metadata.keys()) - # For Qwen3-next, since the kv_cache_config has already categorized - # linear_attn and self_attn, the attn_metadata is first arranged with - # self_attn followed by linear_attn. Therefore, using zip directly - # filters out the update operations for linear_attn. - # TODO: We use a new variable `attn_keys` to ensure the loop count is - # correct after get by `zip` because of the new structure of the attn_metadata - # when running with the merged full eagle-graph. Should check it with Qwen3-next. 
- num_layers = len(attn_keys) - if num_layers == 0: - return - if forward_context.is_draft_model: - attn_keys = attn_keys * (len(graph_params.attn_params[num_tokens]) // num_layers) - attn_count = 0 - with torch.npu.stream(update_stream): - for key, param, handle, event in zip( - attn_keys, - graph_params.attn_params[num_tokens], - graph_params.handles[num_tokens], - graph_params.events[num_tokens], - ): - ( - query, - key_cache, - value, - block_tables, - attn_mask, - block_size, - seq_lens, - query_start_loc, - num_kv_heads, - num_heads, - scale, - attn_output, - softmax_lse, - ) = param - - if forward_context.is_draft_model: - draft_step = attn_count // num_layers - seq_lens = attn_metadata[draft_step][key].seq_lens_list - actual_seq_lengths_q = attn_metadata[draft_step][key].actual_seq_lengths_q - attn_count = attn_count + 1 - else: - seq_lens = attn_metadata[key].seq_lens_list - actual_seq_lengths_q = attn_metadata[key].actual_seq_lengths_q - - torch.npu.graph_task_update_begin(update_stream, handle) - torch_npu.npu_fused_infer_attention_score.out( - query=query, - key=key_cache, - value=value, - block_table=block_tables, - atten_mask=attn_mask, - input_layout="TND", - block_size=block_size, - actual_seq_lengths=actual_seq_lengths_q, - actual_seq_lengths_kv=seq_lens, - num_key_value_heads=num_kv_heads, - num_heads=num_heads, - scale=scale, - sparse_mode=3, - workspace=graph_params.workspaces.get(num_tokens), - out=[attn_output, softmax_lse], - ) - torch.npu.graph_task_update_end(update_stream) - - event.record(update_stream) - def process_weights_after_loading(self, act_dtype: torch.dtype): super().process_weights_after_loading(act_dtype) if flashcomm2_oshard_manager.flashcomm2_oshard_enable(): diff --git a/vllm_ascend/attention/context_parallel/attention_cp.py b/vllm_ascend/attention/context_parallel/attention_cp.py index 04f3d956..ae406aa9 100644 --- a/vllm_ascend/attention/context_parallel/attention_cp.py +++ b/vllm_ascend/attention/context_parallel/attention_cp.py @@ -276,79 +276,6 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): self.dcp_rank = get_decode_context_model_parallel_rank() if self.dcp_size > 1 else 0 self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None - @staticmethod - def update_graph_params( - update_stream, - forward_context, - num_tokens, - vllm_config, - speculative_config=None, - num_dcp_pcp_tokens=None, - ): - graph_params = get_graph_params() - # FIXME: Behold! We are using a temporary hack here to update the args - # for each layer's attention op in the graph. 
- with torch.npu.stream(update_stream): - for key, param, handle, event in zip( - forward_context.attn_metadata, - graph_params.attn_params[num_tokens], - graph_params.handles[num_tokens], - graph_params.events[num_tokens], - ): - ( - q_nope, - k_nope, - value, - num_heads, - num_kv_heads, - scale, - block_table, - block_size, - actual_seq_lengths_kv, - actual_seq_lengths_q, - attn_output, - softmax_lse, - dcp_size, - pcp_rank, - dcp_rank, - ) = param - attn_metadata = forward_context.attn_metadata[key] - actual_seq_lengths_kv = attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:, pcp_rank, dcp_rank] - pad_length = num_tokens - len(actual_seq_lengths_kv) - if pad_length > 0: - pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype) - actual_seq_lengths_kv = np.concatenate([actual_seq_lengths_kv, pad_tensor]) - - actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q - - if dcp_size > 1: - num_heads = num_heads * dcp_size - - torch.npu.graph_task_update_begin(update_stream, handle) - - torch_npu.npu_fused_infer_attention_score.out( - q_nope, - k_nope, - value, - num_heads=num_heads, - num_key_value_heads=num_kv_heads, - input_layout="TND", - atten_mask=None, - scale=scale, - antiquant_mode=0, - antiquant_scale=None, - softmax_lse_flag=True, - block_table=block_table, - block_size=block_size, - actual_seq_lengths_kv=actual_seq_lengths_kv, - actual_seq_lengths=actual_seq_lengths_q, - workspace=graph_params.workspaces.get(num_tokens), - out=[attn_output, softmax_lse], - ) - torch.npu.graph_task_update_end(update_stream) - - event.record(update_stream) - def _attention_with_nomask_and_mask( self, q: torch.Tensor, diff --git a/vllm_ascend/attention/context_parallel/mla_cp.py b/vllm_ascend/attention/context_parallel/mla_cp.py index e0cd7998..a45d9397 100644 --- a/vllm_ascend/attention/context_parallel/mla_cp.py +++ b/vllm_ascend/attention/context_parallel/mla_cp.py @@ -284,85 +284,6 @@ class AscendMlaCPImpl(AscendMLAImpl): self.dcp_rank = get_decode_context_model_parallel_rank() if self.dcp_size > 1 else 0 self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None - @staticmethod - def update_graph_params( - update_stream, - forward_context, - num_tokens, - vllm_config=None, - speculative_config=None, - num_dcp_pcp_tokens=None, - ): - if forward_context.is_draft_model: - graph_params = get_draft_graph_params() - else: - graph_params = get_graph_params() - # FIXME: Behold! We are using a temporary hack here to update the args - # for each layer's attention op in the graph. 
- with torch.npu.stream(update_stream): - for key, param, handle, event in zip( - forward_context.attn_metadata, - graph_params.attn_params[num_tokens], - graph_params.handles[num_tokens], - graph_params.events[num_tokens], - ): - ( - q_nope, - k_nope, - q_pe, - k_pe, - num_heads, - num_kv_heads, - input_layout, - spec_attn_mask, - sparse_mode, - scale, - block_table, - block_size, - actual_seq_lengths, - actual_seq_lengths_kv, - attn_output, - softmax_lse, - ) = param - - decode_meta = forward_context.attn_metadata[key].decode - seq_len = decode_meta.cp_seq_len - if isinstance(seq_len, torch.Tensor): - seq_len = seq_len.tolist() - actual_seq_lengths_kv = seq_len - - pad_length = num_tokens - len(actual_seq_lengths_kv) - if pad_length > 0: - actual_seq_lengths_kv = actual_seq_lengths_kv + [0] * (num_tokens - len(actual_seq_lengths_kv)) - - torch.npu.graph_task_update_begin(update_stream, handle) - - torch_npu.npu_fused_infer_attention_score.out( - q_nope, - k_nope, - k_nope, - query_rope=q_pe, - key_rope=k_pe, - num_heads=num_heads, - num_key_value_heads=num_kv_heads, - input_layout=input_layout, - atten_mask=spec_attn_mask, - sparse_mode=sparse_mode, - scale=scale, - antiquant_mode=0, - antiquant_scale=None, - softmax_lse_flag=True, - block_table=block_table, - block_size=block_size, - actual_seq_lengths_kv=actual_seq_lengths_kv, - actual_seq_lengths=actual_seq_lengths, - workspace=graph_params.workspaces.get(num_tokens), - out=[attn_output, softmax_lse], - ) - torch.npu.graph_task_update_end(update_stream) - - event.record(update_stream) - def get_num_actual_tokens(self, attn_metadata: M): if self.pcp_size > 1: return attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 5b81f3ba..832f53c0 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -720,88 +720,6 @@ class AscendMLAImpl(MLAAttentionImpl): ) register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs) - @staticmethod - def update_graph_params( - update_stream, - forward_context, - num_tokens, - vllm_config=None, - speculative_config=None, - num_dcp_pcp_tokens=None, - ): - if forward_context.is_draft_model: - graph_params = get_draft_graph_params() - else: - graph_params = get_graph_params() - # FIXME: Behold! We are using a temporary hack here to update the args - # for each layer's attention op in the graph. 
- with torch.npu.stream(update_stream): - for key, param, handle, event in zip( - forward_context.attn_metadata, - graph_params.attn_params[num_tokens], - graph_params.handles[num_tokens], - graph_params.events[num_tokens], - ): - ( - q_nope, - k_nope, - q_pe, - k_pe, - num_heads, - num_kv_heads, - input_layout, - attn_mask, - sparse_mode, - scale, - block_table, - block_size, - seq_lens_list, - actual_seq_lengths, - attn_output, - softmax_lse, - ) = param - seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list - if speculative_config and speculative_config.method == "mtp" and not forward_context.is_draft_model: - actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q - spec_multiple = speculative_config.num_speculative_tokens + 1 - seq_lens_list = seq_lens_list + [0] * (num_tokens // spec_multiple - len(seq_lens_list)) - actual_seq_lengths = [spec_multiple * (i + 1) for i in range(num_tokens // spec_multiple)] - elif forward_context.is_draft_model: - actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q - block_table = forward_context.attn_metadata[key].decode.block_table - # TODO: This is a hack and should be fixed in the future. - if speculative_config.disable_padded_drafter_batch: - block_table = block_table[: len(actual_seq_lengths)] - seq_lens_list = seq_lens_list + [0] * (len(actual_seq_lengths) - len(seq_lens_list)) - else: - seq_lens_list = seq_lens_list + [0] * (num_tokens - len(seq_lens_list)) - torch.npu.graph_task_update_begin(update_stream, handle) - - torch_npu.npu_fused_infer_attention_score.out( - q_nope, - k_nope, - k_nope, - query_rope=q_pe, - key_rope=k_pe, - num_heads=num_heads, - num_key_value_heads=num_kv_heads, - input_layout=input_layout, - atten_mask=attn_mask, - sparse_mode=sparse_mode, - scale=scale, - antiquant_mode=0, - antiquant_scale=None, - block_table=block_table, - block_size=block_size, - actual_seq_lengths_kv=seq_lens_list, - actual_seq_lengths=actual_seq_lengths, - workspace=graph_params.workspaces.get(num_tokens), - out=[attn_output, softmax_lse], - ) - torch.npu.graph_task_update_end(update_stream) - - event.record(update_stream) - def _v_up_proj(self, x): # Convert from (N, B, L)/(N, B, 1, L) to (N, B, L) x = x.view(self.num_heads, -1, self.kv_lora_rank) diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index 18db774c..56613ac9 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -8,6 +8,7 @@ from dataclasses import dataclass from typing import Any from unittest.mock import patch +import numpy as np import torch import torch_npu import vllm.envs as envs @@ -19,6 +20,8 @@ from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.logger import logger from vllm.platforms import current_platform +from vllm_ascend.attention.utils import using_paged_attention + from ..utils import weak_ref_tensors @@ -210,24 +213,343 @@ def weak_ref_workspaces(params): params.workspaces[num_tokens] = weak_ref_tensors(params.workspaces[num_tokens]) -def update_full_graph_params( - attn_backend, - update_stream, - forward_context, - num_tokens, - vllm_config, - speculative_config=None, - num_dcp_pcp_tokens=None, -): - impl_cls = attn_backend.get_impl_cls() - impl_cls.update_graph_params( - update_stream, - forward_context, - num_tokens, - vllm_config, - speculative_config, - num_dcp_pcp_tokens, - ) +def _update_attn_pa_params(update_stream, forward_context, runtime_shape): + graph_params = 
get_graph_params() + # FIXME: Behold! We are using a temporary hack here to update the args + # for each layer's attention op in the graph. + with torch.npu.stream(update_stream): + for key, param, handle, event in zip( + forward_context.attn_metadata, + graph_params.attn_params[runtime_shape], + graph_params.handles[runtime_shape], + graph_params.events[runtime_shape], + ): + ( + query, + key_cache, + value_cache, + num_kv_heads, + num_heads, + scale, + block_table, + seq_lens, + output, + ) = param + seq_lens = forward_context.attn_metadata[key].seq_lens + + torch.npu.graph_task_update_begin(update_stream, handle) + torch_npu._npu_paged_attention( + query=query, + key_cache=key_cache, + value_cache=value_cache, + num_kv_heads=num_kv_heads, + num_heads=num_heads, + scale_value=scale, + block_table=block_table, + context_lens=seq_lens, + out=output, + workspace=graph_params.workspaces.get(runtime_shape), + ) + torch.npu.graph_task_update_end(update_stream) + + event.record(update_stream) + + +def _update_attn_fia_params(update_stream, forward_context, runtime_shape, draft_attn_metadatas=None): + if forward_context.is_draft_model: + graph_params = get_draft_graph_params() + attn_metadata = draft_attn_metadatas + attn_keys = list(attn_metadata[0].keys()) + else: + graph_params = get_graph_params() + attn_metadata = forward_context.attn_metadata + attn_keys = list(attn_metadata.keys()) + # For Qwen3-next, since the kv_cache_config has already categorized + # linear_attn and self_attn, the attn_metadata is first arranged with + # self_attn followed by linear_attn. Therefore, using zip directly + # filters out the update operations for linear_attn. + # TODO: We use a new variable `attn_keys` to ensure the loop count is + # correct after get by `zip` because of the new structure of the attn_metadata + # when running with the merged full eagle-graph. Should check it with Qwen3-next. 
+ num_layers = len(attn_keys) + if num_layers == 0: + return + if forward_context.is_draft_model: + attn_keys = attn_keys * (len(graph_params.attn_params[runtime_shape]) // num_layers) + attn_count = 0 + with torch.npu.stream(update_stream): + for key, param, handle, event in zip( + attn_keys, + graph_params.attn_params[runtime_shape], + graph_params.handles[runtime_shape], + graph_params.events[runtime_shape], + ): + ( + query, + key_cache, + value, + block_tables, + attn_mask, + block_size, + seq_lens, + query_start_loc, + num_kv_heads, + num_heads, + scale, + attn_output, + softmax_lse, + ) = param + + if forward_context.is_draft_model: + draft_step = attn_count // num_layers + seq_lens = attn_metadata[draft_step][key].seq_lens_list + actual_seq_lengths_q = attn_metadata[draft_step][key].actual_seq_lengths_q + attn_count = attn_count + 1 + else: + seq_lens = attn_metadata[key].seq_lens_list + actual_seq_lengths_q = attn_metadata[key].actual_seq_lengths_q + + torch.npu.graph_task_update_begin(update_stream, handle) + torch_npu.npu_fused_infer_attention_score.out( + query=query, + key=key_cache, + value=value, + block_table=block_tables, + atten_mask=attn_mask, + input_layout="TND", + block_size=block_size, + actual_seq_lengths=actual_seq_lengths_q, + actual_seq_lengths_kv=seq_lens, + num_key_value_heads=num_kv_heads, + num_heads=num_heads, + scale=scale, + sparse_mode=3, + workspace=graph_params.workspaces.get(runtime_shape), + out=[attn_output, softmax_lse], + ) + torch.npu.graph_task_update_end(update_stream) + + event.record(update_stream) + + +def update_attn_params(update_stream, forward_context, runtime_shape, vllm_config, draft_attn_metadatas=None): + if using_paged_attention(runtime_shape, vllm_config): + _update_attn_pa_params(update_stream, forward_context, runtime_shape) + else: + _update_attn_fia_params(update_stream, forward_context, runtime_shape, draft_attn_metadatas) + + +def update_mla_attn_params(update_stream, forward_context, runtime_shape, speculative_config): + if forward_context.is_draft_model: + graph_params = get_draft_graph_params() + else: + graph_params = get_graph_params() + # FIXME: Behold! We are using a temporary hack here to update the args + # for each layer's attention op in the graph. + with torch.npu.stream(update_stream): + for key, param, handle, event in zip( + forward_context.attn_metadata, + graph_params.attn_params[runtime_shape], + graph_params.handles[runtime_shape], + graph_params.events[runtime_shape], + ): + ( + q_nope, + k_nope, + q_pe, + k_pe, + num_heads, + num_kv_heads, + input_layout, + attn_mask, + sparse_mode, + scale, + block_table, + block_size, + seq_lens_list, + actual_seq_lengths, + attn_output, + softmax_lse, + ) = param + seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list + if speculative_config and speculative_config.method == "mtp" and not forward_context.is_draft_model: + actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q + spec_multiple = speculative_config.num_speculative_tokens + 1 + seq_lens_list = seq_lens_list + [0] * (runtime_shape // spec_multiple - len(seq_lens_list)) + actual_seq_lengths = [spec_multiple * (i + 1) for i in range(runtime_shape // spec_multiple)] + elif forward_context.is_draft_model: + actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q + block_table = forward_context.attn_metadata[key].decode.block_table + # TODO: This is a hack and should be fixed in the future. 
+ if speculative_config.disable_padded_drafter_batch: + block_table = block_table[: len(actual_seq_lengths)] + seq_lens_list = seq_lens_list + [0] * (len(actual_seq_lengths) - len(seq_lens_list)) + else: + seq_lens_list = seq_lens_list + [0] * (runtime_shape - len(seq_lens_list)) + torch.npu.graph_task_update_begin(update_stream, handle) + + torch_npu.npu_fused_infer_attention_score.out( + q_nope, + k_nope, + k_nope, + query_rope=q_pe, + key_rope=k_pe, + num_heads=num_heads, + num_key_value_heads=num_kv_heads, + input_layout=input_layout, + atten_mask=attn_mask, + sparse_mode=sparse_mode, + scale=scale, + antiquant_mode=0, + antiquant_scale=None, + block_table=block_table, + block_size=block_size, + actual_seq_lengths_kv=seq_lens_list, + actual_seq_lengths=actual_seq_lengths, + workspace=graph_params.workspaces.get(runtime_shape), + out=[attn_output, softmax_lse], + ) + torch.npu.graph_task_update_end(update_stream) + + event.record(update_stream) + + +def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape): + # FIXME: Behold! We are using a temporary hack here to update the args + # for each layer's attention op in the graph. + graph_params = get_graph_params() + with torch.npu.stream(update_stream): + for key, param, handle, event in zip( + forward_context.attn_metadata, + graph_params.attn_params[runtime_shape], + graph_params.handles[runtime_shape], + graph_params.events[runtime_shape], + ): + ( + q_nope, + k_nope, + value, + num_heads, + num_kv_heads, + scale, + block_table, + block_size, + actual_seq_lengths_kv, + actual_seq_lengths_q, + attn_output, + softmax_lse, + dcp_size, + pcp_rank, + dcp_rank, + ) = param + attn_metadata = forward_context.attn_metadata[key] + actual_seq_lengths_kv = attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:, pcp_rank, dcp_rank] + pad_length = runtime_shape - len(actual_seq_lengths_kv) + if pad_length > 0: + pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype) + actual_seq_lengths_kv = np.concatenate([actual_seq_lengths_kv, pad_tensor]) + + actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q + + if dcp_size > 1: + num_heads = num_heads * dcp_size + + torch.npu.graph_task_update_begin(update_stream, handle) + + torch_npu.npu_fused_infer_attention_score.out( + q_nope, + k_nope, + value, + num_heads=num_heads, + num_key_value_heads=num_kv_heads, + input_layout="TND", + atten_mask=None, + scale=scale, + antiquant_mode=0, + antiquant_scale=None, + softmax_lse_flag=True, + block_table=block_table, + block_size=block_size, + actual_seq_lengths_kv=actual_seq_lengths_kv, + actual_seq_lengths=actual_seq_lengths_q, + workspace=graph_params.workspaces.get(runtime_shape), + out=[attn_output, softmax_lse], + ) + torch.npu.graph_task_update_end(update_stream) + + event.record(update_stream) + + +def update_mla_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape): + if forward_context.is_draft_model: + graph_params = get_draft_graph_params() + else: + graph_params = get_graph_params() + # FIXME: Behold! We are using a temporary hack here to update the args + # for each layer's attention op in the graph. 
+ with torch.npu.stream(update_stream): + for key, param, handle, event in zip( + forward_context.attn_metadata, + graph_params.attn_params[runtime_shape], + graph_params.handles[runtime_shape], + graph_params.events[runtime_shape], + ): + ( + q_nope, + k_nope, + q_pe, + k_pe, + num_heads, + num_kv_heads, + input_layout, + spec_attn_mask, + sparse_mode, + scale, + block_table, + block_size, + actual_seq_lengths, + actual_seq_lengths_kv, + attn_output, + softmax_lse, + ) = param + + decode_meta = forward_context.attn_metadata[key].decode + seq_len = decode_meta.cp_seq_len + if isinstance(seq_len, torch.Tensor): + seq_len = seq_len.tolist() + actual_seq_lengths_kv = seq_len + + pad_length = runtime_shape - len(actual_seq_lengths_kv) + if pad_length > 0: + actual_seq_lengths_kv = actual_seq_lengths_kv + [0] * (runtime_shape - len(actual_seq_lengths_kv)) + + torch.npu.graph_task_update_begin(update_stream, handle) + + torch_npu.npu_fused_infer_attention_score.out( + q_nope, + k_nope, + k_nope, + query_rope=q_pe, + key_rope=k_pe, + num_heads=num_heads, + num_key_value_heads=num_kv_heads, + input_layout=input_layout, + atten_mask=spec_attn_mask, + sparse_mode=sparse_mode, + scale=scale, + antiquant_mode=0, + antiquant_scale=None, + softmax_lse_flag=True, + block_table=block_table, + block_size=block_size, + actual_seq_lengths_kv=actual_seq_lengths_kv, + actual_seq_lengths=actual_seq_lengths, + workspace=graph_params.workspaces.get(runtime_shape), + out=[attn_output, softmax_lse], + ) + torch.npu.graph_task_update_end(update_stream) + + event.record(update_stream) @dataclass diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/cpu_offload_connector.py b/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/cpu_offload_connector.py index c9d2cc1d..44377b4d 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/cpu_offload_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/cpu_offload_connector.py @@ -416,7 +416,7 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]: kv_cache_spec: dict[str, KVCacheSpec] = {} attn_layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase) # NOTE: Must process Attention/MLAAttention before MambaBase to maintain - # ordering expected by graph parameter update logic in attention backends. + # ordering expected by acl_graph.py's _update_attn_fia_params. 
mamba_layers: dict[str, MambaBase] = {} for layer_name, attn_module in attn_layers.items(): if isinstance(attn_module, Attention): diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index a6fc106a..ccf18d7f 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -36,7 +36,10 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper, - update_full_graph_params) + update_attn_dcp_pcp_params, + update_attn_params, + update_mla_attn_dcp_pcp_params, + update_mla_attn_params) from vllm_ascend.ops.rotary_embedding import update_cos_sin from vllm_ascend.ops.triton.spec_decode.utils import \ prepare_inputs_padded_kernel @@ -1178,9 +1181,21 @@ class EagleProposer(VllmEagleProposer): # update full-graph params for one spec token def _update_full_graph_params(self, forward_context, num_tokens, draft_attn_metadatas=None): - update_full_graph_params( - self.runner.attn_backend, self.update_stream, forward_context, num_tokens, - self.vllm_config, self.vllm_config.speculative_config) + if self.vllm_config.model_config.use_mla: + if self.pcp_size * self.dcp_size > 1: + update_mla_attn_dcp_pcp_params(self.update_stream, + forward_context, num_tokens) + else: + update_mla_attn_params(self.update_stream, forward_context, + num_tokens, + self.vllm_config.speculative_config) + else: + if self.pcp_size * self.dcp_size > 1: + update_attn_dcp_pcp_params(self.update_stream, forward_context, + num_tokens) + else: + update_attn_params(self.update_stream, forward_context, + num_tokens, self.vllm_config, draft_attn_metadatas) # padding tensor into desired size def _pad_tensor(self, tensor, pad_size): diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 59c432c9..4805486b 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -84,7 +84,10 @@ from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, using_pag from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper, set_draft_graph_params, set_graph_params, - update_full_graph_params) + update_attn_dcp_pcp_params, + update_attn_params, + update_mla_attn_dcp_pcp_params, + update_mla_attn_params) # yapf: enable from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor from vllm_ascend.eplb.core.eplb_device_transfer_loader import \ @@ -1139,9 +1142,26 @@ class NPUModelRunner(GPUModelRunner): if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL \ and not self.use_sparse: # TODO: maybe_padded_num_tokens will be removed, use num_input_tokens instead - update_full_graph_params(self.attn_backend, self.update_stream, forward_context, - maybe_padded_num_tokens, self.vllm_config, - self.vllm_config.speculative_config) + if self.vllm_config.model_config.use_mla: + if self.pcp_size * self.dcp_size > 1: + # FIXME: Try using `auto_dispatch_capture=True` + update_mla_attn_dcp_pcp_params(self.update_stream, + forward_context, + maybe_padded_num_tokens) + else: + # FIXME: Try using `auto_dispatch_capture=True` + update_mla_attn_params(self.update_stream, forward_context, + maybe_padded_num_tokens, + self.speculative_config) + else: + if self.pcp_size * self.dcp_size > 1: + update_attn_dcp_pcp_params(self.update_stream, + forward_context, + 
maybe_padded_num_tokens) + else: + update_attn_params(self.update_stream, forward_context, + maybe_padded_num_tokens, + self.vllm_config) if get_forward_context().sp_enabled and not isinstance( hidden_states, IntermediateTensors): @@ -2018,9 +2038,25 @@ class NPUModelRunner(GPUModelRunner): assert forward_context is not None if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \ not forward_context.capturing and not self.use_sparse: - update_full_graph_params(self.attn_backend, self.update_stream, forward_context, - num_tokens, self.vllm_config, - self.speculative_config, positions.shape[0]) + if self.vllm_config.model_config.use_mla: + # FIXME: Try using `auto_dispatch_capture=True` + if self.pcp_size * self.dcp_size > 1: + # FIXME: Try using `auto_dispatch_capture=True` + update_mla_attn_dcp_pcp_params(self.update_stream, + forward_context, + positions.shape[0]) + else: + # FIXME: Try using `auto_dispatch_capture=True` + update_mla_attn_params(self.update_stream, forward_context, + num_tokens, self.speculative_config) + else: + if self.pcp_size * self.dcp_size > 1: + update_attn_dcp_pcp_params(self.update_stream, + forward_context, + positions.shape[0]) + else: + update_attn_params(self.update_stream, forward_context, + num_tokens, self.vllm_config) if self.use_aux_hidden_state_outputs: hidden_states, _ = hidden_states @@ -2863,7 +2899,7 @@ class NPUModelRunner(GPUModelRunner): attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase) # NOTE: Must process Attention/MLAAttention before MambaBase to maintain - # ordering expected by graph parameter update logic in attention backends. + # ordering expected by acl_graph.py's _update_attn_fia_params. mamba_layers: dict[str, MambaBase] = {} for layer_name, attn_module in attn_layers.items(): if isinstance(attn_module, Attention):
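
Taken together, the call sites above replace the single `update_full_graph_params(attn_backend, ...)` entry point with a four-way choice between the free functions now exported from `vllm_ascend/compilation/acl_graph.py`. Below is a condensed, standalone sketch of that dispatch; the stub bodies are placeholders for the real helpers and are not part of this patch.

```python
# Stand-ins for the helpers added in vllm_ascend/compilation/acl_graph.py.
def update_attn_params(*args): print("plain attention (PA/FIA) update")
def update_mla_attn_params(*args): print("MLA attention update")
def update_attn_dcp_pcp_params(*args): print("attention update under DCP/PCP")
def update_mla_attn_dcp_pcp_params(*args): print("MLA attention update under DCP/PCP")


def dispatch_graph_param_update(use_mla, pcp_size, dcp_size, *call_args):
    # Mirrors the branching now done in EagleProposer._update_full_graph_params
    # and NPUModelRunner: MLA vs. non-MLA first, then whether context
    # parallelism (PCP/DCP) is active.
    if use_mla:
        fn = update_mla_attn_dcp_pcp_params if pcp_size * dcp_size > 1 else update_mla_attn_params
    else:
        fn = update_attn_dcp_pcp_params if pcp_size * dcp_size > 1 else update_attn_params
    fn(*call_args)


# Example: an MLA model without context parallelism picks update_mla_attn_params.
dispatch_graph_param_update(use_mla=True, pcp_size=1, dcp_size=1)
```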
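
The DCP/PCP helpers (`update_attn_dcp_pcp_params`, `update_mla_attn_dcp_pcp_params`) right-pad the per-request KV sequence lengths with zeros up to the captured graph shape before re-dispatching the fused attention kernel. A minimal, self-contained illustration of that padding step follows; the function name is invented for the example.

```python
import numpy as np


def pad_seq_lengths_kv(actual_seq_lengths_kv: np.ndarray, runtime_shape: int) -> np.ndarray:
    # Graph replay always runs at the captured shape, so the live batch's
    # KV lengths are extended with zeros for the unused slots.
    pad_length = runtime_shape - len(actual_seq_lengths_kv)
    if pad_length > 0:
        pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype)
        actual_seq_lengths_kv = np.concatenate([actual_seq_lengths_kv, pad_tensor])
    return actual_seq_lengths_kv


# Example: 3 live requests replayed in a graph captured for 8 tokens.
print(pad_seq_lengths_kv(np.array([17, 5, 42]), runtime_shape=8))
# -> [17  5 42  0  0  0  0  0]
```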
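
For the MTP branch of `update_mla_attn_params`, each request occupies `num_speculative_tokens + 1` query positions, so the cumulative query lengths are regenerated for the padded batch and the KV lengths are zero-padded to the request count the captured graph expects. A standalone sketch of that reshaping, with a helper name made up for the example:

```python
def mtp_lengths(seq_lens_list, runtime_shape, num_speculative_tokens):
    # One target token plus the speculative tokens per request.
    spec_multiple = num_speculative_tokens + 1
    num_reqs = runtime_shape // spec_multiple
    # KV lengths padded with zeros for the inactive request slots.
    seq_lens_kv = seq_lens_list + [0] * (num_reqs - len(seq_lens_list))
    # Cumulative query lengths: spec_multiple positions per request.
    actual_seq_lengths_q = [spec_multiple * (i + 1) for i in range(num_reqs)]
    return seq_lens_kv, actual_seq_lengths_q


# Example: 2 live requests, graph captured for 8 tokens, 1 speculative token.
print(mtp_lengths([33, 12], runtime_shape=8, num_speculative_tokens=1))
# -> ([33, 12, 0, 0], [2, 4, 6, 8])
```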