Files
xc-llm-ascend/tests/ut/ops/test_rotary_embedding.py
Angazenn bdb65319a9 [UT] Align input arguments with Ascend(Yarn)RotaryEmbedding with vLLM and add ut (#7358)
### What this PR does / why we need it?
This PR adds the missing arguments to `AscendRotaryEmbedding` and
`AscendYaRNRotaryEmbedding` so that their signatures conform to vLLM's.
Corresponding unit tests are also introduced.

- vLLM version: v0.17.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: Angazenn <supperccell@163.com>
2026-03-24 16:02:56 +08:00

340 lines
14 KiB
Python

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import inspect
import pytest
import torch
from unittest.mock import MagicMock, patch, PropertyMock
from vllm.model_executor.layers.rotary_embedding import (YaRNScalingRotaryEmbedding, RotaryEmbedding)
from vllm_ascend.ops.rotary_embedding import (AscendYaRNRotaryEmbedding, AscendRotaryEmbedding)
# Shared dimensions and hyper-parameters used by the fixtures and tests below.
HEAD_SIZE = 64  # head_size handed to the embedding constructors
ROTARY_DIM = 64  # rotary_dim handed to the embedding constructors
MAX_POS = 2048  # max_position_embeddings; also the row count of the dummy cos_sin_cache
BASE = 10000.0  # base handed to the embedding constructors
DTYPE = torch.bfloat16  # dtype handed to the embedding constructors
SEQ_LEN = 4  # default number of rows produced by _make_tensors
NUM_HEADS = 2  # default head count used to size query/key in _make_tensors
def _make_tensors(seq_len=SEQ_LEN, num_heads=NUM_HEADS, head_size=HEAD_SIZE):
    """Build a ``(positions, query, key)`` triple of dummy inputs.

    ``positions`` is ``torch.arange(seq_len)`` with dtype ``torch.long``.
    ``query`` and ``key`` are independently drawn random tensors of shape
    ``(seq_len, num_heads * head_size)``; ``query`` is drawn first.
    """
    hidden_shape = (seq_len, num_heads * head_size)
    return (
        torch.arange(seq_len, dtype=torch.long),
        torch.randn(hidden_shape),
        torch.randn(hidden_shape),
    )
def check_parent_init_signature_has_not_changed(parent_func, child_func):
    """Assert that two callables accept the same set of parameter names.

    Only parameter *names* are compared (``self`` is ignored); order, kinds
    and defaults are not checked.

    Args:
        parent_func: The upstream callable whose signature is the reference.
        child_func: The overriding callable expected to mirror it.

    Raises:
        AssertionError: If ``parent_func`` has a parameter that ``child_func``
            lacks, or ``child_func`` has one that ``parent_func`` lacks.
    """
    parent_params = set(inspect.signature(parent_func).parameters) - {"self"}
    child_params = set(inspect.signature(child_func).parameters) - {"self"}

    # ``__name__`` is just "__init__" for both sides when comparing
    # constructors, so prefer the qualified name to identify the owner class.
    parent_name = getattr(parent_func, "__qualname__", None) or getattr(
        parent_func, "__name__", repr(parent_func)
    )
    child_name = getattr(child_func, "__qualname__", None) or getattr(
        child_func, "__name__", repr(child_func)
    )

    # Sorted so the failure message is deterministic across runs.
    missing_in_child = sorted(parent_params - child_params)
    assert not missing_in_child, (
        f"{parent_name} added new parameter(s): {missing_in_child}. "
        f"Check whether {child_name} needs to forward them."
    )

    extra_in_child = sorted(child_params - parent_params)
    assert not extra_in_child, (
        f"{parent_name} removed parameter(s): {extra_in_child}. "
        f"Check whether {child_name} should stop accepting them."
    )
@pytest.fixture(autouse=True)
def patch_init_side_effects():
    """Neutralise the module-level hooks touched while constructing embeddings.

    For the duration of each test this replaces ``_record_cos_sin_cache``,
    ``_record_cos_and_sin_cache_interleaved`` and ``get_current_vllm_config``
    in ``vllm_ascend.ops.rotary_embedding`` with mocks, so no real NPU ops or
    vLLM global config are needed.

    Yields:
        The mock standing in for ``get_current_vllm_config``; its return
        value starts with ``speculative_config = None``.
    """
    cos_sin_patch = patch("vllm_ascend.ops.rotary_embedding._record_cos_sin_cache")
    interleaved_patch = patch(
        "vllm_ascend.ops.rotary_embedding._record_cos_and_sin_cache_interleaved"
    )
    config_patch = patch("vllm_ascend.ops.rotary_embedding.get_current_vllm_config")

    with cos_sin_patch, interleaved_patch, config_patch as config_getter:
        # Default: no speculative config, i.e. the non-MTP path.
        config_getter.return_value.speculative_config = None
        yield config_getter
@pytest.fixture()
def make_embedding(patch_init_side_effects):
    """Provide a factory that builds an ``AscendRotaryEmbedding`` for tests.

    The factory takes ``use_mtp`` and ``is_neox_style``. The parent
    ``RotaryEmbedding.__init__`` is mocked out, so the attributes it would
    normally populate are assigned by hand before the Ascend ``__init__`` runs.
    """

    def _factory(use_mtp: bool = False, is_neox_style: bool = True):
        # Presumably __init__ derives use_mtp from this mocked config value.
        mocked_config = patch_init_side_effects.return_value
        mocked_config.speculative_config = MagicMock(method="mtp") if use_mtp else None

        parent_init_patch = patch(
            "vllm_ascend.ops.rotary_embedding.RotaryEmbedding.__init__",
            return_value=None,
        )
        with parent_init_patch:
            from vllm_ascend.ops.rotary_embedding import AscendRotaryEmbedding

            embedding = AscendRotaryEmbedding.__new__(AscendRotaryEmbedding)

            # Stand-ins for the state the real parent constructor would set.
            parent_state = {
                "head_size": HEAD_SIZE,
                "rotary_dim": ROTARY_DIM,
                "is_neox_style": is_neox_style,
                "cos_sin_cache": torch.zeros(MAX_POS, ROTARY_DIM),
            }
            for attr_name, attr_value in parent_state.items():
                setattr(embedding, attr_name, attr_value)

            # Run the Ascend-specific constructor logic under test.
            init_args = (HEAD_SIZE, ROTARY_DIM, MAX_POS, BASE, is_neox_style, DTYPE)
            AscendRotaryEmbedding.__init__(embedding, *init_args)

        return embedding

    return _factory
@pytest.fixture()
def make_yarn_embedding(patch_init_side_effects):
    """Provide a factory that builds an ``AscendYaRNRotaryEmbedding`` for tests.

    The parent ``YaRNScalingRotaryEmbedding.__init__`` is mocked out, so the
    attributes it would normally populate are assigned by hand. The
    ``patch_init_side_effects`` fixture is requested explicitly even though
    it is autouse.
    """

    def _factory(is_neox_style: bool = True):
        parent_init_patch = patch(
            "vllm_ascend.ops.rotary_embedding.YaRNScalingRotaryEmbedding.__init__",
            return_value=None,
        )
        with parent_init_patch:
            from vllm_ascend.ops.rotary_embedding import AscendYaRNRotaryEmbedding

            embedding = AscendYaRNRotaryEmbedding.__new__(AscendYaRNRotaryEmbedding)

            # Stand-ins for the state the real parent constructor would set.
            embedding.head_size = HEAD_SIZE
            embedding.rotary_dim = ROTARY_DIM
            embedding.is_neox_style = is_neox_style
            embedding.cos_sin_cache = torch.zeros(MAX_POS, ROTARY_DIM)

            init_kwargs = dict(
                head_size=HEAD_SIZE,
                rotary_dim=ROTARY_DIM,
                max_position_embeddings=MAX_POS,
                base=BASE,
                is_neox_style=is_neox_style,
                scaling_factor=1.0,
                dtype=DTYPE,
            )
            AscendYaRNRotaryEmbedding.__init__(embedding, **init_kwargs)

        return embedding

    return _factory
class TestAscendEmbeddingForwardOOT:
    """Tests for ``AscendRotaryEmbedding.forward_oot`` with the NPU ops mocked."""

    @staticmethod
    def _stub_forward_context(
        mock_get_forward_context, *, is_draft_model=False, flash_comm_v1_enabled=False
    ):
        """Make the patched ``get_forward_context`` return a context with the given flags.

        Returns the ``MagicMock`` context so callers may tweak it further.
        """
        context = MagicMock()
        context.is_draft_model = is_draft_model
        context.flash_comm_v1_enabled = flash_comm_v1_enabled
        mock_get_forward_context.return_value = context
        return context

    @patch("torch.ops.vllm.npu_rotary_embedding")
    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
    def test_basic_call_delegates_to_npu_op(self, mock_get_forward_context, mock_npu_op, make_embedding):
        """forward_oot always calls npu_rotary_embedding and returns its result."""
        self._stub_forward_context(mock_get_forward_context)
        expected_output = (torch.randn(SEQ_LEN, NUM_HEADS * HEAD_SIZE),) * 2
        mock_npu_op.return_value = expected_output

        emb = make_embedding()
        positions, query, key = _make_tensors()
        result = emb.forward_oot(positions, query, key)

        mock_npu_op.assert_called_once_with(
            positions, query, key, emb.cos_sin_cache,
            HEAD_SIZE, ROTARY_DIM, emb.is_neox_style,
        )
        assert result is expected_output

    @patch("torch.ops.vllm.npu_rotary_embedding")
    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
    def test_neox_style_override_true(self, mock_get_forward_context, mock_npu_op, make_embedding):
        """is_neox_style_override=True wins over self.is_neox_style=False."""
        self._stub_forward_context(mock_get_forward_context)
        mock_npu_op.return_value = MagicMock()

        emb = make_embedding(is_neox_style=False)
        positions, query, key = _make_tensors()
        emb.forward_oot(positions, query, key, is_neox_style_override=True)

        # Last positional argument passed to the op is is_neox_style.
        assert mock_npu_op.call_args[0][-1] is True

    @patch("torch.ops.vllm.npu_rotary_embedding")
    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
    def test_neox_style_override_false(self, mock_get_forward_context, mock_npu_op, make_embedding):
        """is_neox_style_override=False wins over self.is_neox_style=True."""
        self._stub_forward_context(mock_get_forward_context)
        mock_npu_op.return_value = MagicMock()

        emb = make_embedding(is_neox_style=True)
        positions, query, key = _make_tensors()
        emb.forward_oot(positions, query, key, is_neox_style_override=False)

        assert mock_npu_op.call_args[0][-1] is False

    @patch("torch.ops.vllm.npu_rotary_embedding")
    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
    def test_neox_style_override_none_uses_self(self, mock_get_forward_context, mock_npu_op, make_embedding):
        """When override is None, self.is_neox_style is used unchanged."""
        self._stub_forward_context(mock_get_forward_context)
        mock_npu_op.return_value = MagicMock()

        emb = make_embedding(is_neox_style=True)
        positions, query, key = _make_tensors()
        emb.forward_oot(positions, query, key, is_neox_style_override=None)

        assert mock_npu_op.call_args[0][-1] is True

    @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
    @patch("torch.ops.vllm.npu_rotary_embedding")
    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
    def test_gather_unpad_called_when_all_conditions_met(
        self, mock_get_forward_context, mock_npu_op, mock_gather, make_embedding
    ):
        """
        maybe_all_gather_and_maybe_unpad is called iff:
        is_draft_model=True AND use_mtp=True AND flash_comm_v1_enabled=True
        """
        self._stub_forward_context(
            mock_get_forward_context, is_draft_model=True, flash_comm_v1_enabled=True
        )
        gathered_positions = torch.arange(SEQ_LEN, dtype=torch.long)
        mock_gather.return_value = gathered_positions
        mock_npu_op.return_value = MagicMock()

        emb = make_embedding(use_mtp=True)
        positions, query, key = _make_tensors()
        emb.forward_oot(positions, query, key)

        mock_gather.assert_called_once()
        # npu op should receive the gathered positions, not the originals
        assert mock_npu_op.call_args[0][0] is gathered_positions

    @pytest.mark.parametrize("is_draft_model,flash_comm,use_mtp", [
        (False, True, True),  # not draft
        (True, False, True),  # flash_comm disabled
        (True, True, False),  # use_mtp disabled
    ])
    @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
    @patch("torch.ops.vllm.npu_rotary_embedding")
    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
    def test_gather_unpad_skipped_unless_all_conditions_met(
        self, mock_get_forward_context, mock_npu_op, mock_gather,
        is_draft_model, flash_comm, use_mtp, make_embedding,
    ):
        """gather/unpad must NOT fire if any one of the three conditions is False."""
        self._stub_forward_context(
            mock_get_forward_context,
            is_draft_model=is_draft_model,
            flash_comm_v1_enabled=flash_comm,
        )
        mock_npu_op.return_value = MagicMock()

        emb = make_embedding(use_mtp=use_mtp)
        positions, query, key = _make_tensors()
        emb.forward_oot(positions, query, key)

        mock_gather.assert_not_called()
        # Original positions tensor is passed through untouched
        assert mock_npu_op.call_args[0][0] is positions

    def test_parent_init_signature_has_not_changed(self):
        """
        Fail loudly if RotaryEmbedding.__init__ adds, removes, or
        renames parameters, so a developer knows to update AscendRotaryEmbedding
        accordingly.
        """
        check_parent_init_signature_has_not_changed(
            RotaryEmbedding.__init__,
            AscendRotaryEmbedding.__init__
        )
class TestAscendYaRNRotaryEmbeddingForwardOOT:
    """Tests that the YaRN variant's ``forward_oot`` hands off to ``AscendRotaryEmbedding``."""

    @patch("vllm_ascend.ops.rotary_embedding.AscendRotaryEmbedding.forward_oot")
    def test_delegates_to_ascend_rotary_forward_oot(self, mock_delegate, make_yarn_embedding):
        """With default arguments the delegate is called once and its result returned."""
        delegate_result = MagicMock()
        mock_delegate.return_value = delegate_result
        yarn = make_yarn_embedding()
        positions, query, key = _make_tensors()

        outcome = yarn.forward_oot(positions, query, key)

        # offsets and is_neox_style_override both default to None.
        mock_delegate.assert_called_once_with(yarn, positions, query, key, None, None)
        assert outcome is delegate_result

    @patch("vllm_ascend.ops.rotary_embedding.AscendRotaryEmbedding.forward_oot")
    def test_return_value_passed_through(self, mock_delegate, make_yarn_embedding):
        """Whatever the delegate returns comes back as the very same object."""
        sentinel = (
            torch.randn(SEQ_LEN, HEAD_SIZE),
            torch.randn(SEQ_LEN, HEAD_SIZE),
        )
        mock_delegate.return_value = sentinel
        yarn = make_yarn_embedding()
        positions, query, key = _make_tensors()

        assert yarn.forward_oot(positions, query, key) is sentinel

    @pytest.mark.parametrize("override", [True, False])
    @patch("vllm_ascend.ops.rotary_embedding.AscendRotaryEmbedding.forward_oot")
    def test_is_neox_style_override_forwarded(self, mock_delegate, override, make_yarn_embedding):
        """An explicit is_neox_style_override reaches the delegate unchanged."""
        mock_delegate.return_value = MagicMock()
        yarn = make_yarn_embedding()
        positions, query, key = _make_tensors()

        yarn.forward_oot(positions, query, key, is_neox_style_override=override)

        # Positional layout: (self, positions, query, key, offsets, override).
        forwarded_args = mock_delegate.call_args_list[0][0]
        assert forwarded_args[5] is override

    @patch("vllm_ascend.ops.rotary_embedding.AscendRotaryEmbedding.forward_oot")
    def test_all_args_forwarded_together(self, mock_delegate, make_yarn_embedding):
        """Smoke test: offsets and the override are forwarded in the same call."""
        mock_delegate.return_value = MagicMock()
        yarn = make_yarn_embedding()
        positions, query, key = _make_tensors()
        offsets = torch.ones(SEQ_LEN, dtype=torch.long)

        yarn.forward_oot(positions, query, key, offsets=offsets, is_neox_style_override=False)

        mock_delegate.assert_called_once_with(yarn, positions, query, key, offsets, False)

    def test_parent_init_signature_has_not_changed(self):
        """
        Guard against upstream drift: fail when YaRNScalingRotaryEmbedding.__init__
        and AscendYaRNRotaryEmbedding.__init__ no longer share the same
        parameter names.
        """
        check_parent_init_signature_has_not_changed(
            YaRNScalingRotaryEmbedding.__init__,
            AscendYaRNRotaryEmbedding.__init__,
        )