[BugFix] fix tl.extract_slice and tl.insert_slice. (#8567)

### What this PR does / why we need it?
fix tl.extract_slice and tl.insert_slice to extract_slice and
insert_slice from torch_utils
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?

---------

Signed-off-by: wangx700 <wangxin700@huawei.com>
This commit is contained in:
wangx700
2026-04-23 09:50:29 +08:00
committed by GitHub
parent c3b1d409a9
commit 4020b3df60

View File

@@ -20,7 +20,7 @@ import torch
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
from vllm_ascend.ops.triton.triton_utils import extract_slice, get_vectorcore_num, insert_slice
@triton.jit(
@@ -86,13 +86,13 @@ def split_qkv_rmsnorm_mrope_kernel(
.to(tl.float32)
.reshape(num_q_heads, head_size * 2)
)
in_q_tensor = tl.extract_slice(
in_q_tensor = extract_slice(
in_q_gate_tensor,
offsets=(0, 0),
sizes=(num_q_heads, head_size),
strides=(1, 1),
)
in_gate_tensor = tl.extract_slice(
in_gate_tensor = extract_slice(
in_q_gate_tensor,
offsets=(0, head_size),
sizes=(num_q_heads, head_size),
@@ -162,27 +162,27 @@ def split_qkv_rmsnorm_mrope_kernel(
k_normalized = k_normalized + k_bias
# q-mrope
x1 = tl.extract_slice(
x1 = extract_slice(
q_normalized,
offsets=(0, 0),
sizes=(num_q_heads, half_rope_dim),
strides=(1, 1),
)
x2 = tl.extract_slice(
x2 = extract_slice(
q_normalized,
offsets=(0, half_rope_dim),
sizes=(num_q_heads, half_rope_dim),
strides=(1, 1),
)
cat_x = tl.zeros((num_q_heads, rope_dim), dtype=tl.float32)
cat_x = tl.insert_slice(
cat_x = insert_slice(
cat_x,
-x2,
offsets=(0, 0),
sizes=(num_q_heads, half_rope_dim),
strides=(1, 1),
)
cat_x = tl.insert_slice(
cat_x = insert_slice(
cat_x,
x1,
offsets=(0, half_rope_dim),
@@ -190,7 +190,7 @@ def split_qkv_rmsnorm_mrope_kernel(
strides=(1, 1),
)
if IS_PARTIAL_ROPE:
orig_qk = tl.extract_slice(
orig_qk = extract_slice(
q_normalized,
offsets=(0, 0),
sizes=(num_q_heads, rope_dim),
@@ -201,27 +201,27 @@ def split_qkv_rmsnorm_mrope_kernel(
roped_q = cat_x * sin_tensor + orig_qk * cos_tensor
# k-mrope
y1 = tl.extract_slice(
y1 = extract_slice(
k_normalized,
offsets=(0, 0),
sizes=(num_kv_heads, half_rope_dim),
strides=(1, 1),
)
y2 = tl.extract_slice(
y2 = extract_slice(
k_normalized,
offsets=(0, half_rope_dim),
sizes=(num_kv_heads, half_rope_dim),
strides=(1, 1),
)
cat_y = tl.zeros((num_kv_heads, rope_dim), dtype=tl.float32)
cat_y = tl.insert_slice(
cat_y = insert_slice(
cat_y,
-y2,
offsets=(0, 0),
sizes=(num_kv_heads, half_rope_dim),
strides=(1, 1),
)
cat_y = tl.insert_slice(
cat_y = insert_slice(
cat_y,
y1,
offsets=(0, half_rope_dim),
@@ -229,7 +229,7 @@ def split_qkv_rmsnorm_mrope_kernel(
strides=(1, 1),
)
if IS_PARTIAL_ROPE:
orig_qk = tl.extract_slice(
orig_qk = extract_slice(
k_normalized,
offsets=(0, 0),
sizes=(num_kv_heads, rope_dim),
@@ -240,14 +240,14 @@ def split_qkv_rmsnorm_mrope_kernel(
roped_k = cat_y * sin_tensor + orig_qk * cos_tensor
if IS_PARTIAL_ROPE:
q_normalized = tl.insert_slice(
q_normalized = insert_slice(
q_normalized,
roped_q,
offsets=(0, 0),
sizes=(num_q_heads, rope_dim),
strides=(1, 1),
)
k_normalized = tl.insert_slice(
k_normalized = insert_slice(
k_normalized,
roped_k,
offsets=(0, 0),