add 2 kernels and optimize the calculation of topk_indices (#134)

Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
This commit is contained in:
fromck
2026-01-22 10:29:28 +08:00
committed by GitHub
parent c9f00c132c
commit 74d4f804e8
4 changed files with 108 additions and 32 deletions

View File

@@ -1,4 +1,5 @@
import torch
import xspeedgate_ops
def int8_mqa_logits(
q: torch.Tensor,
@@ -6,6 +7,10 @@ def int8_mqa_logits(
weights: torch.Tensor,
cu_seqlen_ks: torch.Tensor,
cu_seqlen_ke: torch.Tensor,
context_q_lens_xpu: torch.Tensor,
context_q_lens_cpu: torch.Tensor,
context_k_lens_xpu: torch.Tensor,
context_k_lens_cpu: torch.Tensor,
) -> torch.Tensor:
"""Compute FP8 MQA logits for a single sequence without KV paging.
@@ -24,28 +29,27 @@ def int8_mqa_logits(
Returns:
Logits tensor of shape [M, N], dtype `torch.float32`.
"""
logits = torch.empty((q.shape[0], kv[0].shape[0]), dtype=torch.float32, device=q.device)
context_q_lens_xpu = torch.tensor([0, q.shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
context_k_lens_xpu = torch.tensor([0, kv[0].shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
seq_len_q, seq_len_kv =q.shape[0], kv[0].shape[0]
logits = torch.empty((seq_len_q, seq_len_kv), dtype=torch.float32, device=q.device)
torch.ops._C.I8_mqa_logits(
q=q,
fused_kv_cache=kv,
weights=weights,
context_q_lens=(context_q_lens_xpu.cpu(), context_q_lens_xpu),
context_k_lens=(context_k_lens_xpu.cpu(), context_k_lens_xpu),
context_q_lens=(context_q_lens_cpu, context_q_lens_xpu),
context_k_lens=(context_k_lens_cpu, context_k_lens_xpu),
logits=logits,
clean_logits=True,
use_xfa_boost=False,
)
seq_len_kv = kv[0].shape[0]
# mask参考 https://github.com/vllm-project/vllm/blob/v0.11.0/tests/kernels/attention/test_deepgemm_attention.py 的_ref_fp8_mqa_logits函数的实现
mask_lo = (torch.arange(0, seq_len_kv, device=cu_seqlen_ks.device)[None, :]
>= cu_seqlen_ks[:, None])
mask_hi = (torch.arange(0, seq_len_kv, device=cu_seqlen_ke.device)[None, :]
< cu_seqlen_ke[:, None])
mask = mask_lo & mask_hi
logits = logits.masked_fill(~mask, float("-inf"))
torch.ops.xspeedgate_ops.mask_for_I8_mqa_logits(
seq_len_kv=seq_len_kv,
cu_seqlen_ks=cu_seqlen_ks,
cu_seqlen_ke=cu_seqlen_ke,
logits=logits,
)
return logits

View File

@@ -153,13 +153,7 @@ class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMetho
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
torch.ops._C.silu_and_mul(out1, y)
out = torch.empty(
M,
top_k,
layer.w2_weight.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device,
)
del y
out1 = out1.reshape(-1, out1.shape[-1])
x_shape = out1.shape
@@ -168,6 +162,14 @@ class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMetho
(x_shape[0], 1), dtype=torch.float32, device=moe_expand.device
)
torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True)
del out1, moe_expand
out = torch.empty(
M,
top_k,
layer.w2_weight.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device,
)
torch.ops._C.moe_fc(
x=x_q,
@@ -182,6 +184,7 @@ class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMetho
# sort_mode=False,
act=None,
)
del x_q, x_scale, sorted_tokens_num_lod,expert_m
dequant_scale = torch.ones([M, top_k], dtype=torch.float32, device=out.device)
output = torch.empty(