[Fusion] [Graph] Add qknorm rope fusion operator (#4711)

### What this PR does / why we need it?
This PR add `qkv_rmsnorm_rope` operator and introduces a graph fusion
pass for `qknorm_rope` operations. The implementation includes a new
configuration flag, a pattern matching pass using
`torch._inductor.pattern_matcher`, and a custom Triton kernel for the
fused operation.

Co-authored-by: Angazenn
[supperccell@163.com](mailto:supperccell@163.com)

### Does this PR introduce _any_ user-facing change?
Yes, add new additional_config

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: wxsIcey <1790571317@qq.com>
This commit is contained in:
Icey
2025-12-17 08:53:44 +08:00
committed by GitHub
parent b1a853b0f6
commit cadfa5ddc1
14 changed files with 754 additions and 71 deletions

View File

@@ -210,4 +210,4 @@ def test_aclgraph_enable():
# after check_and_update_config, mode should be VLLM_COMPILE and piecewise cudagraph
NPUPlatform.check_and_update_config(VllmConfig)
assert VllmConfig.compilation_config.mode == CompilationMode.VLLM_COMPILE
assert VllmConfig.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
assert VllmConfig.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE

View File

@@ -17,6 +17,7 @@ from typing import Optional
from uuid import uuid4
from vllm.logger import logger
from vllm.triton_utils import HAS_TRITON
def check_kv_extra_config(vllm_config):
@@ -231,7 +232,10 @@ class AscendCompilationConfig:
deployed on Ascend platforms.
"""
def __init__(self, fuse_norm_quant: bool = True, **kwargs):
def __init__(self,
fuse_norm_quant: bool = True,
fuse_qknorm_rope: bool = False,
**kwargs):
"""
Initialize the configuration.
@@ -239,11 +243,12 @@ class AscendCompilationConfig:
fuse_norm_quant (bool): Whether to enable norm and quant fusion optimization.
When set to True, the system will optimize norm and quant operations.
Default: True
fuse_qknorm_rope (bool): Whether to enable qknorm and rope fusion optimization.
Default: False
**kwargs: Additional optional parameters for forward compatibility and configuration extension.
"""
self.fuse_norm_quant = fuse_norm_quant
# Add more compilation related configs here as needed
self.fuse_qknorm_rope = HAS_TRITON or fuse_qknorm_rope
class XliteGraphConfig:

View File

@@ -209,37 +209,6 @@ def get_mc2_mask():
return _reserved_mc2_mask
def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
device):
global _cos
global _sin
if _cos is not None:
return
compilation_config = vllm_config.compilation_config
model_config = vllm_config.model_config
if model_config.use_mla and compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
rope_dim = model_config.hf_text_config.qk_rope_head_dim
_cos = torch.ones(max_num_reqs * decode_token_per_req,
1,
1,
rope_dim,
dtype=dtype,
device=device)
_sin = torch.zeros(max_num_reqs * decode_token_per_req,
1,
1,
rope_dim,
dtype=dtype,
device=device)
else:
_cos = None
_sin = None
def get_cos_and_sin():
return _cos, _sin
def select_moe_comm_method(num_tokens: int,
vllm_config: VllmConfig) -> Optional[MoECommType]:
"""1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all

View File

@@ -16,7 +16,6 @@ from vllm.utils.math_utils import cdiv, round_down
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import MLAAttentionSpec
from vllm_ascend.ascend_forward_context import get_cos_and_sin
from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
AscendMLAImpl, AscendMLAMetadata,
AscendMLAMetadataBuilder,
@@ -29,6 +28,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
wait_for_kv_layer_from_connector)
from vllm_ascend.compilation.acl_graph import (get_graph_params,
update_graph_params_workspaces)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.shared_weight_layer import (
is_hidden_layer, reach_layer_for_shared_weight_series)
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
@@ -286,7 +286,7 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
decode_metadata = None
if num_decodes > 0:
cos, sin = get_cos_and_sin()
cos, sin = get_cos_and_sin_mla()
# Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes +
1].tolist()

View File

@@ -22,7 +22,6 @@ from vllm.v1.kv_cache_interface import MLAAttentionSpec
from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import get_cos_and_sin
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
maybe_save_kv_layer_to_connector,
@@ -32,6 +31,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
from vllm_ascend.compilation.acl_graph import (get_graph_params,
get_mtp_graph_params,
update_graph_params_workspaces)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.shared_weight_layer import (
is_hidden_layer, post_process_after_loading_for_shared_weight_series,
reach_layer_for_shared_weight_series,
@@ -531,7 +531,7 @@ class AscendMLAMetadataBuilder:
decode_metadata = None
if num_decodes > 0:
cos, sin = get_cos_and_sin()
cos, sin = get_cos_and_sin_mla()
# Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes +
1].tolist()

View File

@@ -16,12 +16,12 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import get_cos_and_sin
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
trans_rope_weight, transdata,
wait_for_kv_layer_from_connector)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.shared_weight_layer import (
is_hidden_layer, post_process_after_loading_for_shared_weight_series,
reach_layer_for_shared_weight_series,
@@ -187,7 +187,7 @@ class AscendSFAMetadataBuilder:
cum_query_lens = common_attn_metadata.query_start_loc[1:num_reqs + 1]
seq_lens = common_attn_metadata.seq_lens[:num_reqs]
cos, sin = get_cos_and_sin()
cos, sin = get_cos_and_sin_mla()
assert self.cos_cache is not None and self.sin_cache is not None
new_cos = self.cos_cache[input_positions][:, None, None]

View File

@@ -50,4 +50,7 @@ class GraphFusionPassManager:
from .passes.norm_quant_fusion_pass import \
AddRMSNormQuantFusionPass
self.passes.append(AddRMSNormQuantFusionPass(config))
# Add more passes here as needed
if self.ascend_compilation_config.get("fuse_qknorm_rope", True):
from .passes.qknorm_rope_fusion_pass import QKNormRopeFusionPass
self.passes.append(QKNormRopeFusionPass(config))

View File

@@ -0,0 +1,293 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import (PatternMatcherPass,
PatternPrettyPrinter)
from vllm.attention.layer import Attention
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import (VllmConfig, get_current_vllm_config,
get_layers_from_vllm_config)
class QKNormRopeFusionPattern:
def __init__(self,
vllm_config,
head_dim,
num_heads,
num_kv_heads,
eps=1e-6):
self.vllm_config = vllm_config
self.head_dim = head_dim
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.eps = eps
vllm_config = get_current_vllm_config()
self.device = vllm_config.device_config.device if vllm_config.device_config else None
def get_inputs(self):
T = 5
qkv = torch.empty(T,
self.q_size + 2 * self.kv_size,
dtype=torch.bfloat16,
device="npu")
q_weight = torch.empty(self.head_dim,
dtype=torch.bfloat16,
device="npu")
k_weight = torch.empty(self.head_dim,
dtype=torch.bfloat16,
device="npu")
cos = torch.empty(1,
T,
1,
self.head_dim,
dtype=torch.bfloat16,
device="npu")
sin = torch.empty(1,
T,
1,
self.head_dim,
dtype=torch.bfloat16,
device="npu")
return [qkv, q_weight, k_weight, cos, sin]
def register(self, pm_pass: PatternMatcherPass):
def pattern(qkv: torch.Tensor, q_weight: torch.Tensor,
k_weight: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor):
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight,
self.eps)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight,
self.eps)
q_flat = q_norm_out.view(q.shape)
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1,
self.head_dim)
k_flat = k_norm_out.view(k.shape)
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1,
self.head_dim)
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(
q_reshape, k_reshape, cos, sin)
return q_rope, k_rope, v
def replacement(qkv: torch.Tensor, q_weight: torch.Tensor,
k_weight: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor):
results = torch.ops.vllm.qkv_rmsnorm_rope(
input=qkv,
q_weight=q_weight,
k_weight=k_weight,
q_hidden_size=self.q_size,
kv_hidden_size=self.kv_size,
head_dim=self.head_dim,
eps=self.eps,
q_bias=None,
k_bias=None,
sin=sin,
cos=cos)
return results
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
class QKNormRopeFusionPatternWithBias:
def __init__(self,
vllm_config,
head_dim,
num_heads,
num_kv_heads,
eps=1e-6):
self.head_dim = head_dim
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.eps = eps
self.vllm_config = vllm_config
self.device = vllm_config.device_config.device if vllm_config.device_config else None
def get_inputs(self):
T = 5
qkv = torch.empty(T,
self.q_size + 2 * self.kv_size,
dtype=torch.bfloat16,
device="npu")
q_weight = torch.empty(self.head_dim,
dtype=torch.bfloat16,
device="npu")
k_weight = torch.empty(self.head_dim,
dtype=torch.bfloat16,
device="npu")
q_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
k_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
cos = torch.empty(1,
T,
1,
self.head_dim,
dtype=torch.bfloat16,
device="npu")
sin = torch.empty(1,
T,
1,
self.head_dim,
dtype=torch.bfloat16,
device="npu")
return [qkv, q_weight, k_weight, q_bias, k_bias, cos, sin]
def register(self, pm_pass: PatternMatcherPass):
def pattern(qkv: torch.Tensor, q_weight: torch.Tensor,
k_weight: torch.Tensor, q_bias: torch.Tensor,
k_bias: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor):
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight,
self.eps)
q_normed = q_norm_out + q_bias
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight,
self.eps)
k_normed = k_norm_out + k_bias
q_flat = q_normed.view(q.shape)
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1,
self.head_dim)
k_flat = k_normed.view(k.shape)
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1,
self.head_dim)
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(
q_reshape, k_reshape, cos, sin)
return q_rope, k_rope, v
def replacement(qkv: torch.Tensor, q_weight: torch.Tensor,
k_weight: torch.Tensor, q_bias: torch.Tensor,
k_bias: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor):
results = torch.ops.vllm.qkv_rmsnorm_rope(
input=qkv,
q_weight=q_weight,
k_weight=k_weight,
q_hidden_size=self.q_size,
kv_hidden_size=self.kv_size,
head_dim=self.head_dim,
eps=self.eps,
q_bias=q_bias,
k_bias=k_bias,
cos=cos,
sin=sin)
return results
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
class QKNormRopeFusionPass(VllmInductorPass):
"""
A pass for fusing QKV split and RMSNorm operations into a single qk_rmsnorm operator.
"""
def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(
pass_name="qknorm_rope_fusion_pass")
dtype = vllm_config.model_config.dtype
if dtype not in (torch.bfloat16, torch.float16):
logging.info(
"QKNorm and Rope fusion not enabled: unsupported dtype %s",
dtype)
return
# use one attn layer to get meta (such as head_dim) for QKNormRopeFusionPattern
attn_layers: dict[str, Attention] = get_layers_from_vllm_config(
vllm_config, Attention)
if len(attn_layers) == 0:
logging.info(
"QKNorm and Rope fusion enabled, but no Attention layers were discovered."
)
return
layer = next(iter(attn_layers.values()))
for epsilon in [1e-6, 1e-5]:
if layer.head_size != 128:
logging.debug(
"QKNorm and Rope fusion not enabled: head_dim %d is not equal of 128",
layer.head_size)
continue
QKNormRopeFusionPattern(vllm_config=vllm_config,
head_dim=layer.head_size,
num_heads=layer.num_heads,
num_kv_heads=layer.num_kv_heads,
eps=epsilon).register(
self.pattern_match_passes)
QKNormRopeFusionPatternWithBias(vllm_config=vllm_config,
head_dim=layer.head_size,
num_heads=layer.num_heads,
num_kv_heads=layer.num_kv_heads,
eps=epsilon).register(
self.pattern_match_passes)
def __call__(self, graph: torch.fx.Graph):
self.begin()
self.matched_count = self.pattern_match_passes.apply(graph)
logging.debug("Fused %s QKNorm and Rope patterns", self.matched_count)
logging.debug("Patterns registered for replacement:")
pattern_idx = 0
for pattern_entry in self.pattern_match_passes.patterns.values():
for p in pattern_entry:
p_str = PatternPrettyPrinter.run(p.pattern)
logging.debug("Pattern %d: %s", pattern_idx, p_str)
pattern_idx += 1
self.end_and_log()
def is_applicable(self, runtime_shape):
"""
Check if the pass is applicable for the current configuration.
"""
return True

View File

@@ -16,10 +16,15 @@
#
import torch
from vllm.triton_utils import HAS_TRITON
import vllm_ascend.ops.fused_moe.fused_moe # noqa
import vllm_ascend.ops.layernorm # noqa
import vllm_ascend.ops.register_custom_ops # noqa
if HAS_TRITON:
import vllm_ascend.ops.triton.linearnorm.split_qkv_rmsnorm_rope # noqa
import vllm_ascend.ops.vocab_parallel_embedding # noqa
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
from vllm_ascend.ops.rotary_embedding import (

View File

@@ -20,14 +20,117 @@ from typing import Optional, Tuple
import torch
import torch_npu
from vllm.forward_context import get_forward_context
from vllm.config import CUDAGraphMode
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
YaRNScalingRotaryEmbedding)
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
get_ascend_device_type)
get_ascend_device_type, is_vl_model)
# Currently, rope ops used on npu requires detached cos && sin as inputs.
# However, RotaryEmbedding in vllm use cos_sin_cache as a whole variable.
# So we have to preprocess cos_sin_cache int cos && sin. In the future,
# we shall implement a new rope ops which accept cos_sin_cache as inputs.
# NOTE(Angazenn): MLA && SFA models uses attn_metadata to pass cos && sin
# to rope in AscendMLA(SFA)Impl. However, since rope is isolated from
# AscendAttentionBackendImpl for GQA models, we cannot pass cos && sin by
# attn_metadata. This causes that rope in GQA models must pass cos && sin
# by different approaches.
_cos_mla: Optional[torch.Tensor] = None
_sin_mla: Optional[torch.Tensor] = None
_cos_sin_cache: Optional[torch.Tensor] = None
_cos: Optional[torch.Tensor] = None
_sin: Optional[torch.Tensor] = None
_cos_slice: Optional[torch.Tensor] = None
_sin_slice: Optional[torch.Tensor] = None
def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
device):
global _cos_mla
global _sin_mla
global _cos
global _sin
if _cos_mla is not None or \
_sin_mla is not None or \
_cos is not None or \
_sin is not None:
return
compilation_config = vllm_config.compilation_config
model_config = vllm_config.model_config
max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
if model_config.use_mla and compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
rope_dim = model_config.hf_text_config.qk_rope_head_dim
_cos_mla = torch.ones(max_num_reqs * decode_token_per_req,
1,
1,
rope_dim,
dtype=dtype,
device=device)
_sin_mla = torch.zeros(max_num_reqs * decode_token_per_req,
1,
1,
rope_dim,
dtype=dtype,
device=device)
elif not is_vl_model(vllm_config) and not vllm_config.model_config.use_mla:
rope_dim = model_config.get_head_size()
# For models using partial rope like Qwen3-Next.
if hasattr(model_config.hf_text_config, "partial_rotary_factor"):
rope_dim = int(rope_dim *
model_config.hf_text_config.partial_rotary_factor)
_cos = torch.ones(1,
max_num_batched_tokens,
1,
rope_dim,
dtype=dtype,
device=device)
_sin = torch.zeros(1,
max_num_batched_tokens,
1,
rope_dim,
dtype=dtype,
device=device)
def get_cos_and_sin_mla():
return _cos_mla, _sin_mla
def _record_cos_sin_cache(cos_sin_cache):
global _cos_sin_cache
if _cos_sin_cache is not None:
return
_cos_sin_cache = cos_sin_cache
def update_cos_sin(positions):
global _cos
global _sin
global _cos_slice
global _sin_slice
if _cos_sin_cache is None or \
_cos is None or \
_sin is None:
return
num_tokens = positions.size(0)
_cos[:, :num_tokens] = _cos_sin_cache.index_select(0, positions).view(
num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)[0]
_sin[:, :num_tokens] = _cos_sin_cache.index_select(0, positions).view(
num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)[1]
_cos_slice = _cos[:, :num_tokens]
_sin_slice = _sin[:, :num_tokens]
def get_cos_and_sin_slice():
return _cos_slice, _sin_slice
def _custom_rotary_embedding_enabled(query, neox_style, head_size):
@@ -65,8 +168,9 @@ def _rope_forward_oot(
raise NotImplementedError(
"Batched rotary embedding is currently not supported on NPU.")
else:
if hasattr(self, "cos") and hasattr(self, "sin") and \
self.cos is not None and self.sin is not None:
cos, sin = get_cos_and_sin_slice()
if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[
-1] == 128 and cos is not None and sin is not None:
# If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation.
# This method requires head_size and rotary_dim equal 128 and neox_style is True
query = query.contiguous().view(1, query.shape[0], -1,
@@ -75,7 +179,7 @@ def _rope_forward_oot(
# Although this function modifies in-place, please retain the function's return value.
# Otherwise, the graph fusion operation may fail.
query, key = torch_npu.npu_apply_rotary_pos_emb(
query, key, self.cos, self.sin)
query, key, cos, sin)
elif self.rotary_dim < self.head_size:
num_tokens = query.shape[0]
query = query.view(num_tokens, -1, self.head_size)
@@ -125,10 +229,9 @@ class AscendRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
self.cos = None
self.sin = None
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
_record_cos_sin_cache(self.cos_sin_cache)
def forward_oot(
self,
@@ -141,20 +244,6 @@ class AscendRotaryEmbedding(RotaryEmbedding):
is_neox_style = self.is_neox_style
if is_neox_style_override is not None:
is_neox_style = is_neox_style_override
forward_context = get_forward_context()
is_first_layer = forward_context.is_first_layer
# Generate cos and sin outside layers to avoid repeated calculation.
if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[
-1] == 128:
if is_first_layer:
cos_sin = self.cos_sin_cache.index_select(0, positions)
last_dim = cos_sin.size()[-1]
cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat(
1, 1, 2).chunk(2, dim=-2)
# BSNH
self.cos = cos.view(1, -1, 1, last_dim).contiguous()
self.sin = sin.view(1, -1, 1, last_dim).contiguous()
forward_context.is_first_layer = False
return _rope_forward_oot(self, positions, query, key, is_neox_style,
offsets)
@@ -176,8 +265,6 @@ class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding):
beta_fast: int = 32,
beta_slow: int = 1,
) -> None:
self.cos = None
self.sin = None
extra_kwargs = {
"extrapolation_factor": extrapolation_factor,
"attn_factor": attn_factor,
@@ -186,6 +273,7 @@ class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding):
}
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, scaling_factor, dtype, **extra_kwargs)
_record_cos_sin_cache(self.cos_sin_cache)
def forward_oot(
self,

View File

@@ -0,0 +1,305 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import Optional
import torch
import triton # type: ignore
import triton.language as tl # type: ignore
from vllm.utils.torch_utils import direct_register_custom_op
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
@triton.jit
def split_qkv_rmsnorm_rope_kernel(
input_ptr,
sin_ptr,
cos_ptr,
q_ptr,
k_ptr,
v_ptr,
q_weight_ptr,
q_bias_ptr,
k_weight_ptr,
k_bias_ptr,
batch_size,
q_hidden_size: tl.constexpr,
kv_hidden_size: tl.constexpr,
total_hidden_size: tl.constexpr,
eps: tl.constexpr,
Q_BLOCK_SIZE: tl.constexpr,
KV_BLOCK_SIZE: tl.constexpr,
BIAS: tl.constexpr,
HEAD_DIM: tl.constexpr,
HALF_HEAD_DIM: tl.constexpr,
):
row_pid = tl.program_id(0)
col_pid = tl.program_id(1)
row_step = tl.num_programs(0)
# q
weight_values = tl.load(q_weight_ptr + tl.arange(0, HEAD_DIM))
if BIAS:
bias_values = tl.load(q_bias_ptr + tl.arange(0, HEAD_DIM))
input_offset = row_pid * total_hidden_size
output_offset = row_pid * q_hidden_size
input_offset_step = row_step * total_hidden_size
output_offset_step = row_step * q_hidden_size
for row_idx in tl.range(row_pid, batch_size, row_step):
col_indices = col_pid * Q_BLOCK_SIZE + tl.arange(0, Q_BLOCK_SIZE)
valid_mask = col_indices < q_hidden_size
input_values = (tl.load(input_ptr + input_offset + col_indices,
mask=valid_mask,
other=0.0).to(tl.float32).reshape(
Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM))
squares = input_values * input_values
variances = tl.sum(squares, axis=1) / HEAD_DIM
reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(
Q_BLOCK_SIZE // HEAD_DIM, 1)
normalized_values = (input_values * reciprocal_std
) # (Q_BLOCK_SIZE//HEAD_DIM, HEAD_DIM)
if BIAS:
normalized_values = (normalized_values * weight_values +
bias_values).to(tl.bfloat16)
else:
normalized_values = (normalized_values * weight_values).to(
tl.bfloat16)
sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM)
sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM)
cos = (tl.load(cos_ptr + sc_offsets)).reshape(1, HEAD_DIM)
x1 = tl.extract_slice(
normalized_values,
offsets=(0, 0),
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
x2 = tl.extract_slice(
normalized_values,
offsets=(0, HALF_HEAD_DIM),
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
cat_x = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM),
dtype=tl.bfloat16)
cat_x = tl.insert_slice(
cat_x,
-x2,
offsets=(0, 0),
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
cat_x = tl.insert_slice(
cat_x,
x1,
offsets=(0, HALF_HEAD_DIM),
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
roped_q = cat_x * sin + normalized_values * cos
tl.store(
q_ptr + output_offset + col_indices,
roped_q.reshape(Q_BLOCK_SIZE).to(q_ptr.dtype.element_ty),
mask=valid_mask,
)
input_offset += input_offset_step
output_offset += output_offset_step
weight_values = tl.load(k_weight_ptr + tl.arange(0, HEAD_DIM))
if BIAS:
bias_values = tl.load(k_bias_ptr + tl.arange(0, HEAD_DIM))
input_offset = row_pid * total_hidden_size + q_hidden_size
output_offset = row_pid * kv_hidden_size
output_offset_step = row_step * kv_hidden_size
for row_idx in tl.range(row_pid, batch_size, row_step):
col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE)
valid_mask = col_indices < kv_hidden_size
input_values = (tl.load(input_ptr + input_offset + col_indices,
mask=valid_mask,
other=0.0).to(tl.float32).reshape(
KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM))
squares = input_values * input_values
variances = tl.sum(squares, axis=1) / HEAD_DIM
reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(
KV_BLOCK_SIZE // HEAD_DIM, 1)
normalized_values = (input_values * reciprocal_std
) # (KV_BLOCK_SIZE/HEAD_DIM, HEAD_DIM)
if BIAS:
normalized_values = (normalized_values * weight_values +
bias_values).to(tl.bfloat16)
else:
normalized_values = (normalized_values * weight_values).to(
tl.bfloat16)
sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM)
sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM)
cos = (tl.load(cos_ptr + sc_offsets)).reshape(1, HEAD_DIM)
x1 = tl.extract_slice(
normalized_values,
offsets=(0, 0),
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
x2 = tl.extract_slice(
normalized_values,
offsets=(0, HALF_HEAD_DIM),
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
cat_x = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM),
dtype=tl.bfloat16)
cat_x = tl.insert_slice(
cat_x,
-x2,
offsets=(0, 0),
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
cat_x = tl.insert_slice(
cat_x,
x1,
offsets=(0, HALF_HEAD_DIM),
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
roped_k = cat_x * sin + normalized_values * cos
tl.store(
k_ptr + output_offset + col_indices,
roped_k.to(tl.bfloat16).reshape(KV_BLOCK_SIZE),
mask=valid_mask,
)
input_offset += input_offset_step
output_offset += output_offset_step
input_offset = row_pid * total_hidden_size + q_hidden_size + kv_hidden_size
output_offset = row_pid * kv_hidden_size
for _ in tl.range(row_pid, batch_size, row_step):
col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE)
valid_mask = col_indices < kv_hidden_size
input_values = tl.load(input_ptr + input_offset + col_indices,
mask=valid_mask,
other=0.0)
tl.store(v_ptr + output_offset + col_indices,
input_values,
mask=valid_mask)
input_offset += input_offset_step
output_offset += output_offset_step
def split_qkv_rmsnorm_rope_impl(
input: torch.Tensor,
sin: torch.Tensor,
cos: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_hidden_size: int,
kv_hidden_size: int,
head_dim: int,
eps: float,
q_bias: Optional[torch.Tensor],
k_bias: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
KV_BLOCK_SIZE = triton.next_power_of_2(head_dim)
assert KV_BLOCK_SIZE == head_dim
assert q_hidden_size % kv_hidden_size == 0
Q_BLOCK_SIZE = q_hidden_size // kv_hidden_size * head_dim
batch_size = input.shape[0]
total_hidden_size = q_hidden_size + kv_hidden_size * 2
q_output = torch.empty(batch_size,
q_hidden_size,
device=input.device,
dtype=input.dtype)
k_output = torch.empty(batch_size,
kv_hidden_size,
device=input.device,
dtype=input.dtype)
v_output = torch.empty(batch_size,
kv_hidden_size,
device=input.device,
dtype=input.dtype)
n_cols = kv_hidden_size // KV_BLOCK_SIZE
num_vectorcore = get_vectorcore_num()
assert num_vectorcore % n_cols == 0
n_rows = num_vectorcore // n_cols
BIAS = q_bias is not None
split_qkv_rmsnorm_rope_kernel[(n_rows, n_cols, 1)](
input,
sin,
cos,
q_output,
k_output,
v_output,
q_weight,
q_bias,
k_weight,
k_bias,
batch_size,
q_hidden_size,
kv_hidden_size,
total_hidden_size,
eps,
Q_BLOCK_SIZE,
KV_BLOCK_SIZE,
BIAS,
head_dim,
head_dim // 2,
)
return q_output, k_output, v_output
def split_qkv_rmsnorm_rope_impl_fake(
input: torch.Tensor,
sin: torch.Tensor,
cos: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_hidden_size: int,
kv_hidden_size: int,
head_dim: int,
eps: float,
q_bias: Optional[torch.Tensor] = None,
k_bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# Fake implementation for shape inference during Dynamo/AOT tracing.
# Note: sin and cos are not used in shape computation, but must be present in signature.
batch_size = input.shape[0]
q_output = torch.empty(
batch_size,
q_hidden_size,
device=input.device,
dtype=input.dtype,
)
k_output = torch.empty(
batch_size,
kv_hidden_size,
device=input.device,
dtype=input.dtype,
)
v_output = torch.empty(
batch_size,
kv_hidden_size,
device=input.device,
dtype=input.dtype,
)
return q_output, k_output, v_output
direct_register_custom_op(op_name="qkv_rmsnorm_rope",
op_func=split_qkv_rmsnorm_rope_impl,
fake_impl=split_qkv_rmsnorm_rope_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1")

View File

@@ -25,6 +25,7 @@ from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.ops.rotary_embedding import update_cos_sin
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
PADDING_SLOT_ID = -1
@@ -143,6 +144,9 @@ class EagleProposer(Proposer):
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor=None,
dummy_compute_logits=lambda hidden_states: None):
# update global cos, sin
update_cos_sin(self.positions[:num_tokens])
with set_ascend_forward_context(None,
self.vllm_config,
num_tokens=num_tokens):
@@ -338,6 +342,8 @@ class EagleProposer(Proposer):
builder = self.runner.attn_groups[0][0].get_metadata_builder()
attn_metadata = builder.build(0, common_attn_metadata,
self.runner.get_model())
# update global cos, sin
update_cos_sin(self.positions[:num_input_tokens])
with set_ascend_forward_context(attn_metadata,
self.vllm_config,
@@ -443,6 +449,10 @@ class EagleProposer(Proposer):
attn_metadata.attn_mask = attn_mask
# Run the model.
# update global cos, sin
update_cos_sin(self.positions[:input_batch_size])
with set_ascend_forward_context(attn_metadata,
self.vllm_config,
num_tokens=input_batch_size):

View File

@@ -84,12 +84,6 @@ from vllm.v1.worker.utils import AttentionGroup
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import (MoECommType,
get_mc2_tokens_capacity,
select_moe_comm_method,
set_ascend_forward_context,
set_cos_and_sin, set_mc2_mask,
set_mc2_tokens_capacity)
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
@@ -111,6 +105,7 @@ from vllm_ascend.eplb.core.eplb_utils import EPLBParamUtils
from vllm_ascend.eplb.core.eplb_worker import EplbProcess
from vllm_ascend.eplb.eplb_updator import EplbUpdator
from vllm_ascend.eplb.utils import model_register
from vllm_ascend.ops.rotary_embedding import set_cos_and_sin, update_cos_sin
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
from vllm_ascend.sample.logits_processor import build_logitsprocs
@@ -125,6 +120,10 @@ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
is_moe_model, lmhead_tp_enable, vllm_version_is)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
from vllm_ascend.ascend_forward_context import ( # isort: skip
MoECommType, get_mc2_tokens_capacity, select_moe_comm_method,
set_ascend_forward_context, set_mc2_mask, set_mc2_tokens_capacity)
if TYPE_CHECKING:
import xgrammar as xgr # type: ignore[import-untyped]
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
@@ -1122,6 +1121,9 @@ class NPUModelRunner(GPUModelRunner):
for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = attn_metadata_i
# update global cos, sin
update_cos_sin(positions)
if lmhead_tp_enable():
max_num_reqs_across_dp = self.max_num_reqs * self.uniform_decode_query_len
logits_indices = nn.functional.pad(
@@ -2084,6 +2086,9 @@ class NPUModelRunner(GPUModelRunner):
else:
positions = self.positions.gpu[:num_tokens_padded]
# update global cos, sin
update_cos_sin(positions)
if get_pp_group().is_first_rank:
intermediate_tensors = None
else: