[Fusion] [Graph] Add qknorm rope fusion operator (#4711)
### What this PR does / why we need it?
This PR adds the `qkv_rmsnorm_rope` operator and introduces a graph fusion
pass for `qknorm_rope` operations. The implementation includes a new
configuration flag, a pattern matching pass using
`torch._inductor.pattern_matcher`, and a custom Triton kernel for the
fused operation.
Co-authored-by: Angazenn
[supperccell@163.com](mailto:supperccell@163.com)
### Does this PR introduce _any_ user-facing change?
Yes, add new additional_config
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: wxsIcey <1790571317@qq.com>
This commit is contained in:
@@ -16,10 +16,15 @@
|
||||
#
|
||||
|
||||
import torch
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
import vllm_ascend.ops.fused_moe.fused_moe # noqa
|
||||
import vllm_ascend.ops.layernorm # noqa
|
||||
import vllm_ascend.ops.register_custom_ops # noqa
|
||||
|
||||
if HAS_TRITON:
|
||||
import vllm_ascend.ops.triton.linearnorm.split_qkv_rmsnorm_rope # noqa
|
||||
|
||||
import vllm_ascend.ops.vocab_parallel_embedding # noqa
|
||||
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
|
||||
from vllm_ascend.ops.rotary_embedding import (
|
||||
|
||||
@@ -20,14 +20,117 @@ from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.config import CUDAGraphMode
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
|
||||
YaRNScalingRotaryEmbedding)
|
||||
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
|
||||
get_ascend_device_type)
|
||||
get_ascend_device_type, is_vl_model)
|
||||
|
||||
# Currently, rope ops used on npu requires detached cos && sin as inputs.
# However, RotaryEmbedding in vllm use cos_sin_cache as a whole variable.
# So we have to preprocess cos_sin_cache into cos && sin. In the future,
# we shall implement a new rope ops which accept cos_sin_cache as inputs.
# NOTE(Angazenn): MLA && SFA models uses attn_metadata to pass cos && sin
# to rope in AscendMLA(SFA)Impl. However, since rope is isolated from
# AscendAttentionBackendImpl for GQA models, we cannot pass cos && sin by
# attn_metadata. This causes that rope in GQA models must pass cos && sin
# by different approaches.

# Decode-shaped cos/sin buffers for MLA models; allocated once in
# set_cos_and_sin() when full-decode-only cudagraphs are enabled.
_cos_mla: Optional[torch.Tensor] = None
_sin_mla: Optional[torch.Tensor] = None
# Raw cos_sin_cache recorded from the first RotaryEmbedding constructed
# (see _record_cos_sin_cache); source table for update_cos_sin().
_cos_sin_cache: Optional[torch.Tensor] = None
# Token-indexed cos/sin buffers for GQA models (allocated in
# set_cos_and_sin), plus views of their currently-valid leading slice,
# refreshed each step by update_cos_sin().
_cos: Optional[torch.Tensor] = None
_sin: Optional[torch.Tensor] = None
_cos_slice: Optional[torch.Tensor] = None
_sin_slice: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
                    device):
    """Allocate the module-level cos/sin buffers exactly once.

    Two mutually exclusive layouts are prepared:
    * MLA models running with FULL_DECODE_ONLY cudagraphs get decode-shaped
      buffers ``_cos_mla``/``_sin_mla``.
    * Other non-VL (GQA) models get token-indexed buffers ``_cos``/``_sin``
      sized to ``max_num_batched_tokens``.

    Any call after the first initialization is a no-op.
    """
    global _cos_mla
    global _sin_mla
    global _cos
    global _sin

    # Bail out if any buffer has already been allocated by a prior call.
    if any(buf is not None for buf in (_cos_mla, _sin_mla, _cos, _sin)):
        return

    compilation_config = vllm_config.compilation_config
    model_config = vllm_config.model_config
    max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens

    mla_full_decode = (
        model_config.use_mla
        and compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY)

    if mla_full_decode:
        # One (tokens, 1, 1, rope_dim) buffer pair covering all decode tokens.
        rope_dim = model_config.hf_text_config.qk_rope_head_dim
        mla_shape = (max_num_reqs * decode_token_per_req, 1, 1, rope_dim)
        _cos_mla = torch.ones(mla_shape, dtype=dtype, device=device)
        _sin_mla = torch.zeros(mla_shape, dtype=dtype, device=device)
    elif not is_vl_model(vllm_config) and not model_config.use_mla:
        rope_dim = model_config.get_head_size()
        # For models using partial rope like Qwen3-Next.
        if hasattr(model_config.hf_text_config, "partial_rotary_factor"):
            rope_dim = int(
                rope_dim * model_config.hf_text_config.partial_rotary_factor)
        gqa_shape = (1, max_num_batched_tokens, 1, rope_dim)
        _cos = torch.ones(gqa_shape, dtype=dtype, device=device)
        _sin = torch.zeros(gqa_shape, dtype=dtype, device=device)
|
||||
|
||||
|
||||
def get_cos_and_sin_mla():
    """Return the (cos, sin) buffers preallocated for MLA decode cudagraphs.

    Both values are ``None`` until ``set_cos_and_sin`` has run for an MLA
    model with FULL_DECODE_ONLY cudagraph mode.
    """
    return _cos_mla, _sin_mla
|
||||
|
||||
|
||||
def _record_cos_sin_cache(cos_sin_cache):
    """Remember the first cos_sin_cache seen; later calls are no-ops."""
    global _cos_sin_cache
    if _cos_sin_cache is None:
        _cos_sin_cache = cos_sin_cache
|
||||
|
||||
|
||||
def update_cos_sin(positions):
    """Refresh the shared cos/sin buffers for the given token positions.

    Fills the leading ``num_tokens`` rows of the preallocated ``_cos``/``_sin``
    buffers from ``_cos_sin_cache`` and records views of the filled region so
    layers can fetch them via ``get_cos_and_sin_slice``. A no-op until both
    the cache and the buffers have been initialized.

    Args:
        positions: 1-D tensor of position ids, one per scheduled token.
    """
    global _cos
    global _sin
    global _cos_slice
    global _sin_slice

    if _cos_sin_cache is None or \
        _cos is None or \
        _sin is None:
        return

    num_tokens = positions.size(0)
    # Each cache row stores [cos_half | sin_half]; duplicate each half to the
    # full head width, then split back into cos and sin components.
    # Computed once here (the original pipeline ran index_select/view/repeat
    # twice, once per component, doing the same work both times).
    cos_sin = _cos_sin_cache.index_select(0, positions).view(
        num_tokens, 2, -1).repeat(1, 1, 2)
    cos_part, sin_part = cos_sin.chunk(2, dim=-2)
    _cos[:, :num_tokens] = cos_part
    _sin[:, :num_tokens] = sin_part
    _cos_slice = _cos[:, :num_tokens]
    _sin_slice = _sin[:, :num_tokens]
|
||||
|
||||
|
||||
def get_cos_and_sin_slice():
    """Return views of the cos/sin buffers covering the current step's tokens.

    Both values are ``None`` until ``update_cos_sin`` has run at least once.
    """
    return _cos_slice, _sin_slice
|
||||
|
||||
|
||||
def _custom_rotary_embedding_enabled(query, neox_style, head_size):
|
||||
@@ -65,8 +168,9 @@ def _rope_forward_oot(
|
||||
raise NotImplementedError(
|
||||
"Batched rotary embedding is currently not supported on NPU.")
|
||||
else:
|
||||
if hasattr(self, "cos") and hasattr(self, "sin") and \
|
||||
self.cos is not None and self.sin is not None:
|
||||
cos, sin = get_cos_and_sin_slice()
|
||||
if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[
|
||||
-1] == 128 and cos is not None and sin is not None:
|
||||
# If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation.
|
||||
# This method requires head_size and rotary_dim equal 128 and neox_style is True
|
||||
query = query.contiguous().view(1, query.shape[0], -1,
|
||||
@@ -75,7 +179,7 @@ def _rope_forward_oot(
|
||||
# Although this function modifies in-place, please retain the function's return value.
|
||||
# Otherwise, the graph fusion operation may fail.
|
||||
query, key = torch_npu.npu_apply_rotary_pos_emb(
|
||||
query, key, self.cos, self.sin)
|
||||
query, key, cos, sin)
|
||||
elif self.rotary_dim < self.head_size:
|
||||
num_tokens = query.shape[0]
|
||||
query = query.view(num_tokens, -1, self.head_size)
|
||||
@@ -125,10 +229,9 @@ class AscendRotaryEmbedding(RotaryEmbedding):
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
self.cos = None
|
||||
self.sin = None
|
||||
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||
is_neox_style, dtype)
|
||||
_record_cos_sin_cache(self.cos_sin_cache)
|
||||
|
||||
def forward_oot(
|
||||
self,
|
||||
@@ -141,20 +244,6 @@ class AscendRotaryEmbedding(RotaryEmbedding):
|
||||
is_neox_style = self.is_neox_style
|
||||
if is_neox_style_override is not None:
|
||||
is_neox_style = is_neox_style_override
|
||||
forward_context = get_forward_context()
|
||||
is_first_layer = forward_context.is_first_layer
|
||||
# Generate cos and sin outside layers to avoid repeated calculation.
|
||||
if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[
|
||||
-1] == 128:
|
||||
if is_first_layer:
|
||||
cos_sin = self.cos_sin_cache.index_select(0, positions)
|
||||
last_dim = cos_sin.size()[-1]
|
||||
cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat(
|
||||
1, 1, 2).chunk(2, dim=-2)
|
||||
# BSNH
|
||||
self.cos = cos.view(1, -1, 1, last_dim).contiguous()
|
||||
self.sin = sin.view(1, -1, 1, last_dim).contiguous()
|
||||
forward_context.is_first_layer = False
|
||||
return _rope_forward_oot(self, positions, query, key, is_neox_style,
|
||||
offsets)
|
||||
|
||||
@@ -176,8 +265,6 @@ class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding):
|
||||
beta_fast: int = 32,
|
||||
beta_slow: int = 1,
|
||||
) -> None:
|
||||
self.cos = None
|
||||
self.sin = None
|
||||
extra_kwargs = {
|
||||
"extrapolation_factor": extrapolation_factor,
|
||||
"attn_factor": attn_factor,
|
||||
@@ -186,6 +273,7 @@ class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding):
|
||||
}
|
||||
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||
is_neox_style, scaling_factor, dtype, **extra_kwargs)
|
||||
_record_cos_sin_cache(self.cos_sin_cache)
|
||||
|
||||
def forward_oot(
|
||||
self,
|
||||
|
||||
0
vllm_ascend/ops/triton/linearnorm/__init__.py
Normal file
0
vllm_ascend/ops/triton/linearnorm/__init__.py
Normal file
305
vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py
Normal file
305
vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py
Normal file
@@ -0,0 +1,305 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import triton # type: ignore
|
||||
import triton.language as tl # type: ignore
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
||||
|
||||
|
||||
@triton.jit
def split_qkv_rmsnorm_rope_kernel(
    input_ptr,  # [batch_size, total_hidden_size] fused QKV projection output
    sin_ptr,  # per-token sin table, HEAD_DIM values per row
    cos_ptr,  # per-token cos table, HEAD_DIM values per row
    q_ptr,  # out: [batch_size, q_hidden_size] RMSNorm'd + roped Q
    k_ptr,  # out: [batch_size, kv_hidden_size] RMSNorm'd + roped K
    v_ptr,  # out: [batch_size, kv_hidden_size] plain copy of V
    q_weight_ptr,  # per-head RMSNorm weight for Q, HEAD_DIM values
    q_bias_ptr,  # per-head RMSNorm bias for Q (read only when BIAS)
    k_weight_ptr,  # per-head RMSNorm weight for K
    k_bias_ptr,  # per-head RMSNorm bias for K (read only when BIAS)
    batch_size,
    q_hidden_size: tl.constexpr,
    kv_hidden_size: tl.constexpr,
    total_hidden_size: tl.constexpr,  # q_hidden_size + 2 * kv_hidden_size
    eps: tl.constexpr,  # RMSNorm epsilon
    Q_BLOCK_SIZE: tl.constexpr,  # elements of Q handled per program, multiple of HEAD_DIM
    KV_BLOCK_SIZE: tl.constexpr,  # elements of K/V handled per program (== HEAD_DIM)
    BIAS: tl.constexpr,  # whether q/k RMSNorm adds a bias term
    HEAD_DIM: tl.constexpr,
    HALF_HEAD_DIM: tl.constexpr,  # HEAD_DIM // 2, used by rotate-half
):
    """Split a fused QKV row into Q/K/V, applying per-head RMSNorm and
    neox-style (rotate-half) RoPE to Q and K; V is copied through unchanged.

    Grid: axis 0 strides over rows (tokens), axis 1 tiles the hidden dim.
    """
    row_pid = tl.program_id(0)
    col_pid = tl.program_id(1)
    row_step = tl.num_programs(0)
    # q
    weight_values = tl.load(q_weight_ptr + tl.arange(0, HEAD_DIM))
    if BIAS:
        bias_values = tl.load(q_bias_ptr + tl.arange(0, HEAD_DIM))
    input_offset = row_pid * total_hidden_size
    output_offset = row_pid * q_hidden_size
    input_offset_step = row_step * total_hidden_size
    output_offset_step = row_step * q_hidden_size
    for row_idx in tl.range(row_pid, batch_size, row_step):
        col_indices = col_pid * Q_BLOCK_SIZE + tl.arange(0, Q_BLOCK_SIZE)
        valid_mask = col_indices < q_hidden_size
        # Load a tile of Q and view it as (heads_in_block, HEAD_DIM) rows.
        input_values = (tl.load(input_ptr + input_offset + col_indices,
                                mask=valid_mask,
                                other=0.0).to(tl.float32).reshape(
                                    Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM))
        # RMSNorm per head: x / sqrt(mean(x^2) + eps), computed in fp32.
        squares = input_values * input_values
        variances = tl.sum(squares, axis=1) / HEAD_DIM
        reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(
            Q_BLOCK_SIZE // HEAD_DIM, 1)
        normalized_values = (input_values * reciprocal_std
                             )  # (Q_BLOCK_SIZE//HEAD_DIM, HEAD_DIM)
        # NOTE(review): the norm output is hard-cast to bf16 here and below,
        # so the kernel assumes bf16 activations -- confirm for other dtypes.
        if BIAS:
            normalized_values = (normalized_values * weight_values +
                                 bias_values).to(tl.bfloat16)
        else:
            normalized_values = (normalized_values * weight_values).to(
                tl.bfloat16)

        # cos/sin for this token; shared across all heads of the row.
        sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM)
        sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM)
        cos = (tl.load(cos_ptr + sc_offsets)).reshape(1, HEAD_DIM)
        # rotate-half: build cat_x = [-x2 | x1] from x = [x1 | x2].
        x1 = tl.extract_slice(
            normalized_values,
            offsets=(0, 0),
            sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
            strides=(1, 1),
        )
        x2 = tl.extract_slice(
            normalized_values,
            offsets=(0, HALF_HEAD_DIM),
            sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
            strides=(1, 1),
        )
        cat_x = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM),
                         dtype=tl.bfloat16)
        cat_x = tl.insert_slice(
            cat_x,
            -x2,
            offsets=(0, 0),
            sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
            strides=(1, 1),
        )
        cat_x = tl.insert_slice(
            cat_x,
            x1,
            offsets=(0, HALF_HEAD_DIM),
            sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
            strides=(1, 1),
        )
        # RoPE: x * cos + rotate_half(x) * sin.
        roped_q = cat_x * sin + normalized_values * cos
        tl.store(
            q_ptr + output_offset + col_indices,
            roped_q.reshape(Q_BLOCK_SIZE).to(q_ptr.dtype.element_ty),
            mask=valid_mask,
        )
        input_offset += input_offset_step
        output_offset += output_offset_step

    # k: same RMSNorm + RoPE pipeline, reading past the Q segment of each row.
    weight_values = tl.load(k_weight_ptr + tl.arange(0, HEAD_DIM))
    if BIAS:
        bias_values = tl.load(k_bias_ptr + tl.arange(0, HEAD_DIM))
    input_offset = row_pid * total_hidden_size + q_hidden_size
    output_offset = row_pid * kv_hidden_size
    output_offset_step = row_step * kv_hidden_size
    for row_idx in tl.range(row_pid, batch_size, row_step):
        col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE)
        valid_mask = col_indices < kv_hidden_size
        input_values = (tl.load(input_ptr + input_offset + col_indices,
                                mask=valid_mask,
                                other=0.0).to(tl.float32).reshape(
                                    KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM))
        squares = input_values * input_values
        variances = tl.sum(squares, axis=1) / HEAD_DIM
        reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(
            KV_BLOCK_SIZE // HEAD_DIM, 1)
        normalized_values = (input_values * reciprocal_std
                             )  # (KV_BLOCK_SIZE/HEAD_DIM, HEAD_DIM)
        if BIAS:
            normalized_values = (normalized_values * weight_values +
                                 bias_values).to(tl.bfloat16)
        else:
            normalized_values = (normalized_values * weight_values).to(
                tl.bfloat16)
        sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM)
        sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM)
        cos = (tl.load(cos_ptr + sc_offsets)).reshape(1, HEAD_DIM)
        x1 = tl.extract_slice(
            normalized_values,
            offsets=(0, 0),
            sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
            strides=(1, 1),
        )
        x2 = tl.extract_slice(
            normalized_values,
            offsets=(0, HALF_HEAD_DIM),
            sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
            strides=(1, 1),
        )
        cat_x = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM),
                         dtype=tl.bfloat16)
        cat_x = tl.insert_slice(
            cat_x,
            -x2,
            offsets=(0, 0),
            sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
            strides=(1, 1),
        )
        cat_x = tl.insert_slice(
            cat_x,
            x1,
            offsets=(0, HALF_HEAD_DIM),
            sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
            strides=(1, 1),
        )
        roped_k = cat_x * sin + normalized_values * cos

        # NOTE(review): K is stored via .to(tl.bfloat16) while Q used
        # q_ptr.dtype.element_ty -- confirm K output dtype is always bf16.
        tl.store(
            k_ptr + output_offset + col_indices,
            roped_k.to(tl.bfloat16).reshape(KV_BLOCK_SIZE),
            mask=valid_mask,
        )
        input_offset += input_offset_step
        output_offset += output_offset_step

    # v: straight copy of the trailing segment of each fused row.
    input_offset = row_pid * total_hidden_size + q_hidden_size + kv_hidden_size
    output_offset = row_pid * kv_hidden_size
    for _ in tl.range(row_pid, batch_size, row_step):
        col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE)
        valid_mask = col_indices < kv_hidden_size
        input_values = tl.load(input_ptr + input_offset + col_indices,
                               mask=valid_mask,
                               other=0.0)
        tl.store(v_ptr + output_offset + col_indices,
                 input_values,
                 mask=valid_mask)
        input_offset += input_offset_step
        output_offset += output_offset_step
|
||||
|
||||
|
||||
def split_qkv_rmsnorm_rope_impl(
    input: torch.Tensor,
    sin: torch.Tensor,
    cos: torch.Tensor,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    q_hidden_size: int,
    kv_hidden_size: int,
    head_dim: int,
    eps: float,
    q_bias: Optional[torch.Tensor] = None,
    k_bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Launch the fused split-QKV + RMSNorm + RoPE Triton kernel.

    Args:
        input: [num_tokens, q_hidden_size + 2 * kv_hidden_size] fused QKV
            projection output.
        sin/cos: per-token rotary tables, ``head_dim`` values per token.
        q_weight/k_weight: per-head RMSNorm weights of length ``head_dim``.
        q_hidden_size: total Q width (num_q_heads * head_dim).
        kv_hidden_size: total K (== V) width (num_kv_heads * head_dim).
        head_dim: per-head dimension; must be a power of two.
        eps: RMSNorm epsilon.
        q_bias/k_bias: optional RMSNorm biases. Defaults to None so the
            signature matches split_qkv_rmsnorm_rope_impl_fake, which the
            custom-op registration pairs with this function.

    Returns:
        (q, k, v) tensors of shapes [num_tokens, q_hidden_size],
        [num_tokens, kv_hidden_size], [num_tokens, kv_hidden_size].
    """
    KV_BLOCK_SIZE = triton.next_power_of_2(head_dim)
    # head_dim must already be a power of two so one K/V block is one head.
    assert KV_BLOCK_SIZE == head_dim
    # Q width must be an integer multiple of K width (GQA head grouping).
    assert q_hidden_size % kv_hidden_size == 0
    Q_BLOCK_SIZE = q_hidden_size // kv_hidden_size * head_dim
    batch_size = input.shape[0]
    total_hidden_size = q_hidden_size + kv_hidden_size * 2
    q_output = torch.empty(batch_size,
                           q_hidden_size,
                           device=input.device,
                           dtype=input.dtype)
    k_output = torch.empty(batch_size,
                           kv_hidden_size,
                           device=input.device,
                           dtype=input.dtype)
    v_output = torch.empty(batch_size,
                           kv_hidden_size,
                           device=input.device,
                           dtype=input.dtype)
    # Grid: axis 1 tiles the hidden dim, axis 0 strides over rows using all
    # remaining vector cores.
    n_cols = kv_hidden_size // KV_BLOCK_SIZE
    num_vectorcore = get_vectorcore_num()
    assert num_vectorcore % n_cols == 0
    n_rows = num_vectorcore // n_cols
    BIAS = q_bias is not None

    split_qkv_rmsnorm_rope_kernel[(n_rows, n_cols, 1)](
        input,
        sin,
        cos,
        q_output,
        k_output,
        v_output,
        q_weight,
        q_bias,
        k_weight,
        k_bias,
        batch_size,
        q_hidden_size,
        kv_hidden_size,
        total_hidden_size,
        eps,
        Q_BLOCK_SIZE,
        KV_BLOCK_SIZE,
        BIAS,
        head_dim,
        head_dim // 2,
    )
    return q_output, k_output, v_output
|
||||
|
||||
|
||||
def split_qkv_rmsnorm_rope_impl_fake(
    input: torch.Tensor,
    sin: torch.Tensor,
    cos: torch.Tensor,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    q_hidden_size: int,
    kv_hidden_size: int,
    head_dim: int,
    eps: float,
    q_bias: Optional[torch.Tensor] = None,
    k_bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fake (meta) implementation for shape inference during tracing.

    Only the output shapes/dtypes matter here; sin, cos, weights and biases
    are unused but must stay in the signature to mirror the real op.
    """
    num_tokens = input.shape[0]

    def _alloc(width: int) -> torch.Tensor:
        # All outputs share the input's device and dtype.
        return torch.empty(num_tokens,
                           width,
                           device=input.device,
                           dtype=input.dtype)

    return _alloc(q_hidden_size), _alloc(kv_hidden_size), _alloc(
        kv_hidden_size)
|
||||
|
||||
|
||||
# Register the fused kernel as a torch custom op named "qkv_rmsnorm_rope" so
# the graph-fusion pass can emit calls to it. mutates_args=[] because the op
# is purely functional (all three outputs are freshly allocated).
# NOTE(review): "PrivateUse1" is presumably the dispatch key NPU tensors route
# through -- confirm against the platform's dispatch-key setup.
direct_register_custom_op(op_name="qkv_rmsnorm_rope",
                          op_func=split_qkv_rmsnorm_rope_impl,
                          fake_impl=split_qkv_rmsnorm_rope_impl_fake,
                          mutates_args=[],
                          dispatch_key="PrivateUse1")
|
||||
Reference in New Issue
Block a user