add 2 kernels and optimize the calculation of topk_indices (#134)
Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
@@ -1,4 +1,5 @@
 import torch
+import xspeedgate_ops

 def int8_mqa_logits(
     q: torch.Tensor,
@@ -6,6 +7,10 @@ def int8_mqa_logits(
     weights: torch.Tensor,
     cu_seqlen_ks: torch.Tensor,
     cu_seqlen_ke: torch.Tensor,
+    context_q_lens_xpu: torch.Tensor,
+    context_q_lens_cpu: torch.Tensor,
+    context_k_lens_xpu: torch.Tensor,
+    context_k_lens_cpu: torch.Tensor,
 ) -> torch.Tensor:
     """Compute FP8 MQA logits for a single sequence without KV paging.

@@ -24,28 +29,27 @@ def int8_mqa_logits(
     Returns:
         Logits tensor of shape [M, N], dtype `torch.float32`.
     """
-    logits = torch.empty((q.shape[0], kv[0].shape[0]), dtype=torch.float32, device=q.device)
-    context_q_lens_xpu = torch.tensor([0, q.shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
-    context_k_lens_xpu = torch.tensor([0, kv[0].shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
+    seq_len_q, seq_len_kv = q.shape[0], kv[0].shape[0]
+    logits = torch.empty((seq_len_q, seq_len_kv), dtype=torch.float32, device=q.device)

     torch.ops._C.I8_mqa_logits(
         q=q,
         fused_kv_cache=kv,
         weights=weights,
-        context_q_lens=(context_q_lens_xpu.cpu(), context_q_lens_xpu),
-        context_k_lens=(context_k_lens_xpu.cpu(), context_k_lens_xpu),
+        context_q_lens=(context_q_lens_cpu, context_q_lens_xpu),
+        context_k_lens=(context_k_lens_cpu, context_k_lens_xpu),
         logits=logits,
         clean_logits=True,
         use_xfa_boost=False,
     )
-    seq_len_kv = kv[0].shape[0]
-
-    # Mask follows the _ref_fp8_mqa_logits implementation in https://github.com/vllm-project/vllm/blob/v0.11.0/tests/kernels/attention/test_deepgemm_attention.py
-    mask_lo = (torch.arange(0, seq_len_kv, device=cu_seqlen_ks.device)[None, :]
-               >= cu_seqlen_ks[:, None])
-    mask_hi = (torch.arange(0, seq_len_kv, device=cu_seqlen_ke.device)[None, :]
-               < cu_seqlen_ke[:, None])
-    mask = mask_lo & mask_hi
-    logits = logits.masked_fill(~mask, float("-inf"))
+    torch.ops.xspeedgate_ops.mask_for_I8_mqa_logits(
+        seq_len_kv=seq_len_kv,
+        cu_seqlen_ks=cu_seqlen_ks,
+        cu_seqlen_ke=cu_seqlen_ke,
+        logits=logits,
+    )

     return logits
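
A minimal caller-side sketch of how the new context-length parameters might be fed in (not part of this commit; the helper name build_context_lens and the assumption that the caller caches these tensors are illustrative). The old code allocated the [0, len] tensors inside int8_mqa_logits and called .cpu() on them every invocation; with the new signature the host and device copies can be built once and reused.

from typing import Tuple

import torch


def build_context_lens(
    seq_len_q: int, seq_len_kv: int, device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    # Mirrors the [0, len] tensors the old code created on every call,
    # but built once so the .cpu() host sync happens outside the hot path.
    q_lens_xpu = torch.tensor([0, seq_len_q], dtype=torch.int32, device=device)
    k_lens_xpu = torch.tensor([0, seq_len_kv], dtype=torch.int32, device=device)
    return q_lens_xpu.cpu(), q_lens_xpu, k_lens_xpu.cpu(), k_lens_xpu

The four returned tensors would then be passed as context_q_lens_cpu / context_q_lens_xpu and context_k_lens_cpu / context_k_lens_xpu.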
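
For reference, the masking that the new mask_for_I8_mqa_logits kernel replaces can be expressed in pure PyTorch, following the removed code above, and used as a CPU-side correctness check. The function name ref_mask_logits and the toy inputs below are hypothetical.

import torch


def ref_mask_logits(
    logits: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    # Keep logits[i, j] only where cu_seqlen_ks[i] <= j < cu_seqlen_ke[i];
    # everything else becomes -inf (same as the removed masked_fill path).
    seq_len_kv = logits.shape[1]
    col = torch.arange(seq_len_kv, device=logits.device)[None, :]
    mask = (col >= cu_seqlen_ks[:, None]) & (col < cu_seqlen_ke[:, None])
    return logits.masked_fill(~mask, float("-inf"))


if __name__ == "__main__":
    logits = torch.randn(4, 8)
    ks = torch.tensor([0, 1, 2, 3])
    ke = torch.tensor([8, 7, 6, 5])
    print(ref_mask_logits(logits, ks, ke))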