diff --git a/tests/e2e/singlecard/test_aclgraph.py b/tests/e2e/singlecard/test_aclgraph.py index cf14a9e..e1ccdc9 100644 --- a/tests/e2e/singlecard/test_aclgraph.py +++ b/tests/e2e/singlecard/test_aclgraph.py @@ -20,6 +20,8 @@ Compare the outputs of vLLM with and without aclgraph. Run `pytest tests/compile/test_aclgraph.py`. """ +import os + import pytest from vllm import SamplingParams @@ -73,3 +75,76 @@ def test_models_with_aclgraph( name_0="vllm_eager_outputs", name_1="vllm_aclgraph_outputs", ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("max_tokens", [5]) +def test_models_with_aclgraph_full_decode_only( + model: str, + max_tokens: int, +) -> None: + if 'HCCL_OP_EXPANSION_MODE' in os.environ: + del os.environ['HCCL_OP_EXPANSION_MODE'] + prompts = [ + ('Solve the following math problem step by step.' + 'The last line of your response should be of the form Answer: ' + '$Answer (without quotes) where $Answer is the answer to the problem.\n\n' + 'In triangle $ABC$, $\\sin \\angle A = \\frac{4}{5}$ and $\\angle A < 90^\\circ$. Let $D$' + 'be a point outside triangle $ABC$ such that $\\angle BAD = \\angle DAC$,' + '$\\angle BDC = 90^\\circ$. Suppose $AD = 1$ and $\\frac{BD}{CD} = \\frac{3}{2}$.' + 'If $AB + AC$ can be expressed in the form $\\frac{a\\sqrt{b}}{c}$,' + 'where $a, b, c$ are pairwise relatively prime integers, find $a + b + c$.' + ), + ('Solve the following math problem step by step.' + 'The last line of your response should be of the form Answer: ' + '$Answer (without quotes) where $Answer is the answer to the problem.\n\n' + 'Let $ABCD$ be a unit square in the plane. Points $X$ and $Y$ are chosen' + 'independently and uniformly at random on the perimeter of $ABCD$.' + 'If the expected value of the area of triangle $\\triangle AXY$' + 'can be expressed as $\\frac{m}{n}$, for relatively prime positive' + 'integers $m$ and $n$, compute $m+n$.'), + ('Solve the following math problem step by step.' 
+ 'The last line of your response should be of the form Answer: ' + '$Answer (without quotes) where $Answer is the answer to the problem.\n\n' + 'Let $a, b, c$ be distinct numbers such that the equations $x^2 + ax + 1 = 0$' + 'and $x^2 + bx + c = 0$ have a common real root, and the equations $x^2 + x + a = 0$' + 'and $x^2 + cx + b = 0$ also have a common real root.' + 'Compute the sum $a + b + c$.') + ] + + sampling_params = SamplingParams(max_tokens=max_tokens, + n=1, + temperature=0.0, + top_p=1.0, + top_k=1) + with VllmRunner( + model, + max_model_len=1024, + enforce_eager=False, + compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY"}, + ) as runner: + vllm_aclgraph_outputs = runner.model.generate(prompts, sampling_params) + + with VllmRunner( + model, + max_model_len=1024, + enforce_eager=True, + ) as runner: + vllm_eager_outputs = runner.model.generate(prompts, sampling_params) + + vllm_aclgraph_outputs_list = [] + for output in vllm_aclgraph_outputs: + vllm_aclgraph_outputs_list.append( + (output.outputs[0].index, output.outputs[0].text)) + + vllm_eager_outputs_list = [] + for output in vllm_eager_outputs: + vllm_eager_outputs_list.append( + (output.outputs[0].index, output.outputs[0].text)) + + check_outputs_equal( + outputs_0_lst=vllm_eager_outputs_list, + outputs_1_lst=vllm_aclgraph_outputs_list, + name_0="vllm_eager_outputs", + name_1="vllm_aclgraph_outputs", + ) diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index 6aac6df..1a982ad 100644 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -461,11 +461,13 @@ class TestAscendMLAImpl(TestBase): self.assertEqual(out.shape, prefix_out.shape) self.assertEqual(lse.shape, prefix_lse.shape) + @patch('vllm_ascend.attention.mla_v1.get_forward_context') @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj") @patch("torch_npu.npu_fused_infer_attention_score") def test_forward_decode_without_graph(self, mock_npu_fused_infer_attention_score, -
mock_up_proj): + mock_up_proj, + mock_get_forward_context): num_tokens = 100 block_size = 4 q_nope = torch.randn(num_tokens, self.impl.num_heads, @@ -487,6 +489,7 @@ class TestAscendMLAImpl(TestBase): mock_up_proj.return_value = torch.randn(num_tokens, self.impl.num_heads, self.impl.v_head_dim) + mock_get_forward_context.return_value = MagicMock(capturing=False) result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, block_size, metadata) self.assertEqual(result.shape[0], num_tokens) @@ -614,12 +617,13 @@ class TestAscendMLAImpl(TestBase): self.assertEqual(k_pe.shape[-1], self.impl.qk_rope_head_dim) self.assertEqual(k_nope.shape[-1], self.impl.kv_lora_rank) + @patch('vllm_ascend.attention.mla_v1.get_forward_context') @patch("torch.npu.stream") @patch("vllm_ascend.attention.mla_v1.get_multistream_comm_context") @patch("torch_npu.npu_fused_infer_attention_score") def test_forward_decode(self, mock_npu_fused_infer_attention_score, - mock_get_multistream_comm_context, - mock_npu_stream): + mock_get_multistream_comm_context, mock_npu_stream, + mock_get_forward_context): B = 2 N = self.impl.num_kv_heads BS = 100 @@ -644,6 +648,7 @@ class TestAscendMLAImpl(TestBase): ] mock_get_multistream_comm_context.return_value = None + mock_get_forward_context.return_value = MagicMock(capturing=False) result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS, attn_metadata) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index d289bb4..98c8c57 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -237,6 +237,7 @@ class AscendAttentionMetadataBuilder: self, common_attn_metadata: AscendCommonAttentionMetadata, attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly, + model: Optional[nn.Module] = None, ): if attn_state == AscendAttentionState.DecodeOnly: attn_metadata = self.build( diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 
39340f7..5bf3262 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -10,6 +10,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, MLAAttentionImpl) from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group +from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) from vllm.utils import cdiv, round_down @@ -21,6 +22,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, maybe_save_kv_layer_to_connector, split_decodes_and_prefills, wait_for_kv_layer_from_connector) +from vllm_ascend.compilation.acl_graph import get_graph_params from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig from vllm_ascend.multistream.context import get_multistream_comm_context from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn @@ -169,7 +171,7 @@ M = TypeVar("M", bound=AscendMLAMetadata) class AscendMLAMetadataBuilder: # Does this backend/builder support ACL Graphs for attention (default: no). aclgraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.NEVER + AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE """ NOTE: Please read the comment at the top of the file before trying to understand this class @@ -389,6 +391,8 @@ class AscendMLAMetadataBuilder: decode_metadata = None if num_decodes > 0: + cos = common_attn_metadata.cos + sin = common_attn_metadata.sin # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist() max_seq_lens = seq_lens[:num_decodes].max().item() @@ -397,21 +401,45 @@ class AscendMLAMetadataBuilder: block_table = block_table[:num_decodes, ...] 
seq_lens_list = seq_lens.tolist() - cos = self.cos_cache[input_positions].unsqueeze( # type: ignore - 1).unsqueeze(2) - sin = self.sin_cache[input_positions].unsqueeze( # type: ignore - 1).unsqueeze(2) + # TODO: After the fullgraph supports MTP, the if branch needs to be deleted + assert self.cos_cache is not None + assert self.sin_cache is not None + if cos is None and sin is None: + cos = self.cos_cache[ + input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + sin = self.sin_cache[ + input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) - decode_metadata = AscendMLADecodeMetadata( - input_positions=input_positions, - block_table=block_table, - seq_lens=seq_lens, - seq_lens_list=seq_lens_list, - max_seq_lens=max_seq_lens, - attn_mask=common_attn_metadata.spec_attn_mask, - actual_seq_lengths_q=actual_seq_lengths_q, - sin=sin, - cos=cos) + decode_metadata = AscendMLADecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + seq_lens_list=seq_lens_list, + max_seq_lens=max_seq_lens, + attn_mask=common_attn_metadata.spec_attn_mask, + actual_seq_lengths_q=actual_seq_lengths_q, + sin=sin, + cos=cos) + else: + cos[:num_decodes, + ...] = self.cos_cache[input_positions].unsqueeze( + 1).unsqueeze(2) + sin[:num_decodes, + ...]
= self.sin_cache[input_positions].unsqueeze( + 1).unsqueeze(2) + + decode_metadata = AscendMLADecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + seq_lens_list=seq_lens_list, + max_seq_lens=max_seq_lens, + attn_mask=common_attn_metadata.spec_attn_mask, + actual_seq_lengths_q=actual_seq_lengths_q, + sin=sin[:num_decodes, ...], + cos=cos[:num_decodes, ...]) return self.metadata_cls( # type: ignore num_actual_tokens=num_actual_tokens, @@ -431,6 +459,26 @@ class AscendMLAMetadataBuilder: enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp, ) + def build_for_graph_capture( + self, + common_attn_metadata: AscendCommonAttentionMetadata, + attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly, + model: Optional[nn.Module] = None, + ): + if attn_state == AscendAttentionState.DecodeOnly: + attn_metadata = self.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + model=model, + ) + else: + raise NotImplementedError( + "Currently we only support building dummy metadata for DecodeOnly state" + ) + + attn_metadata.attn_state = attn_state + return attn_metadata + class DecodeMLAPreprocessResult(NamedTuple): ql_nope: Optional[torch.Tensor] = None @@ -834,24 +882,63 @@ class AscendMLAImpl(MLAAttentionImpl): sparse_mode = 0 spec_attn_mask = None - attn_output, _ = torch_npu.npu_fused_infer_attention_score( - q_nope, - k_nope, - k_nope, - query_rope=q_pe, - key_rope=k_pe, - num_heads=self.num_heads, - num_key_value_heads=self.num_kv_heads, - input_layout=input_layout, - atten_mask=spec_attn_mask, - sparse_mode=sparse_mode, - scale=self.scale, - antiquant_mode=0, - antiquant_scale=None, - block_table=decode_meta.block_table, - block_size=block_size, - actual_seq_lengths_kv=decode_meta.seq_lens_list, - actual_seq_lengths=actual_seq_lengths) + common_kwargs = { + 'query_rope': q_pe, + 'key_rope': k_pe, + 'num_heads': self.num_heads, + 'num_key_value_heads': self.num_kv_heads, + 
'input_layout': input_layout, + 'atten_mask': spec_attn_mask, + 'sparse_mode': sparse_mode, + 'scale': self.scale, + 'antiquant_mode': 0, + 'antiquant_scale': None, + 'block_table': decode_meta.block_table, + 'block_size': block_size, + "actual_seq_lengths": actual_seq_lengths, + "actual_seq_lengths_kv": decode_meta.seq_lens_list, + } + graph_params = get_graph_params() + forward_context: ForwardContext = get_forward_context() + if forward_context.capturing: + stream = torch_npu.npu.current_stream() + + event = torch.npu.ExternalEvent() + event.wait(stream) + event.reset(stream) + graph_params.events[num_tokens].append(event) + + workspace = graph_params.workspaces.get(num_tokens) + if workspace is None: + workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace( + q_nope, k_nope, k_nope, **common_kwargs) + graph_params.workspaces[num_tokens] = workspace + + attn_output = torch.empty_like(q_nope) + softmax_lse = torch.empty(num_tokens, + dtype=q_nope.dtype, + device=q_nope.device) + + graph_params.attn_params[num_tokens].append( + (q_nope, k_nope, q_pe, k_pe, self.num_heads, self.num_kv_heads, + input_layout, spec_attn_mask, sparse_mode, self.scale, + decode_meta.block_table, block_size, + decode_meta.seq_lens_list, actual_seq_lengths, workspace, + attn_output, softmax_lse)) + + torch.npu.graph_task_group_begin(stream) + torch_npu.npu_fused_infer_attention_score.out( + q_nope, + k_nope, + k_nope, + **common_kwargs, + workspace=workspace, + out=[attn_output, softmax_lse]) + handle = torch.npu.graph_task_group_end(stream) + graph_params.handles[num_tokens].append(handle) + else: + attn_output, _ = torch_npu.npu_fused_infer_attention_score( + q_nope, k_nope, k_nope, **common_kwargs) current_ms_metadata = get_multistream_comm_context() if current_ms_metadata is None: diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index efc1103..271ff73 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -63,6 
+63,10 @@ class AscendCommonAttentionMetadata: graph_pad_size: int = -1 + # NOTE: This is a temporary solution for rotary embedding in MLA + cos: torch.Tensor = None + sin: torch.Tensor = None + def split_decodes_and_prefills( common_attn_metadata: AscendCommonAttentionMetadata, diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index 8a41807..e5f5ae7 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -229,6 +229,52 @@ def update_attn_params(update_stream, forward_context, runtime_shape): event.record(update_stream) +def update_mla_attn_params(update_stream, forward_context, runtime_shape): + graph_params = get_graph_params() + # FIXME: Behold! We are using a temporary hack here to update the args + # for each layer's attention op in the graph. + for key, param, handle, event in zip( + forward_context.attn_metadata, + graph_params.attn_params[runtime_shape], + graph_params.handles[runtime_shape], + graph_params.events[runtime_shape], + ): + (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout, + spec_attn_mask, sparse_mode, scale, block_table, block_size, + seq_lens_list, actual_seq_lengths, workspace, attn_output, + softmax_lse) = param + seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list + seq_lens_list = seq_lens_list + [0] * (runtime_shape - + len(seq_lens_list)) + + with torch.npu.stream(update_stream): + torch.npu.graph_task_update_begin(update_stream, handle) + + torch_npu.npu_fused_infer_attention_score.out( + q_nope, + k_nope, + k_nope, + query_rope=q_pe, + key_rope=k_pe, + num_heads=num_heads, + num_key_value_heads=num_kv_heads, + input_layout=input_layout, + atten_mask=spec_attn_mask, + sparse_mode=sparse_mode, + scale=scale, + antiquant_mode=0, + antiquant_scale=None, + block_table=block_table, + block_size=block_size, + actual_seq_lengths_kv=seq_lens_list, + actual_seq_lengths=actual_seq_lengths, + workspace=workspace, + 
out=[attn_output, softmax_lse]) + torch.npu.graph_task_update_end(update_stream) + + event.record(update_stream) + + @dataclass class GraphParams: events: dict[int, list[torch.npu.ExternalEvent]] diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index a85895a..a90a73e 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -214,12 +214,7 @@ class NPUPlatform(Platform): if compilation_config.cudagraph_mode == CUDAGraphMode.NONE: compilation_config.level = CompilationLevel.NO_COMPILATION - # TODO: Currently MLA does not support FULL_DECODE_ONLY, remove the second condition - # after MLA being supported - elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE or ( - compilation_config.cudagraph_mode - == CUDAGraphMode.FULL_DECODE_ONLY and model_config is not None - and model_config.use_mla): + elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE: logger.info( "PIECEWISE compilation enabled on NPU. use_inductor not supported - " "using only ACL Graph mode") diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index b46d1be..d2f0aa5 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -104,7 +104,8 @@ from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper, set_graph_params, - update_attn_params) + update_attn_params, + update_mla_attn_params) from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor from vllm_ascend.eplb.core.eplb_device_transfer_loader import \ D2DExpertWeightLoader @@ -358,6 +359,25 @@ class NPUModelRunner(LoRAModelRunnerMixin): dtype=torch.int32, device=self.device) + if self.vllm_config.model_config.use_mla and \ + self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY: + rope_dim = self.model_config.hf_text_config.qk_rope_head_dim + self.cos = 
torch.ones(self.max_num_reqs, + 1, + 1, + rope_dim, + dtype=self.dtype, + device=self.device) + self.sin = torch.zeros(self.max_num_reqs, + 1, + 1, + rope_dim, + dtype=self.dtype, + device=self.device) + else: + self.cos = None + self.sin = None + self.uses_mrope = self.model_config.uses_mrope # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -1427,6 +1447,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): max_query_len=max_num_scheduled_tokens, graph_pad_size=self.graph_pad_size, decode_token_per_req=self.decode_token_per_req, + cos=self.cos, + sin=self.sin, ) if self.speculative_config and \ @@ -1453,7 +1475,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): attn_metadata_i = builder.build( common_prefix_len=common_prefix_len, common_attn_metadata=common_attn_metadata, - model=self.model, + model=self.get_model(), **extra_attn_metadata_args) if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa: @@ -1488,8 +1510,13 @@ class NPUModelRunner(LoRAModelRunnerMixin): forward_context = get_forward_context() if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL: - update_attn_params(self.update_stream, forward_context, - positions.shape[0]) + if self.vllm_config.model_config.use_mla: + # FIXME: Try using `auto_dispatch_capture=True` + update_mla_attn_params(self.update_stream, forward_context, + positions.shape[0]) + else: + update_attn_params(self.update_stream, forward_context, + positions.shape[0]) if get_forward_context().sp_enabled: hidden_states = tensor_model_parallel_all_gather(hidden_states, 0) @@ -2195,14 +2222,21 @@ class NPUModelRunner(LoRAModelRunnerMixin): block_table_tensor=block_table_tensor[:num_reqs], slot_mapping=self.slot_mapping, num_computed_tokens_cpu=num_computed_tokens_cpu, + positions=self.positions, + attn_mask=self.attn_mask, + spec_attn_mask=self.spec_attn_mask, + attn_state=self.attn_state, max_query_len=max_query_len, decode_token_per_req=self.decode_token_per_req, + cos=self.cos, + 
sin=self.sin, ) for attn_group in self.attn_groups[kv_cache_group_id]: builder = attn_group.get_metadata_builder() attn_metadata_i = builder.build_for_graph_capture( - common_attn_metadata) + common_attn_metadata, AscendAttentionState.DecodeOnly, + self.get_model()) for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i @@ -2218,9 +2252,15 @@ class NPUModelRunner(LoRAModelRunnerMixin): inputs_embeds=inputs_embeds) forward_context = get_forward_context() assert forward_context is not None - if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL: - update_attn_params(self.update_stream, forward_context, - positions.shape[0]) + if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \ + not forward_context.capturing: + if self.vllm_config.model_config.use_mla: + # FIXME: Try using `auto_dispatch_capture=True` + update_mla_attn_params(self.update_stream, forward_context, + positions.shape[0]) + else: + update_attn_params(self.update_stream, forward_context, + positions.shape[0]) if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3: hidden_states, _ = hidden_states