[Ops][Triton] Add a triton kernel supporting partial rope. (#4413)
### What this PR does / why we need it? This PR adds a triton rope kernel witch supports scenarios of `rope_dim != head_dim`. This can save the split op before rope and the concat op after rope. Profiling shows improvement. ### Does this PR introduce _any_ user-facing change? None ### How was this patch tested? I will add related ut after ci integrated with triton. - vLLM version: v0.11.2 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2 --------- Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -9,6 +9,7 @@ from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
@@ -16,6 +17,7 @@ from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
wait_for_kv_layer_from_connector)
|
||||
from vllm_ascend.ops.triton.rope import rope_forward_triton
|
||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
is_enable_nz)
|
||||
@@ -492,35 +494,50 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
cos = attn_metadata.cos
|
||||
sin = attn_metadata.sin
|
||||
|
||||
cos_q, sin_q = cos, sin
|
||||
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
|
||||
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
|
||||
|
||||
# q process in new stream
|
||||
q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
|
||||
q = q.view(-1, self.n_head, self.head_dim) # [b,s,64,128]
|
||||
q_pe, q_nope = torch.split(
|
||||
q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
|
||||
dim=-1) # [b,s,64,64+64]
|
||||
|
||||
q_pe = q_pe.unsqueeze(2)
|
||||
q_pe = torch_npu.npu_interleave_rope(q_pe, cos_q, sin_q)
|
||||
q_pe = q_pe.squeeze(2)
|
||||
q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128]
|
||||
q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128]
|
||||
|
||||
k_proj, _ = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128]
|
||||
k_proj = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
k_proj, need_gather_q_kv)
|
||||
k = self.k_norm(k_proj).unsqueeze(1)
|
||||
k_pe, k_nope = torch.split(
|
||||
k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
|
||||
dim=-1) # [b,s,64+64]
|
||||
k = k.view(-1, 1, self.head_dim)
|
||||
|
||||
k_pe = k_pe.unsqueeze(2)
|
||||
k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin)
|
||||
k_pe = k_pe.squeeze(2)
|
||||
if HAS_TRITON:
|
||||
cos = cos.view(-1, self.qk_rope_head_dim)
|
||||
sin = sin.view(-1, self.qk_rope_head_dim)
|
||||
q, k = rope_forward_triton(q,
|
||||
k,
|
||||
cos,
|
||||
sin,
|
||||
rope_dim=self.qk_rope_head_dim,
|
||||
is_neox_style=True)
|
||||
else:
|
||||
cos_q, sin_q = cos, sin
|
||||
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
|
||||
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
|
||||
|
||||
k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128]
|
||||
q_pe, q_nope = torch.split(
|
||||
q,
|
||||
[self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
|
||||
dim=-1) # [b,s,64,64+64]
|
||||
|
||||
q_pe = q_pe.unsqueeze(2)
|
||||
q_pe = torch_npu.npu_interleave_rope(q_pe, cos_q, sin_q)
|
||||
q_pe = q_pe.squeeze(2)
|
||||
q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128]
|
||||
|
||||
k_pe, k_nope = torch.split(
|
||||
k,
|
||||
[self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
|
||||
dim=-1) # [b,s,64+64]
|
||||
|
||||
k_pe = k_pe.unsqueeze(2)
|
||||
k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin)
|
||||
k_pe = k_pe.squeeze(2)
|
||||
|
||||
k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128]
|
||||
|
||||
if kv_cache is not None:
|
||||
torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]),
|
||||
|
||||
Reference in New Issue
Block a user