[5/N][refactor]add torchair rotary ops (#2559)
### What this PR does / why we need it? Move torchair related rotary ops into torchair dir to make the code clear. Next step we'll remove all torchair related code outside of torchair rotary ops. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? vLLM version: main vLLM main:ab9f2cfd19- vLLM version: v0.10.1.1 - vLLM main:81eea3d348Signed-off-by: hust17yixuan <303660421@qq.com>
This commit is contained in:
332
tests/ut/torchair/ops/test_torchair_rotary_embedding.py
Normal file
332
tests/ut/torchair/ops/test_torchair_rotary_embedding.py
Normal file
@@ -0,0 +1,332 @@
|
||||
import math
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.torchair.ops.torchair_rotary_embedding import (
|
||||
custom_rotary_embedding_enabled, native_rope_deepseek_forward,
|
||||
rope_forward_oot, rotate_half, yarn_find_correction_dim, yarn_get_mscale)
|
||||
|
||||
|
||||
class TestCustomRotaryEmbeddingEnabled(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
# Common setup for tests
|
||||
self.positions = torch.tensor([1, 2, 3])
|
||||
self.query = torch.randn(3, 4, dtype=torch.float16)
|
||||
self.key = torch.randn(3, 4, dtype=torch.float16)
|
||||
self.head_size = 32
|
||||
self.cos_sin_cache = torch.randn(3, 4)
|
||||
|
||||
# Mock self object for rope_forward_oot
|
||||
self.mock_self = MagicMock()
|
||||
self.mock_self.head_size = self.head_size
|
||||
self.mock_self.cos_sin_cache = self.cos_sin_cache
|
||||
self.mock_self.is_neox_style = True
|
||||
self.mock_self.forward_native.return_value = (self.query, self.key)
|
||||
|
||||
def test_custom_rotary_embedding_enabled(self):
|
||||
# Test when all conditions are True
|
||||
with patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
|
||||
return_value=True):
|
||||
result = custom_rotary_embedding_enabled(self.query, True,
|
||||
self.head_size)
|
||||
self.assertTrue(result)
|
||||
|
||||
# Test when dtype is not float16
|
||||
with patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
|
||||
return_value=True):
|
||||
query = self.query.to(torch.float32)
|
||||
result = custom_rotary_embedding_enabled(query, True,
|
||||
self.head_size)
|
||||
self.assertFalse(result)
|
||||
|
||||
# Test when neox_style is False
|
||||
with patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
|
||||
return_value=True):
|
||||
result = custom_rotary_embedding_enabled(self.query, False,
|
||||
self.head_size)
|
||||
self.assertFalse(result)
|
||||
|
||||
# Test when head_size is not divisible by 32
|
||||
with patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
|
||||
return_value=True):
|
||||
result = custom_rotary_embedding_enabled(self.query, True,
|
||||
self.head_size + 1)
|
||||
self.assertFalse(result)
|
||||
|
||||
# Test when custom op is disabled
|
||||
with patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
|
||||
return_value=False):
|
||||
result = custom_rotary_embedding_enabled(self.query, True,
|
||||
self.head_size)
|
||||
self.assertFalse(result)
|
||||
|
||||
|
||||
class TestRopeForwardOot(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
# Common setup for tests
|
||||
self.positions = torch.tensor([1, 2, 3])
|
||||
self.query = torch.randn(3, 4, dtype=torch.float16)
|
||||
self.key = torch.randn(3, 4, dtype=torch.float16)
|
||||
self.head_size = 32
|
||||
self.cos_sin_cache = torch.randn(3, 4)
|
||||
|
||||
# Mock self object for rope_forward_oot
|
||||
self.mock_self = MagicMock()
|
||||
self.mock_self.head_size = self.head_size
|
||||
self.mock_self.cos_sin_cache = self.cos_sin_cache
|
||||
self.mock_self.is_neox_style = True
|
||||
self.mock_self.forward_native.return_value = (self.query, self.key)
|
||||
|
||||
@patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
|
||||
def test_rope_forward_oot_torchair_enabled_base(self,
|
||||
mock_get_ascend_config):
|
||||
# Setup mock for torchair enabled
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = True
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
|
||||
self.query, self.key)
|
||||
|
||||
self.mock_self.forward_native.assert_called_once_with(
|
||||
self.positions, self.query, self.key, None)
|
||||
self.assertTrue(torch.equal(result_q, self.query))
|
||||
self.assertTrue(torch.equal(result_k, self.key))
|
||||
|
||||
@patch('torch.ops._C')
|
||||
@patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
|
||||
@patch('vllm_ascend.torchair.ops.torchair_rotary_embedding.is_310p',
|
||||
return_value=False)
|
||||
@patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.custom_rotary_embedding_enabled',
|
||||
return_value=True)
|
||||
@patch('torch.ops._npu_rotary_embedding')
|
||||
def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
|
||||
mock_custom_enabled, mock_is_310p,
|
||||
mock_get_ascend_config, mock__c):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
# Setup mock for custom kernel path
|
||||
|
||||
mock__c.rotary_embedding.return_value = self.query, self.key
|
||||
|
||||
result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
|
||||
self.query, self.key)
|
||||
|
||||
self.assertEqual(result_q.shape, self.query.shape)
|
||||
self.assertEqual(result_k.shape, self.key.shape)
|
||||
|
||||
@patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
|
||||
@patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.custom_rotary_embedding_enabled',
|
||||
return_value=False)
|
||||
@patch('torch_npu._npu_rotary_embedding')
|
||||
def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
|
||||
mock_custom_enabled,
|
||||
mock_get_ascend_config):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
# Test contiguous path when custom is disabled
|
||||
non_contig_query = self.query.transpose(0, 1)
|
||||
non_contig_key = self.key.transpose(0, 1)
|
||||
|
||||
result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
|
||||
non_contig_query, non_contig_key)
|
||||
|
||||
mock_npu_rotary.assert_called_once()
|
||||
self.assertEqual(result_q.shape, non_contig_query.shape)
|
||||
self.assertEqual(result_k.shape, non_contig_key.shape)
|
||||
|
||||
@patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
|
||||
def test_rope_forward_oot_with_offsets(self, mock_get_ascend_config):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
# Test that NotImplementedError is raised when offsets is provided
|
||||
offsets = torch.tensor([1, 2, 3])
|
||||
with self.assertRaises(NotImplementedError):
|
||||
rope_forward_oot(self.mock_self, self.positions, self.query,
|
||||
self.key, offsets)
|
||||
|
||||
@patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
|
||||
@patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.custom_rotary_embedding_enabled',
|
||||
return_value=False)
|
||||
@patch('torch_npu._npu_rotary_embedding')
|
||||
def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
|
||||
mock_custom_enabled,
|
||||
mock_get_ascend_config):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
# Test neox_style override
|
||||
result_q, result_k = rope_forward_oot(self.mock_self,
|
||||
self.positions,
|
||||
self.query,
|
||||
self.key,
|
||||
is_neox_style_override=False)
|
||||
|
||||
# Check that neox_style=False was passed to the NPU function
|
||||
args, kwargs = mock_npu_rotary.call_args
|
||||
self.assertFalse(args[-1])
|
||||
|
||||
|
||||
class MockRopeModule:
|
||||
|
||||
def __init__(self, max_seq_len=2048, is_neox_style=True):
|
||||
self.max_seq_len = max_seq_len
|
||||
self.is_neox_style = is_neox_style
|
||||
self.cos_cached = None
|
||||
self.sin_cached = None
|
||||
self.rotary_dim = 1
|
||||
self.base = 1
|
||||
|
||||
|
||||
class TestNativeRopeDeepseekForward(TestBase):
|
||||
|
||||
@patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
|
||||
def test_native_rope_deepseek_forward_base(self, mock_rope_forward_oot):
|
||||
module = MockRopeModule()
|
||||
positions = torch.tensor([1, 2, 3])
|
||||
query = torch.randn(1, 8, 128)
|
||||
key = torch.randn(1, 8, 128)
|
||||
|
||||
mock_rope_forward_oot.return_value = (query, key)
|
||||
|
||||
q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
|
||||
key)
|
||||
|
||||
assert q_pe.shape == query.shape
|
||||
assert k_pe.shape == key.shape
|
||||
|
||||
@patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding._set_cos_sin_cache'
|
||||
)
|
||||
@patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
|
||||
def test_native_rope_deepseek_forward_cache_handling(
|
||||
self, mock_rope_forward_oot, mock_set_cache):
|
||||
# Test cache situation is true
|
||||
module = MockRopeModule(max_seq_len=1024)
|
||||
positions = torch.tensor([1, 2, 3])
|
||||
query = torch.randn(1, 8, 128)
|
||||
key = torch.randn(1, 8, 128)
|
||||
|
||||
mock_rope_forward_oot.return_value = (query, key)
|
||||
|
||||
q_pe, k_pe = native_rope_deepseek_forward(module,
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
max_seq_len=2048)
|
||||
|
||||
assert q_pe.shape == query.shape
|
||||
assert k_pe.shape == key.shape
|
||||
|
||||
@patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
|
||||
def test_native_rope_deepseek_forward_key_reshaping(
|
||||
self, mock_rope_forward_oot):
|
||||
module = MockRopeModule()
|
||||
positions = torch.tensor([1, 2, 3])
|
||||
query = torch.randn(1, 8, 128)
|
||||
key = torch.randn(1, 128)
|
||||
|
||||
mock_rope_forward_oot.return_value = (query, key)
|
||||
|
||||
q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
|
||||
key)
|
||||
|
||||
assert q_pe.shape == query.shape
|
||||
assert k_pe.shape == (1, 128)
|
||||
|
||||
@patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
|
||||
def test_native_rope_deepseek_forward_non_neox_style(
|
||||
self, mock_rope_forward_oot):
|
||||
module = MockRopeModule(is_neox_style=False)
|
||||
positions = torch.tensor([1, 2, 3])
|
||||
query = torch.randn(1, 8, 128)
|
||||
key = torch.randn(1, 8, 128)
|
||||
|
||||
mock_rope_forward_oot.return_value = (query, key)
|
||||
|
||||
q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
|
||||
key)
|
||||
|
||||
assert q_pe.shape == query.shape
|
||||
assert k_pe.shape == key.shape
|
||||
|
||||
|
||||
class TestRotateHalf(TestBase):
|
||||
|
||||
def test_rotate_half_even_dim(self):
|
||||
# Test with even dimension
|
||||
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
|
||||
expected = torch.tensor([-3.0, -4.0, 1.0, 2.0])
|
||||
result = rotate_half(x)
|
||||
self.assertTrue(torch.allclose(result, expected))
|
||||
|
||||
|
||||
class TestYarnFindCorrectionDim(TestBase):
|
||||
|
||||
def test_basic_case(self):
|
||||
# Test with standard values
|
||||
num_rotations = 100
|
||||
dim = 512
|
||||
base = 10000
|
||||
max_position_embeddings = 2048
|
||||
|
||||
result = yarn_find_correction_dim(num_rotations, dim, base,
|
||||
max_position_embeddings)
|
||||
|
||||
# Calculate expected value manually
|
||||
expected = (dim * torch.log(
|
||||
torch.tensor(max_position_embeddings) /
|
||||
(num_rotations * 2 * torch.pi))) / (2 *
|
||||
torch.log(torch.tensor(base)))
|
||||
|
||||
self.assertTrue(torch.allclose(result, expected))
|
||||
|
||||
|
||||
class TestYarnGetMscale(TestBase):
|
||||
|
||||
def test_scale_less_than_or_equal_1(self):
|
||||
self.assertEqual(yarn_get_mscale(scale=0.5), 1.0)
|
||||
self.assertEqual(yarn_get_mscale(scale=1.0), 1.0)
|
||||
self.assertEqual(yarn_get_mscale(scale=0.999), 1.0)
|
||||
|
||||
def test_scale_greater_than_1(self):
|
||||
test_cases = [(2.0, 1.0, 1.0 + 0.1 * math.log(2.0)),
|
||||
(10.0, 1.0, 1.0 + 0.1 * math.log(10.0)),
|
||||
(5.0, 2.0, 1.0 + 0.2 * math.log(5.0)),
|
||||
(math.e, 1.0, 1.0 + 0.1)]
|
||||
|
||||
for scale, mscale, expected in test_cases:
|
||||
result = yarn_get_mscale(scale, mscale)
|
||||
self.assertAlmostEqual(
|
||||
result,
|
||||
expected,
|
||||
places=6,
|
||||
msg=f"Failed for scale={scale}, mscale={mscale}")
|
||||
372
vllm_ascend/torchair/ops/torchair_rotary_embedding.py
Normal file
372
vllm_ascend/torchair/ops/torchair_rotary_embedding.py
Normal file
@@ -0,0 +1,372 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.utils import enable_custom_op, is_310p
|
||||
|
||||
|
||||
def custom_rotary_embedding_enabled(query, neox_style, head_size):
|
||||
return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op(
|
||||
)
|
||||
|
||||
|
||||
def rope_forward_oot(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
is_neox_style_override: Optional[bool] = None,
|
||||
is_qwen_torchair: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if get_ascend_config(
|
||||
).torchair_graph_config.enabled and not is_qwen_torchair:
|
||||
return self.forward_native(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
offsets,
|
||||
)
|
||||
|
||||
query_shape, key_shape = query.shape, key.shape
|
||||
if self.cos_sin_cache.device != query.device:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
|
||||
if self.cos_sin_cache.dtype != query.dtype:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
|
||||
neox_style = self.is_neox_style
|
||||
if is_neox_style_override is not None:
|
||||
neox_style = is_neox_style_override
|
||||
# adopt custom kernel path for rotary_embedding
|
||||
if custom_rotary_embedding_enabled(query, neox_style,
|
||||
self.head_size) and not is_310p():
|
||||
query, key = torch.ops._C.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
neox_style,
|
||||
)
|
||||
return query.view(query_shape), key.view(key_shape)
|
||||
if offsets is not None:
|
||||
raise NotImplementedError(
|
||||
"Batched rotary embedding is currently not supported on NPU.")
|
||||
else:
|
||||
# TODO: Remove the contiguous in the future.
|
||||
query = query.contiguous().view(query.shape[0], -1)
|
||||
key = key.contiguous().view(key.shape[0], -1)
|
||||
torch_npu._npu_rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
neox_style,
|
||||
)
|
||||
return query.view(query_shape), key.view(key_shape)
|
||||
|
||||
|
||||
def native_rope_deepseek_forward(self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
max_seq_len: Optional[int] = None):
|
||||
if max_seq_len is not None and max_seq_len > self.max_seq_len:
|
||||
_set_cos_sin_cache(self, max_seq_len, query.device, query.dtype)
|
||||
if len(key.shape) == 2:
|
||||
key = key[:, None, :]
|
||||
# Note: we implement the non neox_style method with shuffle the last dim and neox style
|
||||
# calculation method which is also more compute friendly to the ascend machine
|
||||
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py
|
||||
neox_style = True
|
||||
if self.is_neox_style is False:
|
||||
b, h_q, d = query.shape
|
||||
query = query.view(b, h_q, d // 2, 2).transpose(3,
|
||||
2).reshape(b, h_q, d)
|
||||
b, h_k, d = key.shape
|
||||
key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d)
|
||||
q_pe, k_pe = rope_forward_oot(self, positions, query, key, offsets,
|
||||
neox_style)
|
||||
return q_pe, k_pe
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., :x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2:]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Inverse dim formula to find dim based on number of rotations
|
||||
def yarn_find_correction_dim(num_rotations,
|
||||
dim,
|
||||
base=10000,
|
||||
max_position_embeddings=2048):
|
||||
# Note: use torch instead of math to solve MTP compilation error.
|
||||
return (dim * torch.log(
|
||||
torch.tensor(max_position_embeddings) /
|
||||
(num_rotations * 2 * torch.pi))) / (2 * torch.log(torch.tensor(base)))
|
||||
|
||||
|
||||
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
|
||||
# Find dim range bounds based on rotations
|
||||
def yarn_find_correction_range(low_rot,
|
||||
high_rot,
|
||||
dim,
|
||||
base=10000,
|
||||
max_position_embeddings=2048):
|
||||
# Note: use torch instead of math to solve MTP compilation error.
|
||||
low = torch.floor(
|
||||
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
|
||||
high = torch.ceil(
|
||||
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
|
||||
# Note: use torch instead of max/min to solve MTP compilation error.
|
||||
return torch.clamp(low, min=0), torch.clamp(high, max=dim - 1)
|
||||
|
||||
|
||||
def yarn_linear_ramp_mask(min_value, max_value, dim):
|
||||
# Note: The if conditional branch is not used here
|
||||
# to solve MTP compilation error.
|
||||
max_value += (min_value == max_value).float() * 0.001
|
||||
linear_func = (torch.arange(dim, dtype=torch.float32) -
|
||||
min_value) / (max_value - min_value)
|
||||
ramp_func = torch.clamp(linear_func, 0, 1)
|
||||
return ramp_func
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`):
|
||||
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
||||
used to pass offsetted position ids when working with a KV-cache.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos[position_ids]
|
||||
sin = sin[position_ids]
|
||||
cos = cos[:, None, None, :]
|
||||
sin = sin[:, None, None, :]
|
||||
|
||||
if len(q.shape) == 3:
|
||||
q = q[:, :, None, :]
|
||||
if len(k.shape) == 2:
|
||||
k = k[:, None, None, :]
|
||||
elif len(k.shape) == 3:
|
||||
k = k[:, :, None, :]
|
||||
|
||||
b, h_q, s, d = q.shape
|
||||
q = q.view(b, h_q, s, d // 2, 2).transpose(4, 3).reshape(b, h_q, s, d)
|
||||
|
||||
b, h_k, s, d = k.shape
|
||||
k = k.view(b, h_k, s, d // 2, 2).transpose(4, 3).reshape(b, h_k, s, d)
|
||||
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
|
||||
q_embed = q_embed.view(b, h_q, d)
|
||||
k_embed = k_embed.view(b, h_k, d)
|
||||
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
dim = self.rotary_dim
|
||||
|
||||
freq_extra = 1.0 / (self.base**(
|
||||
torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
||||
freq_inter = 1.0 / (self.scaling_factor * self.base**(
|
||||
torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
||||
|
||||
low, high = yarn_find_correction_range(
|
||||
self.beta_fast,
|
||||
self.beta_slow,
|
||||
dim,
|
||||
self.base,
|
||||
self.max_position_embeddings,
|
||||
)
|
||||
inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
|
||||
device=device, dtype=torch.float32)
|
||||
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
t = torch.arange(seq_len * self.scaling_factor,
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
|
||||
sin_cached = torch.cat([freqs, freqs], dim=-1).sin() * self.mscale
|
||||
cos_cached = cos_cached.to(dtype)
|
||||
sin_cached = sin_cached.to(dtype)
|
||||
cache = torch.cat([freqs.cos() * self.mscale,
|
||||
freqs.sin() * self.mscale],
|
||||
dim=-1).to(dtype)
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
self.register_buffer("cos_cached", cos_cached, persistent=False)
|
||||
self.register_buffer("sin_cached", sin_cached, persistent=False)
|
||||
|
||||
|
||||
def __set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
inv_freq = 1.0 / (self.base**(torch.arange(
|
||||
0, self.rotary_dim, 2, device=device, dtype=torch.float32) *
|
||||
(1 / self.rotary_dim)))
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
|
||||
t = torch.arange(self.max_position_embeddings,
|
||||
device=self.inv_freq.device,
|
||||
dtype=torch.float32)
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos", emb.cos().to(dtype=dtype), persistent=False)
|
||||
self.register_buffer("sin", emb.sin().to(dtype=dtype), persistent=False)
|
||||
self.embed = F.embedding
|
||||
|
||||
|
||||
_original_re_init = RotaryEmbedding.__init__
|
||||
|
||||
|
||||
def qwen_rope_init_func(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: float,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
_original_re_init(self, head_size, rotary_dim, max_position_embeddings,
|
||||
base, is_neox_style, dtype)
|
||||
if get_ascend_config().torchair_graph_config.enabled:
|
||||
__set_cos_sin_cache(self,
|
||||
seq_len=max_position_embeddings,
|
||||
device="npu",
|
||||
dtype=dtype)
|
||||
|
||||
|
||||
def rope_forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
is_neox_style_override: Optional[bool] = None,
|
||||
max_seq_len: Optional[int] = None,
|
||||
is_prefill: Optional[bool] = True,
|
||||
is_qwen_torchair: Optional[bool] = False,
|
||||
):
|
||||
if get_ascend_config().torchair_graph_config.enabled \
|
||||
and is_qwen_torchair and not is_prefill:
|
||||
if max_seq_len is not None and torch.gt(max_seq_len,
|
||||
self.max_position_embeddings):
|
||||
__set_cos_sin_cache(self,
|
||||
seq_len=max_seq_len,
|
||||
device=query.device,
|
||||
dtype=torch.float32)
|
||||
|
||||
# bsnd/bnsd
|
||||
if positions is not None:
|
||||
cos = self.embed(positions, self.cos)
|
||||
sin = self.embed(positions, self.sin)
|
||||
self.cos_embed = cos
|
||||
self.sin_embed = sin
|
||||
else:
|
||||
cos = self.cos_embed
|
||||
sin = self.sin_embed
|
||||
|
||||
query = query.view(*query.shape[:-1], -1, self.head_size).contiguous()
|
||||
key = key.view(*key.shape[:-1], -1, self.head_size).contiguous()
|
||||
|
||||
cos = cos.unsqueeze(-2).unsqueeze(-2)
|
||||
sin = sin.unsqueeze(-2).unsqueeze(-2)
|
||||
|
||||
query = query.unsqueeze(1)
|
||||
key = key.unsqueeze(1)
|
||||
|
||||
q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(
|
||||
query, key, cos, sin)
|
||||
return q_embed.flatten(-2), k_embed.flatten(-2)
|
||||
else:
|
||||
return rope_forward_oot(self, positions, query, key, offsets,
|
||||
is_neox_style_override,
|
||||
is_qwen_torchair) # type: ignore
|
||||
|
||||
|
||||
def deepseek_rope_init_func(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
scaling_factor: float,
|
||||
dtype: torch.dtype,
|
||||
*,
|
||||
extrapolation_factor: float = 1,
|
||||
attn_factor: float = 1,
|
||||
beta_fast: int = 32,
|
||||
beta_slow: int = 1,
|
||||
mscale: float = 1,
|
||||
mscale_all_dim: float = 0,
|
||||
) -> None:
|
||||
self.scaling_factor = scaling_factor
|
||||
self.extrapolation_factor = extrapolation_factor
|
||||
self.attn_factor = attn_factor
|
||||
self.beta_fast = beta_fast
|
||||
self.beta_slow = beta_slow
|
||||
# Get n-d magnitude scaling corrected for interpolation.
|
||||
self.mscale = float(
|
||||
yarn_get_mscale(self.scaling_factor, float(mscale)) /
|
||||
yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
|
||||
attn_factor)
|
||||
super(DeepseekScalingRotaryEmbedding,
|
||||
self).__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||
is_neox_style, dtype)
|
||||
self.max_seq_len = max_position_embeddings
|
||||
_set_cos_sin_cache(self,
|
||||
max_position_embeddings,
|
||||
dtype=dtype,
|
||||
device="npu")
|
||||
@@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
|
||||
#
|
||||
# isort: skip_file
|
||||
|
||||
import types
|
||||
from typing import Optional
|
||||
@@ -34,12 +34,10 @@ from vllm.logger import logger
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
|
||||
check_torchair_cache_exist,
|
||||
converting_weight_acl_format,
|
||||
register_torchair_model,
|
||||
torchair_quant_method_register,
|
||||
write_kv_cache_bytes_to_file)
|
||||
from vllm_ascend.torchair.utils import (
|
||||
TorchairCommonAttentionMetadata, check_torchair_cache_exist,
|
||||
converting_weight_acl_format, register_torchair_model, torchair_ops_patch,
|
||||
torchair_quant_method_register, write_kv_cache_bytes_to_file)
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
is_310p)
|
||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
@@ -68,6 +66,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
|
||||
self._check_batch_sizes_consistency()
|
||||
register_torchair_model()
|
||||
torchair_ops_patch()
|
||||
torchair_quant_method_register()
|
||||
|
||||
def _sync_metadata_across_dp(
|
||||
|
||||
@@ -182,3 +182,18 @@ def torchair_quant_method_register():
|
||||
"W8A8_DYNAMIC"] = TorchairW8A8DYNAMICQuantizer
|
||||
SUPPORT_ASCEND_QUANTIZER_TYPE[
|
||||
"W4A8_DYNAMIC"] = TorchairW4A8DYNAMICQuantizer
|
||||
|
||||
|
||||
def torchair_ops_patch():
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
||||
|
||||
from vllm_ascend.torchair.ops.torchair_rotary_embedding import (
|
||||
deepseek_rope_init_func, native_rope_deepseek_forward,
|
||||
qwen_rope_init_func, rope_forward)
|
||||
|
||||
RotaryEmbedding.__init__ = qwen_rope_init_func
|
||||
RotaryEmbedding.forward_oot = rope_forward
|
||||
|
||||
DeepseekScalingRotaryEmbedding.__init__ = deepseek_rope_init_func
|
||||
DeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward
|
||||
|
||||
Reference in New Issue
Block a user