[Misc] Clean up useless code in rotary_embedding (#2663)

Clean up code in rotary_embedding that was only used for torchair: drop the torchair-only forward arguments (max_seq_len, is_prefill, is_qwen_torchair), remove the unused set_cos_sin_cache helper, and make the remaining rope helpers module-private.
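For reference, the custom-kernel gating that this cleanup keeps (now as the module-private _custom_rotary_embedding_enabled) boils down to the check below. This is a minimal standalone sketch: the helper name custom_kernel_eligible and the explicit custom_op_enabled flag are illustrative only; in the real code the flag comes from vllm_ascend.utils.enable_custom_op().

import torch

def custom_kernel_eligible(query: torch.Tensor, neox_style: bool,
                           head_size: int, custom_op_enabled: bool) -> bool:
    # Mirrors _custom_rotary_embedding_enabled below: the custom AscendC
    # rotary_embedding kernel is only used for float16 queries, neox-style
    # rope, head sizes divisible by 32, and when the custom op is enabled.
    return (query.dtype == torch.float16 and neox_style
            and head_size % 32 == 0 and custom_op_enabled)

q = torch.zeros(4, 1024, dtype=torch.float16)
assert custom_kernel_eligible(q, True, 128, True)                # eligible
assert not custom_kernel_eligible(q.float(), True, 128, True)    # not float16
assert not custom_kernel_eligible(q, False, 128, True)           # not neox style
assert not custom_kernel_eligible(q, True, 100, True)            # 100 % 32 != 0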

- vLLM version: v0.10.1.1
- vLLM main: a344a5aa0a

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Author: wangxiyuan
Date: 2025-09-02 17:25:33 +08:00
Committed by: GitHub
Parent: 253b01b9a5
Commit: c1e607b7b7
2 changed files with 32 additions and 49 deletions


@@ -7,7 +7,7 @@ from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from tests.ut.base import TestBase
from vllm_ascend.ops.rotary_embedding import custom_rotary_embedding_enabled
from vllm_ascend.ops.rotary_embedding import _custom_rotary_embedding_enabled
class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):
@@ -31,37 +31,37 @@ class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):
# Test when all conditions are True
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
return_value=True):
result = custom_rotary_embedding_enabled(self.query, True,
self.head_size)
result = _custom_rotary_embedding_enabled(self.query, True,
self.head_size)
self.assertTrue(result)
# Test when dtype is not float16
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
return_value=True):
query = self.query.to(torch.float32)
result = custom_rotary_embedding_enabled(query, True,
self.head_size)
result = _custom_rotary_embedding_enabled(query, True,
self.head_size)
self.assertFalse(result)
# Test when neox_style is False
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
return_value=True):
result = custom_rotary_embedding_enabled(self.query, False,
self.head_size)
result = _custom_rotary_embedding_enabled(self.query, False,
self.head_size)
self.assertFalse(result)
# Test when head_size is not divisible by 32
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
return_value=True):
result = custom_rotary_embedding_enabled(self.query, True,
self.head_size + 1)
result = _custom_rotary_embedding_enabled(self.query, True,
self.head_size + 1)
self.assertFalse(result)
# Test when custom op is disabled
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
return_value=False):
result = custom_rotary_embedding_enabled(self.query, True,
self.head_size)
result = _custom_rotary_embedding_enabled(self.query, True,
self.head_size)
self.assertFalse(result)
@@ -90,7 +90,7 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
@patch('torch.ops._C')
@patch('vllm_ascend.ops.rotary_embedding.is_310p', return_value=False)
@patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
return_value=True)
@patch('torch.ops._npu_rotary_embedding')
def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
@@ -110,7 +110,7 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
self.assertEqual(result_q.shape, self.query.shape)
self.assertEqual(result_k.shape, self.key.shape)
@patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
return_value=False)
@patch('torch_npu._npu_rotary_embedding')
def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
@@ -139,7 +139,7 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
with self.assertRaises(NotImplementedError):
self.layer.forward(self.positions, self.query, self.key, offsets)
@patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
return_value=False)
@patch('torch_npu._npu_rotary_embedding')
def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
@@ -198,7 +198,7 @@ class TestAscendDeepseekScalingRotaryEmbedding(TestBase):
def test_native_rope_deepseek_forward_base(self, mock_npuplatform):
mock_npuplatform.device_type = torch.device("cpu")
self.layer = self._create_layer()
with patch("vllm_ascend.ops.rotary_embedding.rope_forward_oot",
with patch("vllm_ascend.ops.rotary_embedding._rope_forward_oot",
return_value=(self.query,
self.key)) as mock_rope_forward_oot:
q_pe, k_pe = self.layer.forward(self.positions, self.query,
@@ -207,7 +207,7 @@ class TestAscendDeepseekScalingRotaryEmbedding(TestBase):
assert q_pe.shape == self.query.shape
assert k_pe.shape == self.key.shape
@patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
@patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
@@ -229,7 +229,7 @@ class TestAscendDeepseekScalingRotaryEmbedding(TestBase):
assert q_pe.shape == self.query.shape
assert k_pe.shape == self.key.shape
@patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
@patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
@@ -248,7 +248,7 @@ class TestAscendDeepseekScalingRotaryEmbedding(TestBase):
assert q_pe.shape == self.query.shape
assert k_pe.shape == key.shape
@patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
@patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",

vllm_ascend/ops/rotary_embedding.py

@@ -19,7 +19,6 @@ import math
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
import torch_npu
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
@@ -28,19 +27,18 @@ from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import enable_custom_op, is_310p
def custom_rotary_embedding_enabled(query, neox_style, head_size):
def _custom_rotary_embedding_enabled(query, neox_style, head_size):
return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op(
)
def rope_forward_oot(
def _rope_forward_oot(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
is_neox_style_override: Optional[bool] = None,
is_qwen_torchair: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
query_shape, key_shape = query.shape, key.shape
if self.cos_sin_cache.device != query.device:
@@ -51,8 +49,8 @@ def rope_forward_oot(
if is_neox_style_override is not None:
neox_style = is_neox_style_override
# adopt custom kernel path for rotary_embedding
if custom_rotary_embedding_enabled(query, neox_style,
self.head_size) and not is_310p():
if _custom_rotary_embedding_enabled(query, neox_style,
self.head_size) and not is_310p():
query, key = torch.ops._C.rotary_embedding(
positions,
query,
@@ -80,23 +78,6 @@ def rope_forward_oot(
return query.view(query_shape), key.view(key_shape)
def set_cos_sin_cache(self, seq_len, device, dtype):
inv_freq = 1.0 / (self.base**(torch.arange(
0, self.rotary_dim, 2, device=device, dtype=torch.float32) *
(1 / self.rotary_dim)))
self.register_buffer("inv_freq", inv_freq)
t = torch.arange(self.max_position_embeddings,
device=self.inv_freq.device,
dtype=torch.float32)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos", emb.cos().to(dtype=dtype), persistent=False)
self.register_buffer("sin", emb.sin().to(dtype=dtype), persistent=False)
self.embed = F.embedding
class AscendRotaryEmbedding(RotaryEmbedding):
def __init__(
@@ -118,13 +99,15 @@ class AscendRotaryEmbedding(RotaryEmbedding):
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
is_neox_style_override: Optional[bool] = None,
max_seq_len: Optional[int] = None,
is_prefill: Optional[bool] = True,
is_qwen_torchair: Optional[bool] = False,
):
return rope_forward_oot(self, positions, query, key, offsets,
is_neox_style_override,
is_qwen_torchair) # type: ignore
return _rope_forward_oot(
self,
positions,
query,
key,
offsets,
is_neox_style_override,
)
class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
@@ -328,6 +311,6 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
b, h_k, d = key.shape
key = key.view(b, h_k, d // 2, 2).transpose(3,
2).reshape(b, h_k, d)
q_pe, k_pe = rope_forward_oot(self, positions, query, key, offsets,
neox_style)
q_pe, k_pe = _rope_forward_oot(self, positions, query, key, offsets,
neox_style)
return q_pe, k_pe
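Note for downstream code: because the helpers are now module-private, anything that patches them (as the updated unit tests above do) must target the underscored names. A minimal sketch, assuming vllm_ascend is importable in the environment:

from unittest.mock import patch
import torch

query = torch.zeros(2, 128, dtype=torch.float16)
key = torch.zeros(2, 128, dtype=torch.float16)

# The old public names (custom_rotary_embedding_enabled, rope_forward_oot) were
# renamed with a leading underscore, so patch targets have to be updated too.
with patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
           return_value=False), \
     patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot',
           return_value=(query, key)):
    pass  # exercise AscendRotaryEmbedding / AscendDeepseekScalingRotaryEmbedding here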