[UT] Align input arguments with Ascend(Yarn)RotaryEmbedding with vLLM and add ut (#7358)
### What this PR does / why we need it?
This PR adds the missing arguments to `AscendRotaryEmbedding` and
`AscendYaRNRotaryEmbedding` so that their signatures conform with vLLM.
In addition, corresponding unit tests are introduced.
- vLLM version: v0.17.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: Angazenn <supperccell@163.com>
This commit is contained in:
340
tests/ut/ops/test_rotary_embedding.py
Normal file
340
tests/ut/ops/test_rotary_embedding.py
Normal file
@@ -0,0 +1,340 @@
|
|||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# This file is a part of the vllm-ascend project.
|
||||||
|
#
|
||||||
|
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from unittest.mock import MagicMock, patch, PropertyMock
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.rotary_embedding import (YaRNScalingRotaryEmbedding, RotaryEmbedding)
|
||||||
|
from vllm_ascend.ops.rotary_embedding import (AscendYaRNRotaryEmbedding, AscendRotaryEmbedding)
|
||||||
|
|
||||||
|
|
||||||
|
HEAD_SIZE = 64
|
||||||
|
ROTARY_DIM = 64
|
||||||
|
MAX_POS = 2048
|
||||||
|
BASE = 10000.0
|
||||||
|
DTYPE = torch.bfloat16
|
||||||
|
SEQ_LEN = 4
|
||||||
|
NUM_HEADS = 2
|
||||||
|
|
||||||
|
|
||||||
|
def _make_tensors(seq_len=SEQ_LEN, num_heads=NUM_HEADS, head_size=HEAD_SIZE):
|
||||||
|
positions = torch.arange(seq_len, dtype=torch.long)
|
||||||
|
query = torch.randn(seq_len, num_heads * head_size)
|
||||||
|
key = torch.randn(seq_len, num_heads * head_size)
|
||||||
|
return positions, query, key
|
||||||
|
|
||||||
|
|
||||||
|
def check_parent_init_signature_has_not_changed(parent_func, child_func):
    """Assert that the parent's __init__ parameters match the child's.

    Compares the two signatures by parameter name (ignoring ``self``) and
    fails with a descriptive message when the parent gained or lost
    parameters relative to the child, signalling that the Ascend wrapper
    must be updated to forward them.
    """
    def _param_names(func):
        # Parameter names only; defaults/annotations are irrelevant here.
        return set(inspect.signature(func).parameters) - {"self"}

    added = _param_names(parent_func) - _param_names(child_func)
    removed = _param_names(child_func) - _param_names(parent_func)

    assert not added, (
        f"{parent_func.__name__} added new parameter(s): {added}. "
        f"Check whether {child_func.__name__} needs to forward them."
    )
    assert not removed, (
        f"{parent_func.__name__} removed parameter(s): {removed}. "
        f"Check whether {child_func.__name__} needs to forward them."
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def patch_init_side_effects():
    """Neutralise every __init__-time side effect so tests run hermetically.

    Patches the cos/sin cache recording helpers and the global vLLM config
    lookup, then yields the mocked ``get_current_vllm_config`` so individual
    tests can reconfigure the speculative config as needed.
    """
    with (
        patch("vllm_ascend.ops.rotary_embedding._record_cos_sin_cache"),
        patch("vllm_ascend.ops.rotary_embedding._record_cos_and_sin_cache_interleaved"),
        patch("vllm_ascend.ops.rotary_embedding.get_current_vllm_config") as cfg_mock,
    ):
        # Default: no speculative decoding configured → use_mtp = False.
        cfg_mock.return_value.speculative_config = None
        yield cfg_mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def make_embedding(patch_init_side_effects):
    """Factory that creates an AscendRotaryEmbedding with controllable use_mtp."""

    def _factory(use_mtp: bool = False, is_neox_style: bool = True):
        # Drive use_mtp through the mocked vLLM config's speculative settings.
        if use_mtp:
            patch_init_side_effects.return_value.speculative_config = MagicMock(method="mtp")
        else:
            patch_init_side_effects.return_value.speculative_config = None

        with patch("vllm_ascend.ops.rotary_embedding.RotaryEmbedding.__init__") as parent_init:
            parent_init.return_value = None
            from vllm_ascend.ops.rotary_embedding import AscendRotaryEmbedding

            emb = AscendRotaryEmbedding.__new__(AscendRotaryEmbedding)
            # Manually provide the attributes the suppressed parent
            # __init__ would normally set.
            emb.head_size = HEAD_SIZE
            emb.rotary_dim = ROTARY_DIM
            emb.is_neox_style = is_neox_style
            emb.cos_sin_cache = torch.zeros(MAX_POS, ROTARY_DIM)
            # Run the subclass __init__ to exercise our own code path.
            AscendRotaryEmbedding.__init__(emb, HEAD_SIZE, ROTARY_DIM, MAX_POS,
                                           BASE, is_neox_style, DTYPE)
            return emb

    return _factory
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def make_yarn_embedding(patch_init_side_effects):
    """Factory for AscendYaRNRotaryEmbedding with the parent __init__ suppressed.

    Relies on the same autouse patch_init_side_effects fixture for the
    remaining init-time side effects.
    """

    def _factory(is_neox_style: bool = True):
        with patch("vllm_ascend.ops.rotary_embedding.YaRNScalingRotaryEmbedding.__init__") as parent_init:
            parent_init.return_value = None
            from vllm_ascend.ops.rotary_embedding import AscendYaRNRotaryEmbedding

            emb = AscendYaRNRotaryEmbedding.__new__(AscendYaRNRotaryEmbedding)
            # Attributes the suppressed parent __init__ would normally set.
            emb.head_size = HEAD_SIZE
            emb.rotary_dim = ROTARY_DIM
            emb.is_neox_style = is_neox_style
            emb.cos_sin_cache = torch.zeros(MAX_POS, ROTARY_DIM)
            AscendYaRNRotaryEmbedding.__init__(
                emb,
                head_size=HEAD_SIZE,
                rotary_dim=ROTARY_DIM,
                max_position_embeddings=MAX_POS,
                base=BASE,
                is_neox_style=is_neox_style,
                scaling_factor=1.0,
                dtype=DTYPE,
            )
            return emb

    return _factory
|
||||||
|
|
||||||
|
|
||||||
|
class TestAscendEmbeddingForwardOOT:
    """Unit tests for AscendRotaryEmbedding.forward_oot."""

    @staticmethod
    def _setup_forward_context(mock_get_forward_context,
                               is_draft_model=False,
                               flash_comm=False):
        """Configure the mocked forward context with the two flags that
        forward_oot reads: is_draft_model and flash_comm_v1_enabled."""
        ctx = MagicMock()
        ctx.is_draft_model = is_draft_model
        ctx.flash_comm_v1_enabled = flash_comm
        mock_get_forward_context.return_value = ctx

    @patch("torch.ops.vllm.npu_rotary_embedding")
    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
    def test_basic_call_delegates_to_npu_op(self, mock_get_forward_context, mock_npu_op, make_embedding):
        """forward_oot always calls npu_rotary_embedding and returns its result."""
        self._setup_forward_context(mock_get_forward_context)
        expected_output = (torch.randn(SEQ_LEN, NUM_HEADS * HEAD_SIZE),) * 2
        mock_npu_op.return_value = expected_output

        emb = make_embedding()
        positions, query, key = _make_tensors()

        result = emb.forward_oot(positions, query, key)

        mock_npu_op.assert_called_once_with(
            positions, query, key, emb.cos_sin_cache,
            HEAD_SIZE, ROTARY_DIM, emb.is_neox_style,
        )
        assert result is expected_output

    @patch("torch.ops.vllm.npu_rotary_embedding")
    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
    def test_neox_style_override_true(self, mock_get_forward_context, mock_npu_op, make_embedding):
        """is_neox_style_override=True wins over self.is_neox_style=False."""
        self._setup_forward_context(mock_get_forward_context)
        mock_npu_op.return_value = MagicMock()

        emb = make_embedding(is_neox_style=False)
        positions, query, key = _make_tensors()

        emb.forward_oot(positions, query, key, is_neox_style_override=True)

        # Verify the override was forwarded correctly
        assert mock_npu_op.call_args[0][-1] is True  # last positional arg = is_neox_style

    @patch("torch.ops.vllm.npu_rotary_embedding")
    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
    def test_neox_style_override_false(self, mock_get_forward_context, mock_npu_op, make_embedding):
        """is_neox_style_override=False wins over self.is_neox_style=True."""
        self._setup_forward_context(mock_get_forward_context)
        mock_npu_op.return_value = MagicMock()

        emb = make_embedding(is_neox_style=True)
        positions, query, key = _make_tensors()

        emb.forward_oot(positions, query, key, is_neox_style_override=False)

        assert mock_npu_op.call_args[0][-1] is False

    @patch("torch.ops.vllm.npu_rotary_embedding")
    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
    def test_neox_style_override_none_uses_self(self, mock_get_forward_context, mock_npu_op, make_embedding):
        """When override is None, self.is_neox_style is used unchanged."""
        self._setup_forward_context(mock_get_forward_context)
        mock_npu_op.return_value = MagicMock()

        emb = make_embedding(is_neox_style=True)
        positions, query, key = _make_tensors()

        emb.forward_oot(positions, query, key, is_neox_style_override=None)

        assert mock_npu_op.call_args[0][-1] is True

    @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
    @patch("torch.ops.vllm.npu_rotary_embedding")
    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
    def test_gather_unpad_called_when_all_conditions_met(
        self, mock_get_forward_context, mock_npu_op, mock_gather, make_embedding
    ):
        """
        maybe_all_gather_and_maybe_unpad is called iff:
        is_draft_model=True AND use_mtp=True AND flash_comm_v1_enabled=True
        """
        self._setup_forward_context(mock_get_forward_context,
                                    is_draft_model=True, flash_comm=True)
        gathered_positions = torch.arange(SEQ_LEN, dtype=torch.long)
        mock_gather.return_value = gathered_positions
        mock_npu_op.return_value = MagicMock()

        emb = make_embedding(use_mtp=True)
        positions, query, key = _make_tensors()

        emb.forward_oot(positions, query, key)

        mock_gather.assert_called_once()
        # npu op should receive the gathered positions, not the originals
        assert mock_npu_op.call_args[0][0] is gathered_positions

    @pytest.mark.parametrize("is_draft_model,flash_comm,use_mtp", [
        (False, True, True),   # not draft
        (True, False, True),   # flash_comm disabled
        (True, True, False),   # use_mtp disabled
    ])
    @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
    @patch("torch.ops.vllm.npu_rotary_embedding")
    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
    def test_gather_unpad_skipped_unless_all_conditions_met(
        self, mock_get_forward_context, mock_npu_op, mock_gather,
        is_draft_model, flash_comm, use_mtp, make_embedding,
    ):
        """gather/unpad must NOT fire if any one of the three conditions is False."""
        self._setup_forward_context(mock_get_forward_context,
                                    is_draft_model=is_draft_model,
                                    flash_comm=flash_comm)
        mock_npu_op.return_value = MagicMock()

        emb = make_embedding(use_mtp=use_mtp)
        positions, query, key = _make_tensors()

        emb.forward_oot(positions, query, key)

        mock_gather.assert_not_called()
        # Original positions tensor is passed through untouched
        assert mock_npu_op.call_args[0][0] is positions

    def test_parent_init_signature_has_not_changed(self):
        """
        Fail loudly if RotaryEmbedding.__init__ adds, removes, or
        renames parameters, so a developer knows to update AscendRotaryEmbedding
        accordingly.
        """
        check_parent_init_signature_has_not_changed(
            RotaryEmbedding.__init__,
            AscendRotaryEmbedding.__init__
        )
|
||||||
|
|
||||||
|
|
||||||
|
class TestAscendYaRNRotaryEmbeddingForwardOOT:
    """Unit tests for AscendYaRNRotaryEmbedding.forward_oot delegation."""

    @patch("vllm_ascend.ops.rotary_embedding.AscendRotaryEmbedding.forward_oot")
    def test_delegates_to_ascend_rotary_forward_oot(self, mock_delegate, make_yarn_embedding):
        """forward_oot must delegate to AscendRotaryEmbedding.forward_oot."""
        expected = MagicMock()
        mock_delegate.return_value = expected

        emb = make_yarn_embedding()
        positions, query, key = _make_tensors()
        result = emb.forward_oot(positions, query, key)

        mock_delegate.assert_called_once_with(emb, positions, query, key, None, None)
        assert result is expected

    @patch("vllm_ascend.ops.rotary_embedding.AscendRotaryEmbedding.forward_oot")
    def test_return_value_passed_through(self, mock_delegate, make_yarn_embedding):
        """The delegate's return value comes back unchanged."""
        sentinel = (torch.randn(SEQ_LEN, HEAD_SIZE), torch.randn(SEQ_LEN, HEAD_SIZE))
        mock_delegate.return_value = sentinel

        emb = make_yarn_embedding()
        positions, query, key = _make_tensors()

        assert emb.forward_oot(positions, query, key) is sentinel

    @pytest.mark.parametrize("override", [True, False])
    @patch("vllm_ascend.ops.rotary_embedding.AscendRotaryEmbedding.forward_oot")
    def test_is_neox_style_override_forwarded(self, mock_delegate, override, make_yarn_embedding):
        """is_neox_style_override must be forwarded verbatim, both True and False."""
        mock_delegate.return_value = MagicMock()

        emb = make_yarn_embedding()
        positions, query, key = _make_tensors()
        emb.forward_oot(positions, query, key, is_neox_style_override=override)

        delegate_args, _ = mock_delegate.call_args
        assert delegate_args[5] is override  # 6th positional arg

    @patch("vllm_ascend.ops.rotary_embedding.AscendRotaryEmbedding.forward_oot")
    def test_all_args_forwarded_together(self, mock_delegate, make_yarn_embedding):
        """Smoke test: all args passed simultaneously are all forwarded correctly."""
        mock_delegate.return_value = MagicMock()

        emb = make_yarn_embedding()
        positions, query, key = _make_tensors()
        offsets = torch.ones(SEQ_LEN, dtype=torch.long)

        emb.forward_oot(positions, query, key, offsets=offsets, is_neox_style_override=False)

        mock_delegate.assert_called_once_with(emb, positions, query, key, offsets, False)

    def test_parent_init_signature_has_not_changed(self):
        """
        Fail loudly if YaRNScalingRotaryEmbedding.__init__ adds, removes, or
        renames parameters, so a developer knows to update AscendYaRNRotaryEmbedding
        accordingly.
        """
        check_parent_init_signature_has_not_changed(
            YaRNScalingRotaryEmbedding.__init__,
            AscendYaRNRotaryEmbedding.__init__
        )
|
||||||
@@ -222,8 +222,9 @@ class AscendRotaryEmbedding(RotaryEmbedding):
|
|||||||
base: float,
|
base: float,
|
||||||
is_neox_style: bool,
|
is_neox_style: bool,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
|
init_cache: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype)
|
super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, init_cache)
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
self.use_mtp = vllm_config.speculative_config and vllm_config.speculative_config.method == "mtp"
|
self.use_mtp = vllm_config.speculative_config and vllm_config.speculative_config.method == "mtp"
|
||||||
_record_cos_sin_cache(self.cos_sin_cache)
|
_record_cos_sin_cache(self.cos_sin_cache)
|
||||||
@@ -264,6 +265,7 @@ class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding):
|
|||||||
attn_factor: float = 1,
|
attn_factor: float = 1,
|
||||||
beta_fast: int = 32,
|
beta_fast: int = 32,
|
||||||
beta_slow: int = 1,
|
beta_slow: int = 1,
|
||||||
|
apply_yarn_scaling: bool = True,
|
||||||
truncate: bool = False,
|
truncate: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
extra_kwargs = {
|
extra_kwargs = {
|
||||||
@@ -271,6 +273,7 @@ class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding):
|
|||||||
"attn_factor": attn_factor,
|
"attn_factor": attn_factor,
|
||||||
"beta_fast": beta_fast,
|
"beta_fast": beta_fast,
|
||||||
"beta_slow": beta_slow,
|
"beta_slow": beta_slow,
|
||||||
|
"apply_yarn_scaling": apply_yarn_scaling,
|
||||||
# TODO: current not support actual truncate,adaptation for extra parameters to be compatible with vllm
|
# TODO: current not support actual truncate,adaptation for extra parameters to be compatible with vllm
|
||||||
"truncate": truncate,
|
"truncate": truncate,
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user