From 4020b3df60942357282062f48874f1955a3c7200 Mon Sep 17 00:00:00 2001 From: wangx700 Date: Thu, 23 Apr 2026 09:50:29 +0800 Subject: [PATCH] [BugFix] fix tl.extract_slice and tl.insert_slice. (#8567) ### What this PR does / why we need it? Replace the tl.extract_slice and tl.insert_slice calls in split_qkv_rmsnorm_mrope_kernel with the extract_slice and insert_slice helpers imported from vllm_ascend.ops.triton.triton_utils. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? --------- Signed-off-by: wangx700 --- .../linearnorm/split_qkv_rmsnorm_mrope.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_mrope.py b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_mrope.py index 25aa727b..7c4cf9f7 100644 --- a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_mrope.py +++ b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_mrope.py @@ -20,7 +20,7 @@ import torch from vllm.triton_utils import tl, triton from vllm.utils.torch_utils import direct_register_custom_op -from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num +from vllm_ascend.ops.triton.triton_utils import extract_slice, get_vectorcore_num, insert_slice @triton.jit( @@ -86,13 +86,13 @@ def split_qkv_rmsnorm_mrope_kernel( .to(tl.float32) .reshape(num_q_heads, head_size * 2) ) - in_q_tensor = tl.extract_slice( + in_q_tensor = extract_slice( in_q_gate_tensor, offsets=(0, 0), sizes=(num_q_heads, head_size), strides=(1, 1), ) - in_gate_tensor = tl.extract_slice( + in_gate_tensor = extract_slice( in_q_gate_tensor, offsets=(0, head_size), sizes=(num_q_heads, head_size), @@ -162,27 +162,27 @@ def split_qkv_rmsnorm_mrope_kernel( k_normalized = k_normalized + k_bias # q-mrope - x1 = tl.extract_slice( + x1 = extract_slice( q_normalized, offsets=(0, 0), sizes=(num_q_heads, half_rope_dim), strides=(1, 1), ) - x2 = tl.extract_slice( + x2 = extract_slice( q_normalized, offsets=(0, half_rope_dim), sizes=(num_q_heads, half_rope_dim), strides=(1, 1), ) cat_x = 
tl.zeros((num_q_heads, rope_dim), dtype=tl.float32) - cat_x = tl.insert_slice( + cat_x = insert_slice( cat_x, -x2, offsets=(0, 0), sizes=(num_q_heads, half_rope_dim), strides=(1, 1), ) - cat_x = tl.insert_slice( + cat_x = insert_slice( cat_x, x1, offsets=(0, half_rope_dim), @@ -190,7 +190,7 @@ def split_qkv_rmsnorm_mrope_kernel( strides=(1, 1), ) if IS_PARTIAL_ROPE: - orig_qk = tl.extract_slice( + orig_qk = extract_slice( q_normalized, offsets=(0, 0), sizes=(num_q_heads, rope_dim), @@ -201,27 +201,27 @@ def split_qkv_rmsnorm_mrope_kernel( roped_q = cat_x * sin_tensor + orig_qk * cos_tensor # k-mrope - y1 = tl.extract_slice( + y1 = extract_slice( k_normalized, offsets=(0, 0), sizes=(num_kv_heads, half_rope_dim), strides=(1, 1), ) - y2 = tl.extract_slice( + y2 = extract_slice( k_normalized, offsets=(0, half_rope_dim), sizes=(num_kv_heads, half_rope_dim), strides=(1, 1), ) cat_y = tl.zeros((num_kv_heads, rope_dim), dtype=tl.float32) - cat_y = tl.insert_slice( + cat_y = insert_slice( cat_y, -y2, offsets=(0, 0), sizes=(num_kv_heads, half_rope_dim), strides=(1, 1), ) - cat_y = tl.insert_slice( + cat_y = insert_slice( cat_y, y1, offsets=(0, half_rope_dim), @@ -229,7 +229,7 @@ def split_qkv_rmsnorm_mrope_kernel( strides=(1, 1), ) if IS_PARTIAL_ROPE: - orig_qk = tl.extract_slice( + orig_qk = extract_slice( k_normalized, offsets=(0, 0), sizes=(num_kv_heads, rope_dim), @@ -240,14 +240,14 @@ def split_qkv_rmsnorm_mrope_kernel( roped_k = cat_y * sin_tensor + orig_qk * cos_tensor if IS_PARTIAL_ROPE: - q_normalized = tl.insert_slice( + q_normalized = insert_slice( q_normalized, roped_q, offsets=(0, 0), sizes=(num_q_heads, rope_dim), strides=(1, 1), ) - k_normalized = tl.insert_slice( + k_normalized = insert_slice( k_normalized, roped_k, offsets=(0, 0),