[Fusion] [Graph] Add qknorm rope fusion operator (#4711)
### What this PR does / why we need it?
This PR add `qkv_rmsnorm_rope` operator and introduces a graph fusion
pass for `qknorm_rope` operations. The implementation includes a new
configuration flag, a pattern matching pass using
`torch._inductor.pattern_matcher`, and a custom Triton kernel for the
fused operation.
Co-authored-by: Angazenn
[supperccell@163.com](mailto:supperccell@163.com)
### Does this PR introduce _any_ user-facing change?
Yes, add new additional_config
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: wxsIcey <1790571317@qq.com>
This commit is contained in:
@@ -20,14 +20,117 @@ from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.config import CUDAGraphMode
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
|
||||
YaRNScalingRotaryEmbedding)
|
||||
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
|
||||
get_ascend_device_type)
|
||||
get_ascend_device_type, is_vl_model)
|
||||
|
||||
# Currently, rope ops used on npu requires detached cos && sin as inputs.
|
||||
# However, RotaryEmbedding in vllm use cos_sin_cache as a whole variable.
|
||||
# So we have to preprocess cos_sin_cache int cos && sin. In the future,
|
||||
# we shall implement a new rope ops which accept cos_sin_cache as inputs.
|
||||
# NOTE(Angazenn): MLA && SFA models uses attn_metadata to pass cos && sin
|
||||
# to rope in AscendMLA(SFA)Impl. However, since rope is isolated from
|
||||
# AscendAttentionBackendImpl for GQA models, we cannot pass cos && sin by
|
||||
# attn_metadata. This causes that rope in GQA models must pass cos && sin
|
||||
# by different approaches.
|
||||
_cos_mla: Optional[torch.Tensor] = None
|
||||
_sin_mla: Optional[torch.Tensor] = None
|
||||
_cos_sin_cache: Optional[torch.Tensor] = None
|
||||
_cos: Optional[torch.Tensor] = None
|
||||
_sin: Optional[torch.Tensor] = None
|
||||
_cos_slice: Optional[torch.Tensor] = None
|
||||
_sin_slice: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
|
||||
device):
|
||||
global _cos_mla
|
||||
global _sin_mla
|
||||
global _cos
|
||||
global _sin
|
||||
|
||||
if _cos_mla is not None or \
|
||||
_sin_mla is not None or \
|
||||
_cos is not None or \
|
||||
_sin is not None:
|
||||
return
|
||||
|
||||
compilation_config = vllm_config.compilation_config
|
||||
model_config = vllm_config.model_config
|
||||
max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
||||
|
||||
if model_config.use_mla and compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
|
||||
rope_dim = model_config.hf_text_config.qk_rope_head_dim
|
||||
_cos_mla = torch.ones(max_num_reqs * decode_token_per_req,
|
||||
1,
|
||||
1,
|
||||
rope_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
_sin_mla = torch.zeros(max_num_reqs * decode_token_per_req,
|
||||
1,
|
||||
1,
|
||||
rope_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
elif not is_vl_model(vllm_config) and not vllm_config.model_config.use_mla:
|
||||
rope_dim = model_config.get_head_size()
|
||||
# For models using partial rope like Qwen3-Next.
|
||||
if hasattr(model_config.hf_text_config, "partial_rotary_factor"):
|
||||
rope_dim = int(rope_dim *
|
||||
model_config.hf_text_config.partial_rotary_factor)
|
||||
_cos = torch.ones(1,
|
||||
max_num_batched_tokens,
|
||||
1,
|
||||
rope_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
_sin = torch.zeros(1,
|
||||
max_num_batched_tokens,
|
||||
1,
|
||||
rope_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
|
||||
|
||||
def get_cos_and_sin_mla():
|
||||
return _cos_mla, _sin_mla
|
||||
|
||||
|
||||
def _record_cos_sin_cache(cos_sin_cache):
|
||||
global _cos_sin_cache
|
||||
if _cos_sin_cache is not None:
|
||||
return
|
||||
_cos_sin_cache = cos_sin_cache
|
||||
|
||||
|
||||
def update_cos_sin(positions):
|
||||
global _cos
|
||||
global _sin
|
||||
global _cos_slice
|
||||
global _sin_slice
|
||||
|
||||
if _cos_sin_cache is None or \
|
||||
_cos is None or \
|
||||
_sin is None:
|
||||
return
|
||||
|
||||
num_tokens = positions.size(0)
|
||||
_cos[:, :num_tokens] = _cos_sin_cache.index_select(0, positions).view(
|
||||
num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)[0]
|
||||
_sin[:, :num_tokens] = _cos_sin_cache.index_select(0, positions).view(
|
||||
num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)[1]
|
||||
_cos_slice = _cos[:, :num_tokens]
|
||||
_sin_slice = _sin[:, :num_tokens]
|
||||
|
||||
|
||||
def get_cos_and_sin_slice():
|
||||
return _cos_slice, _sin_slice
|
||||
|
||||
|
||||
def _custom_rotary_embedding_enabled(query, neox_style, head_size):
|
||||
@@ -65,8 +168,9 @@ def _rope_forward_oot(
|
||||
raise NotImplementedError(
|
||||
"Batched rotary embedding is currently not supported on NPU.")
|
||||
else:
|
||||
if hasattr(self, "cos") and hasattr(self, "sin") and \
|
||||
self.cos is not None and self.sin is not None:
|
||||
cos, sin = get_cos_and_sin_slice()
|
||||
if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[
|
||||
-1] == 128 and cos is not None and sin is not None:
|
||||
# If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation.
|
||||
# This method requires head_size and rotary_dim equal 128 and neox_style is True
|
||||
query = query.contiguous().view(1, query.shape[0], -1,
|
||||
@@ -75,7 +179,7 @@ def _rope_forward_oot(
|
||||
# Although this function modifies in-place, please retain the function's return value.
|
||||
# Otherwise, the graph fusion operation may fail.
|
||||
query, key = torch_npu.npu_apply_rotary_pos_emb(
|
||||
query, key, self.cos, self.sin)
|
||||
query, key, cos, sin)
|
||||
elif self.rotary_dim < self.head_size:
|
||||
num_tokens = query.shape[0]
|
||||
query = query.view(num_tokens, -1, self.head_size)
|
||||
@@ -125,10 +229,9 @@ class AscendRotaryEmbedding(RotaryEmbedding):
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
self.cos = None
|
||||
self.sin = None
|
||||
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||
is_neox_style, dtype)
|
||||
_record_cos_sin_cache(self.cos_sin_cache)
|
||||
|
||||
def forward_oot(
|
||||
self,
|
||||
@@ -141,20 +244,6 @@ class AscendRotaryEmbedding(RotaryEmbedding):
|
||||
is_neox_style = self.is_neox_style
|
||||
if is_neox_style_override is not None:
|
||||
is_neox_style = is_neox_style_override
|
||||
forward_context = get_forward_context()
|
||||
is_first_layer = forward_context.is_first_layer
|
||||
# Generate cos and sin outside layers to avoid repeated calculation.
|
||||
if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[
|
||||
-1] == 128:
|
||||
if is_first_layer:
|
||||
cos_sin = self.cos_sin_cache.index_select(0, positions)
|
||||
last_dim = cos_sin.size()[-1]
|
||||
cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat(
|
||||
1, 1, 2).chunk(2, dim=-2)
|
||||
# BSNH
|
||||
self.cos = cos.view(1, -1, 1, last_dim).contiguous()
|
||||
self.sin = sin.view(1, -1, 1, last_dim).contiguous()
|
||||
forward_context.is_first_layer = False
|
||||
return _rope_forward_oot(self, positions, query, key, is_neox_style,
|
||||
offsets)
|
||||
|
||||
@@ -176,8 +265,6 @@ class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding):
|
||||
beta_fast: int = 32,
|
||||
beta_slow: int = 1,
|
||||
) -> None:
|
||||
self.cos = None
|
||||
self.sin = None
|
||||
extra_kwargs = {
|
||||
"extrapolation_factor": extrapolation_factor,
|
||||
"attn_factor": attn_factor,
|
||||
@@ -186,6 +273,7 @@ class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding):
|
||||
}
|
||||
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||
is_neox_style, scaling_factor, dtype, **extra_kwargs)
|
||||
_record_cos_sin_cache(self.cos_sin_cache)
|
||||
|
||||
def forward_oot(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user