### What this PR does / why we need it?
This PR adds missing arguments to `AscendRotaryEmbedding` and
`AscendYaRNRotaryEmbedding` so that their signatures conform with vLLM.
Corresponding unit tests are also introduced.
- vLLM version: v0.17.0
- vLLM main: 4034c3d32e
---------
Signed-off-by: Angazenn <supperccell@163.com>
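
For context, here is a minimal sketch of the alignment the new unit tests pin down. It is inferred from the test assertions below rather than copied from the diff, so the real vllm-ascend signatures may differ in detail: `forward_oot` accepts the optional `offsets` and `is_neox_style_override` arguments and falls back to `self.is_neox_style` when no override is given (the MTP draft-model gather/unpad branch is left out here).

```python
from typing import Optional, Tuple

import torch
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding


class SketchAscendRotaryEmbedding(RotaryEmbedding):
    """Illustration only; not the code added by this PR."""

    def forward_oot(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
        is_neox_style_override: Optional[bool] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # `offsets` is accepted for signature parity with vLLM; how it is
        # consumed is not shown here.
        # Use the layer's own style unless the caller overrides it.
        neox_style = (self.is_neox_style if is_neox_style_override is None
                      else is_neox_style_override)
        return torch.ops.vllm.npu_rotary_embedding(positions, query, key,
                                                   self.cos_sin_cache,
                                                   self.head_size,
                                                   self.rotary_dim, neox_style)
```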
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

import inspect
from unittest.mock import MagicMock, patch

import pytest
import torch
from vllm.model_executor.layers.rotary_embedding import (
    RotaryEmbedding, YaRNScalingRotaryEmbedding)

from vllm_ascend.ops.rotary_embedding import (AscendRotaryEmbedding,
                                              AscendYaRNRotaryEmbedding)

HEAD_SIZE = 64
ROTARY_DIM = 64
MAX_POS = 2048
BASE = 10000.0
DTYPE = torch.bfloat16
SEQ_LEN = 4
NUM_HEADS = 2

def _make_tensors(seq_len=SEQ_LEN, num_heads=NUM_HEADS, head_size=HEAD_SIZE):
    positions = torch.arange(seq_len, dtype=torch.long)
    query = torch.randn(seq_len, num_heads * head_size)
    key = torch.randn(seq_len, num_heads * head_size)
    return positions, query, key

def check_parent_init_signature_has_not_changed(parent_func, child_func):
    parent_sig = inspect.signature(parent_func)
    parent_params = set(parent_sig.parameters) - {"self"}

    child_sig = inspect.signature(child_func)
    child_params = set(child_sig.parameters) - {"self"}

    # Parameters the parent gained that the child does not yet forward, and
    # parameters the child still forwards that the parent has dropped.
    added = parent_params - child_params
    removed = child_params - parent_params

    assert not added, (
        f"{parent_func.__name__} added new parameter(s): {added}. "
        f"Check whether {child_func.__name__} needs to forward them.")
    assert not removed, (
        f"{parent_func.__name__} removed parameter(s): {removed}. "
        f"Check whether {child_func.__name__} needs to forward them.")

@pytest.fixture(autouse=True)
def patch_init_side_effects():
    """
    Suppress all side-effects that fire during __init__ so every test starts
    from a clean, predictable state without needing real NPU ops or vLLM
    global config.
    """
    with (
            patch("vllm_ascend.ops.rotary_embedding._record_cos_sin_cache"),
            patch("vllm_ascend.ops.rotary_embedding._record_cos_and_sin_cache_interleaved"),
            patch("vllm_ascend.ops.rotary_embedding.get_current_vllm_config") as mock_cfg,
    ):
        # Default: speculative_config is None → use_mtp = False
        mock_cfg.return_value.speculative_config = None
        yield mock_cfg

@pytest.fixture()
def make_embedding(patch_init_side_effects):
    """Factory that creates an AscendRotaryEmbedding with controllable use_mtp."""

    def _factory(use_mtp: bool = False, is_neox_style: bool = True):
        spec_cfg = MagicMock(method="mtp") if use_mtp else None
        patch_init_side_effects.return_value.speculative_config = spec_cfg

        with patch("vllm_ascend.ops.rotary_embedding.RotaryEmbedding.__init__"
                   ) as mock_parent_init:
            mock_parent_init.return_value = None
            from vllm_ascend.ops.rotary_embedding import AscendRotaryEmbedding

            emb = AscendRotaryEmbedding.__new__(AscendRotaryEmbedding)
            # Manually set attrs that the real parent would set
            emb.head_size = HEAD_SIZE
            emb.rotary_dim = ROTARY_DIM
            emb.is_neox_style = is_neox_style
            emb.cos_sin_cache = torch.zeros(MAX_POS, ROTARY_DIM)
            # Call __init__ to exercise our code path
            AscendRotaryEmbedding.__init__(emb, HEAD_SIZE, ROTARY_DIM, MAX_POS,
                                           BASE, is_neox_style, DTYPE)
            return emb

    return _factory

@pytest.fixture()
def make_yarn_embedding(patch_init_side_effects):
    """
    Factory for AscendYaRNRotaryEmbedding with parent __init__ suppressed.
    patch_init_side_effects is the same autouse fixture as before.
    """

    def _factory(is_neox_style: bool = True):
        with patch("vllm_ascend.ops.rotary_embedding.YaRNScalingRotaryEmbedding.__init__"
                   ) as mock_parent_init:
            mock_parent_init.return_value = None
            from vllm_ascend.ops.rotary_embedding import AscendYaRNRotaryEmbedding

            emb = AscendYaRNRotaryEmbedding.__new__(AscendYaRNRotaryEmbedding)
            emb.head_size = HEAD_SIZE
            emb.rotary_dim = ROTARY_DIM
            emb.is_neox_style = is_neox_style
            emb.cos_sin_cache = torch.zeros(MAX_POS, ROTARY_DIM)
            AscendYaRNRotaryEmbedding.__init__(
                emb,
                head_size=HEAD_SIZE,
                rotary_dim=ROTARY_DIM,
                max_position_embeddings=MAX_POS,
                base=BASE,
                is_neox_style=is_neox_style,
                scaling_factor=1.0,
                dtype=DTYPE,
            )
            return emb

    return _factory

class TestAscendEmbeddingForwardOOT:

    @patch("torch.ops.vllm.npu_rotary_embedding")
    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
    def test_basic_call_delegates_to_npu_op(self, mock_get_forward_context,
                                            mock_npu_op, make_embedding):
        """forward_oot always calls npu_rotary_embedding and returns its result."""
        mock_get_forward_context.return_value = MagicMock()
        mock_get_forward_context.return_value.is_draft_model = False
        mock_get_forward_context.return_value.flash_comm_v1_enabled = False
        expected_output = (torch.randn(SEQ_LEN, NUM_HEADS * HEAD_SIZE), ) * 2
        mock_npu_op.return_value = expected_output

        emb = make_embedding()
        positions, query, key = _make_tensors()

        result = emb.forward_oot(positions, query, key)

        mock_npu_op.assert_called_once_with(
            positions, query, key, emb.cos_sin_cache,
            HEAD_SIZE, ROTARY_DIM, emb.is_neox_style,
        )
        assert result is expected_output

@patch("torch.ops.vllm.npu_rotary_embedding")
|
|
@patch("vllm_ascend.ascend_forward_context.get_forward_context")
|
|
def test_neox_style_override_true(self, mock_get_forward_context, mock_npu_op, make_embedding):
|
|
"""is_neox_style_override=True wins over self.is_neox_style=False."""
|
|
mock_get_forward_context.return_value = MagicMock()
|
|
mock_get_forward_context.return_value.is_draft_model = False
|
|
mock_get_forward_context.return_value.flash_comm_v1_enabled = False
|
|
mock_npu_op.return_value = MagicMock()
|
|
|
|
emb = make_embedding(is_neox_style=False)
|
|
positions, query, key = _make_tensors()
|
|
|
|
emb.forward_oot(positions, query, key, is_neox_style_override=True)
|
|
|
|
_, kwargs = mock_npu_op.call_args
|
|
# Verify the override was forwarded correctly
|
|
assert mock_npu_op.call_args[0][-1] is True # last positional arg = is_neox_style
|
|
|
|
@patch("torch.ops.vllm.npu_rotary_embedding")
|
|
@patch("vllm_ascend.ascend_forward_context.get_forward_context")
|
|
def test_neox_style_override_false(self, mock_get_forward_context, mock_npu_op, make_embedding):
|
|
"""is_neox_style_override=False wins over self.is_neox_style=True."""
|
|
mock_get_forward_context.return_value = MagicMock()
|
|
mock_get_forward_context.return_value.is_draft_model = False
|
|
mock_get_forward_context.return_value.flash_comm_v1_enabled = False
|
|
mock_npu_op.return_value = MagicMock()
|
|
|
|
emb = make_embedding(is_neox_style=True)
|
|
positions, query, key = _make_tensors()
|
|
|
|
emb.forward_oot(positions, query, key, is_neox_style_override=False)
|
|
|
|
assert mock_npu_op.call_args[0][-1] is False
|
|
|
|
@patch("torch.ops.vllm.npu_rotary_embedding")
|
|
@patch("vllm_ascend.ascend_forward_context.get_forward_context")
|
|
def test_neox_style_override_none_uses_self(self, mock_get_forward_context, mock_npu_op, make_embedding):
|
|
"""When override is None, self.is_neox_style is used unchanged."""
|
|
mock_get_forward_context.return_value = MagicMock()
|
|
mock_get_forward_context.return_value.is_draft_model = False
|
|
mock_get_forward_context.return_value.flash_comm_v1_enabled = False
|
|
mock_npu_op.return_value = MagicMock()
|
|
|
|
emb = make_embedding(is_neox_style=True)
|
|
positions, query, key = _make_tensors()
|
|
|
|
emb.forward_oot(positions, query, key, is_neox_style_override=None)
|
|
|
|
assert mock_npu_op.call_args[0][-1] is True
|
|
|
|
@patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
|
|
@patch("torch.ops.vllm.npu_rotary_embedding")
|
|
@patch("vllm_ascend.ascend_forward_context.get_forward_context")
|
|
def test_gather_unpad_called_when_all_conditions_met(
|
|
self, mock_get_forward_context, mock_npu_op, mock_gather, make_embedding
|
|
):
|
|
"""
|
|
maybe_all_gather_and_maybe_unpad is called iff:
|
|
is_draft_model=True AND use_mtp=True AND flash_comm_v1_enabled=True
|
|
"""
|
|
mock_get_forward_context.return_value = MagicMock()
|
|
mock_get_forward_context.return_value.is_draft_model = True
|
|
mock_get_forward_context.return_value.flash_comm_v1_enabled = True
|
|
gathered_positions = torch.arange(SEQ_LEN, dtype=torch.long)
|
|
mock_gather.return_value = gathered_positions
|
|
mock_npu_op.return_value = MagicMock()
|
|
|
|
emb = make_embedding(use_mtp=True)
|
|
positions, query, key = _make_tensors()
|
|
|
|
emb.forward_oot(positions, query, key)
|
|
|
|
mock_gather.assert_called_once()
|
|
# npu op should receive the gathered positions, not the originals
|
|
assert mock_npu_op.call_args[0][0] is gathered_positions
|
|
|
|
@pytest.mark.parametrize("is_draft_model,flash_comm,use_mtp", [
|
|
(False, True, True), # not draft
|
|
(True, False, True), # flash_comm disabled
|
|
(True, True, False), # use_mtp disabled
|
|
])
|
|
@patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
|
|
@patch("torch.ops.vllm.npu_rotary_embedding")
|
|
@patch("vllm_ascend.ascend_forward_context.get_forward_context")
|
|
def test_gather_unpad_skipped_unless_all_conditions_met(
|
|
self, mock_get_forward_context, mock_npu_op, mock_gather,
|
|
is_draft_model, flash_comm, use_mtp, make_embedding,
|
|
):
|
|
"""gather/unpad must NOT fire if any one of the three conditions is False."""
|
|
mock_get_forward_context.return_value = MagicMock()
|
|
mock_get_forward_context.return_value.is_draft_model = is_draft_model
|
|
mock_get_forward_context.return_value.flash_comm_v1_enabled = flash_comm
|
|
mock_npu_op.return_value = MagicMock()
|
|
|
|
emb = make_embedding(use_mtp=use_mtp)
|
|
positions, query, key = _make_tensors()
|
|
|
|
emb.forward_oot(positions, query, key)
|
|
|
|
mock_gather.assert_not_called()
|
|
# Original positions tensor is passed through untouched
|
|
assert mock_npu_op.call_args[0][0] is positions
|
|
|
|
    def test_parent_init_signature_has_not_changed(self):
        """
        Fail loudly if RotaryEmbedding.__init__ adds, removes, or renames
        parameters, so a developer knows to update AscendRotaryEmbedding
        accordingly.
        """
        check_parent_init_signature_has_not_changed(
            RotaryEmbedding.__init__, AscendRotaryEmbedding.__init__)

class TestAscendYaRNRotaryEmbeddingForwardOOT:

    @patch("vllm_ascend.ops.rotary_embedding.AscendRotaryEmbedding.forward_oot")
    def test_delegates_to_ascend_rotary_forward_oot(self, mock_delegate,
                                                    make_yarn_embedding):
        """forward_oot must delegate to AscendRotaryEmbedding.forward_oot."""
        expected = MagicMock()
        mock_delegate.return_value = expected

        emb = make_yarn_embedding()
        positions, query, key = _make_tensors()

        result = emb.forward_oot(positions, query, key)

        mock_delegate.assert_called_once_with(emb, positions, query, key, None,
                                              None)
        assert result is expected

@patch("vllm_ascend.ops.rotary_embedding.AscendRotaryEmbedding.forward_oot")
|
|
def test_return_value_passed_through(self, mock_delegate, make_yarn_embedding):
|
|
"""Return value from the delegate is returned unchanged."""
|
|
sentinel = (torch.randn(SEQ_LEN, HEAD_SIZE), torch.randn(SEQ_LEN, HEAD_SIZE))
|
|
mock_delegate.return_value = sentinel
|
|
|
|
emb = make_yarn_embedding()
|
|
positions, query, key = _make_tensors()
|
|
|
|
result = emb.forward_oot(positions, query, key)
|
|
|
|
assert result is sentinel
|
|
|
|
@pytest.mark.parametrize("override", [True, False])
|
|
@patch("vllm_ascend.ops.rotary_embedding.AscendRotaryEmbedding.forward_oot")
|
|
def test_is_neox_style_override_forwarded(self, mock_delegate, override, make_yarn_embedding):
|
|
"""is_neox_style_override must be forwarded verbatim, both True and False."""
|
|
mock_delegate.return_value = MagicMock()
|
|
|
|
emb = make_yarn_embedding()
|
|
positions, query, key = _make_tensors()
|
|
|
|
emb.forward_oot(positions, query, key, is_neox_style_override=override)
|
|
|
|
_, call_args, _ = mock_delegate.mock_calls[0]
|
|
assert call_args[5] is override # 6th positional arg
|
|
|
|
@patch("vllm_ascend.ops.rotary_embedding.AscendRotaryEmbedding.forward_oot")
|
|
def test_all_args_forwarded_together(self, mock_delegate, make_yarn_embedding):
|
|
"""Smoke test: all args passed simultaneously are all forwarded correctly."""
|
|
mock_delegate.return_value = MagicMock()
|
|
|
|
emb = make_yarn_embedding()
|
|
positions, query, key = _make_tensors()
|
|
offsets = torch.ones(SEQ_LEN, dtype=torch.long)
|
|
|
|
emb.forward_oot(positions, query, key, offsets=offsets, is_neox_style_override=False)
|
|
|
|
mock_delegate.assert_called_once_with(emb, positions, query, key, offsets, False)
|
|
|
|
    def test_parent_init_signature_has_not_changed(self):
        """
        Fail loudly if YaRNScalingRotaryEmbedding.__init__ adds, removes, or
        renames parameters, so a developer knows to update
        AscendYaRNRotaryEmbedding accordingly.
        """
        check_parent_init_signature_has_not_changed(
            YaRNScalingRotaryEmbedding.__init__,
            AscendYaRNRotaryEmbedding.__init__)
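
Possible follow-up, not part of this PR: if more Ascend wrappers adopt the same signature-drift guard, the two per-class tests could be folded into one parametrized test. This is a sketch only, and it relies on `check_parent_init_signature_has_not_changed` as defined in the file above.

```python
import pytest
from vllm.model_executor.layers.rotary_embedding import (
    RotaryEmbedding, YaRNScalingRotaryEmbedding)

from vllm_ascend.ops.rotary_embedding import (AscendRotaryEmbedding,
                                              AscendYaRNRotaryEmbedding)


@pytest.mark.parametrize("parent_cls, child_cls", [
    (RotaryEmbedding, AscendRotaryEmbedding),
    (YaRNScalingRotaryEmbedding, AscendYaRNRotaryEmbedding),
])
def test_init_signatures_stay_in_sync(parent_cls, child_cls):
    # Same check as the two per-class tests above, just driven by parametrize.
    check_parent_init_signature_has_not_changed(parent_cls.__init__,
                                                child_cls.__init__)
```

Keeping the tests separate, as the PR does, makes the affected class name show up directly in the failing test id, so either layout is reasonable.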