[BugFix] fix tl.extract_slice and tl.insert_slice. (#8567)
### What this PR does / why we need it?
Replace the `tl.extract_slice` and `tl.insert_slice` calls in `split_qkv_rmsnorm_mrope_kernel` with the `extract_slice` and `insert_slice` helpers imported from `vllm_ascend.ops.triton.triton_utils`.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

---------

Signed-off-by: wangx700 <wangxin700@huawei.com>
This commit is contained in:
@@ -20,7 +20,7 @@ import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
||||
from vllm_ascend.ops.triton.triton_utils import extract_slice, get_vectorcore_num, insert_slice
|
||||
|
||||
|
||||
@triton.jit(
|
||||
@@ -86,13 +86,13 @@ def split_qkv_rmsnorm_mrope_kernel(
|
||||
.to(tl.float32)
|
||||
.reshape(num_q_heads, head_size * 2)
|
||||
)
|
||||
in_q_tensor = tl.extract_slice(
|
||||
in_q_tensor = extract_slice(
|
||||
in_q_gate_tensor,
|
||||
offsets=(0, 0),
|
||||
sizes=(num_q_heads, head_size),
|
||||
strides=(1, 1),
|
||||
)
|
||||
in_gate_tensor = tl.extract_slice(
|
||||
in_gate_tensor = extract_slice(
|
||||
in_q_gate_tensor,
|
||||
offsets=(0, head_size),
|
||||
sizes=(num_q_heads, head_size),
|
||||
@@ -162,27 +162,27 @@ def split_qkv_rmsnorm_mrope_kernel(
|
||||
k_normalized = k_normalized + k_bias
|
||||
|
||||
# q-mrope
|
||||
x1 = tl.extract_slice(
|
||||
x1 = extract_slice(
|
||||
q_normalized,
|
||||
offsets=(0, 0),
|
||||
sizes=(num_q_heads, half_rope_dim),
|
||||
strides=(1, 1),
|
||||
)
|
||||
x2 = tl.extract_slice(
|
||||
x2 = extract_slice(
|
||||
q_normalized,
|
||||
offsets=(0, half_rope_dim),
|
||||
sizes=(num_q_heads, half_rope_dim),
|
||||
strides=(1, 1),
|
||||
)
|
||||
cat_x = tl.zeros((num_q_heads, rope_dim), dtype=tl.float32)
|
||||
cat_x = tl.insert_slice(
|
||||
cat_x = insert_slice(
|
||||
cat_x,
|
||||
-x2,
|
||||
offsets=(0, 0),
|
||||
sizes=(num_q_heads, half_rope_dim),
|
||||
strides=(1, 1),
|
||||
)
|
||||
cat_x = tl.insert_slice(
|
||||
cat_x = insert_slice(
|
||||
cat_x,
|
||||
x1,
|
||||
offsets=(0, half_rope_dim),
|
||||
@@ -190,7 +190,7 @@ def split_qkv_rmsnorm_mrope_kernel(
|
||||
strides=(1, 1),
|
||||
)
|
||||
if IS_PARTIAL_ROPE:
|
||||
orig_qk = tl.extract_slice(
|
||||
orig_qk = extract_slice(
|
||||
q_normalized,
|
||||
offsets=(0, 0),
|
||||
sizes=(num_q_heads, rope_dim),
|
||||
@@ -201,27 +201,27 @@ def split_qkv_rmsnorm_mrope_kernel(
|
||||
roped_q = cat_x * sin_tensor + orig_qk * cos_tensor
|
||||
|
||||
# k-mrope
|
||||
y1 = tl.extract_slice(
|
||||
y1 = extract_slice(
|
||||
k_normalized,
|
||||
offsets=(0, 0),
|
||||
sizes=(num_kv_heads, half_rope_dim),
|
||||
strides=(1, 1),
|
||||
)
|
||||
y2 = tl.extract_slice(
|
||||
y2 = extract_slice(
|
||||
k_normalized,
|
||||
offsets=(0, half_rope_dim),
|
||||
sizes=(num_kv_heads, half_rope_dim),
|
||||
strides=(1, 1),
|
||||
)
|
||||
cat_y = tl.zeros((num_kv_heads, rope_dim), dtype=tl.float32)
|
||||
cat_y = tl.insert_slice(
|
||||
cat_y = insert_slice(
|
||||
cat_y,
|
||||
-y2,
|
||||
offsets=(0, 0),
|
||||
sizes=(num_kv_heads, half_rope_dim),
|
||||
strides=(1, 1),
|
||||
)
|
||||
cat_y = tl.insert_slice(
|
||||
cat_y = insert_slice(
|
||||
cat_y,
|
||||
y1,
|
||||
offsets=(0, half_rope_dim),
|
||||
@@ -229,7 +229,7 @@ def split_qkv_rmsnorm_mrope_kernel(
|
||||
strides=(1, 1),
|
||||
)
|
||||
if IS_PARTIAL_ROPE:
|
||||
orig_qk = tl.extract_slice(
|
||||
orig_qk = extract_slice(
|
||||
k_normalized,
|
||||
offsets=(0, 0),
|
||||
sizes=(num_kv_heads, rope_dim),
|
||||
@@ -240,14 +240,14 @@ def split_qkv_rmsnorm_mrope_kernel(
|
||||
roped_k = cat_y * sin_tensor + orig_qk * cos_tensor
|
||||
|
||||
if IS_PARTIAL_ROPE:
|
||||
q_normalized = tl.insert_slice(
|
||||
q_normalized = insert_slice(
|
||||
q_normalized,
|
||||
roped_q,
|
||||
offsets=(0, 0),
|
||||
sizes=(num_q_heads, rope_dim),
|
||||
strides=(1, 1),
|
||||
)
|
||||
k_normalized = tl.insert_slice(
|
||||
k_normalized = insert_slice(
|
||||
k_normalized,
|
||||
roped_k,
|
||||
offsets=(0, 0),
|
||||
|
||||
Reference in New Issue
Block a user