[Fusion] [Graph] Add qknorm rope fusion operator (#4711)

### What this PR does / why we need it?
This PR adds a `qkv_rmsnorm_rope` operator and introduces a graph fusion
pass for `qknorm_rope` operations. The implementation includes a new
configuration flag, a pattern-matching pass built on
`torch._inductor.pattern_matcher`, and a custom Triton kernel for the
fused operation.
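
For reference, below is a minimal, NPU-free PyTorch sketch of the semantics the fused op replaces, distilled from the pattern/replacement pair in this diff. The function name and argument layout are illustrative only, not the signature of the registered `torch.ops.vllm.qkv_rmsnorm_rope` op.

```python
import torch

def qkv_rmsnorm_rope_reference(qkv, q_weight, k_weight, cos, sin,
                               q_size, kv_size, head_dim, eps=1e-6):
    """Unfused reference: split qkv, per-head RMSNorm on q/k, neox-style rope."""
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)

    def rms_norm(x, weight):
        # normalize each head_dim-sized slice independently
        x = x.view(*x.shape[:-1], -1, head_dim)
        var = x.float().pow(2).mean(dim=-1, keepdim=True)
        return (x.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight

    def rope(x, cos, sin):
        # rotate_half, the same trick the Triton kernel builds with extract/insert_slice
        x1, x2 = x.chunk(2, dim=-1)
        return x * cos + torch.cat((-x2, x1), dim=-1) * sin

    # cos/sin are laid out as [1, T, 1, head_dim], matching the pattern inputs
    q = rope(rms_norm(q, q_weight).view(1, q.shape[0], -1, head_dim), cos, sin)
    k = rope(rms_norm(k, k_weight).view(1, k.shape[0], -1, head_dim), cos, sin)
    return q, k, v  # v is passed through untouched
```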

Co-authored-by: Angazenn <supperccell@163.com>

### Does this PR introduce _any_ user-facing change?
Yes, it adds a new `additional_config` option, `fuse_qknorm_rope`, to control the fusion.
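
A minimal sketch of how the new flag would be toggled from user code; the nesting key below mirrors the `AscendCompilationConfig` class name in this diff and is an assumption, so check the vllm-ascend `additional_config` schema for your version.

```python
from vllm import LLM

# Hypothetical key layout: "ascend_compilation_config" mirrors the class name
# AscendCompilationConfig from this PR; verify against the vllm-ascend docs.
llm = LLM(
    model="Qwen/Qwen3-8B",  # example model with qk-norm + rope attention
    additional_config={
        "ascend_compilation_config": {
            "fuse_qknorm_rope": True,  # new flag added by this PR (default: False)
        },
    },
)
```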

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: wxsIcey <1790571317@qq.com>
Committed by Icey (via GitHub) on 2025-12-17 08:53:44 +08:00
parent b1a853b0f6
commit cadfa5ddc1
14 changed files with 754 additions and 71 deletions

View File

@@ -17,6 +17,7 @@ from typing import Optional
from uuid import uuid4
from vllm.logger import logger
from vllm.triton_utils import HAS_TRITON
def check_kv_extra_config(vllm_config):
@@ -231,7 +232,10 @@ class AscendCompilationConfig:
deployed on Ascend platforms.
"""
def __init__(self, fuse_norm_quant: bool = True, **kwargs):
def __init__(self,
fuse_norm_quant: bool = True,
fuse_qknorm_rope: bool = False,
**kwargs):
"""
Initialize the configuration.
@@ -239,11 +243,12 @@ class AscendCompilationConfig:
fuse_norm_quant (bool): Whether to enable norm and quant fusion optimization.
When set to True, the system will optimize norm and quant operations.
Default: True
fuse_qknorm_rope (bool): Whether to enable qknorm and rope fusion optimization.
Default: False
**kwargs: Additional optional parameters for forward compatibility and configuration extension.
"""
self.fuse_norm_quant = fuse_norm_quant
# Add more compilation related configs here as needed
self.fuse_qknorm_rope = HAS_TRITON or fuse_qknorm_rope
class XliteGraphConfig:

View File

@@ -209,37 +209,6 @@ def get_mc2_mask():
return _reserved_mc2_mask
def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
device):
global _cos
global _sin
if _cos is not None:
return
compilation_config = vllm_config.compilation_config
model_config = vllm_config.model_config
if model_config.use_mla and compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
rope_dim = model_config.hf_text_config.qk_rope_head_dim
_cos = torch.ones(max_num_reqs * decode_token_per_req,
1,
1,
rope_dim,
dtype=dtype,
device=device)
_sin = torch.zeros(max_num_reqs * decode_token_per_req,
1,
1,
rope_dim,
dtype=dtype,
device=device)
else:
_cos = None
_sin = None
def get_cos_and_sin():
return _cos, _sin
def select_moe_comm_method(num_tokens: int,
vllm_config: VllmConfig) -> Optional[MoECommType]:
"""1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all """1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all

View File

@@ -16,7 +16,6 @@ from vllm.utils.math_utils import cdiv, round_down
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import MLAAttentionSpec
from vllm_ascend.ascend_forward_context import get_cos_and_sin
from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
AscendMLAImpl, AscendMLAMetadata,
AscendMLAMetadataBuilder,
@@ -29,6 +28,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
wait_for_kv_layer_from_connector)
from vllm_ascend.compilation.acl_graph import (get_graph_params,
update_graph_params_workspaces)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.shared_weight_layer import (
is_hidden_layer, reach_layer_for_shared_weight_series)
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
@@ -286,7 +286,7 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
decode_metadata = None
if num_decodes > 0:
cos, sin = get_cos_and_sin()
cos, sin = get_cos_and_sin_mla()
# Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes +
1].tolist()

View File

@@ -22,7 +22,6 @@ from vllm.v1.kv_cache_interface import MLAAttentionSpec
from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import get_cos_and_sin
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
maybe_save_kv_layer_to_connector,
@@ -32,6 +31,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
from vllm_ascend.compilation.acl_graph import (get_graph_params,
get_mtp_graph_params,
update_graph_params_workspaces)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.shared_weight_layer import (
is_hidden_layer, post_process_after_loading_for_shared_weight_series,
reach_layer_for_shared_weight_series,
@@ -531,7 +531,7 @@ class AscendMLAMetadataBuilder:
decode_metadata = None
if num_decodes > 0:
cos, sin = get_cos_and_sin()
cos, sin = get_cos_and_sin_mla()
# Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes +
1].tolist()

View File

@@ -16,12 +16,12 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import get_cos_and_sin
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
trans_rope_weight, transdata,
wait_for_kv_layer_from_connector)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.shared_weight_layer import (
is_hidden_layer, post_process_after_loading_for_shared_weight_series,
reach_layer_for_shared_weight_series,
@@ -187,7 +187,7 @@ class AscendSFAMetadataBuilder:
cum_query_lens = common_attn_metadata.query_start_loc[1:num_reqs + 1]
seq_lens = common_attn_metadata.seq_lens[:num_reqs]
cos, sin = get_cos_and_sin()
cos, sin = get_cos_and_sin_mla()
assert self.cos_cache is not None and self.sin_cache is not None
new_cos = self.cos_cache[input_positions][:, None, None]

View File

@@ -50,4 +50,7 @@ class GraphFusionPassManager:
from .passes.norm_quant_fusion_pass import \
AddRMSNormQuantFusionPass
self.passes.append(AddRMSNormQuantFusionPass(config))
# Add more passes here as needed
if self.ascend_compilation_config.get("fuse_qknorm_rope", True):
from .passes.qknorm_rope_fusion_pass import QKNormRopeFusionPass
self.passes.append(QKNormRopeFusionPass(config))

View File

@@ -0,0 +1,293 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import (PatternMatcherPass,
PatternPrettyPrinter)
from vllm.attention.layer import Attention
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import (VllmConfig, get_current_vllm_config,
get_layers_from_vllm_config)
class QKNormRopeFusionPattern:
def __init__(self,
vllm_config,
head_dim,
num_heads,
num_kv_heads,
eps=1e-6):
self.vllm_config = vllm_config
self.head_dim = head_dim
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.eps = eps
vllm_config = get_current_vllm_config()
self.device = vllm_config.device_config.device if vllm_config.device_config else None
def get_inputs(self):
T = 5
qkv = torch.empty(T,
self.q_size + 2 * self.kv_size,
dtype=torch.bfloat16,
device="npu")
q_weight = torch.empty(self.head_dim,
dtype=torch.bfloat16,
device="npu")
k_weight = torch.empty(self.head_dim,
dtype=torch.bfloat16,
device="npu")
cos = torch.empty(1,
T,
1,
self.head_dim,
dtype=torch.bfloat16,
device="npu")
sin = torch.empty(1,
T,
1,
self.head_dim,
dtype=torch.bfloat16,
device="npu")
return [qkv, q_weight, k_weight, cos, sin]
def register(self, pm_pass: PatternMatcherPass):
def pattern(qkv: torch.Tensor, q_weight: torch.Tensor,
k_weight: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor):
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight,
self.eps)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight,
self.eps)
q_flat = q_norm_out.view(q.shape)
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1,
self.head_dim)
k_flat = k_norm_out.view(k.shape)
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1,
self.head_dim)
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(
q_reshape, k_reshape, cos, sin)
return q_rope, k_rope, v
def replacement(qkv: torch.Tensor, q_weight: torch.Tensor,
k_weight: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor):
results = torch.ops.vllm.qkv_rmsnorm_rope(
input=qkv,
q_weight=q_weight,
k_weight=k_weight,
q_hidden_size=self.q_size,
kv_hidden_size=self.kv_size,
head_dim=self.head_dim,
eps=self.eps,
q_bias=None,
k_bias=None,
sin=sin,
cos=cos)
return results
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
class QKNormRopeFusionPatternWithBias:
def __init__(self,
vllm_config,
head_dim,
num_heads,
num_kv_heads,
eps=1e-6):
self.head_dim = head_dim
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.eps = eps
self.vllm_config = vllm_config
self.device = vllm_config.device_config.device if vllm_config.device_config else None
def get_inputs(self):
T = 5
qkv = torch.empty(T,
self.q_size + 2 * self.kv_size,
dtype=torch.bfloat16,
device="npu")
q_weight = torch.empty(self.head_dim,
dtype=torch.bfloat16,
device="npu")
k_weight = torch.empty(self.head_dim,
dtype=torch.bfloat16,
device="npu")
q_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
k_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
cos = torch.empty(1,
T,
1,
self.head_dim,
dtype=torch.bfloat16,
device="npu")
sin = torch.empty(1,
T,
1,
self.head_dim,
dtype=torch.bfloat16,
device="npu")
return [qkv, q_weight, k_weight, q_bias, k_bias, cos, sin]
def register(self, pm_pass: PatternMatcherPass):
def pattern(qkv: torch.Tensor, q_weight: torch.Tensor,
k_weight: torch.Tensor, q_bias: torch.Tensor,
k_bias: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor):
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight,
self.eps)
q_normed = q_norm_out + q_bias
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight,
self.eps)
k_normed = k_norm_out + k_bias
q_flat = q_normed.view(q.shape)
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1,
self.head_dim)
k_flat = k_normed.view(k.shape)
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1,
self.head_dim)
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(
q_reshape, k_reshape, cos, sin)
return q_rope, k_rope, v
def replacement(qkv: torch.Tensor, q_weight: torch.Tensor,
k_weight: torch.Tensor, q_bias: torch.Tensor,
k_bias: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor):
results = torch.ops.vllm.qkv_rmsnorm_rope(
input=qkv,
q_weight=q_weight,
k_weight=k_weight,
q_hidden_size=self.q_size,
kv_hidden_size=self.kv_size,
head_dim=self.head_dim,
eps=self.eps,
q_bias=q_bias,
k_bias=k_bias,
cos=cos,
sin=sin)
return results
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
class QKNormRopeFusionPass(VllmInductorPass):
"""
A pass for fusing the QKV split, per-head RMSNorm and rotary embedding into a single qkv_rmsnorm_rope operator.
"""
def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(
pass_name="qknorm_rope_fusion_pass")
dtype = vllm_config.model_config.dtype
if dtype not in (torch.bfloat16, torch.float16):
logging.info(
"QKNorm and Rope fusion not enabled: unsupported dtype %s",
dtype)
return
# use one attn layer to get meta (such as head_dim) for QKNormRopeFusionPattern
attn_layers: dict[str, Attention] = get_layers_from_vllm_config(
vllm_config, Attention)
if len(attn_layers) == 0:
logging.info(
"QKNorm and Rope fusion enabled, but no Attention layers were discovered."
)
return
layer = next(iter(attn_layers.values()))
for epsilon in [1e-6, 1e-5]:
if layer.head_size != 128:
logging.debug(
"QKNorm and Rope fusion not enabled: head_dim %d is not equal of 128",
layer.head_size)
continue
QKNormRopeFusionPattern(vllm_config=vllm_config,
head_dim=layer.head_size,
num_heads=layer.num_heads,
num_kv_heads=layer.num_kv_heads,
eps=epsilon).register(
self.pattern_match_passes)
QKNormRopeFusionPatternWithBias(vllm_config=vllm_config,
head_dim=layer.head_size,
num_heads=layer.num_heads,
num_kv_heads=layer.num_kv_heads,
eps=epsilon).register(
self.pattern_match_passes)
def __call__(self, graph: torch.fx.Graph):
self.begin()
self.matched_count = self.pattern_match_passes.apply(graph)
logging.debug("Fused %s QKNorm and Rope patterns", self.matched_count)
logging.debug("Patterns registered for replacement:")
pattern_idx = 0
for pattern_entry in self.pattern_match_passes.patterns.values():
for p in pattern_entry:
p_str = PatternPrettyPrinter.run(p.pattern)
logging.debug("Pattern %d: %s", pattern_idx, p_str)
pattern_idx += 1
self.end_and_log()
def is_applicable(self, runtime_shape):
"""
Check if the pass is applicable for the current configuration.
"""
return True

View File

@@ -16,10 +16,15 @@
#
import torch
from vllm.triton_utils import HAS_TRITON
import vllm_ascend.ops.fused_moe.fused_moe # noqa
import vllm_ascend.ops.layernorm # noqa
import vllm_ascend.ops.register_custom_ops # noqa
if HAS_TRITON:
import vllm_ascend.ops.triton.linearnorm.split_qkv_rmsnorm_rope # noqa
import vllm_ascend.ops.vocab_parallel_embedding # noqa
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
from vllm_ascend.ops.rotary_embedding import (

View File

@@ -20,14 +20,117 @@ from typing import Optional, Tuple
import torch
import torch_npu
from vllm.forward_context import get_forward_context
from vllm.config import CUDAGraphMode
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
YaRNScalingRotaryEmbedding)
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
get_ascend_device_type)
get_ascend_device_type, is_vl_model)
# Currently, the rope ops used on npu require detached cos && sin as inputs.
# However, RotaryEmbedding in vllm uses cos_sin_cache as a single variable.
# So we have to preprocess cos_sin_cache into cos && sin. In the future,
# we shall implement a new rope op which accepts cos_sin_cache as input.
# NOTE(Angazenn): MLA && SFA models use attn_metadata to pass cos && sin
# to rope in AscendMLA(SFA)Impl. However, since rope is isolated from
# AscendAttentionBackendImpl for GQA models, we cannot pass cos && sin via
# attn_metadata. As a result, rope in GQA models must receive cos && sin
# through a different approach.
_cos_mla: Optional[torch.Tensor] = None
_sin_mla: Optional[torch.Tensor] = None
_cos_sin_cache: Optional[torch.Tensor] = None
_cos: Optional[torch.Tensor] = None
_sin: Optional[torch.Tensor] = None
_cos_slice: Optional[torch.Tensor] = None
_sin_slice: Optional[torch.Tensor] = None
def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
device):
global _cos_mla
global _sin_mla
global _cos
global _sin
if _cos_mla is not None or \
_sin_mla is not None or \
_cos is not None or \
_sin is not None:
return
compilation_config = vllm_config.compilation_config
model_config = vllm_config.model_config
max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
if model_config.use_mla and compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
rope_dim = model_config.hf_text_config.qk_rope_head_dim
_cos_mla = torch.ones(max_num_reqs * decode_token_per_req,
1,
1,
rope_dim,
dtype=dtype,
device=device)
_sin_mla = torch.zeros(max_num_reqs * decode_token_per_req,
1,
1,
rope_dim,
dtype=dtype,
device=device)
elif not is_vl_model(vllm_config) and not vllm_config.model_config.use_mla:
rope_dim = model_config.get_head_size()
# For models using partial rope like Qwen3-Next.
if hasattr(model_config.hf_text_config, "partial_rotary_factor"):
rope_dim = int(rope_dim *
model_config.hf_text_config.partial_rotary_factor)
_cos = torch.ones(1,
max_num_batched_tokens,
1,
rope_dim,
dtype=dtype,
device=device)
_sin = torch.zeros(1,
max_num_batched_tokens,
1,
rope_dim,
dtype=dtype,
device=device)
def get_cos_and_sin_mla():
return _cos_mla, _sin_mla
def _record_cos_sin_cache(cos_sin_cache):
global _cos_sin_cache
if _cos_sin_cache is not None:
return
_cos_sin_cache = cos_sin_cache
def update_cos_sin(positions):
global _cos
global _sin
global _cos_slice
global _sin_slice
if _cos_sin_cache is None or \
_cos is None or \
_sin is None:
return
num_tokens = positions.size(0)
_cos[:, :num_tokens] = _cos_sin_cache.index_select(0, positions).view(
num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)[0]
_sin[:, :num_tokens] = _cos_sin_cache.index_select(0, positions).view(
num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)[1]
_cos_slice = _cos[:, :num_tokens]
_sin_slice = _sin[:, :num_tokens]
def get_cos_and_sin_slice():
return _cos_slice, _sin_slice
def _custom_rotary_embedding_enabled(query, neox_style, head_size):
@@ -65,8 +168,9 @@ def _rope_forward_oot(
raise NotImplementedError(
"Batched rotary embedding is currently not supported on NPU.")
else:
if hasattr(self, "cos") and hasattr(self, "sin") and \
self.cos is not None and self.sin is not None:
cos, sin = get_cos_and_sin_slice()
if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[
-1] == 128 and cos is not None and sin is not None:
# If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation.
# This method requires head_size and rotary_dim equal 128 and neox_style is True
query = query.contiguous().view(1, query.shape[0], -1,
@@ -75,7 +179,7 @@ def _rope_forward_oot(
# Although this function modifies in-place, please retain the function's return value.
# Otherwise, the graph fusion operation may fail.
query, key = torch_npu.npu_apply_rotary_pos_emb(
query, key, self.cos, self.sin)
query, key, cos, sin)
elif self.rotary_dim < self.head_size:
num_tokens = query.shape[0]
query = query.view(num_tokens, -1, self.head_size)
@@ -125,10 +229,9 @@ class AscendRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
self.cos = None
self.sin = None
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
_record_cos_sin_cache(self.cos_sin_cache)
def forward_oot(
self,
@@ -141,20 +244,6 @@ class AscendRotaryEmbedding(RotaryEmbedding):
is_neox_style = self.is_neox_style
if is_neox_style_override is not None:
is_neox_style = is_neox_style_override
forward_context = get_forward_context()
is_first_layer = forward_context.is_first_layer
# Generate cos and sin outside layers to avoid repeated calculation.
if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[
-1] == 128:
if is_first_layer:
cos_sin = self.cos_sin_cache.index_select(0, positions)
last_dim = cos_sin.size()[-1]
cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat(
1, 1, 2).chunk(2, dim=-2)
# BSNH
self.cos = cos.view(1, -1, 1, last_dim).contiguous()
self.sin = sin.view(1, -1, 1, last_dim).contiguous()
forward_context.is_first_layer = False
return _rope_forward_oot(self, positions, query, key, is_neox_style,
offsets)
@@ -176,8 +265,6 @@ class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding):
beta_fast: int = 32,
beta_slow: int = 1,
) -> None:
self.cos = None
self.sin = None
extra_kwargs = {
"extrapolation_factor": extrapolation_factor,
"attn_factor": attn_factor,
@@ -186,6 +273,7 @@ class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding):
}
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, scaling_factor, dtype, **extra_kwargs)
_record_cos_sin_cache(self.cos_sin_cache)
def forward_oot(
self,

View File

@@ -0,0 +1,305 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import Optional
import torch
import triton # type: ignore
import triton.language as tl # type: ignore
from vllm.utils.torch_utils import direct_register_custom_op
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
@triton.jit
def split_qkv_rmsnorm_rope_kernel(
input_ptr,
sin_ptr,
cos_ptr,
q_ptr,
k_ptr,
v_ptr,
q_weight_ptr,
q_bias_ptr,
k_weight_ptr,
k_bias_ptr,
batch_size,
q_hidden_size: tl.constexpr,
kv_hidden_size: tl.constexpr,
total_hidden_size: tl.constexpr,
eps: tl.constexpr,
Q_BLOCK_SIZE: tl.constexpr,
KV_BLOCK_SIZE: tl.constexpr,
BIAS: tl.constexpr,
HEAD_DIM: tl.constexpr,
HALF_HEAD_DIM: tl.constexpr,
):
row_pid = tl.program_id(0)
col_pid = tl.program_id(1)
row_step = tl.num_programs(0)
# q
weight_values = tl.load(q_weight_ptr + tl.arange(0, HEAD_DIM))
if BIAS:
bias_values = tl.load(q_bias_ptr + tl.arange(0, HEAD_DIM))
input_offset = row_pid * total_hidden_size
output_offset = row_pid * q_hidden_size
input_offset_step = row_step * total_hidden_size
output_offset_step = row_step * q_hidden_size
for row_idx in tl.range(row_pid, batch_size, row_step):
col_indices = col_pid * Q_BLOCK_SIZE + tl.arange(0, Q_BLOCK_SIZE)
valid_mask = col_indices < q_hidden_size
input_values = (tl.load(input_ptr + input_offset + col_indices,
mask=valid_mask,
other=0.0).to(tl.float32).reshape(
Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM))
squares = input_values * input_values
variances = tl.sum(squares, axis=1) / HEAD_DIM
reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(
Q_BLOCK_SIZE // HEAD_DIM, 1)
normalized_values = (input_values * reciprocal_std
) # (Q_BLOCK_SIZE//HEAD_DIM, HEAD_DIM)
if BIAS:
normalized_values = (normalized_values * weight_values +
bias_values).to(tl.bfloat16)
else:
normalized_values = (normalized_values * weight_values).to(
tl.bfloat16)
sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM)
sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM)
cos = (tl.load(cos_ptr + sc_offsets)).reshape(1, HEAD_DIM)
x1 = tl.extract_slice(
normalized_values,
offsets=(0, 0),
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
x2 = tl.extract_slice(
normalized_values,
offsets=(0, HALF_HEAD_DIM),
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
cat_x = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM),
dtype=tl.bfloat16)
cat_x = tl.insert_slice(
cat_x,
-x2,
offsets=(0, 0),
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
cat_x = tl.insert_slice(
cat_x,
x1,
offsets=(0, HALF_HEAD_DIM),
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
roped_q = cat_x * sin + normalized_values * cos
tl.store(
q_ptr + output_offset + col_indices,
roped_q.reshape(Q_BLOCK_SIZE).to(q_ptr.dtype.element_ty),
mask=valid_mask,
)
input_offset += input_offset_step
output_offset += output_offset_step
weight_values = tl.load(k_weight_ptr + tl.arange(0, HEAD_DIM))
if BIAS:
bias_values = tl.load(k_bias_ptr + tl.arange(0, HEAD_DIM))
input_offset = row_pid * total_hidden_size + q_hidden_size
output_offset = row_pid * kv_hidden_size
output_offset_step = row_step * kv_hidden_size
for row_idx in tl.range(row_pid, batch_size, row_step):
col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE)
valid_mask = col_indices < kv_hidden_size
input_values = (tl.load(input_ptr + input_offset + col_indices,
mask=valid_mask,
other=0.0).to(tl.float32).reshape(
KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM))
squares = input_values * input_values
variances = tl.sum(squares, axis=1) / HEAD_DIM
reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(
KV_BLOCK_SIZE // HEAD_DIM, 1)
normalized_values = (input_values * reciprocal_std
) # (KV_BLOCK_SIZE/HEAD_DIM, HEAD_DIM)
if BIAS:
normalized_values = (normalized_values * weight_values +
bias_values).to(tl.bfloat16)
else:
normalized_values = (normalized_values * weight_values).to(
tl.bfloat16)
sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM)
sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM)
cos = (tl.load(cos_ptr + sc_offsets)).reshape(1, HEAD_DIM)
x1 = tl.extract_slice(
normalized_values,
offsets=(0, 0),
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
x2 = tl.extract_slice(
normalized_values,
offsets=(0, HALF_HEAD_DIM),
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
cat_x = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM),
dtype=tl.bfloat16)
cat_x = tl.insert_slice(
cat_x,
-x2,
offsets=(0, 0),
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
cat_x = tl.insert_slice(
cat_x,
x1,
offsets=(0, HALF_HEAD_DIM),
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
roped_k = cat_x * sin + normalized_values * cos
tl.store(
k_ptr + output_offset + col_indices,
roped_k.to(tl.bfloat16).reshape(KV_BLOCK_SIZE),
mask=valid_mask,
)
input_offset += input_offset_step
output_offset += output_offset_step
input_offset = row_pid * total_hidden_size + q_hidden_size + kv_hidden_size
output_offset = row_pid * kv_hidden_size
for _ in tl.range(row_pid, batch_size, row_step):
col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE)
valid_mask = col_indices < kv_hidden_size
input_values = tl.load(input_ptr + input_offset + col_indices,
mask=valid_mask,
other=0.0)
tl.store(v_ptr + output_offset + col_indices,
input_values,
mask=valid_mask)
input_offset += input_offset_step
output_offset += output_offset_step
def split_qkv_rmsnorm_rope_impl(
input: torch.Tensor,
sin: torch.Tensor,
cos: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_hidden_size: int,
kv_hidden_size: int,
head_dim: int,
eps: float,
q_bias: Optional[torch.Tensor],
k_bias: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
KV_BLOCK_SIZE = triton.next_power_of_2(head_dim)
assert KV_BLOCK_SIZE == head_dim
assert q_hidden_size % kv_hidden_size == 0
Q_BLOCK_SIZE = q_hidden_size // kv_hidden_size * head_dim
batch_size = input.shape[0]
total_hidden_size = q_hidden_size + kv_hidden_size * 2
q_output = torch.empty(batch_size,
q_hidden_size,
device=input.device,
dtype=input.dtype)
k_output = torch.empty(batch_size,
kv_hidden_size,
device=input.device,
dtype=input.dtype)
v_output = torch.empty(batch_size,
kv_hidden_size,
device=input.device,
dtype=input.dtype)
n_cols = kv_hidden_size // KV_BLOCK_SIZE
num_vectorcore = get_vectorcore_num()
assert num_vectorcore % n_cols == 0
n_rows = num_vectorcore // n_cols
BIAS = q_bias is not None
split_qkv_rmsnorm_rope_kernel[(n_rows, n_cols, 1)](
input,
sin,
cos,
q_output,
k_output,
v_output,
q_weight,
q_bias,
k_weight,
k_bias,
batch_size,
q_hidden_size,
kv_hidden_size,
total_hidden_size,
eps,
Q_BLOCK_SIZE,
KV_BLOCK_SIZE,
BIAS,
head_dim,
head_dim // 2,
)
return q_output, k_output, v_output
def split_qkv_rmsnorm_rope_impl_fake(
input: torch.Tensor,
sin: torch.Tensor,
cos: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_hidden_size: int,
kv_hidden_size: int,
head_dim: int,
eps: float,
q_bias: Optional[torch.Tensor] = None,
k_bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# Fake implementation for shape inference during Dynamo/AOT tracing.
# Note: sin and cos are not used in shape computation, but must be present in signature.
batch_size = input.shape[0]
q_output = torch.empty(
batch_size,
q_hidden_size,
device=input.device,
dtype=input.dtype,
)
k_output = torch.empty(
batch_size,
kv_hidden_size,
device=input.device,
dtype=input.dtype,
)
v_output = torch.empty(
batch_size,
kv_hidden_size,
device=input.device,
dtype=input.dtype,
)
return q_output, k_output, v_output
direct_register_custom_op(op_name="qkv_rmsnorm_rope",
op_func=split_qkv_rmsnorm_rope_impl,
fake_impl=split_qkv_rmsnorm_rope_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1")

View File

@@ -25,6 +25,7 @@ from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.ops.rotary_embedding import update_cos_sin
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
PADDING_SLOT_ID = -1
@@ -143,6 +144,9 @@ class EagleProposer(Proposer):
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor=None,
dummy_compute_logits=lambda hidden_states: None):
# update global cos, sin
update_cos_sin(self.positions[:num_tokens])
with set_ascend_forward_context(None,
self.vllm_config,
num_tokens=num_tokens):
@@ -338,6 +342,8 @@ class EagleProposer(Proposer):
builder = self.runner.attn_groups[0][0].get_metadata_builder()
attn_metadata = builder.build(0, common_attn_metadata,
self.runner.get_model())
# update global cos, sin
update_cos_sin(self.positions[:num_input_tokens])
with set_ascend_forward_context(attn_metadata,
self.vllm_config,
@@ -443,6 +449,10 @@ class EagleProposer(Proposer):
attn_metadata.attn_mask = attn_mask
# Run the model.
# update global cos, sin
update_cos_sin(self.positions[:input_batch_size])
with set_ascend_forward_context(attn_metadata,
self.vllm_config,
num_tokens=input_batch_size):

View File

@@ -84,12 +84,6 @@ from vllm.v1.worker.utils import AttentionGroup
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import (MoECommType,
get_mc2_tokens_capacity,
select_moe_comm_method,
set_ascend_forward_context,
set_cos_and_sin, set_mc2_mask,
set_mc2_tokens_capacity)
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
@@ -111,6 +105,7 @@ from vllm_ascend.eplb.core.eplb_utils import EPLBParamUtils
from vllm_ascend.eplb.core.eplb_worker import EplbProcess
from vllm_ascend.eplb.eplb_updator import EplbUpdator
from vllm_ascend.eplb.utils import model_register
from vllm_ascend.ops.rotary_embedding import set_cos_and_sin, update_cos_sin
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
from vllm_ascend.sample.logits_processor import build_logitsprocs
@@ -125,6 +120,10 @@ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
is_moe_model, lmhead_tp_enable, vllm_version_is)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
from vllm_ascend.ascend_forward_context import ( # isort: skip
MoECommType, get_mc2_tokens_capacity, select_moe_comm_method,
set_ascend_forward_context, set_mc2_mask, set_mc2_tokens_capacity)
if TYPE_CHECKING:
import xgrammar as xgr # type: ignore[import-untyped]
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
@@ -1122,6 +1121,9 @@ class NPUModelRunner(GPUModelRunner):
for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = attn_metadata_i
# update global cos, sin
update_cos_sin(positions)
if lmhead_tp_enable():
max_num_reqs_across_dp = self.max_num_reqs * self.uniform_decode_query_len
logits_indices = nn.functional.pad(
@@ -2084,6 +2086,9 @@ class NPUModelRunner(GPUModelRunner):
else:
positions = self.positions.gpu[:num_tokens_padded]
# update global cos, sin
update_cos_sin(positions)
if get_pp_group().is_first_rank:
intermediate_tensors = None
else: