[BugFix] fix tl.extract_slice and tl.insert_slice. (#8567)
### What this PR does / why we need it?
Replace the calls to `tl.extract_slice` and `tl.insert_slice` in `split_qkv_rmsnorm_mrope_kernel` with the `extract_slice` and `insert_slice` helpers imported from `vllm_ascend.ops.triton.triton_utils`.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
---------
Signed-off-by: wangx700 <wangxin700@huawei.com>
This commit is contained in:
@@ -20,7 +20,7 @@ import torch
|
|||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
from vllm.utils.torch_utils import direct_register_custom_op
|
from vllm.utils.torch_utils import direct_register_custom_op
|
||||||
|
|
||||||
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
from vllm_ascend.ops.triton.triton_utils import extract_slice, get_vectorcore_num, insert_slice
|
||||||
|
|
||||||
|
|
||||||
@triton.jit(
|
@triton.jit(
|
||||||
@@ -86,13 +86,13 @@ def split_qkv_rmsnorm_mrope_kernel(
|
|||||||
.to(tl.float32)
|
.to(tl.float32)
|
||||||
.reshape(num_q_heads, head_size * 2)
|
.reshape(num_q_heads, head_size * 2)
|
||||||
)
|
)
|
||||||
in_q_tensor = tl.extract_slice(
|
in_q_tensor = extract_slice(
|
||||||
in_q_gate_tensor,
|
in_q_gate_tensor,
|
||||||
offsets=(0, 0),
|
offsets=(0, 0),
|
||||||
sizes=(num_q_heads, head_size),
|
sizes=(num_q_heads, head_size),
|
||||||
strides=(1, 1),
|
strides=(1, 1),
|
||||||
)
|
)
|
||||||
in_gate_tensor = tl.extract_slice(
|
in_gate_tensor = extract_slice(
|
||||||
in_q_gate_tensor,
|
in_q_gate_tensor,
|
||||||
offsets=(0, head_size),
|
offsets=(0, head_size),
|
||||||
sizes=(num_q_heads, head_size),
|
sizes=(num_q_heads, head_size),
|
||||||
@@ -162,27 +162,27 @@ def split_qkv_rmsnorm_mrope_kernel(
|
|||||||
k_normalized = k_normalized + k_bias
|
k_normalized = k_normalized + k_bias
|
||||||
|
|
||||||
# q-mrope
|
# q-mrope
|
||||||
x1 = tl.extract_slice(
|
x1 = extract_slice(
|
||||||
q_normalized,
|
q_normalized,
|
||||||
offsets=(0, 0),
|
offsets=(0, 0),
|
||||||
sizes=(num_q_heads, half_rope_dim),
|
sizes=(num_q_heads, half_rope_dim),
|
||||||
strides=(1, 1),
|
strides=(1, 1),
|
||||||
)
|
)
|
||||||
x2 = tl.extract_slice(
|
x2 = extract_slice(
|
||||||
q_normalized,
|
q_normalized,
|
||||||
offsets=(0, half_rope_dim),
|
offsets=(0, half_rope_dim),
|
||||||
sizes=(num_q_heads, half_rope_dim),
|
sizes=(num_q_heads, half_rope_dim),
|
||||||
strides=(1, 1),
|
strides=(1, 1),
|
||||||
)
|
)
|
||||||
cat_x = tl.zeros((num_q_heads, rope_dim), dtype=tl.float32)
|
cat_x = tl.zeros((num_q_heads, rope_dim), dtype=tl.float32)
|
||||||
cat_x = tl.insert_slice(
|
cat_x = insert_slice(
|
||||||
cat_x,
|
cat_x,
|
||||||
-x2,
|
-x2,
|
||||||
offsets=(0, 0),
|
offsets=(0, 0),
|
||||||
sizes=(num_q_heads, half_rope_dim),
|
sizes=(num_q_heads, half_rope_dim),
|
||||||
strides=(1, 1),
|
strides=(1, 1),
|
||||||
)
|
)
|
||||||
cat_x = tl.insert_slice(
|
cat_x = insert_slice(
|
||||||
cat_x,
|
cat_x,
|
||||||
x1,
|
x1,
|
||||||
offsets=(0, half_rope_dim),
|
offsets=(0, half_rope_dim),
|
||||||
@@ -190,7 +190,7 @@ def split_qkv_rmsnorm_mrope_kernel(
|
|||||||
strides=(1, 1),
|
strides=(1, 1),
|
||||||
)
|
)
|
||||||
if IS_PARTIAL_ROPE:
|
if IS_PARTIAL_ROPE:
|
||||||
orig_qk = tl.extract_slice(
|
orig_qk = extract_slice(
|
||||||
q_normalized,
|
q_normalized,
|
||||||
offsets=(0, 0),
|
offsets=(0, 0),
|
||||||
sizes=(num_q_heads, rope_dim),
|
sizes=(num_q_heads, rope_dim),
|
||||||
@@ -201,27 +201,27 @@ def split_qkv_rmsnorm_mrope_kernel(
|
|||||||
roped_q = cat_x * sin_tensor + orig_qk * cos_tensor
|
roped_q = cat_x * sin_tensor + orig_qk * cos_tensor
|
||||||
|
|
||||||
# k-mrope
|
# k-mrope
|
||||||
y1 = tl.extract_slice(
|
y1 = extract_slice(
|
||||||
k_normalized,
|
k_normalized,
|
||||||
offsets=(0, 0),
|
offsets=(0, 0),
|
||||||
sizes=(num_kv_heads, half_rope_dim),
|
sizes=(num_kv_heads, half_rope_dim),
|
||||||
strides=(1, 1),
|
strides=(1, 1),
|
||||||
)
|
)
|
||||||
y2 = tl.extract_slice(
|
y2 = extract_slice(
|
||||||
k_normalized,
|
k_normalized,
|
||||||
offsets=(0, half_rope_dim),
|
offsets=(0, half_rope_dim),
|
||||||
sizes=(num_kv_heads, half_rope_dim),
|
sizes=(num_kv_heads, half_rope_dim),
|
||||||
strides=(1, 1),
|
strides=(1, 1),
|
||||||
)
|
)
|
||||||
cat_y = tl.zeros((num_kv_heads, rope_dim), dtype=tl.float32)
|
cat_y = tl.zeros((num_kv_heads, rope_dim), dtype=tl.float32)
|
||||||
cat_y = tl.insert_slice(
|
cat_y = insert_slice(
|
||||||
cat_y,
|
cat_y,
|
||||||
-y2,
|
-y2,
|
||||||
offsets=(0, 0),
|
offsets=(0, 0),
|
||||||
sizes=(num_kv_heads, half_rope_dim),
|
sizes=(num_kv_heads, half_rope_dim),
|
||||||
strides=(1, 1),
|
strides=(1, 1),
|
||||||
)
|
)
|
||||||
cat_y = tl.insert_slice(
|
cat_y = insert_slice(
|
||||||
cat_y,
|
cat_y,
|
||||||
y1,
|
y1,
|
||||||
offsets=(0, half_rope_dim),
|
offsets=(0, half_rope_dim),
|
||||||
@@ -229,7 +229,7 @@ def split_qkv_rmsnorm_mrope_kernel(
|
|||||||
strides=(1, 1),
|
strides=(1, 1),
|
||||||
)
|
)
|
||||||
if IS_PARTIAL_ROPE:
|
if IS_PARTIAL_ROPE:
|
||||||
orig_qk = tl.extract_slice(
|
orig_qk = extract_slice(
|
||||||
k_normalized,
|
k_normalized,
|
||||||
offsets=(0, 0),
|
offsets=(0, 0),
|
||||||
sizes=(num_kv_heads, rope_dim),
|
sizes=(num_kv_heads, rope_dim),
|
||||||
@@ -240,14 +240,14 @@ def split_qkv_rmsnorm_mrope_kernel(
|
|||||||
roped_k = cat_y * sin_tensor + orig_qk * cos_tensor
|
roped_k = cat_y * sin_tensor + orig_qk * cos_tensor
|
||||||
|
|
||||||
if IS_PARTIAL_ROPE:
|
if IS_PARTIAL_ROPE:
|
||||||
q_normalized = tl.insert_slice(
|
q_normalized = insert_slice(
|
||||||
q_normalized,
|
q_normalized,
|
||||||
roped_q,
|
roped_q,
|
||||||
offsets=(0, 0),
|
offsets=(0, 0),
|
||||||
sizes=(num_q_heads, rope_dim),
|
sizes=(num_q_heads, rope_dim),
|
||||||
strides=(1, 1),
|
strides=(1, 1),
|
||||||
)
|
)
|
||||||
k_normalized = tl.insert_slice(
|
k_normalized = insert_slice(
|
||||||
k_normalized,
|
k_normalized,
|
||||||
roped_k,
|
roped_k,
|
||||||
offsets=(0, 0),
|
offsets=(0, 0),
|
||||||
|
|||||||
Reference in New Issue
Block a user