From 48b624e4cc95099bad6f58981c56db02c232e619 Mon Sep 17 00:00:00 2001
From: Hexiang Wang <56632993+whx-sjtu@users.noreply.github.com>
Date: Mon, 9 Mar 2026 23:08:43 +0800
Subject: [PATCH] [BugFix] Fix implementation bug of triton rope_siso (#7082)

### What this PR does / why we need it?
Previously, the implementation of triton rope_siso was missing the store of
the second half of the rope results, which resulted in:
1. an accuracy problem in the neox-style scenario
2. UB overflow in the non-neox-style scenario

This PR fixes it and supplements a nightly test case for it.

- vLLM version: v0.16.0
- vLLM main: https://github.com/vllm-project/vllm/commit/4034c3d32e30d01639459edd3ab486f56993876d

Signed-off-by: whx-sjtu <2952154980@qq.com>
---
 .../ops/singlecard_ops/triton/test_rope.py    | 86 ++++++++++++++++++-
 vllm_ascend/ops/triton/rope.py                |  3 +
 2 files changed, 88 insertions(+), 1 deletion(-)

diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_rope.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_rope.py
index 7e68bf96..6e577a39 100644
--- a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_rope.py
+++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_rope.py
@@ -3,7 +3,7 @@ import gc
 import pytest
 import torch
-from vllm_ascend.ops.triton.rope import rope_forward_triton
+from vllm_ascend.ops.triton.rope import rope_forward_triton, rope_forward_triton_siso
 from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
 
 IS_NEOX_STYLE = [True, False]
 
@@ -65,6 +65,33 @@ def _rope_pytorch_native(
     key = key_rot.to(orig_dtype)
     return query, key
 
 
+def _rope_siso_pytorch_native(
+        query, cos, sin, rope_dim,
+        is_neox_style) -> tuple[torch.Tensor, torch.Tensor | None]:
+    """PyTorch-native implementation equivalent to forward()."""
+    assert query is not None
+    orig_dtype = query.dtype
+    query_rot = query[..., :rope_dim].to(torch.float32)
+    head_size = query.shape[-1]
+    if rope_dim < head_size:
+        query_pass = query[..., rope_dim:]
+
+    if is_neox_style:
+        cos = cos.repeat(1, 2).unsqueeze(-2).to(torch.float32)
+        sin = sin.repeat(1, 2).unsqueeze(-2).to(torch.float32)
+    else:
+        cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2).to(torch.float32)
+        sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2).to(torch.float32)
+
+    rotate_fn = rotate_neox if is_neox_style else rotate_gptj
+    query_rot = query_rot * cos + rotate_fn(query_rot) * sin
+
+    if rope_dim < head_size:
+        query = torch.cat((query_rot.to(orig_dtype), query_pass), dim=-1)
+    else:
+        query = query_rot.to(orig_dtype)
+    return query
+
 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -220,3 +247,60 @@ def test_rotary_embedding_triton_kernel_with_cos_sin_cache(
     gc.collect()
     torch.npu.empty_cache()
     torch.npu.reset_peak_memory_stats()
+
+
+@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("num_q_heads", NUM_Q_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
+@torch.inference_mode()
+def test_rotary_embedding_triton_kernel_siso(
+    is_neox_style: bool,
+    num_tokens: int,
+    num_q_heads: int,
+    head_size: int,
+    rotary_dim: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+) -> None:
+    torch.manual_seed(seed)
+    torch.set_default_device(device)
+    init_device_properties_triton()
+    if rotary_dim == -1:
+        rotary_dim = head_size
+    sin = torch.randn(num_tokens, rotary_dim // 2, dtype=dtype, device=device)
+    cos = torch.randn(num_tokens, rotary_dim // 2, dtype=dtype, device=device)
+    q_trt = torch.randn(num_tokens,
+                        num_q_heads,
+                        head_size,
+                        dtype=dtype,
+                        device=device)
+    q_gold = torch.randn(num_tokens,
+                         num_q_heads,
+                         head_size,
+                         dtype=dtype,
+                         device=device)
+    q_trt.copy_(q_gold)
+    q_trt = rope_forward_triton_siso(q_trt,
+                                     cos,
+                                     sin,
+                                     rope_dim=rotary_dim,
+                                     is_neox_style=is_neox_style)
+    q_gold = _rope_siso_pytorch_native(q_gold,
+                                       cos,
+                                       sin,
+                                       rope_dim=rotary_dim,
+                                       is_neox_style=is_neox_style)
+    # Compare the results.
+    torch.testing.assert_close(q_trt.view(q_gold.size()),
+                               q_gold,
+                               atol=DEFAULT_ATOL,
+                               rtol=DEFAULT_RTOL)
+    gc.collect()
+    torch.npu.empty_cache()
+    torch.npu.reset_peak_memory_stats()
\ No newline at end of file
diff --git a/vllm_ascend/ops/triton/rope.py b/vllm_ascend/ops/triton/rope.py
index 90906517..02a8f043 100644
--- a/vllm_ascend/ops/triton/rope.py
+++ b/vllm_ascend/ops/triton/rope.py
@@ -218,6 +218,9 @@ def _triton_rope_siso(
     new_qk_tile_1 = qk_tile_1 * cos_row - qk_tile_2 * sin_row
     tl.store(qk_start_ptr + first_half_offsets, new_qk_tile_1, mask=first_mask)
 
+    new_qk_tile_2 = qk_tile_2 * cos_row + qk_tile_1 * sin_row
+    tl.store(qk_start_ptr + second_half_offsets, new_qk_tile_2, mask=second_mask)
+
 
 def rope_forward_triton(
     q: torch.Tensor,