Support mrope triton kernel and add unit test (#11722)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: b8zhong <b8zhong@uwaterloo.ca>
This commit is contained in:
Yuan Luo
2025-10-20 11:51:07 +08:00
committed by GitHub
parent c4e81e64fb
commit 271d3d0d50
4 changed files with 630 additions and 1 deletions

View File

@@ -0,0 +1,140 @@
from typing import NamedTuple
import pytest
import torch
from packaging.version import Version
from transformers import AutoConfig
from transformers import __version__ as TRANSFORMERS_VERSION
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.utils import (
cpu_has_amx_support,
is_cpu,
is_cuda,
is_hip,
is_npu,
is_xpu,
)
_is_cuda = is_cuda()
_is_hip = is_hip()
_is_cpu = is_cpu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_npu = is_npu()
_is_xpu = is_xpu()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def generate_test_data(
num_tokens: int,
num_q_heads: int,
num_kv_heads: int,
head_size: int,
max_position_embeddings: int,
dtype: torch.dtype,
device: torch.device,
):
"""Generate test data for given configuration."""
torch.manual_seed(42)
# Create 2D positions (3, num_tokens) for multimodal case
positions = torch.randint(
0, max_position_embeddings // 4, (3, num_tokens), device=device
)
# Create query and key tensors
query = torch.randn(num_tokens, num_q_heads * head_size, dtype=dtype, device=device)
key = torch.randn(num_tokens, num_kv_heads * head_size, dtype=dtype, device=device)
return positions, query, key
class MRoPETestInfo(NamedTuple):
model_name: str
atol: float = 1e-2
rtol: float = 1.6e-2
marks: list[pytest.MarkDecorator] = []
TRANSFORMERS_BASE_VERSION = Version(TRANSFORMERS_VERSION).base_version
MODELS_TO_TEST = [
MRoPETestInfo(model_name="Qwen/Qwen2-VL-7B-Instruct"),
MRoPETestInfo(model_name="Qwen/Qwen2-VL-72B-Instruct"),
MRoPETestInfo(model_name="Qwen/Qwen2.5-VL-72B-Instruct"),
]
num_tokens_list = [11, 8192]
@pytest.mark.skipif(not _is_cuda, reason="Skipping CUDA/ROCm only tests.")
@pytest.mark.parametrize(
"model_info, model_name",
[
pytest.param(test_config, test_config.model_name, marks=test_config.marks)
for test_config in MODELS_TO_TEST
],
)
@pytest.mark.parametrize("tp_size", [1, 2])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("num_tokens", num_tokens_list)
def test_mrope(
model_name: str,
model_info: MRoPETestInfo,
tp_size: int,
dtype: torch.dtype,
num_tokens: int,
):
atol = model_info.atol
rtol = model_info.rtol
config = AutoConfig.from_pretrained(model_name)
config = config.get_text_config()
# get the model config
total_num_kv_heads = config.num_key_value_heads
total_num_heads = config.num_attention_heads
num_heads = total_num_heads // tp_size
num_kv_heads = max(1, total_num_kv_heads // tp_size)
head_dim = (
config.head_dim
if hasattr(config, "head_dim")
else config.hidden_size // total_num_heads
)
is_neox_style = True
rope_theta = config.rope_theta
max_position = config.max_position_embeddings
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
rotary_dim = int(head_dim * partial_rotary_factor)
mrope_helper_class = get_rope(
head_size=head_dim,
rotary_dim=rotary_dim,
max_position=max_position,
base=rope_theta,
is_neox_style=is_neox_style,
rope_scaling=config.rope_scaling,
dtype=dtype,
).to(device=device)
# create q k v input tensors
# create rotary pos emb input tensors
positions, query, key = generate_test_data(
num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device
)
query_native, key_native = mrope_helper_class.forward_native(
positions,
query.clone(),
key.clone(),
)
query_cuda, key_cuda = mrope_helper_class.forward(
positions,
query.clone(),
key.clone(),
)
torch.testing.assert_close(query_native, query_cuda, atol=atol, rtol=rtol)
torch.testing.assert_close(key_native, key_cuda, atol=atol, rtol=rtol)

View File

@@ -77,6 +77,7 @@ suites = {
TestFile("test_eval_fp8_accuracy.py", 303),
TestFile("test_fa3.py", 376),
# TestFile("test_flashmla.py", 352),
TestFile("rotary_embedding/test_mrope.py", 300),
TestFile("test_function_call_parser.py", 10),
TestFile("test_fused_moe.py", 30),
TestFile("test_gpt_oss_1gpu.py", 600),