diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index 3035718..7f8d5f7 100644
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -331,15 +331,30 @@ class TestAscendMLAMetadataBuilder(TestBase):
         runner.chunked_prefill_enabled = False
         runner.attn_mask = torch.zeros((1, 1), dtype=torch.bool)
         runner.spec_attn_mask = torch.zeros((1, 1), dtype=torch.bool)
+        runner.dtype = torch.float16
 
         builder = AscendMLAMetadataBuilder(runner=runner,
                                            metadata_cls=AscendMLAMetadata)
+        builder.rope_dim = 64
 
         with patch.object(builder,
                           "_get_graph_runner_block_tables",
                           side_effect=lambda x, y: y):
             metadata = builder.build_torchair_graph_dummy(3, 3)
 
+        sin_golden = torch.ones(3,
+                                1,
+                                1,
+                                64,
+                                dtype=runner.dtype,
+                                device=runner.device)
+        cos_golden = torch.ones(3,
+                                1,
+                                1,
+                                64,
+                                dtype=runner.dtype,
+                                device=runner.device)
+
         self.assertIsInstance(metadata, AscendMLAMetadata)
         self.assertEqual(metadata.num_input_tokens, 3)
         self.assertEqual(metadata.num_actual_tokens, 3)
@@ -354,6 +369,8 @@ class TestAscendMLAMetadataBuilder(TestBase):
         self.assertEqual(metadata.seq_lens.shape[0], 3)
         self.assertEqual(metadata.slot_mapping.shape[0], 3)
         self.assertEqual(metadata.query_start_loc.shape[0], 3)
+        self.assertTrue(torch.equal(sin_golden, metadata.decode.sin))
+        self.assertTrue(torch.equal(cos_golden, metadata.decode.cos))
 
 
 class TestAscendMLAImpl(TestBase):
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 5d993e0..4e24756 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -80,6 +80,8 @@ class AscendMLAPrefillMetadata:
     max_query_len: int
     max_seq_lens: int
     chunked_context: Optional[ChunkedContextMetadata] = None
+    sin: Optional[torch.Tensor] = None
+    cos: Optional[torch.Tensor] = None
 
 
 @dataclass
@@ -92,6 +94,8 @@ class AscendMLADecodeMetadata:
     max_seq_lens: int
     seq_lens_list: list[int]
     attn_mask: Optional[torch.Tensor] = None
+    sin: Optional[torch.Tensor] = None
+    cos: Optional[torch.Tensor] = None
 
 
 @dataclass
@@ -200,6 +204,9 @@ class AscendMLAMetadataBuilder:
         )
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.rope_dim = self.runner.model_config.hf_text_config.qk_rope_head_dim
+        self.cos_cache = None
+        self.sin_cache = None
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -318,13 +325,27 @@ class AscendMLAMetadataBuilder:
                                   -1,
                                   dtype=torch.int32,
                                   device=device)
+        sin = torch.ones(num_reqs,
+                         1,
+                         1,
+                         self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
+        cos = torch.ones(num_reqs,
+                         1,
+                         1,
+                         self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
             block_table=block_table,
             seq_lens=seq_lens,
             seq_lens_list=seq_lens.tolist(),
             max_seq_lens=1,
-            attn_mask=self.runner.spec_attn_mask)
+            attn_mask=self.runner.spec_attn_mask,
+            sin=sin,
+            cos=cos)
         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
             num_actual_tokens=num_actual_tokens,
@@ -370,6 +391,16 @@ class AscendMLAMetadataBuilder:
             seq_lens = seq_lens_cpu
             max_query_len = query_lens.max().item()
             max_seq_lens = seq_lens.max().item()
+        if self.cos_cache is None:
+            self.cos_cache = self.runner.get_model(
+            ).model.layers[0].self_attn.rotary_emb.cos_cached
+            self.sin_cache = self.runner.get_model(
+            ).model.layers[0].self_attn.rotary_emb.sin_cached
+        if self.cos_cache.dtype != self.runner.dtype:  # type: ignore
+            self.cos_cache = self.cos_cache.to(  # type: ignore
+                self.runner.dtype)  # type: ignore
+            self.sin_cache = self.sin_cache.to(  # type: ignore
+                self.runner.dtype)  # type: ignore
 
         prefill_metadata = None
         chunked_context_metadata = None
@@ -415,18 +446,26 @@ class AscendMLAMetadataBuilder:
                     chunk_seq_lens=chunk_seq_lens,
                     workspace=self.chunked_prefill_workspace,
                 )
-
+            prefill_input_positions = input_positions[tokens_start:]
+            cos = self.cos_cache[
+                prefill_input_positions].unsqueeze(  # type: ignore
+                    1).unsqueeze(2)
+            sin = self.sin_cache[
+                prefill_input_positions].unsqueeze(  # type: ignore
+                    1).unsqueeze(2)
             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.runner.attn_mask,
                 query_lens=query_lens[tokens_start:],
                 seq_lens=seq_lens,
                 context_lens=seq_lens[tokens_start:],
-                input_positions=input_positions[tokens_start:],
+                input_positions=prefill_input_positions,
                 block_table=block_table[reqs_start:, ...],
                 max_query_len=max_query_len,
                 max_seq_lens=max_seq_lens,
                 query_start_loc=prefill_query_start_loc,
                 chunked_context=chunked_context_metadata,
+                sin=sin,
+                cos=cos,
             )
 
         decode_metadata = None
@@ -467,6 +506,10 @@ class AscendMLAMetadataBuilder:
                     dtype=input_positions.dtype,
                     device=input_positions.device)
                 input_positions = torch.cat([input_positions, padding_0])
+            cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
+                1).unsqueeze(2)
+            sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
+                1).unsqueeze(2)
 
             decode_metadata = AscendMLADecodeMetadata(
                 input_positions=input_positions,
@@ -474,7 +517,9 @@
                 seq_lens=seq_lens,
                 seq_lens_list=seq_lens.tolist(),
                 max_seq_lens=max_seq_lens,
-                attn_mask=self.runner.spec_attn_mask)
+                attn_mask=self.runner.spec_attn_mask,
+                sin=sin,
+                cos=cos)
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
@@ -1069,15 +1114,8 @@ class AscendMLAImpl(MLAAttentionImpl):
         decode_k_nope = None
         assert attn_metadata.decode is not None
         if self.running_in_graph:
-            seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-            cos = self.rotary_emb.cos_cached[:seq_len].to(
-                dtype=decode_hs_or_q_c.dtype)
-            sin = self.rotary_emb.sin_cached[:seq_len].to(
-                dtype=decode_hs_or_q_c.dtype)
-            cos = cos[attn_metadata.decode.input_positions]
-            sin = sin[attn_metadata.decode.input_positions]
-            cos = cos[:, None, None, :]
-            sin = sin[:, None, None, :]
+            cos = attn_metadata.decode.cos
+            sin = attn_metadata.decode.sin
             with npu_stream_switch("mla_secondary",
                                    0,
                                    enabled=enable_multistream_mla):
@@ -1124,15 +1162,8 @@ class AscendMLAImpl(MLAAttentionImpl):
             prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
             if self.torchair_graph_enabled:
                 num_tokens = prefill_hs_or_q_c.shape[0]
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                cos = cos[attn_metadata.prefill.input_positions]
-                sin = sin[attn_metadata.prefill.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.prefill.cos
+                sin = attn_metadata.prefill.sin
 
                 prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
                 prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 4c008b4..2bee8dd 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1799,6 +1799,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                     attn_metadata.decode.input_positions)
                 torch._dynamo.mark_static(
                     get_forward_context().mc2_mask)
+                if hasattr(attn_metadata.decode, "sin"):
+                    torch._dynamo.mark_static(attn_metadata.decode.sin)
+                    torch._dynamo.mark_static(attn_metadata.decode.cos)
                 torch._dynamo.mark_static(attn_metadata.slot_mapping)
         for kv in self.kv_caches:
             assert isinstance(
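
Reviewer note: the change above hoists the per-layer cos/sin work (slice the rotary caches, gather by position, reshape) out of `AscendMLAImpl.forward` and into the metadata builder, so it runs once per step instead of once per layer, and the resulting tensors can be marked static for the torchair graph. Below is a minimal, self-contained sketch of that caching scheme under stated assumptions: `RopeCacheDemo` and its frequency table are illustrative stand-ins, not vllm-ascend API; the only parts mirrored from the diff are the lazy dtype conversion and the gather/reshape to `[num_tokens, 1, 1, rope_dim]`.

```python
import torch


class RopeCacheDemo:
    """Illustrative stand-in for the builder's cos/sin caching (not vllm-ascend API)."""

    def __init__(self,
                 max_position: int = 16,
                 rope_dim: int = 4,
                 dtype: torch.dtype = torch.float16):
        self.dtype = dtype
        # Stand-ins for rotary_emb.cos_cached / rotary_emb.sin_cached:
        # one row of rope_dim angles per position, shape [max_position, rope_dim].
        t = torch.arange(max_position, dtype=torch.float32)
        freqs = torch.outer(t, torch.ones(rope_dim // 2))
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cache = emb.cos()
        self.sin_cache = emb.sin()

    def build(self, input_positions: torch.Tensor):
        # Convert the caches to the runner dtype once, lazily, mirroring the
        # dtype check added to AscendMLAMetadataBuilder.build.
        if self.cos_cache.dtype != self.dtype:
            self.cos_cache = self.cos_cache.to(self.dtype)
            self.sin_cache = self.sin_cache.to(self.dtype)
        # Gather per-token rows and reshape to [num_tokens, 1, 1, rope_dim],
        # the layout the attention path consumes via metadata.decode/prefill.
        cos = self.cos_cache[input_positions].unsqueeze(1).unsqueeze(2)
        sin = self.sin_cache[input_positions].unsqueeze(1).unsqueeze(2)
        return cos, sin


if __name__ == "__main__":
    demo = RopeCacheDemo()
    cos, sin = demo.build(torch.tensor([0, 1, 5]))
    assert cos.shape == (3, 1, 1, 4) and cos.dtype == torch.float16
```

Because the gathered tensors are produced once per step with a fixed shape, the graph-mode path can treat them like `input_positions`: hence the extra `mark_static` calls in model_runner_v1.py, and the all-ones `sin`/`cos` placeholders in `build_torchair_graph_dummy` that the new unit test checks against.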