diff --git a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_mrope.py b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_mrope.py index 25aa727b..7c4cf9f7 100644 --- a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_mrope.py +++ b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_mrope.py @@ -20,7 +20,7 @@ import torch from vllm.triton_utils import tl, triton from vllm.utils.torch_utils import direct_register_custom_op -from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num +from vllm_ascend.ops.triton.triton_utils import extract_slice, get_vectorcore_num, insert_slice @triton.jit( @@ -86,13 +86,13 @@ def split_qkv_rmsnorm_mrope_kernel( .to(tl.float32) .reshape(num_q_heads, head_size * 2) ) - in_q_tensor = tl.extract_slice( + in_q_tensor = extract_slice( in_q_gate_tensor, offsets=(0, 0), sizes=(num_q_heads, head_size), strides=(1, 1), ) - in_gate_tensor = tl.extract_slice( + in_gate_tensor = extract_slice( in_q_gate_tensor, offsets=(0, head_size), sizes=(num_q_heads, head_size), @@ -162,27 +162,27 @@ def split_qkv_rmsnorm_mrope_kernel( k_normalized = k_normalized + k_bias # q-mrope - x1 = tl.extract_slice( + x1 = extract_slice( q_normalized, offsets=(0, 0), sizes=(num_q_heads, half_rope_dim), strides=(1, 1), ) - x2 = tl.extract_slice( + x2 = extract_slice( q_normalized, offsets=(0, half_rope_dim), sizes=(num_q_heads, half_rope_dim), strides=(1, 1), ) cat_x = tl.zeros((num_q_heads, rope_dim), dtype=tl.float32) - cat_x = tl.insert_slice( + cat_x = insert_slice( cat_x, -x2, offsets=(0, 0), sizes=(num_q_heads, half_rope_dim), strides=(1, 1), ) - cat_x = tl.insert_slice( + cat_x = insert_slice( cat_x, x1, offsets=(0, half_rope_dim), @@ -190,7 +190,7 @@ def split_qkv_rmsnorm_mrope_kernel( strides=(1, 1), ) if IS_PARTIAL_ROPE: - orig_qk = tl.extract_slice( + orig_qk = extract_slice( q_normalized, offsets=(0, 0), sizes=(num_q_heads, rope_dim), @@ -201,27 +201,27 @@ def split_qkv_rmsnorm_mrope_kernel( roped_q = cat_x * sin_tensor + orig_qk * cos_tensor # k-mrope - y1 = tl.extract_slice( + y1 = extract_slice( k_normalized, offsets=(0, 0), sizes=(num_kv_heads, half_rope_dim), strides=(1, 1), ) - y2 = tl.extract_slice( + y2 = extract_slice( k_normalized, offsets=(0, half_rope_dim), sizes=(num_kv_heads, half_rope_dim), strides=(1, 1), ) cat_y = tl.zeros((num_kv_heads, rope_dim), dtype=tl.float32) - cat_y = tl.insert_slice( + cat_y = insert_slice( cat_y, -y2, offsets=(0, 0), sizes=(num_kv_heads, half_rope_dim), strides=(1, 1), ) - cat_y = tl.insert_slice( + cat_y = insert_slice( cat_y, y1, offsets=(0, half_rope_dim), @@ -229,7 +229,7 @@ def split_qkv_rmsnorm_mrope_kernel( strides=(1, 1), ) if IS_PARTIAL_ROPE: - orig_qk = tl.extract_slice( + orig_qk = extract_slice( k_normalized, offsets=(0, 0), sizes=(num_kv_heads, rope_dim), @@ -240,14 +240,14 @@ def split_qkv_rmsnorm_mrope_kernel( roped_k = cat_y * sin_tensor + orig_qk * cos_tensor if IS_PARTIAL_ROPE: - q_normalized = tl.insert_slice( + q_normalized = insert_slice( q_normalized, roped_q, offsets=(0, 0), sizes=(num_q_heads, rope_dim), strides=(1, 1), ) - k_normalized = tl.insert_slice( + k_normalized = insert_slice( k_normalized, roped_k, offsets=(0, 0),