[Fusion] [Graph] Add qknorm rope fusion operator (#4711)

### What this PR does / why we need it?
This PR add `qkv_rmsnorm_rope` operator and introduces a graph fusion
pass for `qknorm_rope` operations. The implementation includes a new
configuration flag, a pattern matching pass using
`torch._inductor.pattern_matcher`, and a custom Triton kernel for the
fused operation.

Co-authored-by: Angazenn
[supperccell@163.com](mailto:supperccell@163.com)

### Does this PR introduce _any_ user-facing change?
Yes, add new additional_config

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: wxsIcey <1790571317@qq.com>
This commit is contained in:
Icey
2025-12-17 08:53:44 +08:00
committed by GitHub
parent b1a853b0f6
commit cadfa5ddc1
14 changed files with 754 additions and 71 deletions

View File

@@ -20,14 +20,117 @@ from typing import Optional, Tuple
import torch
import torch_npu
from vllm.forward_context import get_forward_context
from vllm.config import CUDAGraphMode
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
YaRNScalingRotaryEmbedding)
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
get_ascend_device_type)
get_ascend_device_type, is_vl_model)
# Currently, rope ops used on npu requires detached cos && sin as inputs.
# However, RotaryEmbedding in vllm use cos_sin_cache as a whole variable.
# So we have to preprocess cos_sin_cache int cos && sin. In the future,
# we shall implement a new rope ops which accept cos_sin_cache as inputs.
# NOTE(Angazenn): MLA && SFA models uses attn_metadata to pass cos && sin
# to rope in AscendMLA(SFA)Impl. However, since rope is isolated from
# AscendAttentionBackendImpl for GQA models, we cannot pass cos && sin by
# attn_metadata. This causes that rope in GQA models must pass cos && sin
# by different approaches.
_cos_mla: Optional[torch.Tensor] = None
_sin_mla: Optional[torch.Tensor] = None
_cos_sin_cache: Optional[torch.Tensor] = None
_cos: Optional[torch.Tensor] = None
_sin: Optional[torch.Tensor] = None
_cos_slice: Optional[torch.Tensor] = None
_sin_slice: Optional[torch.Tensor] = None
def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
device):
global _cos_mla
global _sin_mla
global _cos
global _sin
if _cos_mla is not None or \
_sin_mla is not None or \
_cos is not None or \
_sin is not None:
return
compilation_config = vllm_config.compilation_config
model_config = vllm_config.model_config
max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
if model_config.use_mla and compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
rope_dim = model_config.hf_text_config.qk_rope_head_dim
_cos_mla = torch.ones(max_num_reqs * decode_token_per_req,
1,
1,
rope_dim,
dtype=dtype,
device=device)
_sin_mla = torch.zeros(max_num_reqs * decode_token_per_req,
1,
1,
rope_dim,
dtype=dtype,
device=device)
elif not is_vl_model(vllm_config) and not vllm_config.model_config.use_mla:
rope_dim = model_config.get_head_size()
# For models using partial rope like Qwen3-Next.
if hasattr(model_config.hf_text_config, "partial_rotary_factor"):
rope_dim = int(rope_dim *
model_config.hf_text_config.partial_rotary_factor)
_cos = torch.ones(1,
max_num_batched_tokens,
1,
rope_dim,
dtype=dtype,
device=device)
_sin = torch.zeros(1,
max_num_batched_tokens,
1,
rope_dim,
dtype=dtype,
device=device)
def get_cos_and_sin_mla():
return _cos_mla, _sin_mla
def _record_cos_sin_cache(cos_sin_cache):
global _cos_sin_cache
if _cos_sin_cache is not None:
return
_cos_sin_cache = cos_sin_cache
def update_cos_sin(positions):
global _cos
global _sin
global _cos_slice
global _sin_slice
if _cos_sin_cache is None or \
_cos is None or \
_sin is None:
return
num_tokens = positions.size(0)
_cos[:, :num_tokens] = _cos_sin_cache.index_select(0, positions).view(
num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)[0]
_sin[:, :num_tokens] = _cos_sin_cache.index_select(0, positions).view(
num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)[1]
_cos_slice = _cos[:, :num_tokens]
_sin_slice = _sin[:, :num_tokens]
def get_cos_and_sin_slice():
return _cos_slice, _sin_slice
def _custom_rotary_embedding_enabled(query, neox_style, head_size):
@@ -65,8 +168,9 @@ def _rope_forward_oot(
raise NotImplementedError(
"Batched rotary embedding is currently not supported on NPU.")
else:
if hasattr(self, "cos") and hasattr(self, "sin") and \
self.cos is not None and self.sin is not None:
cos, sin = get_cos_and_sin_slice()
if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[
-1] == 128 and cos is not None and sin is not None:
# If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation.
# This method requires head_size and rotary_dim equal 128 and neox_style is True
query = query.contiguous().view(1, query.shape[0], -1,
@@ -75,7 +179,7 @@ def _rope_forward_oot(
# Although this function modifies in-place, please retain the function's return value.
# Otherwise, the graph fusion operation may fail.
query, key = torch_npu.npu_apply_rotary_pos_emb(
query, key, self.cos, self.sin)
query, key, cos, sin)
elif self.rotary_dim < self.head_size:
num_tokens = query.shape[0]
query = query.view(num_tokens, -1, self.head_size)
@@ -125,10 +229,9 @@ class AscendRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
self.cos = None
self.sin = None
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
_record_cos_sin_cache(self.cos_sin_cache)
def forward_oot(
self,
@@ -141,20 +244,6 @@ class AscendRotaryEmbedding(RotaryEmbedding):
is_neox_style = self.is_neox_style
if is_neox_style_override is not None:
is_neox_style = is_neox_style_override
forward_context = get_forward_context()
is_first_layer = forward_context.is_first_layer
# Generate cos and sin outside layers to avoid repeated calculation.
if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[
-1] == 128:
if is_first_layer:
cos_sin = self.cos_sin_cache.index_select(0, positions)
last_dim = cos_sin.size()[-1]
cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat(
1, 1, 2).chunk(2, dim=-2)
# BSNH
self.cos = cos.view(1, -1, 1, last_dim).contiguous()
self.sin = sin.view(1, -1, 1, last_dim).contiguous()
forward_context.is_first_layer = False
return _rope_forward_oot(self, positions, query, key, is_neox_style,
offsets)
@@ -176,8 +265,6 @@ class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding):
beta_fast: int = 32,
beta_slow: int = 1,
) -> None:
self.cos = None
self.sin = None
extra_kwargs = {
"extrapolation_factor": extrapolation_factor,
"attn_factor": attn_factor,
@@ -186,6 +273,7 @@ class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding):
}
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, scaling_factor, dtype, **extra_kwargs)
_record_cos_sin_cache(self.cos_sin_cache)
def forward_oot(
self,