From c2c97f3079957efafbfba29ba97ded417e23fd55 Mon Sep 17 00:00:00 2001 From: Wang Yixuan <88923622+hust17yixuan@users.noreply.github.com> Date: Mon, 1 Sep 2025 09:09:21 +0800 Subject: [PATCH] [5/N][refactor]add torchair rotary ops (#2559) ### What this PR does / why we need it? Move torchair related rotary ops into torchair dir to make the code clear. Next step we'll remove all torchair related code outside of torchair rotary ops. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? vLLM version: main vLLM main: https://github.com/vllm-project/vllm/commit/ab9f2cfd1942f7ddfee658ce86ea96b4789862af - vLLM version: v0.10.1.1 - vLLM main: https://github.com/vllm-project/vllm/commit/81eea3d348c26fb1e6ff0185ad109aedd60a28a2 Signed-off-by: hust17yixuan <303660421@qq.com> --- .../ops/test_torchair_rotary_embedding.py | 332 ++++++++++++++++ .../torchair/ops/torchair_rotary_embedding.py | 372 ++++++++++++++++++ vllm_ascend/torchair/torchair_model_runner.py | 13 +- vllm_ascend/torchair/utils.py | 15 + 4 files changed, 725 insertions(+), 7 deletions(-) create mode 100644 tests/ut/torchair/ops/test_torchair_rotary_embedding.py create mode 100644 vllm_ascend/torchair/ops/torchair_rotary_embedding.py diff --git a/tests/ut/torchair/ops/test_torchair_rotary_embedding.py b/tests/ut/torchair/ops/test_torchair_rotary_embedding.py new file mode 100644 index 0000000..e7c68f7 --- /dev/null +++ b/tests/ut/torchair/ops/test_torchair_rotary_embedding.py @@ -0,0 +1,332 @@ +import math +from unittest.mock import MagicMock, patch + +import torch + +from tests.ut.base import TestBase +from vllm_ascend.torchair.ops.torchair_rotary_embedding import ( + custom_rotary_embedding_enabled, native_rope_deepseek_forward, + rope_forward_oot, rotate_half, yarn_find_correction_dim, yarn_get_mscale) + + +class TestCustomRotaryEmbeddingEnabled(TestBase): + + def setUp(self): + # Common setup for tests + self.positions = torch.tensor([1, 2, 3]) + self.query = torch.randn(3, 4, dtype=torch.float16) + self.key = torch.randn(3, 4, dtype=torch.float16) + self.head_size = 32 + self.cos_sin_cache = torch.randn(3, 4) + + # Mock self object for rope_forward_oot + self.mock_self = MagicMock() + self.mock_self.head_size = self.head_size + self.mock_self.cos_sin_cache = self.cos_sin_cache + self.mock_self.is_neox_style = True + self.mock_self.forward_native.return_value = (self.query, self.key) + + def test_custom_rotary_embedding_enabled(self): + # Test when all conditions are True + with patch( + 'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op', + return_value=True): + result = custom_rotary_embedding_enabled(self.query, True, + self.head_size) + self.assertTrue(result) + + # Test when dtype is not float16 + with patch( + 'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op', + return_value=True): + query = self.query.to(torch.float32) + result = custom_rotary_embedding_enabled(query, True, + self.head_size) + self.assertFalse(result) + + # Test when neox_style is False + with patch( + 'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op', + return_value=True): + result = custom_rotary_embedding_enabled(self.query, False, + self.head_size) + self.assertFalse(result) + + # Test when head_size is not divisible by 32 + with patch( + 'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op', + return_value=True): + result = custom_rotary_embedding_enabled(self.query, True, + self.head_size + 1) + self.assertFalse(result) + + # Test when custom op is 
disabled + with patch( + 'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op', + return_value=False): + result = custom_rotary_embedding_enabled(self.query, True, + self.head_size) + self.assertFalse(result) + + +class TestRopeForwardOot(TestBase): + + def setUp(self): + # Common setup for tests + self.positions = torch.tensor([1, 2, 3]) + self.query = torch.randn(3, 4, dtype=torch.float16) + self.key = torch.randn(3, 4, dtype=torch.float16) + self.head_size = 32 + self.cos_sin_cache = torch.randn(3, 4) + + # Mock self object for rope_forward_oot + self.mock_self = MagicMock() + self.mock_self.head_size = self.head_size + self.mock_self.cos_sin_cache = self.cos_sin_cache + self.mock_self.is_neox_style = True + self.mock_self.forward_native.return_value = (self.query, self.key) + + @patch( + 'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config') + def test_rope_forward_oot_torchair_enabled_base(self, + mock_get_ascend_config): + # Setup mock for torchair enabled + mock_config = MagicMock() + mock_config.torchair_graph_config.enabled = True + mock_get_ascend_config.return_value = mock_config + + result_q, result_k = rope_forward_oot(self.mock_self, self.positions, + self.query, self.key) + + self.mock_self.forward_native.assert_called_once_with( + self.positions, self.query, self.key, None) + self.assertTrue(torch.equal(result_q, self.query)) + self.assertTrue(torch.equal(result_k, self.key)) + + @patch('torch.ops._C') + @patch( + 'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config') + @patch('vllm_ascend.torchair.ops.torchair_rotary_embedding.is_310p', + return_value=False) + @patch( + 'vllm_ascend.torchair.ops.torchair_rotary_embedding.custom_rotary_embedding_enabled', + return_value=True) + @patch('torch.ops._npu_rotary_embedding') + def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding, + mock_custom_enabled, mock_is_310p, + mock_get_ascend_config, mock__c): + mock_config = MagicMock() + mock_config.torchair_graph_config.enabled = False + mock_get_ascend_config.return_value = mock_config + + # Setup mock for custom kernel path + + mock__c.rotary_embedding.return_value = self.query, self.key + + result_q, result_k = rope_forward_oot(self.mock_self, self.positions, + self.query, self.key) + + self.assertEqual(result_q.shape, self.query.shape) + self.assertEqual(result_k.shape, self.key.shape) + + @patch( + 'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config') + @patch( + 'vllm_ascend.torchair.ops.torchair_rotary_embedding.custom_rotary_embedding_enabled', + return_value=False) + @patch('torch_npu._npu_rotary_embedding') + def test_rope_forward_oot_contiguous(self, mock_npu_rotary, + mock_custom_enabled, + mock_get_ascend_config): + mock_config = MagicMock() + mock_config.torchair_graph_config.enabled = False + mock_get_ascend_config.return_value = mock_config + + # Test contiguous path when custom is disabled + non_contig_query = self.query.transpose(0, 1) + non_contig_key = self.key.transpose(0, 1) + + result_q, result_k = rope_forward_oot(self.mock_self, self.positions, + non_contig_query, non_contig_key) + + mock_npu_rotary.assert_called_once() + self.assertEqual(result_q.shape, non_contig_query.shape) + self.assertEqual(result_k.shape, non_contig_key.shape) + + @patch( + 'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config') + def test_rope_forward_oot_with_offsets(self, mock_get_ascend_config): + mock_config = MagicMock() + mock_config.torchair_graph_config.enabled = False + 
mock_get_ascend_config.return_value = mock_config + + # Test that NotImplementedError is raised when offsets is provided + offsets = torch.tensor([1, 2, 3]) + with self.assertRaises(NotImplementedError): + rope_forward_oot(self.mock_self, self.positions, self.query, + self.key, offsets) + + @patch( + 'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config') + @patch( + 'vllm_ascend.torchair.ops.torchair_rotary_embedding.custom_rotary_embedding_enabled', + return_value=False) + @patch('torch_npu._npu_rotary_embedding') + def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary, + mock_custom_enabled, + mock_get_ascend_config): + mock_config = MagicMock() + mock_config.torchair_graph_config.enabled = False + mock_get_ascend_config.return_value = mock_config + + # Test neox_style override + result_q, result_k = rope_forward_oot(self.mock_self, + self.positions, + self.query, + self.key, + is_neox_style_override=False) + + # Check that neox_style=False was passed to the NPU function + args, kwargs = mock_npu_rotary.call_args + self.assertFalse(args[-1]) + + +class MockRopeModule: + + def __init__(self, max_seq_len=2048, is_neox_style=True): + self.max_seq_len = max_seq_len + self.is_neox_style = is_neox_style + self.cos_cached = None + self.sin_cached = None + self.rotary_dim = 1 + self.base = 1 + + +class TestNativeRopeDeepseekForward(TestBase): + + @patch( + 'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot') + def test_native_rope_deepseek_forward_base(self, mock_rope_forward_oot): + module = MockRopeModule() + positions = torch.tensor([1, 2, 3]) + query = torch.randn(1, 8, 128) + key = torch.randn(1, 8, 128) + + mock_rope_forward_oot.return_value = (query, key) + + q_pe, k_pe = native_rope_deepseek_forward(module, positions, query, + key) + + assert q_pe.shape == query.shape + assert k_pe.shape == key.shape + + @patch( + 'vllm_ascend.torchair.ops.torchair_rotary_embedding._set_cos_sin_cache' + ) + @patch( + 'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot') + def test_native_rope_deepseek_forward_cache_handling( + self, mock_rope_forward_oot, mock_set_cache): + # Test cache situation is true + module = MockRopeModule(max_seq_len=1024) + positions = torch.tensor([1, 2, 3]) + query = torch.randn(1, 8, 128) + key = torch.randn(1, 8, 128) + + mock_rope_forward_oot.return_value = (query, key) + + q_pe, k_pe = native_rope_deepseek_forward(module, + positions, + query, + key, + max_seq_len=2048) + + assert q_pe.shape == query.shape + assert k_pe.shape == key.shape + + @patch( + 'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot') + def test_native_rope_deepseek_forward_key_reshaping( + self, mock_rope_forward_oot): + module = MockRopeModule() + positions = torch.tensor([1, 2, 3]) + query = torch.randn(1, 8, 128) + key = torch.randn(1, 128) + + mock_rope_forward_oot.return_value = (query, key) + + q_pe, k_pe = native_rope_deepseek_forward(module, positions, query, + key) + + assert q_pe.shape == query.shape + assert k_pe.shape == (1, 128) + + @patch( + 'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot') + def test_native_rope_deepseek_forward_non_neox_style( + self, mock_rope_forward_oot): + module = MockRopeModule(is_neox_style=False) + positions = torch.tensor([1, 2, 3]) + query = torch.randn(1, 8, 128) + key = torch.randn(1, 8, 128) + + mock_rope_forward_oot.return_value = (query, key) + + q_pe, k_pe = native_rope_deepseek_forward(module, positions, query, + key) + + assert q_pe.shape == 
query.shape + assert k_pe.shape == key.shape + + +class TestRotateHalf(TestBase): + + def test_rotate_half_even_dim(self): + # Test with even dimension + x = torch.tensor([1.0, 2.0, 3.0, 4.0]) + expected = torch.tensor([-3.0, -4.0, 1.0, 2.0]) + result = rotate_half(x) + self.assertTrue(torch.allclose(result, expected)) + + +class TestYarnFindCorrectionDim(TestBase): + + def test_basic_case(self): + # Test with standard values + num_rotations = 100 + dim = 512 + base = 10000 + max_position_embeddings = 2048 + + result = yarn_find_correction_dim(num_rotations, dim, base, + max_position_embeddings) + + # Calculate expected value manually + expected = (dim * torch.log( + torch.tensor(max_position_embeddings) / + (num_rotations * 2 * torch.pi))) / (2 * + torch.log(torch.tensor(base))) + + self.assertTrue(torch.allclose(result, expected)) + + +class TestYarnGetMscale(TestBase): + + def test_scale_less_than_or_equal_1(self): + self.assertEqual(yarn_get_mscale(scale=0.5), 1.0) + self.assertEqual(yarn_get_mscale(scale=1.0), 1.0) + self.assertEqual(yarn_get_mscale(scale=0.999), 1.0) + + def test_scale_greater_than_1(self): + test_cases = [(2.0, 1.0, 1.0 + 0.1 * math.log(2.0)), + (10.0, 1.0, 1.0 + 0.1 * math.log(10.0)), + (5.0, 2.0, 1.0 + 0.2 * math.log(5.0)), + (math.e, 1.0, 1.0 + 0.1)] + + for scale, mscale, expected in test_cases: + result = yarn_get_mscale(scale, mscale) + self.assertAlmostEqual( + result, + expected, + places=6, + msg=f"Failed for scale={scale}, mscale={mscale}") diff --git a/vllm_ascend/torchair/ops/torchair_rotary_embedding.py b/vllm_ascend/torchair/ops/torchair_rotary_embedding.py new file mode 100644 index 0000000..5793288 --- /dev/null +++ b/vllm_ascend/torchair/ops/torchair_rotary_embedding.py @@ -0,0 +1,372 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# + +import math +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +import torch_npu +from vllm.model_executor.layers.rotary_embedding import ( + DeepseekScalingRotaryEmbedding, RotaryEmbedding) + +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.utils import enable_custom_op, is_310p + + +def custom_rotary_embedding_enabled(query, neox_style, head_size): + return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op( + ) + + +def rope_forward_oot( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + is_neox_style_override: Optional[bool] = None, + is_qwen_torchair: Optional[bool] = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + if get_ascend_config( + ).torchair_graph_config.enabled and not is_qwen_torchair: + return self.forward_native( + positions, + query, + key, + offsets, + ) + + query_shape, key_shape = query.shape, key.shape + if self.cos_sin_cache.device != query.device: + self.cos_sin_cache = self.cos_sin_cache.to(query.device) + if self.cos_sin_cache.dtype != query.dtype: + self.cos_sin_cache = self.cos_sin_cache.to(query.dtype) + neox_style = self.is_neox_style + if is_neox_style_override is not None: + neox_style = is_neox_style_override + # adopt custom kernel path for rotary_embedding + if custom_rotary_embedding_enabled(query, neox_style, + self.head_size) and not is_310p(): + query, key = torch.ops._C.rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + neox_style, + ) + return query.view(query_shape), key.view(key_shape) + if offsets is not None: + raise NotImplementedError( + "Batched rotary embedding is currently not supported on NPU.") + else: + # TODO: Remove the contiguous in the future. + query = query.contiguous().view(query.shape[0], -1) + key = key.contiguous().view(key.shape[0], -1) + torch_npu._npu_rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + neox_style, + ) + return query.view(query_shape), key.view(key_shape) + + +def native_rope_deepseek_forward(self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + max_seq_len: Optional[int] = None): + if max_seq_len is not None and max_seq_len > self.max_seq_len: + _set_cos_sin_cache(self, max_seq_len, query.device, query.dtype) + if len(key.shape) == 2: + key = key[:, None, :] + # Note: we implement the non neox_style method with shuffle the last dim and neox style + # calculation method which is also more compute friendly to the ascend machine + # https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py + neox_style = True + if self.is_neox_style is False: + b, h_q, d = query.shape + query = query.view(b, h_q, d // 2, 2).transpose(3, + 2).reshape(b, h_q, d) + b, h_k, d = key.shape + key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d) + q_pe, k_pe = rope_forward_oot(self, positions, query, key, offsets, + neox_style) + return q_pe, k_pe + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +# Inverse dim formula to find dim based on number of rotations +def yarn_find_correction_dim(num_rotations, + dim, + base=10000, + max_position_embeddings=2048): + # Note: use torch instead of math to solve MTP compilation error. 
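+    # For rotary pair index d, the wavelength is 2*pi*base^(2d/dim) tokens;
+    # setting max_position_embeddings / wavelength = num_rotations and
+    # solving for d gives the closed form returned below.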
+ return (dim * torch.log( + torch.tensor(max_position_embeddings) / + (num_rotations * 2 * torch.pi))) / (2 * torch.log(torch.tensor(base))) + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +# Find dim range bounds based on rotations +def yarn_find_correction_range(low_rot, + high_rot, + dim, + base=10000, + max_position_embeddings=2048): + # Note: use torch instead of math to solve MTP compilation error. + low = torch.floor( + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = torch.ceil( + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) + # Note: use torch instead of max/min to solve MTP compilation error. + return torch.clamp(low, min=0), torch.clamp(high, max=dim - 1) + + +def yarn_linear_ramp_mask(min_value, max_value, dim): + # Note: The if conditional branch is not used here + # to solve MTP compilation error. + max_value += (min_value == max_value).float() * 0.001 + linear_func = (torch.arange(dim, dtype=torch.float32) - + min_value) / (max_value - min_value) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
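+    Note:
+        In this variant, a 3-D `q` ([b, h, d]) and a 2-D/3-D `k` are unsqueezed
+        to 4-D internally, the last dim is permuted from interleaved to
+        half-split layout before the half-rotation is applied, and the outputs
+        are reshaped back to [b, h, d].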
+ """ + cos = cos[position_ids] + sin = sin[position_ids] + cos = cos[:, None, None, :] + sin = sin[:, None, None, :] + + if len(q.shape) == 3: + q = q[:, :, None, :] + if len(k.shape) == 2: + k = k[:, None, None, :] + elif len(k.shape) == 3: + k = k[:, :, None, :] + + b, h_q, s, d = q.shape + q = q.view(b, h_q, s, d // 2, 2).transpose(4, 3).reshape(b, h_q, s, d) + + b, h_k, s, d = k.shape + k = k.view(b, h_k, s, d // 2, 2).transpose(4, 3).reshape(b, h_k, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + + q_embed = q_embed.view(b, h_q, d) + k_embed = k_embed.view(b, h_k, d) + + return q_embed, k_embed + + +def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + dim = self.rotary_dim + + freq_extra = 1.0 / (self.base**( + torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) + freq_inter = 1.0 / (self.scaling_factor * self.base**( + torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) + + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.max_position_embeddings, + ) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to( + device=device, dtype=torch.float32) + inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(seq_len * self.scaling_factor, + device=device, + dtype=torch.float32) + + freqs = torch.outer(t, inv_freq) + cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale + sin_cached = torch.cat([freqs, freqs], dim=-1).sin() * self.mscale + cos_cached = cos_cached.to(dtype) + sin_cached = sin_cached.to(dtype) + cache = torch.cat([freqs.cos() * self.mscale, + freqs.sin() * self.mscale], + dim=-1).to(dtype) + self.register_buffer("cos_sin_cache", cache, persistent=False) + self.register_buffer("cos_cached", cos_cached, persistent=False) + self.register_buffer("sin_cached", sin_cached, persistent=False) + + +def __set_cos_sin_cache(self, seq_len, device, dtype): + inv_freq = 1.0 / (self.base**(torch.arange( + 0, self.rotary_dim, 2, device=device, dtype=torch.float32) * + (1 / self.rotary_dim))) + self.register_buffer("inv_freq", inv_freq) + + t = torch.arange(self.max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.float32) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos", emb.cos().to(dtype=dtype), persistent=False) + self.register_buffer("sin", emb.sin().to(dtype=dtype), persistent=False) + self.embed = F.embedding + + +_original_re_init = RotaryEmbedding.__init__ + + +def qwen_rope_init_func( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + dtype: torch.dtype, +) -> None: + _original_re_init(self, head_size, rotary_dim, max_position_embeddings, + base, is_neox_style, dtype) + if get_ascend_config().torchair_graph_config.enabled: + __set_cos_sin_cache(self, + seq_len=max_position_embeddings, + device="npu", + dtype=dtype) + + +def rope_forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + is_neox_style_override: Optional[bool] = None, + max_seq_len: Optional[int] = None, + is_prefill: Optional[bool] = True, + is_qwen_torchair: Optional[bool] = False, +): + if get_ascend_config().torchair_graph_config.enabled \ + and is_qwen_torchair and not is_prefill: + 
if max_seq_len is not None and torch.gt(max_seq_len, + self.max_position_embeddings): + __set_cos_sin_cache(self, + seq_len=max_seq_len, + device=query.device, + dtype=torch.float32) + + # bsnd/bnsd + if positions is not None: + cos = self.embed(positions, self.cos) + sin = self.embed(positions, self.sin) + self.cos_embed = cos + self.sin_embed = sin + else: + cos = self.cos_embed + sin = self.sin_embed + + query = query.view(*query.shape[:-1], -1, self.head_size).contiguous() + key = key.view(*key.shape[:-1], -1, self.head_size).contiguous() + + cos = cos.unsqueeze(-2).unsqueeze(-2) + sin = sin.unsqueeze(-2).unsqueeze(-2) + + query = query.unsqueeze(1) + key = key.unsqueeze(1) + + q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb( + query, key, cos, sin) + return q_embed.flatten(-2), k_embed.flatten(-2) + else: + return rope_forward_oot(self, positions, query, key, offsets, + is_neox_style_override, + is_qwen_torchair) # type: ignore + + +def deepseek_rope_init_func( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1, + mscale_all_dim: float = 0, +) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation. + self.mscale = float( + yarn_get_mscale(self.scaling_factor, float(mscale)) / + yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * + attn_factor) + super(DeepseekScalingRotaryEmbedding, + self).__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + self.max_seq_len = max_position_embeddings + _set_cos_sin_cache(self, + max_position_embeddings, + dtype=dtype, + device="npu") diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index f371c7d..2a0dc15 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -15,7 +15,7 @@ # limitations under the License. # This file is a part of the vllm-ascend project. 
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py -# +# isort: skip_file import types from typing import Optional @@ -34,12 +34,10 @@ from vllm.logger import logger import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.platform import NPUPlatform -from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata, - check_torchair_cache_exist, - converting_weight_acl_format, - register_torchair_model, - torchair_quant_method_register, - write_kv_cache_bytes_to_file) +from vllm_ascend.torchair.utils import ( + TorchairCommonAttentionMetadata, check_torchair_cache_exist, + converting_weight_acl_format, register_torchair_model, torchair_ops_patch, + torchair_quant_method_register, write_kv_cache_bytes_to_file) from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, is_310p) from vllm_ascend.worker.model_runner_v1 import NPUModelRunner @@ -68,6 +66,7 @@ class NPUTorchairModelRunner(NPUModelRunner): self._check_batch_sizes_consistency() register_torchair_model() + torchair_ops_patch() torchair_quant_method_register() def _sync_metadata_across_dp( diff --git a/vllm_ascend/torchair/utils.py b/vllm_ascend/torchair/utils.py index 2f3a28e..a9bbdae 100644 --- a/vllm_ascend/torchair/utils.py +++ b/vllm_ascend/torchair/utils.py @@ -182,3 +182,18 @@ def torchair_quant_method_register(): "W8A8_DYNAMIC"] = TorchairW8A8DYNAMICQuantizer SUPPORT_ASCEND_QUANTIZER_TYPE[ "W4A8_DYNAMIC"] = TorchairW4A8DYNAMICQuantizer + + +def torchair_ops_patch(): + from vllm.model_executor.layers.rotary_embedding import ( + DeepseekScalingRotaryEmbedding, RotaryEmbedding) + + from vllm_ascend.torchair.ops.torchair_rotary_embedding import ( + deepseek_rope_init_func, native_rope_deepseek_forward, + qwen_rope_init_func, rope_forward) + + RotaryEmbedding.__init__ = qwen_rope_init_func + RotaryEmbedding.forward_oot = rope_forward + + DeepseekScalingRotaryEmbedding.__init__ = deepseek_rope_init_func + DeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward
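
For reference, a minimal self-contained sketch of the YaRN bookkeeping that `torchair_rotary_embedding.py` implements (correction range and attention mscale). The functions mirror the ones added in this patch; `beta_fast=32` / `beta_slow=1` match the defaults in `deepseek_rope_init_func`, while `dim=64` and `scale=40` in the `__main__` block are illustrative values only, not taken from any model config:

```python
import math

import torch


def yarn_find_correction_dim(num_rotations, dim, base=10000,
                             max_position_embeddings=2048):
    # Pair index at which the rotary frequency completes `num_rotations`
    # turns over `max_position_embeddings` tokens.
    return (dim * torch.log(
        torch.tensor(max_position_embeddings) /
        (num_rotations * 2 * torch.pi))) / (2 * torch.log(torch.tensor(base)))


def yarn_find_correction_range(low_rot, high_rot, dim, base=10000,
                               max_position_embeddings=2048):
    # Dimension band that YaRN treats as the extrapolation/interpolation ramp.
    low = torch.floor(
        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
    high = torch.ceil(
        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
    return torch.clamp(low, min=0), torch.clamp(high, max=dim - 1)


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    # Magnitude correction folded into the cos/sin caches.
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


if __name__ == "__main__":
    low, high = yarn_find_correction_range(32, 1, dim=64)
    print(low.item(), high.item())      # 8.0 21.0
    print(yarn_get_mscale(scale=40.0))  # ~1.3689
```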