[Ops][Triton] Add a triton kernel supporting partial rope. (#4413)

### What this PR does / why we need it?
This PR adds a triton rope kernel witch supports scenarios of `rope_dim
!= head_dim`. This can save the split op before rope and the concat op
after rope. Profiling shows improvement.

### Does this PR introduce _any_ user-facing change?
None
### How was this patch tested?
I will add related ut after ci integrated with triton.


- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
whx
2025-12-02 17:10:19 +08:00
committed by GitHub
parent 8907010815
commit 96b2cdf6d8
6 changed files with 421 additions and 20 deletions

View File

View File

@@ -0,0 +1,141 @@
import gc
import pytest
import torch
from vllm_ascend.ops.triton.rope import rope_forward_triton
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
IS_NEOX_STYLE = [True, False]
DTYPES = [torch.bfloat16, torch.float16]
HEAD_SIZES = [64, 128]
ROTARY_DIMS = [32, 64]
NUM_Q_HEADS = [64]
NUM_K_HEADS = [1]
NUM_TOKENS = [1, 4, 8, 16, 1024]
SEEDS = [0]
DEVICES = [f"npu:{0}"]
DEFAULT_ATOL = 1e-3
DEFAULT_RTOL = 1e-3
def rotate_neox(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., ::2]
x2 = x[..., 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2)
def _rope_pytorch_native(
query, key, cos, sin, rope_dim,
is_neox_style) -> tuple[torch.Tensor, torch.Tensor | None]:
"""PyTorch-native implementation equivalent to forward()."""
assert key is not None
orig_dtype = query.dtype
query_rot = query[..., :rope_dim].to(torch.float32)
key_rot = key[..., :rope_dim].to(torch.float32)
head_size = query.shape[-1]
if rope_dim < head_size:
query_pass = query[..., rope_dim:]
key_pass = key[..., rope_dim:]
if is_neox_style:
cos = cos.repeat(1, 2).unsqueeze(-2).to(torch.float32)
sin = sin.repeat(1, 2).unsqueeze(-2).to(torch.float32)
else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2).to(torch.float32)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2).to(torch.float32)
rotate_fn = rotate_neox if is_neox_style else rotate_gptj
query_rot = query_rot * cos + rotate_fn(query_rot) * sin
key_rot = key_rot * cos + rotate_fn(key_rot) * sin
if rope_dim < head_size:
query = torch.cat((query_rot.to(orig_dtype), query_pass), dim=-1)
key = torch.cat((key_rot.to(orig_dtype), key_pass), dim=-1)
else:
query = query_rot.to(orig_dtype)
key = key_rot.to(orig_dtype)
return query, key
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_q_heads", NUM_Q_HEADS)
@pytest.mark.parametrize("num_k_heads", NUM_K_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_rotary_embedding_triton_kernel(
is_neox_style: bool,
num_tokens: int,
num_q_heads: int,
num_k_heads: int,
head_size: int,
rotary_dim: int,
dtype: torch.dtype,
seed: int,
device: str,
) -> None:
torch.manual_seed(seed)
torch.set_default_device(device)
init_device_properties_triton()
if rotary_dim == -1:
rotary_dim = head_size
sin = torch.randn(num_tokens, rotary_dim // 2, dtype=dtype, device=device)
cos = torch.randn(num_tokens, rotary_dim // 2, dtype=dtype, device=device)
q_trt = torch.randn(num_tokens,
num_q_heads,
head_size,
dtype=dtype,
device=device)
k_trt = torch.randn(num_tokens,
num_k_heads,
head_size,
dtype=dtype,
device=device)
q_gold = torch.randn(num_tokens,
num_q_heads,
head_size,
dtype=dtype,
device=device)
k_gold = torch.randn(num_tokens,
num_k_heads,
head_size,
dtype=dtype,
device=device)
q_trt.copy_(q_gold)
k_trt.copy_(k_gold)
q_trt, k_trt = rope_forward_triton(q_trt,
k_trt,
cos,
sin,
rope_dim=rotary_dim,
is_neox_style=is_neox_style)
q_gold, k_gold = _rope_pytorch_native(q_gold,
k_gold,
cos,
sin,
rope_dim=rotary_dim,
is_neox_style=is_neox_style)
# Compare the results.
torch.testing.assert_close(q_trt.view(q_gold.size()),
q_gold,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
torch.testing.assert_close(k_trt.view(k_gold.size()),
k_gold,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()

View File

@@ -9,6 +9,7 @@ from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
@@ -16,6 +17,7 @@ from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
wait_for_kv_layer_from_connector) wait_for_kv_layer_from_connector)
from vllm_ascend.ops.triton.rope import rope_forward_triton
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
is_enable_nz) is_enable_nz)
@@ -492,15 +494,33 @@ class AscendSFAImpl(MLAAttentionImpl):
cos = attn_metadata.cos cos = attn_metadata.cos
sin = attn_metadata.sin sin = attn_metadata.sin
# q process in new stream
q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128]
k_proj, _ = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128]
k_proj = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
k_proj, need_gather_q_kv)
k = self.k_norm(k_proj).unsqueeze(1)
k = k.view(-1, 1, self.head_dim)
if HAS_TRITON:
cos = cos.view(-1, self.qk_rope_head_dim)
sin = sin.view(-1, self.qk_rope_head_dim)
q, k = rope_forward_triton(q,
k,
cos,
sin,
rope_dim=self.qk_rope_head_dim,
is_neox_style=True)
else:
cos_q, sin_q = cos, sin cos_q, sin_q = cos, sin
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
# q process in new stream
q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
q = q.view(-1, self.n_head, self.head_dim) # [b,s,64,128]
q_pe, q_nope = torch.split( q_pe, q_nope = torch.split(
q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], q,
[self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
dim=-1) # [b,s,64,64+64] dim=-1) # [b,s,64,64+64]
q_pe = q_pe.unsqueeze(2) q_pe = q_pe.unsqueeze(2)
@@ -508,12 +528,9 @@ class AscendSFAImpl(MLAAttentionImpl):
q_pe = q_pe.squeeze(2) q_pe = q_pe.squeeze(2)
q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128] q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128]
k_proj, _ = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128]
k_proj = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
k_proj, need_gather_q_kv)
k = self.k_norm(k_proj).unsqueeze(1)
k_pe, k_nope = torch.split( k_pe, k_nope = torch.split(
k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], k,
[self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
dim=-1) # [b,s,64+64] dim=-1) # [b,s,64+64]
k_pe = k_pe.unsqueeze(2) k_pe = k_pe.unsqueeze(2)

View File

@@ -0,0 +1,210 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from vllm.triton_utils import HAS_TRITON, tl, triton
if HAS_TRITON:
import torch_npu._inductor # noqa: F401
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
# TODO(whx-sjtu): Add tiling of n_q_head and n_kv_head to support more models.
# I only have tested this kernel on Deepseek V3.2 and Qwen3-Next.
# For models with larger n_q_head and n_kv_head such as GLM 4.6, this is not supported yet.
@triton.jit
def _triton_rope(
q_ptr,
q_row_stride,
k_ptr,
k_row_stride,
cos,
cos_row_stride,
sin,
sin_row_stride,
num_tokens,
n_qh: tl.constexpr,
n_kh: tl.constexpr,
hd: tl.constexpr,
rope_dim: tl.constexpr,
pad_n_qh: tl.constexpr,
pad_n_kh: tl.constexpr,
pad_rope_dim: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
IS_NEOX_STYLE: tl.constexpr,
):
"""
This triton kernel applies rotary embedding on q and k.
It supports rope_dim != head_dim scenario.
It supports both neox style and non-neox style rope computation.
Input tensor layout assumptions:
q size: (num_tokens, num_q_heads, head_dim)
q stride: (num_q_heads * head_dim, head_dim, 1)
k size: (num_tokens, num_kv_heads, head_dim)
k stride: (num_kv_heads * head_dim, head_dim, 1)
cos/sin size: (num_tokens, rope_dim/2)
cos/sin stride: (rope_dim/2, 1)
Different compute pattern of IS_NEOX_STYLE:
if IS_NEOX_STYLE:
x1, x2 = torch.chunk(x, 2, dim=-1)
else:
x1 = x[..., ::2]
x2 = x[..., 1::2]
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
if IS_NEOX_STYLE:
return torch.cat((o1, o2), dim=-1)
else:
return torch.stack((o1, o2), dim=-1).flatten(-2)
"""
pid = tl.program_id(0).to(tl.int64)
row_block_size = tl.num_programs(0)
for row_idx in tl.range(pid, num_tokens, row_block_size):
q_start_ptr = q_ptr + row_idx * q_row_stride
k_start_ptr = k_ptr + row_idx * k_row_stride
# ####################################################################
# get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
# m of this program instance
# ####################################################################
cos_start_ptr = cos + row_idx * cos_row_stride
sin_start_ptr = sin + row_idx * sin_row_stride
cos_offsets = tl.arange(0, pad_rope_dim // 2)
cos_mask = cos_offsets < (rope_dim // 2)
cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask,
other=0).to(tl.float32)
sin_row = tl.load(sin_start_ptr + cos_offsets, mask=cos_mask,
other=0).to(tl.float32)
# ####################################################################
# Load the left and right half of q and k for the current
# program instance (i.e. for the current token) separately
# ####################################################################
# left half of the head
if IS_NEOX_STYLE:
first_half_q_offsets = tl.arange(
0, pad_n_qh)[:, None] * hd + tl.arange(
0, pad_rope_dim // 2)[None, :]
first_half_k_offsets = tl.arange(
0, pad_n_kh)[:, None] * hd + tl.arange(
0, pad_rope_dim // 2)[None, :]
else:
first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + (
2 * tl.arange(0, pad_rope_dim // 2)[None, :])
first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + (
2 * tl.arange(0, pad_rope_dim // 2)[None, :])
first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(
0, pad_rope_dim // 2)[None, :] < (rope_dim // 2))
first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(
0, pad_rope_dim // 2)[None, :] < (rope_dim // 2))
q_tile_1 = tl.load(q_start_ptr + first_half_q_offsets,
mask=first_q_mask,
other=0).to(sin_row.dtype)
k_tile_1 = tl.load(k_start_ptr + first_half_k_offsets,
mask=first_k_mask,
other=0).to(sin_row.dtype)
# right half of the head
if IS_NEOX_STYLE:
second_half_q_offsets = first_half_q_offsets + (rope_dim // 2)
second_half_k_offsets = first_half_k_offsets + (rope_dim // 2)
else:
second_half_q_offsets = first_half_q_offsets + 1
second_half_k_offsets = first_half_k_offsets + 1
second_q_mask = first_q_mask
second_k_mask = first_k_mask
q_tile_2 = tl.load(q_start_ptr + second_half_q_offsets,
mask=second_q_mask,
other=0).to(sin_row.dtype)
k_tile_2 = tl.load(k_start_ptr + second_half_k_offsets,
mask=second_k_mask,
other=0).to(sin_row.dtype)
# y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
tl.store(q_start_ptr + first_half_q_offsets,
new_q_tile_1,
mask=first_q_mask)
new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
tl.store(q_start_ptr + second_half_q_offsets,
new_q_tile_2,
mask=second_q_mask)
new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
tl.store(k_start_ptr + first_half_k_offsets,
new_k_tile_1,
mask=first_k_mask)
new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
tl.store(k_start_ptr + second_half_k_offsets,
new_k_tile_2,
mask=second_k_mask)
def rope_forward_triton(q,
k,
cos,
sin,
rope_dim: int = -1,
is_neox_style: bool = True):
if not q.is_contiguous():
q = q.contiguous()
if not k.is_contiguous():
k = k.contiguous()
num_tokens, n_q_head, head_dim = q.shape
n_kv_head = k.shape[1]
cos = cos.view(num_tokens, -1)
sin = sin.view(num_tokens, -1)
if rope_dim == -1:
# If rope_dim is not specified, we assume that input cos/sin is not
# duplicated to rope_dim, which means rope_dim == cos.shape[-1] * 2
rope_dim = cos.shape[-1] * 2
assert rope_dim <= head_dim
pad_rope_dim = triton.next_power_of_2(rope_dim)
pad_n_q_head = triton.next_power_of_2(n_q_head)
pad_n_kv_head = triton.next_power_of_2(n_kv_head)
BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
num_vectorcore = get_vectorcore_num()
n_row = min(num_tokens, num_vectorcore)
_triton_rope[(n_row, )](
q,
q.stride(0),
k,
k.stride(0),
cos,
cos.stride(0),
sin,
sin.stride(0),
num_tokens,
n_q_head,
n_kv_head,
head_dim,
rope_dim,
pad_n_q_head,
pad_n_kv_head,
pad_rope_dim,
BLOCK_SIZE=BLOCK_SIZE,
IS_NEOX_STYLE=is_neox_style,
)
return q, k

View File

@@ -0,0 +1,30 @@
from typing import Any, Dict
import torch
from vllm.triton_utils import HAS_TRITON, triton
_NUM_AICORE = -1
_NUM_VECTORCORE = -1
def init_device_properties_triton():
global _NUM_AICORE, _NUM_VECTORCORE
if _NUM_AICORE == -1 and HAS_TRITON:
device_properties: Dict[str, Any] = (
triton.runtime.driver.active.utils.get_device_properties(
torch.npu.current_device()))
_NUM_AICORE = device_properties.get("num_aicore", -1)
_NUM_VECTORCORE = device_properties.get("num_vectorcore", -1)
assert _NUM_AICORE > 0 and _NUM_VECTORCORE > 0, "Failed to detect device properties."
def get_aicore_num():
global _NUM_AICORE
assert _NUM_AICORE > 0, "Device properties not initialized. Please call init_device_properties_triton() first."
return _NUM_AICORE
def get_vectorcore_num():
global _NUM_VECTORCORE
assert _NUM_VECTORCORE > 0, "Device properties not initialized. Please call init_device_properties_triton() first."
return _NUM_VECTORCORE

View File

@@ -49,6 +49,7 @@ from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
from vllm_ascend.cpu_binding import bind_cpus from vllm_ascend.cpu_binding import bind_cpus
from vllm_ascend.device_allocator.camem import CaMemAllocator from vllm_ascend.device_allocator.camem import CaMemAllocator
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
from vllm_ascend.platform import NPUPlatform from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import (check_ascend_device_type, is_enable_nz, from vllm_ascend.utils import (check_ascend_device_type, is_enable_nz,
prefill_context_parallel_enable, prefill_context_parallel_enable,
@@ -226,6 +227,8 @@ class NPUWorker(WorkerBase):
self._init_worker_distributed_environment() self._init_worker_distributed_environment()
# Set random seed. # Set random seed.
NPUPlatform.seed_everything(self.model_config.seed) NPUPlatform.seed_everything(self.model_config.seed)
# Initialize device properties used by triton kernels.
init_device_properties_triton()
return device return device
def init_device(self): def init_device(self):