[Misc] Clean up useless code in rotary_embedding (#2663)
Clean up code in rotary_embedding that was only needed by the torchair path and is no longer used here
- vLLM version: v0.10.1.1
- vLLM main:
a344a5aa0a
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -7,7 +7,7 @@ from vllm.model_executor.layers.rotary_embedding import (
|
|||||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
||||||
|
|
||||||
from tests.ut.base import TestBase
|
from tests.ut.base import TestBase
|
||||||
from vllm_ascend.ops.rotary_embedding import custom_rotary_embedding_enabled
|
from vllm_ascend.ops.rotary_embedding import _custom_rotary_embedding_enabled
|
||||||
|
|
||||||
|
|
||||||
class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):
|
class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):
|
||||||
@@ -31,37 +31,37 @@ class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):
|
|||||||
# Test when all conditions are True
|
# Test when all conditions are True
|
||||||
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
|
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
|
||||||
return_value=True):
|
return_value=True):
|
||||||
result = custom_rotary_embedding_enabled(self.query, True,
|
result = _custom_rotary_embedding_enabled(self.query, True,
|
||||||
self.head_size)
|
self.head_size)
|
||||||
self.assertTrue(result)
|
self.assertTrue(result)
|
||||||
|
|
||||||
# Test when dtype is not float16
|
# Test when dtype is not float16
|
||||||
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
|
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
|
||||||
return_value=True):
|
return_value=True):
|
||||||
query = self.query.to(torch.float32)
|
query = self.query.to(torch.float32)
|
||||||
result = custom_rotary_embedding_enabled(query, True,
|
result = _custom_rotary_embedding_enabled(query, True,
|
||||||
self.head_size)
|
self.head_size)
|
||||||
self.assertFalse(result)
|
self.assertFalse(result)
|
||||||
|
|
||||||
# Test when neox_style is False
|
# Test when neox_style is False
|
||||||
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
|
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
|
||||||
return_value=True):
|
return_value=True):
|
||||||
result = custom_rotary_embedding_enabled(self.query, False,
|
result = _custom_rotary_embedding_enabled(self.query, False,
|
||||||
self.head_size)
|
self.head_size)
|
||||||
self.assertFalse(result)
|
self.assertFalse(result)
|
||||||
|
|
||||||
# Test when head_size is not divisible by 32
|
# Test when head_size is not divisible by 32
|
||||||
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
|
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
|
||||||
return_value=True):
|
return_value=True):
|
||||||
result = custom_rotary_embedding_enabled(self.query, True,
|
result = _custom_rotary_embedding_enabled(self.query, True,
|
||||||
self.head_size + 1)
|
self.head_size + 1)
|
||||||
self.assertFalse(result)
|
self.assertFalse(result)
|
||||||
|
|
||||||
# Test when custom op is disabled
|
# Test when custom op is disabled
|
||||||
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
|
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
|
||||||
return_value=False):
|
return_value=False):
|
||||||
result = custom_rotary_embedding_enabled(self.query, True,
|
result = _custom_rotary_embedding_enabled(self.query, True,
|
||||||
self.head_size)
|
self.head_size)
|
||||||
self.assertFalse(result)
|
self.assertFalse(result)
|
||||||
|
|
||||||
|
|
||||||
@@ -90,7 +90,7 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
|
|||||||
|
|
||||||
@patch('torch.ops._C')
|
@patch('torch.ops._C')
|
||||||
@patch('vllm_ascend.ops.rotary_embedding.is_310p', return_value=False)
|
@patch('vllm_ascend.ops.rotary_embedding.is_310p', return_value=False)
|
||||||
@patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
|
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
|
||||||
return_value=True)
|
return_value=True)
|
||||||
@patch('torch.ops._npu_rotary_embedding')
|
@patch('torch.ops._npu_rotary_embedding')
|
||||||
def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
|
def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
|
||||||
@@ -110,7 +110,7 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
|
|||||||
self.assertEqual(result_q.shape, self.query.shape)
|
self.assertEqual(result_q.shape, self.query.shape)
|
||||||
self.assertEqual(result_k.shape, self.key.shape)
|
self.assertEqual(result_k.shape, self.key.shape)
|
||||||
|
|
||||||
@patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
|
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
|
||||||
return_value=False)
|
return_value=False)
|
||||||
@patch('torch_npu._npu_rotary_embedding')
|
@patch('torch_npu._npu_rotary_embedding')
|
||||||
def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
|
def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
|
||||||
@@ -139,7 +139,7 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
|
|||||||
with self.assertRaises(NotImplementedError):
|
with self.assertRaises(NotImplementedError):
|
||||||
self.layer.forward(self.positions, self.query, self.key, offsets)
|
self.layer.forward(self.positions, self.query, self.key, offsets)
|
||||||
|
|
||||||
@patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
|
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
|
||||||
return_value=False)
|
return_value=False)
|
||||||
@patch('torch_npu._npu_rotary_embedding')
|
@patch('torch_npu._npu_rotary_embedding')
|
||||||
def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
|
def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
|
||||||
@@ -198,7 +198,7 @@ class TestAscendDeepseekScalingRotaryEmbedding(TestBase):
|
|||||||
def test_native_rope_deepseek_forward_base(self, mock_npuplatform):
|
def test_native_rope_deepseek_forward_base(self, mock_npuplatform):
|
||||||
mock_npuplatform.device_type = torch.device("cpu")
|
mock_npuplatform.device_type = torch.device("cpu")
|
||||||
self.layer = self._create_layer()
|
self.layer = self._create_layer()
|
||||||
with patch("vllm_ascend.ops.rotary_embedding.rope_forward_oot",
|
with patch("vllm_ascend.ops.rotary_embedding._rope_forward_oot",
|
||||||
return_value=(self.query,
|
return_value=(self.query,
|
||||||
self.key)) as mock_rope_forward_oot:
|
self.key)) as mock_rope_forward_oot:
|
||||||
q_pe, k_pe = self.layer.forward(self.positions, self.query,
|
q_pe, k_pe = self.layer.forward(self.positions, self.query,
|
||||||
@@ -207,7 +207,7 @@ class TestAscendDeepseekScalingRotaryEmbedding(TestBase):
|
|||||||
assert q_pe.shape == self.query.shape
|
assert q_pe.shape == self.query.shape
|
||||||
assert k_pe.shape == self.key.shape
|
assert k_pe.shape == self.key.shape
|
||||||
|
|
||||||
@patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
|
@patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
|
||||||
@patch("vllm.platforms.current_platform.device_type",
|
@patch("vllm.platforms.current_platform.device_type",
|
||||||
new=torch.device("cpu"))
|
new=torch.device("cpu"))
|
||||||
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
|
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
|
||||||
@@ -229,7 +229,7 @@ class TestAscendDeepseekScalingRotaryEmbedding(TestBase):
|
|||||||
assert q_pe.shape == self.query.shape
|
assert q_pe.shape == self.query.shape
|
||||||
assert k_pe.shape == self.key.shape
|
assert k_pe.shape == self.key.shape
|
||||||
|
|
||||||
@patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
|
@patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
|
||||||
@patch("vllm.platforms.current_platform.device_type",
|
@patch("vllm.platforms.current_platform.device_type",
|
||||||
new=torch.device("cpu"))
|
new=torch.device("cpu"))
|
||||||
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
|
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
|
||||||
@@ -248,7 +248,7 @@ class TestAscendDeepseekScalingRotaryEmbedding(TestBase):
|
|||||||
assert q_pe.shape == self.query.shape
|
assert q_pe.shape == self.query.shape
|
||||||
assert k_pe.shape == key.shape
|
assert k_pe.shape == key.shape
|
||||||
|
|
||||||
@patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
|
@patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
|
||||||
@patch("vllm.platforms.current_platform.device_type",
|
@patch("vllm.platforms.current_platform.device_type",
|
||||||
new=torch.device("cpu"))
|
new=torch.device("cpu"))
|
||||||
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
|
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ import math
|
|||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch_npu
|
import torch_npu
|
||||||
from vllm.model_executor.layers.rotary_embedding import (
|
from vllm.model_executor.layers.rotary_embedding import (
|
||||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
||||||
@@ -28,19 +27,18 @@ from vllm_ascend.platform import NPUPlatform
|
|||||||
from vllm_ascend.utils import enable_custom_op, is_310p
|
from vllm_ascend.utils import enable_custom_op, is_310p
|
||||||
|
|
||||||
|
|
||||||
def custom_rotary_embedding_enabled(query, neox_style, head_size):
|
def _custom_rotary_embedding_enabled(query, neox_style, head_size):
|
||||||
return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op(
|
return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def rope_forward_oot(
|
def _rope_forward_oot(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
offsets: Optional[torch.Tensor] = None,
|
offsets: Optional[torch.Tensor] = None,
|
||||||
is_neox_style_override: Optional[bool] = None,
|
is_neox_style_override: Optional[bool] = None,
|
||||||
is_qwen_torchair: Optional[bool] = False,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
query_shape, key_shape = query.shape, key.shape
|
query_shape, key_shape = query.shape, key.shape
|
||||||
if self.cos_sin_cache.device != query.device:
|
if self.cos_sin_cache.device != query.device:
|
||||||
@@ -51,8 +49,8 @@ def rope_forward_oot(
|
|||||||
if is_neox_style_override is not None:
|
if is_neox_style_override is not None:
|
||||||
neox_style = is_neox_style_override
|
neox_style = is_neox_style_override
|
||||||
# adopt custom kernel path for rotary_embedding
|
# adopt custom kernel path for rotary_embedding
|
||||||
if custom_rotary_embedding_enabled(query, neox_style,
|
if _custom_rotary_embedding_enabled(query, neox_style,
|
||||||
self.head_size) and not is_310p():
|
self.head_size) and not is_310p():
|
||||||
query, key = torch.ops._C.rotary_embedding(
|
query, key = torch.ops._C.rotary_embedding(
|
||||||
positions,
|
positions,
|
||||||
query,
|
query,
|
||||||
@@ -80,23 +78,6 @@ def rope_forward_oot(
|
|||||||
return query.view(query_shape), key.view(key_shape)
|
return query.view(query_shape), key.view(key_shape)
|
||||||
|
|
||||||
|
|
||||||
def set_cos_sin_cache(self, seq_len, device, dtype):
|
|
||||||
inv_freq = 1.0 / (self.base**(torch.arange(
|
|
||||||
0, self.rotary_dim, 2, device=device, dtype=torch.float32) *
|
|
||||||
(1 / self.rotary_dim)))
|
|
||||||
self.register_buffer("inv_freq", inv_freq)
|
|
||||||
|
|
||||||
t = torch.arange(self.max_position_embeddings,
|
|
||||||
device=self.inv_freq.device,
|
|
||||||
dtype=torch.float32)
|
|
||||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
|
||||||
|
|
||||||
emb = torch.cat((freqs, freqs), dim=-1)
|
|
||||||
self.register_buffer("cos", emb.cos().to(dtype=dtype), persistent=False)
|
|
||||||
self.register_buffer("sin", emb.sin().to(dtype=dtype), persistent=False)
|
|
||||||
self.embed = F.embedding
|
|
||||||
|
|
||||||
|
|
||||||
class AscendRotaryEmbedding(RotaryEmbedding):
|
class AscendRotaryEmbedding(RotaryEmbedding):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -118,13 +99,15 @@ class AscendRotaryEmbedding(RotaryEmbedding):
|
|||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
offsets: Optional[torch.Tensor] = None,
|
offsets: Optional[torch.Tensor] = None,
|
||||||
is_neox_style_override: Optional[bool] = None,
|
is_neox_style_override: Optional[bool] = None,
|
||||||
max_seq_len: Optional[int] = None,
|
|
||||||
is_prefill: Optional[bool] = True,
|
|
||||||
is_qwen_torchair: Optional[bool] = False,
|
|
||||||
):
|
):
|
||||||
return rope_forward_oot(self, positions, query, key, offsets,
|
return _rope_forward_oot(
|
||||||
is_neox_style_override,
|
self,
|
||||||
is_qwen_torchair) # type: ignore
|
positions,
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
offsets,
|
||||||
|
is_neox_style_override,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
||||||
@@ -328,6 +311,6 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
|||||||
b, h_k, d = key.shape
|
b, h_k, d = key.shape
|
||||||
key = key.view(b, h_k, d // 2, 2).transpose(3,
|
key = key.view(b, h_k, d // 2, 2).transpose(3,
|
||||||
2).reshape(b, h_k, d)
|
2).reshape(b, h_k, d)
|
||||||
q_pe, k_pe = rope_forward_oot(self, positions, query, key, offsets,
|
q_pe, k_pe = _rope_forward_oot(self, positions, query, key, offsets,
|
||||||
neox_style)
|
neox_style)
|
||||||
return q_pe, k_pe
|
return q_pe, k_pe
|
||||||
|
|||||||
Reference in New Issue
Block a user