Sync from v0.13
This commit is contained in:
253
tests/kernels/core/test_mrope.py
Normal file
253
tests/kernels/core/test_mrope.py
Normal file
@@ -0,0 +1,253 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import NamedTuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from packaging.version import Version
|
||||
from transformers import __version__ as TRANSFORMERS_VERSION
|
||||
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.config import get_config
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
def generate_test_data(
|
||||
num_tokens: int,
|
||||
num_q_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
max_position_embeddings: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
):
|
||||
"""Generate test data for given configuration."""
|
||||
current_platform.seed_everything(42)
|
||||
# Create 2D positions (3, num_tokens) for multimodal case
|
||||
positions = torch.randint(
|
||||
0, max_position_embeddings // 4, (3, num_tokens), device=device
|
||||
)
|
||||
|
||||
# Create query and key tensors
|
||||
query = torch.randn(num_tokens, num_q_heads * head_size, dtype=dtype, device=device)
|
||||
key = torch.randn(num_tokens, num_kv_heads * head_size, dtype=dtype, device=device)
|
||||
|
||||
return positions, query, key
|
||||
|
||||
|
||||
class MRoPETestInfo(NamedTuple):
|
||||
model_name: str
|
||||
# https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317
|
||||
atol: float = 1e-2
|
||||
rtol: float = 1.6e-2
|
||||
marks: list[pytest.MarkDecorator] = []
|
||||
|
||||
|
||||
TRANSFORMERS_BASE_VERSION = Version(TRANSFORMERS_VERSION).base_version
|
||||
|
||||
MODELS_TO_TEST = [
|
||||
MRoPETestInfo(model_name="zai-org/GLM-4.1V-9B-Thinking"),
|
||||
MRoPETestInfo(model_name="Qwen/Qwen2-VL-7B-Instruct"),
|
||||
MRoPETestInfo(model_name="Qwen/Qwen2-VL-72B-Instruct"),
|
||||
MRoPETestInfo(model_name="Qwen/Qwen2.5-VL-72B-Instruct"),
|
||||
MRoPETestInfo(
|
||||
model_name="Qwen/Qwen3-VL-4B-Instruct",
|
||||
marks=[
|
||||
pytest.mark.skipif(
|
||||
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
|
||||
reason="Qwen3-VL only available after Transformers v4.57",
|
||||
)
|
||||
],
|
||||
),
|
||||
MRoPETestInfo(
|
||||
model_name="Qwen/Qwen3-VL-30B-A3B-Instruct",
|
||||
marks=[
|
||||
pytest.mark.skipif(
|
||||
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
|
||||
reason="Qwen3-VL only available after Transformers v4.57",
|
||||
)
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
num_tokens_list = [11, 8192]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests."
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"model_info, model_name",
|
||||
[
|
||||
pytest.param(test_config, test_config.model_name, marks=test_config.marks)
|
||||
for test_config in MODELS_TO_TEST
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("tp_size", [1, 2])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("num_tokens", num_tokens_list)
|
||||
def test_mrope(
|
||||
model_name: str,
|
||||
model_info: MRoPETestInfo,
|
||||
tp_size: int,
|
||||
dtype: torch.dtype,
|
||||
num_tokens: int,
|
||||
):
|
||||
atol = model_info.atol
|
||||
rtol = model_info.rtol
|
||||
|
||||
config = get_config(model_name, False).get_text_config()
|
||||
|
||||
# get the model config
|
||||
total_num_kv_heads = config.num_key_value_heads
|
||||
total_num_heads = config.num_attention_heads
|
||||
num_heads = total_num_heads // tp_size
|
||||
num_kv_heads = max(1, total_num_kv_heads // tp_size)
|
||||
head_dim = (
|
||||
config.head_dim
|
||||
if hasattr(config, "head_dim")
|
||||
else config.hidden_size // total_num_heads
|
||||
)
|
||||
is_neox_style = True
|
||||
|
||||
max_position = config.max_position_embeddings
|
||||
|
||||
mrope_helper_class = get_rope(
|
||||
head_size=head_dim,
|
||||
max_position=max_position,
|
||||
is_neox_style=is_neox_style,
|
||||
rope_parameters=config.rope_parameters,
|
||||
dtype=dtype,
|
||||
).to(device=device)
|
||||
|
||||
# create q k v input tensors
|
||||
# create rotary pos emb input tensors
|
||||
positions, query, key = generate_test_data(
|
||||
num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device
|
||||
)
|
||||
|
||||
query_native, key_native = mrope_helper_class.forward_native(
|
||||
positions,
|
||||
query.clone(),
|
||||
key.clone(),
|
||||
)
|
||||
|
||||
query_cuda, key_cuda = mrope_helper_class.forward_cuda(
|
||||
positions,
|
||||
query.clone(),
|
||||
key.clone(),
|
||||
)
|
||||
|
||||
torch.testing.assert_close(query_native, query_cuda, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(key_native, key_cuda, atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests."
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"model_info, model_name",
|
||||
[
|
||||
pytest.param(test_config, test_config.model_name, marks=test_config.marks)
|
||||
for test_config in MODELS_TO_TEST
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("tp_size", [1, 2])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("num_tokens", num_tokens_list)
|
||||
def test_mrope_torch_compile_tracing(
|
||||
model_name: str,
|
||||
model_info: MRoPETestInfo,
|
||||
tp_size: int,
|
||||
dtype: torch.dtype,
|
||||
num_tokens: int,
|
||||
):
|
||||
atol = model_info.atol
|
||||
rtol = model_info.rtol
|
||||
|
||||
config = get_config(model_name, False).get_text_config()
|
||||
|
||||
# get the model config
|
||||
total_num_kv_heads = config.num_key_value_heads
|
||||
total_num_heads = config.num_attention_heads
|
||||
num_heads = total_num_heads // tp_size
|
||||
num_kv_heads = max(1, total_num_kv_heads // tp_size)
|
||||
head_dim = (
|
||||
config.head_dim
|
||||
if hasattr(config, "head_dim")
|
||||
else config.hidden_size // total_num_heads
|
||||
)
|
||||
is_neox_style = True
|
||||
max_position = config.max_position_embeddings
|
||||
|
||||
mrope_helper_class = get_rope(
|
||||
head_size=head_dim,
|
||||
max_position=max_position,
|
||||
is_neox_style=is_neox_style,
|
||||
rope_parameters=config.rope_parameters,
|
||||
dtype=dtype,
|
||||
).to(device=device)
|
||||
|
||||
# Generate test data
|
||||
positions, query, key = generate_test_data(
|
||||
num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device
|
||||
)
|
||||
|
||||
# Create a wrapper that makes the in-place function appear functional
|
||||
def functional_forward_cuda(pos, q, k):
|
||||
"""Wrapper that converts in-place operation to functional style
|
||||
|
||||
CUDA Graph does not support in-place operations.
|
||||
This wrapper creates working copies of the
|
||||
input tensors and modifies them.
|
||||
"""
|
||||
q_work = q.clone() # Create working copies
|
||||
k_work = k.clone()
|
||||
# Your in-place function modifies q_work and k_work
|
||||
mrope_helper_class.forward_cuda(pos, q_work, k_work)
|
||||
return q_work, k_work # Return the modified tensors
|
||||
|
||||
# Get reference results
|
||||
query_native, key_native = mrope_helper_class.forward_native(
|
||||
positions,
|
||||
query.clone(),
|
||||
key.clone(),
|
||||
)
|
||||
|
||||
try:
|
||||
compiled_forward_cuda = torch.compile(
|
||||
functional_forward_cuda,
|
||||
fullgraph=True,
|
||||
backend="inductor",
|
||||
mode="reduce-overhead",
|
||||
dynamic=False,
|
||||
)
|
||||
|
||||
# Run compiled version
|
||||
query_compiled_cuda, key_compiled_cuda = compiled_forward_cuda(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
)
|
||||
|
||||
# Run original version for comparison
|
||||
query_cuda = query.clone()
|
||||
key_cuda = key.clone()
|
||||
mrope_helper_class.forward_cuda(positions, query_cuda, key_cuda)
|
||||
|
||||
# Verify results
|
||||
torch.testing.assert_close(
|
||||
query_compiled_cuda, query_cuda, atol=atol, rtol=rtol
|
||||
)
|
||||
torch.testing.assert_close(key_compiled_cuda, key_cuda, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(
|
||||
query_compiled_cuda, query_native, atol=atol, rtol=rtol
|
||||
)
|
||||
torch.testing.assert_close(key_compiled_cuda, key_native, atol=atol, rtol=rtol)
|
||||
|
||||
print("✓ forward_cuda successfully traced with torch.compile inductor")
|
||||
|
||||
except Exception as e:
|
||||
pytest.fail(f"forward_cuda failed to trace with torch.compile inductor: {e}")
|
||||
Reference in New Issue
Block a user