[Ops][Refactor] Remove custom rotary_embedding operator (#6523)
### What this PR does / why we need it? This PR removes the custom `rotary_embedding` operator and its associated C++ kernel implementation, PyTorch bindings, and tests. The codebase now falls back to using the native `torch_npu._npu_rotary_embedding` implementation. This change simplifies the codebase by removing custom, platform-specific kernel code and relying on the standard NPU library implementation, which is presumably more optimized and easier to maintain. ### Does this PR introduce _any_ user-facing change? No. This is an internal refactoring and does not introduce any user-facing changes. ### How was this patch tested? The tests for the custom `rotary_embedding` operator have been removed along with the operator itself. The correctness of the fallback to the native `torch_npu` implementation is verified by existing CI tests for attention layers and models that use rotary embeddings. - vLLM version: v0.15.0 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0 Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -1,351 +0,0 @@
|
||||
# Copyright 2023 The vLLM team.
|
||||
|
||||
# Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
|
||||
# Adapted from
|
||||
# https://github.com/vllm-project/vllm/blob/main/vllm/tests/kernels/test_rotary_embedding.py
|
||||
|
||||
import gc
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm_ascend.utils import enable_custom_op
|
||||
|
||||
enable_custom_op()
|
||||
|
||||
# Only Neox style true scenario is supported for now
|
||||
IS_NEOX_STYLE = [True]
|
||||
DTYPES = [torch.half]
|
||||
HEAD_SIZES = [64, 64, 96, 128, 256]
|
||||
ROTARY_DIMS = [None, 32] # None means rotary dim == head size
|
||||
NUM_HEADS = [17] # Arbitrary values for testing
|
||||
BATCH_SIZES = [5] # Arbitrary values for testing
|
||||
SEQ_LENS = [11, 4096] # Arbitrary values for testing
|
||||
NUM_TOKENS = [10, 21]
|
||||
SEEDS = [0]
|
||||
DEVICES = [f"npu:{0}"]
|
||||
# Set tolerance to 1 for quant ops
|
||||
DEFAULT_ATOL = 1e-3
|
||||
DEFAULT_RTOL = 1e-3
|
||||
|
||||
|
||||
def _apply_rotary_emb(
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
is_neox_style: bool,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: [num_tokens, num_heads, head_size]
|
||||
cos: [num_tokens, head_size // 2]
|
||||
sin: [num_tokens, head_size // 2]
|
||||
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
|
||||
positional embeddings.
|
||||
"""
|
||||
cos = cos.unsqueeze(-2).to(x.dtype)
|
||||
sin = sin.unsqueeze(-2).to(x.dtype)
|
||||
if is_neox_style:
|
||||
x1, x2 = torch.chunk(x, 2, dim=-1)
|
||||
else:
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
o1 = x1 * cos - x2 * sin
|
||||
o2 = x2 * cos + x1 * sin
|
||||
if is_neox_style:
|
||||
return torch.cat((o1, o2), dim=-1)
|
||||
else:
|
||||
return torch.stack((o1, o2), dim=-1).flatten(-2)
|
||||
|
||||
|
||||
# adapted from https://github.com/vllm-project/vllm/vllm/model_executor/layers/rotary_embedding.py
|
||||
class RotaryEmbedding(nn.Module):
|
||||
"""Original rotary positional embedding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.head_size = head_size
|
||||
self.rotary_dim = rotary_dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
self.is_neox_style = is_neox_style
|
||||
self.dtype = dtype
|
||||
|
||||
cache = self._compute_cos_sin_cache()
|
||||
cache = cache.to(dtype)
|
||||
self.cos_sin_cache: torch.Tensor
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
|
||||
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
||||
"""Compute the inverse frequency."""
|
||||
# NOTE(woosuk): To exactly match the HF implementation, we need to
|
||||
# use CPU to compute the cache and then move it to GPU. However, we
|
||||
# create the cache on GPU for faster initialization. This may cause
|
||||
# a slight numerical difference between the HF implementation and ours.
|
||||
inv_freq = 1.0 / (base**(torch.arange(
|
||||
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
|
||||
return inv_freq
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
"""Compute the cos and sin cache."""
|
||||
inv_freq = self._compute_inv_freq(self.base)
|
||||
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
|
||||
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
return cache
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""A PyTorch-native implementation of forward()."""
|
||||
if offsets is not None:
|
||||
positions = positions + offsets
|
||||
positions = positions.flatten()
|
||||
num_tokens = positions.shape[0]
|
||||
cos_sin = self.cos_sin_cache.index_select(0, positions)
|
||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
|
||||
query_shape = query.shape
|
||||
query = query.view(num_tokens, -1, self.head_size)
|
||||
query_rot = query[..., :self.rotary_dim]
|
||||
query_pass = query[..., self.rotary_dim:]
|
||||
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
|
||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||
|
||||
key_shape = key.shape
|
||||
key = key.view(num_tokens, -1, self.head_size)
|
||||
key_rot = key[..., :self.rotary_dim]
|
||||
key_pass = key[..., self.rotary_dim:]
|
||||
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
|
||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||
return query, key
|
||||
|
||||
|
||||
# test with leading dimension and merge seqlen and batch_size as num_tokens
|
||||
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("seq_len", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_rotary_embedding_quant_with_leading_dim(
|
||||
is_neox_style: bool,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
rotary_dim: Optional[int],
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
max_position: int = 8192,
|
||||
base: int = 10000,
|
||||
) -> None:
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
|
||||
torch.set_default_device(device)
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
rope = RotaryEmbedding(head_size, rotary_dim, max_position, base,
|
||||
is_neox_style, dtype)
|
||||
rope = rope.to(dtype=dtype)
|
||||
num_tokens = batch_size * seq_len
|
||||
positions = torch.randint(0, max_position, (batch_size * seq_len, ))
|
||||
qkv_tensor = torch.randn(num_tokens,
|
||||
num_heads * head_size * 3,
|
||||
dtype=dtype)
|
||||
query, key, _ = qkv_tensor.split(
|
||||
[num_heads * head_size, num_heads * head_size, num_heads * head_size],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
ref_query, ref_key = rope.forward_native(positions, query, key)
|
||||
query, key = torch.ops._C_ascend.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
rope.head_size,
|
||||
rope.cos_sin_cache,
|
||||
rope.is_neox_style,
|
||||
)
|
||||
|
||||
# Compare the results.
|
||||
torch.testing.assert_close(query.view(ref_query.size()),
|
||||
ref_query,
|
||||
atol=DEFAULT_ATOL,
|
||||
rtol=DEFAULT_RTOL)
|
||||
torch.testing.assert_close(key.view(ref_key.size()),
|
||||
ref_key,
|
||||
atol=DEFAULT_ATOL,
|
||||
rtol=DEFAULT_RTOL)
|
||||
gc.collect()
|
||||
torch.npu.empty_cache()
|
||||
torch.npu.reset_peak_memory_stats()
|
||||
|
||||
|
||||
class ModelwithRotaryEmbedding(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.qkv_proj = nn.Linear(hidden_size, num_heads * head_size * 3)
|
||||
self.rope = RotaryEmbedding(
|
||||
head_size=head_size,
|
||||
rotary_dim=rotary_dim,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
base=base,
|
||||
is_neox_style=is_neox_style,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.o_proj = nn.Linear(num_heads * head_size, hidden_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# we simulated a simple attention layer to test if it can be seamlessly captured into aclgraph
|
||||
qkv = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.chunk(3, dim=-1)
|
||||
query, key = torch.ops._C_ascend.rotary_embedding(
|
||||
positions,
|
||||
q,
|
||||
k,
|
||||
self.rope.head_size,
|
||||
self.rope.cos_sin_cache,
|
||||
self.rope.is_neox_style,
|
||||
)
|
||||
query = query.view(q.shape)
|
||||
key = key.view(k.shape)
|
||||
o = self.o_proj(query)
|
||||
return o
|
||||
|
||||
|
||||
# The first graph seems will have some accuracy issue when directly run pytest on the ops folder,
|
||||
# add a warmup graph replay for workaround
|
||||
ACL_GRPAH_FIRST_RUN = True
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
||||
@pytest.mark.parametrize("num_tokens", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_capture_rotary_embedding_in_aclgraph(
|
||||
is_neox_style: bool,
|
||||
num_tokens: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
max_position_embeddings: int = 8192,
|
||||
base: int = 10000,
|
||||
):
|
||||
"""Test if the rotary embedding can be captured in aclgraph."""
|
||||
torch.manual_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
model = ModelwithRotaryEmbedding(
|
||||
hidden_size=num_heads * head_size,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
rotary_dim=rotary_dim,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
base=base,
|
||||
is_neox_style=is_neox_style,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
def custom_op_checking_backend(gm: torch.fx.GraphModule, example_input):
|
||||
# Validate if the rotary_embedding custom kernel is indeed inside the graph by
|
||||
# string match
|
||||
graph = str(gm.graph)
|
||||
assert "_C_ascend.rotary_embedding" in graph
|
||||
return gm
|
||||
|
||||
static_positions = torch.randint(0, max_position_embeddings,
|
||||
(num_tokens, ))
|
||||
static_hidden_states = torch.randn(num_tokens,
|
||||
num_heads * head_size,
|
||||
dtype=dtype,
|
||||
device="npu")
|
||||
compiled_model = torch.compile(model, backend=custom_op_checking_backend)
|
||||
stream = torch.npu.Stream()
|
||||
stream.wait_stream(torch.npu.current_stream())
|
||||
with torch.npu.stream(stream):
|
||||
# warmup the fx graph before capture
|
||||
for i in range(3):
|
||||
static_output = compiled_model(static_positions,
|
||||
static_hidden_states,
|
||||
offsets=None)
|
||||
stream.wait_stream(torch.npu.current_stream())
|
||||
|
||||
aclgraph = torch.npu.NPUGraph()
|
||||
|
||||
with torch.npu.graph(aclgraph):
|
||||
# Capture the model in aclgraph.
|
||||
static_output = compiled_model(static_positions, static_hidden_states)
|
||||
# Capture the model in aclgraph.
|
||||
random_filled_positions = torch.randint(0,
|
||||
max_position_embeddings,
|
||||
(num_tokens, ),
|
||||
device="npu")
|
||||
random_filled_hidden_states = torch.randn(num_tokens,
|
||||
num_heads * head_size,
|
||||
dtype=dtype,
|
||||
device="npu")
|
||||
static_positions.copy_(random_filled_positions)
|
||||
static_hidden_states.copy_(random_filled_hidden_states)
|
||||
|
||||
aclgraph.replay()
|
||||
global ACL_GRPAH_FIRST_RUN
|
||||
if ACL_GRPAH_FIRST_RUN:
|
||||
ACL_GRPAH_FIRST_RUN = False
|
||||
return
|
||||
output_reference = model(static_positions, static_hidden_states)
|
||||
torch.testing.assert_close(static_output,
|
||||
output_reference,
|
||||
atol=DEFAULT_ATOL,
|
||||
rtol=DEFAULT_RTOL)
|
||||
gc.collect()
|
||||
torch.npu.empty_cache()
|
||||
torch.npu.reset_peak_memory_stats()
|
||||
@@ -1,470 +0,0 @@
|
||||
import math
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import torch
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding)
|
||||
from vllm.platforms import CpuArchEnum
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
||||
from vllm_ascend.ops.rotary_embedding import _custom_rotary_embedding_enabled
|
||||
from vllm_ascend.utils import AscendDeviceType
|
||||
|
||||
MODEL = "Qwen3-0.6B"
|
||||
MODEL_VL = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||
MAX_NUM_BATCHED_TOKEND = 10000
|
||||
|
||||
|
||||
class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Common setup for tests
|
||||
self.positions = torch.tensor([1, 2, 3])
|
||||
self.query = torch.randn(3, 4, dtype=torch.float16)
|
||||
self.key = torch.randn(3, 4, dtype=torch.float16)
|
||||
self.head_size = 32
|
||||
self.cos_sin_cache = torch.randn(3, 4)
|
||||
|
||||
# Mock self object for rope_forward_oot
|
||||
self.mock_self = MagicMock()
|
||||
self.mock_self.head_size = self.head_size
|
||||
self.mock_self.cos_sin_cache = self.cos_sin_cache
|
||||
self.mock_self.is_neox_style = True
|
||||
self.mock_self.forward_native.return_value = (self.query, self.key)
|
||||
|
||||
def test_custom_rotary_embedding_enabled(self):
|
||||
# Test when all conditions are True
|
||||
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
|
||||
return_value=True):
|
||||
result = _custom_rotary_embedding_enabled(self.query, True,
|
||||
self.head_size)
|
||||
self.assertTrue(result)
|
||||
|
||||
# Test when dtype is not float16
|
||||
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
|
||||
return_value=True):
|
||||
query = self.query.to(torch.float32)
|
||||
result = _custom_rotary_embedding_enabled(query, True,
|
||||
self.head_size)
|
||||
self.assertFalse(result)
|
||||
|
||||
# Test when neox_style is False
|
||||
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
|
||||
return_value=True):
|
||||
result = _custom_rotary_embedding_enabled(self.query, False,
|
||||
self.head_size)
|
||||
self.assertFalse(result)
|
||||
|
||||
# Test when head_size is not divisible by 32
|
||||
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
|
||||
return_value=True):
|
||||
result = _custom_rotary_embedding_enabled(self.query, True,
|
||||
self.head_size + 1)
|
||||
self.assertFalse(result)
|
||||
|
||||
# Test when custom op is disabled
|
||||
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
|
||||
return_value=False):
|
||||
result = _custom_rotary_embedding_enabled(self.query, True,
|
||||
self.head_size)
|
||||
self.assertFalse(result)
|
||||
|
||||
|
||||
class TestAscendRotaryEmbedding(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Common setup for tests
|
||||
self.config_patcher = patch('vllm.config.vllm.get_current_vllm_config')
|
||||
self.mock_get_config = self.config_patcher.start()
|
||||
mock_config = MagicMock()
|
||||
mock_config.compilation_config.custom_ops = ["all"]
|
||||
|
||||
self.mock_get_config.return_value = mock_config
|
||||
self.positions = torch.tensor([1, 2, 3])
|
||||
self.query = torch.randn(3, 1, 32, dtype=torch.float16)
|
||||
self.key = torch.randn(3, 1, 32, dtype=torch.float16)
|
||||
self.head_size = 32
|
||||
self.rotary_dim = self.head_size
|
||||
self.max_position = 16
|
||||
self.rope_theta = 10000
|
||||
self.is_neox_style = True
|
||||
self.cos_sin_cache = torch.randn(3, 1, 32)
|
||||
self.layer = RotaryEmbedding(self.head_size, self.rotary_dim,
|
||||
self.max_position, self.rope_theta,
|
||||
self.is_neox_style, torch.float16)
|
||||
|
||||
# Mock self object for rope_forward_oot
|
||||
self.mock_self = MagicMock()
|
||||
self.mock_self.head_size = self.head_size
|
||||
self.mock_self.cos_sin_cache = self.cos_sin_cache
|
||||
self.mock_self.is_neox_style = self.is_neox_style
|
||||
|
||||
@patch('torch.ops._C_ascend')
|
||||
@patch('vllm_ascend.utils.get_ascend_device_type',
|
||||
return_value=AscendDeviceType.A3)
|
||||
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
|
||||
return_value=True)
|
||||
@patch('torch.ops._npu_rotary_embedding')
|
||||
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
|
||||
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
|
||||
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
|
||||
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
|
||||
def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
|
||||
mock_custom_enabled,
|
||||
mock_soc_version, mock__c):
|
||||
mock__c.rotary_embedding.return_value = self.query, self.key
|
||||
vllm_config = VllmConfig()
|
||||
model_config = ModelConfig(MODEL,
|
||||
tokenizer=MODEL,
|
||||
max_model_len=MAX_NUM_BATCHED_TOKEND)
|
||||
model_config.hf_text_config = PretrainedConfig()
|
||||
vllm_config.model_config = model_config
|
||||
with set_ascend_forward_context(None, vllm_config):
|
||||
result_q, result_k = self.layer.forward(self.positions, self.query,
|
||||
self.key)
|
||||
|
||||
mock__c.rotary_embedding.assert_called_once()
|
||||
self.assertEqual(result_q.shape, self.query.shape)
|
||||
self.assertEqual(result_k.shape, self.key.shape)
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
|
||||
return_value=False)
|
||||
@patch('torch_npu._npu_rotary_embedding')
|
||||
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
|
||||
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
|
||||
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
|
||||
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
|
||||
def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
|
||||
mock_custom_enabled):
|
||||
# Test contiguous path when custom is disabled
|
||||
non_contig_query = self.query.transpose(0, 1)
|
||||
non_contig_key = self.key.transpose(0, 1)
|
||||
vllm_config = VllmConfig()
|
||||
model_config = ModelConfig(MODEL,
|
||||
tokenizer=MODEL,
|
||||
max_model_len=MAX_NUM_BATCHED_TOKEND)
|
||||
model_config.hf_text_config = PretrainedConfig()
|
||||
vllm_config.model_config = model_config
|
||||
with set_ascend_forward_context(None, vllm_config):
|
||||
result_q, result_k = self.layer.forward(self.positions,
|
||||
non_contig_query,
|
||||
non_contig_key)
|
||||
|
||||
mock_npu_rotary.assert_called_once()
|
||||
self.assertEqual(result_q.shape, non_contig_query.shape)
|
||||
self.assertEqual(result_k.shape, non_contig_key.shape)
|
||||
|
||||
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
|
||||
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
|
||||
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
|
||||
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
|
||||
def test_rope_forward_oot_with_offsets(self):
|
||||
# Test that NotImplementedError is raised when offsets is provided
|
||||
offsets = torch.tensor([1, 2, 3])
|
||||
with self.assertRaises(NotImplementedError):
|
||||
vllm_config = VllmConfig()
|
||||
model_config = ModelConfig(MODEL,
|
||||
tokenizer=MODEL,
|
||||
max_model_len=MAX_NUM_BATCHED_TOKEND)
|
||||
model_config.hf_text_config = PretrainedConfig()
|
||||
vllm_config.model_config = model_config
|
||||
with set_ascend_forward_context(None, vllm_config):
|
||||
self.layer.forward(self.positions, self.query, self.key,
|
||||
offsets)
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
|
||||
return_value=False)
|
||||
@patch('torch_npu._npu_rotary_embedding')
|
||||
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
|
||||
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
|
||||
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
|
||||
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
|
||||
def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
|
||||
mock_custom_enabled):
|
||||
# Test neox_style override
|
||||
vllm_config = VllmConfig()
|
||||
model_config = ModelConfig(MODEL,
|
||||
tokenizer=MODEL,
|
||||
max_model_len=MAX_NUM_BATCHED_TOKEND)
|
||||
model_config.hf_text_config = PretrainedConfig()
|
||||
vllm_config.model_config = model_config
|
||||
with set_ascend_forward_context(None, vllm_config):
|
||||
result_q, result_k = self.layer.forward(
|
||||
self.positions,
|
||||
self.query,
|
||||
self.key,
|
||||
is_neox_style_override=False)
|
||||
# Check that neox_style=False was passed to the NPU function
|
||||
args, kwargs = mock_npu_rotary.call_args
|
||||
self.assertFalse(args[-1])
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
|
||||
return_value=False)
|
||||
@patch('torch_npu._npu_rotary_embedding')
|
||||
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
|
||||
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
|
||||
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
|
||||
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
|
||||
def test_rope_forward_oot_rotary_dim_less_than_head_size(
|
||||
self, mock_npu_rotary, mock_custom_enabled):
|
||||
# test case when rotary_dim < head_size
|
||||
org_rotary_dim = self.layer.rotary_dim
|
||||
self.layer.rotary_dim = self.layer.head_size // 2
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
model_config = ModelConfig(MODEL,
|
||||
tokenizer=MODEL,
|
||||
max_model_len=MAX_NUM_BATCHED_TOKEND)
|
||||
model_config.hf_text_config = PretrainedConfig()
|
||||
vllm_config.model_config = model_config
|
||||
with set_ascend_forward_context(None, vllm_config):
|
||||
result_q, result_k = self.layer.forward(self.positions, self.query,
|
||||
self.key)
|
||||
|
||||
mock_npu_rotary.assert_called_once()
|
||||
self.assertEqual(result_q.shape, self.query.shape)
|
||||
self.assertEqual(result_k.shape, self.key.shape)
|
||||
|
||||
# restore rotary_dim
|
||||
self.layer.rotary_dim = org_rotary_dim
|
||||
|
||||
|
||||
class MockRopeModule:
|
||||
|
||||
def __init__(self, max_seq_len=2048, is_neox_style=True):
|
||||
self.max_seq_len = max_seq_len
|
||||
self.is_neox_style = is_neox_style
|
||||
self.cos_cached = None
|
||||
self.sin_cached = None
|
||||
self.rotary_dim = 1
|
||||
self.base = 1
|
||||
|
||||
|
||||
class TestAscendDeepseekScalingRotaryEmbedding(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
# Common setup for tests
|
||||
self.config_patcher = patch('vllm.config.vllm.get_current_vllm_config')
|
||||
self.mock_get_config = self.config_patcher.start()
|
||||
mock_config = MagicMock()
|
||||
mock_config.compilation_config.custom_ops = ["all"]
|
||||
|
||||
self.mock_get_config.return_value = mock_config
|
||||
self.positions = torch.tensor([1, 2, 3])
|
||||
self.query = torch.randn(3, 1, 32, dtype=torch.float16)
|
||||
self.key = torch.randn(3, 1, 32, dtype=torch.float16)
|
||||
self.head_size = 32
|
||||
self.rotary_dim = self.head_size
|
||||
self.max_position = 16
|
||||
self.rope_theta = 10000
|
||||
self.is_neox_style = True
|
||||
self.scaling_factor = 1
|
||||
self.layer = None
|
||||
|
||||
def _create_layer(self):
|
||||
self.layer = DeepseekScalingRotaryEmbedding(
|
||||
self.head_size, self.rotary_dim, self.max_position,
|
||||
self.rope_theta, self.is_neox_style, self.scaling_factor,
|
||||
torch.float16)
|
||||
return self.layer
|
||||
|
||||
@patch("vllm.platforms.current_platform.device_type",
|
||||
new=torch.device("cpu"))
|
||||
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
|
||||
new_callable=PropertyMock)
|
||||
def test_native_rope_deepseek_forward_base(self, mock_npuplatform):
|
||||
mock_npuplatform.device_type = torch.device("cpu")
|
||||
self.layer = self._create_layer()
|
||||
with patch("vllm_ascend.ops.rotary_embedding._rope_forward_oot",
|
||||
return_value=(self.query,
|
||||
self.key)) as mock_rope_forward_oot:
|
||||
q_pe, k_pe = self.layer.forward(self.positions, self.query,
|
||||
self.key)
|
||||
mock_rope_forward_oot.assert_called_once()
|
||||
assert q_pe.shape == self.query.shape
|
||||
assert k_pe.shape == self.key.shape
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
|
||||
@patch("vllm.platforms.current_platform.device_type",
|
||||
new=torch.device("cpu"))
|
||||
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
|
||||
new_callable=PropertyMock)
|
||||
def test_native_rope_deepseek_forward_key_reshaping(
|
||||
self, mock_npuplatform, mock_rope_forward_oot):
|
||||
mock_npuplatform.device_type = torch.device("cpu")
|
||||
self.layer = self._create_layer()
|
||||
|
||||
key = torch.randn(1, 32)
|
||||
|
||||
mock_rope_forward_oot.return_value = (self.query, key)
|
||||
|
||||
q_pe, k_pe = self.layer.forward(self.positions, self.query, key)
|
||||
mock_rope_forward_oot.assert_called_once()
|
||||
assert q_pe.shape == self.query.shape
|
||||
assert k_pe.shape == key.shape
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
|
||||
@patch("vllm.platforms.current_platform.device_type",
|
||||
new=torch.device("cpu"))
|
||||
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
|
||||
new_callable=PropertyMock)
|
||||
def test_native_rope_deepseek_forward_non_neox_style(
|
||||
self, mock_npuplatform, mock_rope_forward_oot):
|
||||
mock_npuplatform.device_type = torch.device("cpu")
|
||||
self.layer = self._create_layer()
|
||||
|
||||
mock_rope_forward_oot.return_value = (self.query, self.key)
|
||||
|
||||
q_pe, k_pe = self.layer.forward(self.positions, self.query, self.key)
|
||||
|
||||
mock_rope_forward_oot.assert_called_once()
|
||||
assert q_pe.shape == self.query.shape
|
||||
assert k_pe.shape == self.key.shape
|
||||
|
||||
@patch("vllm.platforms.current_platform.device_type",
|
||||
new=torch.device("cpu"))
|
||||
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
|
||||
new_callable=PropertyMock)
|
||||
def test_basic_case(self, mock_npuplatform):
|
||||
# Test with standard values
|
||||
mock_npuplatform.device_type = torch.device("cpu")
|
||||
self.layer = self._create_layer()
|
||||
num_rotations = 100
|
||||
dim = 512
|
||||
base = 10000
|
||||
max_position_embeddings = 2048
|
||||
|
||||
result = self.layer._yarn_find_correction_dim(num_rotations, dim, base,
|
||||
max_position_embeddings)
|
||||
|
||||
# Calculate expected value manually
|
||||
expected = (dim * torch.log(
|
||||
torch.tensor(max_position_embeddings) /
|
||||
(num_rotations * 2 * torch.pi))) / (2 *
|
||||
torch.log(torch.tensor(base)))
|
||||
|
||||
self.assertTrue(torch.allclose(result, expected))
|
||||
|
||||
@patch("vllm.platforms.current_platform.device_type",
|
||||
new=torch.device("cpu"))
|
||||
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
|
||||
new_callable=PropertyMock)
|
||||
def test_yarn_get_mscale(self, mock_npuplatform):
|
||||
mock_npuplatform.device_type = torch.device("cpu")
|
||||
self.layer = self._create_layer()
|
||||
|
||||
# test_scale_less_than_or_equal_1
|
||||
self.assertEqual(self.layer._yarn_get_mscale(scale=0.5), 1.0)
|
||||
self.assertEqual(self.layer._yarn_get_mscale(scale=1.0), 1.0)
|
||||
self.assertEqual(self.layer._yarn_get_mscale(scale=0.999), 1.0)
|
||||
|
||||
# test_scale_greater_than_1:
|
||||
test_cases = [(2.0, 1.0, 1.0 + 0.1 * math.log(2.0)),
|
||||
(10.0, 1.0, 1.0 + 0.1 * math.log(10.0)),
|
||||
(5.0, 2.0, 1.0 + 0.2 * math.log(5.0)),
|
||||
(math.e, 1.0, 1.0 + 0.1)]
|
||||
|
||||
for scale, mscale, expected in test_cases:
|
||||
result = self.layer._yarn_get_mscale(scale, mscale)
|
||||
self.assertAlmostEqual(
|
||||
result,
|
||||
expected,
|
||||
places=6,
|
||||
msg=f"Failed for scale={scale}, mscale={mscale}")
|
||||
|
||||
|
||||
class TestAscendMRotaryEmbedding(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Common setup for tests
|
||||
self.config_patcher = patch('vllm.config.vllm.get_current_vllm_config')
|
||||
self.mock_get_config = self.config_patcher.start()
|
||||
mock_config = MagicMock()
|
||||
mock_config.compilation_config.custom_ops = ["all"]
|
||||
self.mock_get_config.return_value = mock_config
|
||||
self.number_tokens = 3
|
||||
self.num_head = 8
|
||||
self.num_kvhead = 8
|
||||
self.head_size = 128
|
||||
self.max_position_embeddings = 128000
|
||||
self.is_neox_style = True
|
||||
self.rope_theta = 1000000.0
|
||||
self.positions_1d = torch.tensor([1, 2, 3])
|
||||
self.positions_2d = torch.randint(1, 10, (3, self.number_tokens))
|
||||
|
||||
self.query = torch.randn(
|
||||
(self.number_tokens, self.num_head * self.head_size),
|
||||
dtype=torch.bfloat16)
|
||||
self.key = torch.randn(
|
||||
(self.number_tokens, self.num_kvhead * self.head_size),
|
||||
dtype=torch.bfloat16)
|
||||
|
||||
# Qwen2.5-VL mrope section case
|
||||
self.mrope_section = [16, 24, 24]
|
||||
|
||||
self.layer = MRotaryEmbedding(self.head_size,
|
||||
self.head_size,
|
||||
self.max_position_embeddings,
|
||||
base=self.rope_theta,
|
||||
is_neox_style=self.is_neox_style,
|
||||
dtype=torch.bfloat16,
|
||||
mrope_section=self.mrope_section)
|
||||
|
||||
self.mock_config = MagicMock()
|
||||
|
||||
def _create_vllm_config(self):
|
||||
vllm_config = VllmConfig()
|
||||
model_config = ModelConfig(MODEL_VL,
|
||||
tokenizer=MODEL_VL,
|
||||
max_model_len=MAX_NUM_BATCHED_TOKEND)
|
||||
model_config.hf_text_config = PretrainedConfig()
|
||||
vllm_config.model_config = model_config
|
||||
return vllm_config
|
||||
|
||||
@patch('torch_npu.npu_mrope')
|
||||
@patch('vllm_ascend.platform.NPUPlatform.get_cpu_architecture')
|
||||
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
|
||||
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
|
||||
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
|
||||
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
|
||||
def test_forward_oot_1d_positions(self, mock_cpu_arc, mock_npu_mrope):
|
||||
mock_cpu_arc.return_value = CpuArchEnum.ARM
|
||||
|
||||
mock_npu_mrope.return_value = (torch.zeros_like(self.query),
|
||||
torch.zeros_like(self.key))
|
||||
|
||||
vllm_config = self._create_vllm_config()
|
||||
with set_ascend_forward_context(None, vllm_config):
|
||||
result_q, result_k = self.layer.forward_oot(
|
||||
self.positions_1d, self.query, self.key)
|
||||
|
||||
mock_npu_mrope.assert_called_once()
|
||||
self.assertFalse(torch.isnan(result_q).any().item())
|
||||
self.assertFalse(torch.isnan(result_k).any().item())
|
||||
self.assertEqual(result_q.shape, self.query.shape)
|
||||
|
||||
@patch('torch_npu.npu_mrope')
|
||||
@patch('vllm_ascend.platform.NPUPlatform.get_cpu_architecture')
|
||||
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
|
||||
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
|
||||
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
|
||||
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
|
||||
def test_forward_oot_2d_positions(self, mock_cpu_arc, mock_npu_mrope):
|
||||
mock_cpu_arc.return_value = CpuArchEnum.ARM
|
||||
|
||||
mock_npu_mrope.return_value = (torch.zeros_like(self.query),
|
||||
torch.zeros_like(self.key))
|
||||
|
||||
vllm_config = self._create_vllm_config()
|
||||
with set_ascend_forward_context(None, vllm_config):
|
||||
result_q, result_k = self.layer.forward_oot(
|
||||
self.positions_2d, self.query, self.key)
|
||||
|
||||
mock_npu_mrope.assert_called_once()
|
||||
self.assertFalse(torch.isnan(result_q).any().item())
|
||||
self.assertFalse(torch.isnan(result_k).any().item())
|
||||
self.assertEqual(result_q.shape, self.query.shape)
|
||||
Reference in New Issue
Block a user