From 7a205dbaa8feb1a8927d58b6399c0766d395c78b Mon Sep 17 00:00:00 2001 From: rjg-lyh <83491835+rjg-lyh@users.noreply.github.com> Date: Tue, 9 Sep 2025 14:28:14 +0800 Subject: [PATCH] [main] Optimize rope in Qwen Models (#2571) ### What this PR does / why we need it? Optimize rope by caching sin and cos at the first layer in Qwen Models. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? CI passed with new added/existing test. - vLLM version: v0.10.1.1 - vLLM main: https://github.com/vllm-project/vllm/commit/562663a044acbff0b8df11e07a4929b18d71172f --------- Signed-off-by: MengqingCao Signed-off-by: ZYang6263 Signed-off-by: rjg-lyh <1318825571@qq.com> Co-authored-by: Mengqing Cao Co-authored-by: ZYang6263 <51255902183@stu.ecnu.edu.cn> Co-authored-by: ZYang6263 --- tests/ut/ops/test_rotary_embedding.py | 90 ++++++++++++++++++++++----- vllm_ascend/ascend_forward_context.py | 3 + vllm_ascend/models/qwen3_moe.py | 6 ++ vllm_ascend/ops/rotary_embedding.py | 84 +++++++++++++++---------- 4 files changed, 136 insertions(+), 47 deletions(-) diff --git a/tests/ut/ops/test_rotary_embedding.py b/tests/ut/ops/test_rotary_embedding.py index 39bc76f..de6f4ef 100644 --- a/tests/ut/ops/test_rotary_embedding.py +++ b/tests/ut/ops/test_rotary_embedding.py @@ -3,12 +3,18 @@ import unittest from unittest.mock import MagicMock, PropertyMock, patch import torch +from transformers.configuration_utils import PretrainedConfig +from vllm.config import ModelConfig, VllmConfig from vllm.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, RotaryEmbedding) from tests.ut.base import TestBase +from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.ops.rotary_embedding import _custom_rotary_embedding_enabled +MODEL = "Qwen3-0.6B" +MAX_NUM_BATCHED_TOKEND = 10000 + class TestCustomRotaryEmbeddingEnabled(unittest.TestCase): @@ -93,6 +99,10 @@ class TestAscendRotaryEmbedding(unittest.TestCase): @patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled', return_value=True) @patch('torch.ops._npu_rotary_embedding') + @patch('vllm.config.ModelConfig.__post_init__', MagicMock()) + @patch('vllm.config.VllmConfig.__post_init__', MagicMock()) + @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1)) + @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1)) def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding, mock_custom_enabled, mock_is_310p, mock__c): @@ -102,9 +112,15 @@ class TestAscendRotaryEmbedding(unittest.TestCase): # Setup mock for custom kernel path mock__c.rotary_embedding.return_value = self.query, self.key - - result_q, result_k = self.layer.forward(self.positions, self.query, - self.key) + vllm_config = VllmConfig() + model_config = ModelConfig(MODEL, + tokenizer=MODEL, + max_model_len=MAX_NUM_BATCHED_TOKEND) + model_config.hf_config = PretrainedConfig() + vllm_config.model_config = model_config + with set_ascend_forward_context(None, vllm_config): + result_q, result_k = self.layer.forward(self.positions, self.query, + self.key) mock__c.rotary_embedding.assert_called_once() self.assertEqual(result_q.shape, self.query.shape) @@ -113,6 +129,10 @@ class TestAscendRotaryEmbedding(unittest.TestCase): @patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled', return_value=False) @patch('torch_npu._npu_rotary_embedding') + @patch('vllm.config.ModelConfig.__post_init__', MagicMock()) + @patch('vllm.config.VllmConfig.__post_init__', MagicMock()) + @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1)) + @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1)) def test_rope_forward_oot_contiguous(self, mock_npu_rotary, mock_custom_enabled): mock_config = MagicMock() @@ -121,15 +141,25 @@ class TestAscendRotaryEmbedding(unittest.TestCase): # Test contiguous path when custom is disabled non_contig_query = self.query.transpose(0, 1) non_contig_key = self.key.transpose(0, 1) - - result_q, result_k = self.layer.forward(self.positions, - non_contig_query, - non_contig_key) + vllm_config = VllmConfig() + model_config = ModelConfig(MODEL, + tokenizer=MODEL, + max_model_len=MAX_NUM_BATCHED_TOKEND) + model_config.hf_config = PretrainedConfig() + vllm_config.model_config = model_config + with set_ascend_forward_context(None, vllm_config): + result_q, result_k = self.layer.forward(self.positions, + non_contig_query, + non_contig_key) mock_npu_rotary.assert_called_once() self.assertEqual(result_q.shape, non_contig_query.shape) self.assertEqual(result_k.shape, non_contig_key.shape) + @patch('vllm.config.ModelConfig.__post_init__', MagicMock()) + @patch('vllm.config.VllmConfig.__post_init__', MagicMock()) + @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1)) + @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1)) def test_rope_forward_oot_with_offsets(self): mock_config = MagicMock() mock_config.torchair_graph_config.enabled = False @@ -137,22 +167,41 @@ class TestAscendRotaryEmbedding(unittest.TestCase): # Test that NotImplementedError is raised when offsets is provided offsets = torch.tensor([1, 2, 3]) with self.assertRaises(NotImplementedError): - self.layer.forward(self.positions, self.query, self.key, offsets) + vllm_config = VllmConfig() + model_config = ModelConfig(MODEL, + tokenizer=MODEL, + max_model_len=MAX_NUM_BATCHED_TOKEND) + model_config.hf_config = PretrainedConfig() + vllm_config.model_config = model_config + with set_ascend_forward_context(None, vllm_config): + self.layer.forward(self.positions, self.query, self.key, + offsets) @patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled', return_value=False) @patch('torch_npu._npu_rotary_embedding') + @patch('vllm.config.ModelConfig.__post_init__', MagicMock()) + @patch('vllm.config.VllmConfig.__post_init__', MagicMock()) + @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1)) + @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1)) def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary, mock_custom_enabled): mock_config = MagicMock() mock_config.torchair_graph_config.enabled = False # Test neox_style override - result_q, result_k = self.layer.forward(self.positions, - self.query, - self.key, - is_neox_style_override=False) - + vllm_config = VllmConfig() + model_config = ModelConfig(MODEL, + tokenizer=MODEL, + max_model_len=MAX_NUM_BATCHED_TOKEND) + model_config.hf_config = PretrainedConfig() + vllm_config.model_config = model_config + with set_ascend_forward_context(None, vllm_config): + result_q, result_k = self.layer.forward( + self.positions, + self.query, + self.key, + is_neox_style_override=False) # Check that neox_style=False was passed to the NPU function args, kwargs = mock_npu_rotary.call_args self.assertFalse(args[-1]) @@ -160,6 +209,10 @@ class TestAscendRotaryEmbedding(unittest.TestCase): @patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled', return_value=False) @patch('torch_npu._npu_rotary_embedding') + @patch('vllm.config.ModelConfig.__post_init__', MagicMock()) + @patch('vllm.config.VllmConfig.__post_init__', MagicMock()) + @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1)) + @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1)) def test_rope_forward_oot_rotary_dim_less_than_head_size( self, mock_npu_rotary, mock_custom_enabled): mock_config = MagicMock() @@ -169,8 +222,15 @@ class TestAscendRotaryEmbedding(unittest.TestCase): org_rotary_dim = self.layer.rotary_dim self.layer.rotary_dim = self.layer.head_size // 2 - result_q, result_k = self.layer.forward(self.positions, self.query, - self.key) + vllm_config = VllmConfig() + model_config = ModelConfig(MODEL, + tokenizer=MODEL, + max_model_len=MAX_NUM_BATCHED_TOKEND) + model_config.hf_config = PretrainedConfig() + vllm_config.model_config = model_config + with set_ascend_forward_context(None, vllm_config): + result_q, result_k = self.layer.forward(self.positions, self.query, + self.key) mock_npu_rotary.assert_called_once() self.assertEqual(result_q.shape, self.query.shape) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 9bcddf6..be38b9d 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -119,6 +119,9 @@ def set_ascend_forward_context( forward_context.flashcomm_v1_enabled = flashcomm_v1_enabled + # set this for rope forward_oot using + forward_context.is_first_layer = True + if num_tokens is None and attn_metadata is not None: num_tokens = attn_metadata.num_actual_tokens diff --git a/vllm_ascend/models/qwen3_moe.py b/vllm_ascend/models/qwen3_moe.py index 2fa10f0..d6451c0 100644 --- a/vllm_ascend/models/qwen3_moe.py +++ b/vllm_ascend/models/qwen3_moe.py @@ -20,6 +20,7 @@ from typing import Optional, Union import torch +import torch_npu from torch import nn from transformers import PretrainedConfig from vllm.compilation.decorators import support_torch_compile @@ -280,6 +281,11 @@ class CustomQwen3MoeModel(Qwen3MoeModel): self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + # Call ATB matmul to warm up; otherwise, the first operation (ReshapeAndCache) may cause performance degradation at runtime. + x = torch.rand((2, 4), dtype=torch.float16).npu() + weight = torch.rand((2, 4), dtype=torch.float16).npu() + c = torch.rand((4, 4), dtype=torch.float32).npu() + torch_npu._npu_matmul_add_fp32(x, weight, c) def forward( self, diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index b982a7e..ea47c04 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -20,6 +20,7 @@ from typing import Optional, Tuple import torch import torch_npu +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, RotaryEmbedding) @@ -37,19 +38,16 @@ def _rope_forward_oot( positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - is_neox_style_override: Optional[bool] = None, + is_neox_style: bool, + offsets: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor]: query_shape, key_shape = query.shape, key.shape if self.cos_sin_cache.device != query.device: self.cos_sin_cache = self.cos_sin_cache.to(query.device) if self.cos_sin_cache.dtype != query.dtype: self.cos_sin_cache = self.cos_sin_cache.to(query.dtype) - neox_style = self.is_neox_style - if is_neox_style_override is not None: - neox_style = is_neox_style_override # adopt custom kernel path for rotary_embedding - if _custom_rotary_embedding_enabled(query, neox_style, + if _custom_rotary_embedding_enabled(query, is_neox_style, self.head_size) and not is_310p(): query, key = torch.ops._C.rotary_embedding( positions, @@ -57,14 +55,22 @@ def _rope_forward_oot( key, self.head_size, self.cos_sin_cache, - neox_style, + is_neox_style, ) return query.view(query_shape), key.view(key_shape) if offsets is not None: raise NotImplementedError( "Batched rotary embedding is currently not supported on NPU.") else: - if self.rotary_dim < self.head_size: + if self.cos is not None and \ + self.sin is not None: + # If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation. + # This method requires head_size and rotary_dim equal 128 and neox_style is True + query = query.contiguous().view(1, query.shape[0], -1, + self.head_size) + key = key.contiguous().view(1, key.shape[0], -1, self.head_size) + torch_npu.npu_apply_rotary_pos_emb(query, key, self.cos, self.sin) + elif self.rotary_dim < self.head_size: num_tokens = query.shape[0] query = query.view(num_tokens, -1, self.head_size) key = key.view(num_tokens, -1, self.head_size) @@ -80,25 +86,26 @@ def _rope_forward_oot( k_rot, self.head_size, self.cos_sin_cache, - neox_style, + is_neox_style, ) q_rot = q_rot.view(num_tokens, -1, self.rotary_dim) k_rot = k_rot.view(num_tokens, -1, self.rotary_dim) q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape) k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape) return q, k - # TODO: Remove the contiguous in the future. - query = query.contiguous().view(query.shape[0], -1) - key = key.contiguous().view(key.shape[0], -1) - torch_npu._npu_rotary_embedding( - positions, - query, - key, - self.head_size, - self.cos_sin_cache, - neox_style, - ) - return query.view(query_shape), key.view(key_shape) + else: + # TODO: Remove the contiguous in the future. + query = query.contiguous().view(query.shape[0], -1) + key = key.contiguous().view(key.shape[0], -1) + torch_npu._npu_rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + is_neox_style, + ) + return query.view(query_shape), key.view(key_shape) class AscendRotaryEmbedding(RotaryEmbedding): @@ -112,6 +119,8 @@ class AscendRotaryEmbedding(RotaryEmbedding): is_neox_style: bool, dtype: torch.dtype, ) -> None: + self.cos = None + self.sin = None super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype) @@ -123,14 +132,25 @@ class AscendRotaryEmbedding(RotaryEmbedding): offsets: Optional[torch.Tensor] = None, is_neox_style_override: Optional[bool] = None, ): - return _rope_forward_oot( - self, - positions, - query, - key, - offsets, - is_neox_style_override, - ) + is_neox_style = self.is_neox_style + if is_neox_style_override is not None: + is_neox_style = is_neox_style_override + forward_context = get_forward_context() + is_first_layer = forward_context.is_first_layer + # Generate cos and sin outside layers to avoid repeated calculation. + if is_neox_style and \ + self.head_size == 128: + if is_first_layer: + cos_sin = self.cos_sin_cache.index_select(0, positions) + last_dim = cos_sin.size()[-1] + cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat( + 1, 1, 2).chunk(2, dim=-2) + # BSNH + self.cos = cos.view(1, -1, 1, last_dim).contiguous() + self.sin = sin.view(1, -1, 1, last_dim).contiguous() + forward_context.is_first_layer = False + return _rope_forward_oot(self, positions, query, key, is_neox_style, + offsets) class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding): @@ -322,7 +342,7 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding): # Note: we implement the non neox_style method with shuffle the last dim and neox style # calculation method which is also more compute friendly to the ascend machine # https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py - neox_style = True + is_neox_style = True if self.is_neox_style is False: b, h_q, d = query.shape query = query.view(b, h_q, d // 2, @@ -330,6 +350,6 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding): b, h_k, d = key.shape key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d) - q_pe, k_pe = _rope_forward_oot(self, positions, query, key, offsets, - neox_style) + q_pe, k_pe = _rope_forward_oot(self, positions, query, key, + is_neox_style, offsets) return q_pe, k_pe