[CustomOp] Register RotaryEmbedding instead of overwrite forward (#2385)

### What this PR does / why we need it?
Register RotaryEmbedding instead of overwrite forward

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
CI passed with new added/existing test.

- vLLM version: v0.10.0
- vLLM main:
808d2e9aa0

---------

Signed-off-by: Icey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
This commit is contained in:
Icey
2025-08-25 09:32:35 +08:00
committed by GitHub
parent 950c4b219a
commit f796e6280b
6 changed files with 426 additions and 381 deletions

View File

@@ -287,4 +287,4 @@ jobs:
pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py \
--ignore=tests/e2e/multicard/test_offline_inference_distributed.py \
--ignore=tests/e2e/multicard/test_data_parallel.py \
--ignore=tests/e2e/multicard/test_offline_inference_310p.py
--ignore=tests/e2e/multicard/test_offline_inference_310p.py

View File

@@ -1,17 +1,16 @@
import math
from unittest.mock import MagicMock, patch
import unittest
from unittest.mock import MagicMock, PropertyMock, patch
import torch
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from tests.ut.base import TestBase
from vllm_ascend.ops.rotary_embedding import (custom_rotary_embedding_enabled,
native_rope_deepseek_forward,
rope_forward_oot, rotate_half,
yarn_find_correction_dim,
yarn_get_mscale)
from vllm_ascend.ops.rotary_embedding import custom_rotary_embedding_enabled
class TestCustomRotaryEmbeddingEnabled(TestBase):
class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):
def setUp(self):
# Common setup for tests
@@ -66,22 +65,28 @@ class TestCustomRotaryEmbeddingEnabled(TestBase):
self.assertFalse(result)
class TestRopeForwardOot(TestBase):
class TestAscendRotaryEmbedding(unittest.TestCase):
def setUp(self):
# Common setup for tests
self.positions = torch.tensor([1, 2, 3])
self.query = torch.randn(3, 4, dtype=torch.float16)
self.key = torch.randn(3, 4, dtype=torch.float16)
self.query = torch.randn(3, 1, 32, dtype=torch.float16)
self.key = torch.randn(3, 1, 32, dtype=torch.float16)
self.head_size = 32
self.cos_sin_cache = torch.randn(3, 4)
self.rotary_dim = self.head_size
self.max_position = 16
self.rope_theta = 10000
self.is_neox_style = True
self.cos_sin_cache = torch.randn(3, 1, 32)
self.layer = RotaryEmbedding(self.head_size, self.rotary_dim,
self.max_position, self.rope_theta,
self.is_neox_style, torch.float16)
# Mock self object for rope_forward_oot
self.mock_self = MagicMock()
self.mock_self.head_size = self.head_size
self.mock_self.cos_sin_cache = self.cos_sin_cache
self.mock_self.is_neox_style = True
self.mock_self.forward_native.return_value = (self.query, self.key)
self.mock_self.is_neox_style = self.is_neox_style
@patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
def test_rope_forward_oot_torchair_enabled_base(self,
@@ -90,12 +95,14 @@ class TestRopeForwardOot(TestBase):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = True
mock_get_ascend_config.return_value = mock_config
with patch.object(self.layer,
"forward_native",
return_value=(self.query,
self.key)) as mock_forward_native:
result_q, result_k = self.layer.forward(self.positions, self.query,
self.key)
result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
self.query, self.key)
self.mock_self.forward_native.assert_called_once_with(
self.positions, self.query, self.key, None)
mock_forward_native.assert_called_once()
self.assertTrue(torch.equal(result_q, self.query))
self.assertTrue(torch.equal(result_k, self.key))
@@ -116,9 +123,10 @@ class TestRopeForwardOot(TestBase):
mock__c.rotary_embedding.return_value = self.query, self.key
result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
self.query, self.key)
result_q, result_k = self.layer.forward(self.positions, self.query,
self.key)
mock__c.rotary_embedding.assert_called_once()
self.assertEqual(result_q.shape, self.query.shape)
self.assertEqual(result_k.shape, self.key.shape)
@@ -137,8 +145,9 @@ class TestRopeForwardOot(TestBase):
non_contig_query = self.query.transpose(0, 1)
non_contig_key = self.key.transpose(0, 1)
result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
non_contig_query, non_contig_key)
result_q, result_k = self.layer.forward(self.positions,
non_contig_query,
non_contig_key)
mock_npu_rotary.assert_called_once()
self.assertEqual(result_q.shape, non_contig_query.shape)
@@ -153,8 +162,7 @@ class TestRopeForwardOot(TestBase):
# Test that NotImplementedError is raised when offsets is provided
offsets = torch.tensor([1, 2, 3])
with self.assertRaises(NotImplementedError):
rope_forward_oot(self.mock_self, self.positions, self.query,
self.key, offsets)
self.layer.forward(self.positions, self.query, self.key, offsets)
@patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
@patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
@@ -168,11 +176,10 @@ class TestRopeForwardOot(TestBase):
mock_get_ascend_config.return_value = mock_config
# Test neox_style override
result_q, result_k = rope_forward_oot(self.mock_self,
self.positions,
self.query,
self.key,
is_neox_style_override=False)
result_q, result_k = self.layer.forward(self.positions,
self.query,
self.key,
is_neox_style_override=False)
# Check that neox_style=False was passed to the NPU function
args, kwargs = mock_npu_rotary.call_args
@@ -190,98 +197,118 @@ class MockRopeModule:
self.base = 1
class TestNativeRopeDeepseekForward(TestBase):
class TestAscendDeepseekScalingRotaryEmbedding(TestBase):
def setUp(self):
# Common setup for tests
self.positions = torch.tensor([1, 2, 3])
self.query = torch.randn(3, 1, 32, dtype=torch.float16)
self.key = torch.randn(3, 1, 32, dtype=torch.float16)
self.head_size = 32
self.rotary_dim = self.head_size
self.max_position = 16
self.rope_theta = 10000
self.is_neox_style = True
self.scaling_factor = 1
self.layer = None
def _create_layer(self):
self.layer = DeepseekScalingRotaryEmbedding(
self.head_size, self.rotary_dim, self.max_position,
self.rope_theta, self.is_neox_style, self.scaling_factor,
torch.float16)
return self.layer
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
new_callable=PropertyMock)
def test_native_rope_deepseek_forward_base(self, mock_npuplatform):
mock_npuplatform.device_type = torch.device("cpu")
self.layer = self._create_layer()
with patch("vllm_ascend.ops.rotary_embedding.rope_forward_oot",
return_value=(self.query,
self.key)) as mock_rope_forward_oot:
q_pe, k_pe = self.layer.forward(self.positions, self.query,
self.key)
mock_rope_forward_oot.assert_called_once()
assert q_pe.shape == self.query.shape
assert k_pe.shape == self.key.shape
@patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
def test_native_rope_deepseek_forward_base(self, mock_rope_forward_oot):
module = MockRopeModule()
positions = torch.tensor([1, 2, 3])
query = torch.randn(1, 8, 128)
key = torch.randn(1, 8, 128)
mock_rope_forward_oot.return_value = (query, key)
q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
key)
assert q_pe.shape == query.shape
assert k_pe.shape == key.shape
@patch('vllm_ascend.ops.rotary_embedding._set_cos_sin_cache')
@patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
new_callable=PropertyMock)
def test_native_rope_deepseek_forward_cache_handling(
self, mock_rope_forward_oot, mock_set_cache):
self, mock_npuplatform, mock_rope_forward_oot):
mock_npuplatform.device_type = torch.device("cpu")
self.layer = self._create_layer()
self.layer.max_seq_len = 1024
# Test cache situation is true
module = MockRopeModule(max_seq_len=1024)
positions = torch.tensor([1, 2, 3])
query = torch.randn(1, 8, 128)
key = torch.randn(1, 8, 128)
with patch.object(self.layer, "_set_cos_sin_cache") as mock_set_cache:
mock_rope_forward_oot.return_value = (self.query, self.key)
mock_rope_forward_oot.return_value = (query, key)
q_pe, k_pe = native_rope_deepseek_forward(module,
positions,
query,
key,
max_seq_len=2048)
assert q_pe.shape == query.shape
assert k_pe.shape == key.shape
q_pe, k_pe = self.layer.forward(self.positions,
self.query,
self.key,
max_seq_len=2048)
mock_set_cache.assert_called_once()
assert q_pe.shape == self.query.shape
assert k_pe.shape == self.key.shape
@patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
new_callable=PropertyMock)
def test_native_rope_deepseek_forward_key_reshaping(
self, mock_rope_forward_oot):
module = MockRopeModule()
positions = torch.tensor([1, 2, 3])
query = torch.randn(1, 8, 128)
key = torch.randn(1, 128)
self, mock_npuplatform, mock_rope_forward_oot):
mock_npuplatform.device_type = torch.device("cpu")
self.layer = self._create_layer()
mock_rope_forward_oot.return_value = (query, key)
key = torch.randn(1, 32)
q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
key)
mock_rope_forward_oot.return_value = (self.query, key)
assert q_pe.shape == query.shape
assert k_pe.shape == (1, 128)
@patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
def test_native_rope_deepseek_forward_non_neox_style(
self, mock_rope_forward_oot):
module = MockRopeModule(is_neox_style=False)
positions = torch.tensor([1, 2, 3])
query = torch.randn(1, 8, 128)
key = torch.randn(1, 8, 128)
mock_rope_forward_oot.return_value = (query, key)
q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
key)
assert q_pe.shape == query.shape
q_pe, k_pe = self.layer.forward(self.positions, self.query, key)
mock_rope_forward_oot.assert_called_once()
assert q_pe.shape == self.query.shape
assert k_pe.shape == key.shape
@patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
new_callable=PropertyMock)
def test_native_rope_deepseek_forward_non_neox_style(
self, mock_npuplatform, mock_rope_forward_oot):
mock_npuplatform.device_type = torch.device("cpu")
self.layer = self._create_layer()
class TestRotateHalf(TestBase):
mock_rope_forward_oot.return_value = (self.query, self.key)
def test_rotate_half_even_dim(self):
# Test with even dimension
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
expected = torch.tensor([-3.0, -4.0, 1.0, 2.0])
result = rotate_half(x)
self.assertTrue(torch.allclose(result, expected))
q_pe, k_pe = self.layer.forward(self.positions, self.query, self.key)
mock_rope_forward_oot.assert_called_once()
assert q_pe.shape == self.query.shape
assert k_pe.shape == self.key.shape
class TestYarnFindCorrectionDim(TestBase):
def test_basic_case(self):
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
new_callable=PropertyMock)
def test_basic_case(self, mock_npuplatform):
# Test with standard values
mock_npuplatform.device_type = torch.device("cpu")
self.layer = self._create_layer()
num_rotations = 100
dim = 512
base = 10000
max_position_embeddings = 2048
result = yarn_find_correction_dim(num_rotations, dim, base,
max_position_embeddings)
result = self.layer._yarn_find_correction_dim(num_rotations, dim, base,
max_position_embeddings)
# Calculate expected value manually
expected = (dim * torch.log(
@@ -291,22 +318,27 @@ class TestYarnFindCorrectionDim(TestBase):
self.assertTrue(torch.allclose(result, expected))
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
new_callable=PropertyMock)
def test_yarn_get_mscale(self, mock_npuplatform):
mock_npuplatform.device_type = torch.device("cpu")
self.layer = self._create_layer()
class TestYarnGetMscale(TestBase):
# test_scale_less_than_or_equal_1
self.assertEqual(self.layer._yarn_get_mscale(scale=0.5), 1.0)
self.assertEqual(self.layer._yarn_get_mscale(scale=1.0), 1.0)
self.assertEqual(self.layer._yarn_get_mscale(scale=0.999), 1.0)
def test_scale_less_than_or_equal_1(self):
self.assertEqual(yarn_get_mscale(scale=0.5), 1.0)
self.assertEqual(yarn_get_mscale(scale=1.0), 1.0)
self.assertEqual(yarn_get_mscale(scale=0.999), 1.0)
def test_scale_greater_than_1(self):
# test_scale_greater_than_1:
test_cases = [(2.0, 1.0, 1.0 + 0.1 * math.log(2.0)),
(10.0, 1.0, 1.0 + 0.1 * math.log(10.0)),
(5.0, 2.0, 1.0 + 0.2 * math.log(5.0)),
(math.e, 1.0, 1.0 + 0.1)]
for scale, mscale, expected in test_cases:
result = yarn_get_mscale(scale, mscale)
result = self.layer._yarn_get_mscale(scale, mscale)
self.assertAlmostEqual(
result,
expected,

View File

@@ -356,13 +356,13 @@ class TestUtils(TestBase):
# ascend custom op is not registered
utils.register_ascend_customop()
# should call register_oot three
self.assertEqual(mock_customop.register_oot.call_count, 6)
self.assertEqual(mock_customop.register_oot.call_count, 8)
self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED)
# ascend custom op is already registered
utils.register_ascend_customop()
# should not register_oot again, thus only called three in this ut
self.assertEqual(mock_customop.register_oot.call_count, 6)
self.assertEqual(mock_customop.register_oot.call_count, 8)
class TestProfileExecuteDuration(TestBase):

View File

@@ -17,12 +17,13 @@
import torch
import vllm_ascend.ops.activation # noqa
import vllm_ascend.ops.common_fused_moe # noqa
import vllm_ascend.ops.fused_moe # noqa
import vllm_ascend.ops.layernorm # noqa
import vllm_ascend.ops.rotary_embedding # noqa
import vllm_ascend.ops.vocab_parallel_embedding # noqa
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
from vllm_ascend.ops.rotary_embedding import (
AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
class dummyFusionOp:
@@ -47,3 +48,9 @@ def register_dummy_fusion_op() -> None:
name="fused_add_rms_norm_static_fp8_quant")
torch.ops._C.rms_norm_dynamic_per_token_quant = dummyFusionOp(
name="rms_norm_dynamic_per_token_quant")
__all__ = [
"AscendQuickGELU", "AscendSiluAndMul", "AscendRotaryEmbedding",
"AscendDeepseekScalingRotaryEmbedding"
]

View File

@@ -25,6 +25,7 @@ from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import enable_custom_op, is_310p
@@ -89,167 +90,7 @@ def rope_forward_oot(
return query.view(query_shape), key.view(key_shape)
def native_rope_deepseek_forward(self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
max_seq_len: Optional[int] = None):
if max_seq_len is not None and max_seq_len > self.max_seq_len:
_set_cos_sin_cache(self, max_seq_len, query.device, query.dtype)
if len(key.shape) == 2:
key = key[:, None, :]
# Note: we implement the non neox_style method with shuffle the last dim and neox style
# calculation method which is also more compute friendly to the ascend machine
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py
neox_style = True
if self.is_neox_style is False:
b, h_q, d = query.shape
query = query.view(b, h_q, d // 2, 2).transpose(3,
2).reshape(b, h_q, d)
b, h_k, d = key.shape
key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d)
q_pe, k_pe = rope_forward_oot(self, positions, query, key, offsets,
neox_style)
return q_pe, k_pe
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
# Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(num_rotations,
dim,
base=10000,
max_position_embeddings=2048):
# Note: use torch instead of math to solve MTP compilation error.
return (dim * torch.log(
torch.tensor(max_position_embeddings) /
(num_rotations * 2 * torch.pi))) / (2 * torch.log(torch.tensor(base)))
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
# Find dim range bounds based on rotations
def yarn_find_correction_range(low_rot,
high_rot,
dim,
base=10000,
max_position_embeddings=2048):
# Note: use torch instead of math to solve MTP compilation error.
low = torch.floor(
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = torch.ceil(
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
# Note: use torch instead of max/min to solve MTP compilation error.
return torch.clamp(low, min=0), torch.clamp(high, max=dim - 1)
def yarn_linear_ramp_mask(min_value, max_value, dim):
# Note: The if conditional branch is not used here
# to solve MTP compilation error.
max_value += (min_value == max_value).float() * 0.001
linear_func = (torch.arange(dim, dtype=torch.float32) -
min_value) / (max_value - min_value)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos[position_ids]
sin = sin[position_ids]
cos = cos[:, None, None, :]
sin = sin[:, None, None, :]
if len(q.shape) == 3:
q = q[:, :, None, :]
if len(k.shape) == 2:
k = k[:, None, None, :]
elif len(k.shape) == 3:
k = k[:, :, None, :]
b, h_q, s, d = q.shape
q = q.view(b, h_q, s, d // 2, 2).transpose(4, 3).reshape(b, h_q, s, d)
b, h_k, s, d = k.shape
k = k.view(b, h_k, s, d // 2, 2).transpose(4, 3).reshape(b, h_k, s, d)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
q_embed = q_embed.view(b, h_q, d)
k_embed = k_embed.view(b, h_k, d)
return q_embed, k_embed
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
dim = self.rotary_dim
freq_extra = 1.0 / (self.base**(
torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
freq_inter = 1.0 / (self.scaling_factor * self.base**(
torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
low, high = yarn_find_correction_range(
self.beta_fast,
self.beta_slow,
dim,
self.base,
self.max_position_embeddings,
)
inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
device=device, dtype=torch.float32)
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(seq_len * self.scaling_factor,
device=device,
dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
sin_cached = torch.cat([freqs, freqs], dim=-1).sin() * self.mscale
cos_cached = cos_cached.to(dtype)
sin_cached = sin_cached.to(dtype)
cache = torch.cat([freqs.cos() * self.mscale,
freqs.sin() * self.mscale],
dim=-1).to(dtype)
self.register_buffer("cos_sin_cache", cache, persistent=False)
self.register_buffer("cos_cached", cos_cached, persistent=False)
self.register_buffer("sin_cached", sin_cached, persistent=False)
def __set_cos_sin_cache(self, seq_len, device, dtype):
def set_cos_sin_cache(self, seq_len, device, dtype):
inv_freq = 1.0 / (self.base**(torch.arange(
0, self.rotary_dim, 2, device=device, dtype=torch.float32) *
(1 / self.rotary_dim)))
@@ -266,117 +107,275 @@ def __set_cos_sin_cache(self, seq_len, device, dtype):
self.embed = F.embedding
_original_re_init = RotaryEmbedding.__init__
class AscendRotaryEmbedding(RotaryEmbedding):
def qwen_rope_init_func(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
_original_re_init(self, head_size, rotary_dim, max_position_embeddings,
base, is_neox_style, dtype)
if get_ascend_config().torchair_graph_config.enabled:
__set_cos_sin_cache(self,
seq_len=max_position_embeddings,
device="npu",
dtype=dtype)
def rope_forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
is_neox_style_override: Optional[bool] = None,
max_seq_len: Optional[int] = None,
is_prefill: Optional[bool] = True,
is_qwen_torchair: Optional[bool] = False,
):
if get_ascend_config().torchair_graph_config.enabled \
and is_qwen_torchair and not is_prefill:
if max_seq_len is not None and torch.gt(max_seq_len,
self.max_position_embeddings):
__set_cos_sin_cache(self,
seq_len=max_seq_len,
device=query.device,
dtype=torch.float32)
# bsnd/bnsd
if positions is not None:
cos = self.embed(positions, self.cos)
sin = self.embed(positions, self.sin)
self.cos_embed = cos
self.sin_embed = sin
else:
cos = self.cos_embed
sin = self.sin_embed
query = query.view(*query.shape[:-1], -1, self.head_size).contiguous()
key = key.view(*key.shape[:-1], -1, self.head_size).contiguous()
cos = cos.unsqueeze(-2).unsqueeze(-2)
sin = sin.unsqueeze(-2).unsqueeze(-2)
query = query.unsqueeze(1)
key = key.unsqueeze(1)
q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(
query, key, cos, sin)
return q_embed.flatten(-2), k_embed.flatten(-2)
else:
return rope_forward_oot(self, positions, query, key, offsets,
is_neox_style_override,
is_qwen_torchair) # type: ignore
def deepseek_rope_init_func(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
mscale: float = 1,
mscale_all_dim: float = 0,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
# Get n-d magnitude scaling corrected for interpolation.
self.mscale = float(
yarn_get_mscale(self.scaling_factor, float(mscale)) /
yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
attn_factor)
super(DeepseekScalingRotaryEmbedding,
self).__init__(head_size, rotary_dim, max_position_embeddings, base,
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
self.max_seq_len = max_position_embeddings
_set_cos_sin_cache(self,
max_position_embeddings,
dtype=dtype,
device="npu")
if get_ascend_config().torchair_graph_config.enabled:
set_cos_sin_cache(self,
seq_len=max_position_embeddings,
device="npu",
dtype=dtype)
def forward_oot(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
is_neox_style_override: Optional[bool] = None,
max_seq_len: Optional[int] = None,
is_prefill: Optional[bool] = True,
is_qwen_torchair: Optional[bool] = False,
):
if get_ascend_config().torchair_graph_config.enabled \
and is_qwen_torchair and not is_prefill:
if max_seq_len is not None and torch.gt(
max_seq_len, self.max_position_embeddings):
set_cos_sin_cache(self,
seq_len=max_seq_len,
device=query.device,
dtype=torch.float32)
# bsnd/bnsd
if positions is not None:
cos = self.embed(positions, self.cos)
sin = self.embed(positions, self.sin)
self.cos_embed = cos
self.sin_embed = sin
else:
cos = self.cos_embed
sin = self.sin_embed
query = query.view(*query.shape[:-1], -1,
self.head_size).contiguous()
key = key.view(*key.shape[:-1], -1, self.head_size).contiguous()
cos = cos.unsqueeze(-2).unsqueeze(-2)
sin = sin.unsqueeze(-2).unsqueeze(-2)
query = query.unsqueeze(1)
key = key.unsqueeze(1)
q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(
query, key, cos, sin)
return q_embed.flatten(-2), k_embed.flatten(-2)
else:
return rope_forward_oot(self, positions, query, key, offsets,
is_neox_style_override,
is_qwen_torchair) # type: ignore
RotaryEmbedding.__init__ = qwen_rope_init_func
RotaryEmbedding.forward_oot = rope_forward
class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
# Note: we adopt the native huggingface deepseek rope initialization code from
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py for
# its more ascend compute friendly
DeepseekScalingRotaryEmbedding.__init__ = deepseek_rope_init_func
DeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
mscale: float = 1,
mscale_all_dim: float = 0,
) -> None:
# Note: we adopt the native huggingface deepseek rope initialization code from
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py for
# its more ascend compute friendly
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
# Get n-d magnitude scaling corrected for interpolation.
self.mscale = float(
self._yarn_get_mscale(self.scaling_factor, float(mscale)) /
self._yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
attn_factor)
super(DeepseekScalingRotaryEmbedding,
self).__init__(head_size, rotary_dim, max_position_embeddings,
base, is_neox_style, dtype)
self.max_seq_len = max_position_embeddings
self._set_cos_sin_cache(seq_len=max_position_embeddings,
device=NPUPlatform.device_type,
dtype=dtype)
def _yarn_get_mscale(self, scale: float = 1, mscale: float = 1) -> float:
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
def _rotate_half(self, x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def _yarn_linear_ramp_mask(self, min_value, max_value, dim):
# Note: The if conditional branch is not used here
# to solve MTP compilation error.
max_value += (min_value == max_value).float() * 0.001
linear_func = (torch.arange(dim, dtype=torch.float32) -
min_value) / (max_value - min_value)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
# Inverse dim formula to find dim based on number of rotations
def _yarn_find_correction_dim(self,
num_rotations,
dim,
base=10000,
max_position_embeddings=2048):
# Note: use torch instead of math to solve MTP compilation error.
return (dim * torch.log(
torch.tensor(max_position_embeddings) /
(num_rotations * 2 * torch.pi))) / (2 *
torch.log(torch.tensor(base)))
# Find dim range bounds based on rotations
def _yarn_find_correction_range(self,
low_rot,
high_rot,
dim,
base=10000,
max_position_embeddings=2048):
# Note: use torch instead of math to solve MTP compilation error.
low = torch.floor(
self._yarn_find_correction_dim(low_rot, dim, base,
max_position_embeddings))
high = torch.ceil(
self._yarn_find_correction_dim(high_rot, dim, base,
max_position_embeddings))
# Note: use torch instead of max/min to solve MTP compilation error.
return torch.clamp(low, min=0), torch.clamp(high, max=dim - 1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def _apply_rotary_pos_emb(self,
q,
k,
cos,
sin,
position_ids,
unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos[position_ids]
sin = sin[position_ids]
cos = cos[:, None, None, :]
sin = sin[:, None, None, :]
if len(q.shape) == 3:
q = q[:, :, None, :]
if len(k.shape) == 2:
k = k[:, None, None, :]
elif len(k.shape) == 3:
k = k[:, :, None, :]
b, h_q, s, d = q.shape
q = q.view(b, h_q, s, d // 2, 2).transpose(4, 3).reshape(b, h_q, s, d)
b, h_k, s, d = k.shape
k = k.view(b, h_k, s, d // 2, 2).transpose(4, 3).reshape(b, h_k, s, d)
q_embed = (q * cos) + (self._rotate_half(q) * sin)
k_embed = (k * cos) + (self._rotate_half(k) * sin)
q_embed = q_embed.view(b, h_q, d)
k_embed = k_embed.view(b, h_k, d)
return q_embed, k_embed
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
dim = self.rotary_dim
freq_extra = 1.0 / (self.base**(
torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
freq_inter = 1.0 / (self.scaling_factor * self.base**(
torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
low, high = self._yarn_find_correction_range(
self.beta_fast,
self.beta_slow,
dim,
self.base,
self.max_position_embeddings,
)
inv_freq_mask = 1.0 - self._yarn_linear_ramp_mask(
low, high, dim // 2).to(device=device, dtype=torch.float32)
inv_freq = freq_inter * (1 -
inv_freq_mask) + freq_extra * inv_freq_mask
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(seq_len * self.scaling_factor,
device=device,
dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
sin_cached = torch.cat([freqs, freqs], dim=-1).sin() * self.mscale
cos_cached = cos_cached.to(dtype)
sin_cached = sin_cached.to(dtype)
cache = torch.cat(
[freqs.cos() * self.mscale,
freqs.sin() * self.mscale], dim=-1).to(dtype)
self.register_buffer("cos_sin_cache", cache, persistent=False)
self.register_buffer("cos_cached", cos_cached, persistent=False)
self.register_buffer("sin_cached", sin_cached, persistent=False)
def forward(self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
max_seq_len: Optional[int] = None):
if max_seq_len is not None and max_seq_len > self.max_seq_len:
self._set_cos_sin_cache(max_seq_len, query.device, query.dtype)
if len(key.shape) == 2:
key = key[:, None, :]
# Note: we implement the non neox_style method with shuffle the last dim and neox style
# calculation method which is also more compute friendly to the ascend machine
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py
neox_style = True
if self.is_neox_style is False:
b, h_q, d = query.shape
query = query.view(b, h_q, d // 2,
2).transpose(3, 2).reshape(b, h_q, d)
b, h_k, d = key.shape
key = key.view(b, h_k, d // 2, 2).transpose(3,
2).reshape(b, h_k, d)
q_pe, k_pe = rope_forward_oot(self, positions, query, key, offsets,
neox_style)
return q_pe, k_pe

View File

@@ -478,9 +478,16 @@ def register_ascend_customop():
from vllm_ascend.ops.linear import (AscendMlpColumnParallelLinear,
AscendMlpMergedColumnParallelLinear,
AscendMlpRowParallelLinear)
from vllm_ascend.ops.rotary_embedding import (
AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
CustomOp.register_oot(_decorated_op_cls=AscendQuickGELU, name="QuickGELU")
CustomOp.register_oot(_decorated_op_cls=AscendSiluAndMul,
name="SiluAndMul")
CustomOp.register_oot(_decorated_op_cls=AscendRotaryEmbedding,
name="RotaryEmbedding")
CustomOp.register_oot(
_decorated_op_cls=AscendDeepseekScalingRotaryEmbedding,
name="DeepseekScalingRotaryEmbedding")
if envs_ascend.VLLM_ASCEND_ENABLE_MLP_OPTIMIZE:
CustomOp.register_oot(_decorated_op_cls=AscendMlpColumnParallelLinear,
name="ColumnParallelLinear")