add 2 kernels and optimize the calculation of topk_indices (#134)

Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
Author: fromck
Date: 2026-01-22 10:29:28 +08:00
Committed by: GitHub
Parent: c9f00c132c
Commit: 74d4f804e8
4 changed files with 108 additions and 32 deletions


@@ -596,23 +596,33 @@ def sparse_attn_indexer_vllm_kunlun(
if has_prefill:
prefill_metadata = attn_metadata.prefill
for chunk in prefill_metadata.chunks:
k_fp8, k_scale = cp_gather_indexer_k_quant_cache(
kv_cache,
chunk.block_table,
chunk.cu_seq_lens,
chunk.num_reqs,
head_dim,
k_fp8 = torch.empty([chunk.total_seq_lens, head_dim],
device=k.device,
dtype=torch.int8)
k_scale = torch.empty([chunk.total_seq_lens, 4],
device=k.device,
dtype=torch.uint8)
torch.ops.xspeedgate_ops.cp_gather_indexer_k_quant_cache(
kv_cache=kv_cache,
dst_k=k_fp8,
dst_scale=k_scale,
block_table=chunk.block_table,
cu_seq_lens=chunk.cu_seq_lens,
)
logits = int8_mqa_logits(
q_fp8[chunk.token_start:chunk.token_end],
(k_fp8, k_scale),
(k_fp8, k_scale.view(torch.float32)),
weights[chunk.token_start:chunk.token_end],
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
context_q_lens_xpu=chunk.context_q_lens,
context_q_lens_cpu=chunk.context_q_lens_cpu,
context_k_lens_xpu=chunk.context_k_lens,
context_k_lens_cpu=chunk.context_k_lens_cpu,
)
topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
dim=-1)[1]
del k_fp8, k_scale
topk_indices = torch.argsort(logits, dim=-1, descending=True)[..., :min(topk_tokens, logits.shape[-1])]
topk_indices -= chunk.cu_seqlen_ks[:, None]
mask_lo = topk_indices >= 0
mask_hi = topk_indices - (chunk.cu_seqlen_ke -
@@ -675,8 +685,9 @@ def sparse_attn_indexer_vllm_kunlun(
mask = positions <= index_end_pos
# mask: [B * N, L]
logits = logits.masked_fill(~mask, float('-inf'))
topk_indices = logits.topk(topk_tokens,
dim=-1)[1].to(torch.int32) # [B * N, K]
topk_indices = torch.argsort(logits, dim=-1,
descending=True)[..., :min(topk_tokens, logits.shape[-1])]  # [B * N, K]
# ensure we don't set indices for the top k
# that are out of range (masked already);
# this will happen if the context length is shorter than K
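
Two details in the hunks above can be sanity-checked in isolation. The sketch below uses toy CPU tensors and is illustrative only; it is not part of the change and does not touch the actual kernels:

import torch

# (1) Reinterpreting the [N, 4] uint8 scale buffer as float32 yields [N, 1]:
#     the four raw bytes per token hold one float32 scale, which is what
#     (k_fp8, k_scale.view(torch.float32)) hands to int8_mqa_logits.
n = 8
k_scale = torch.zeros(n, 4, dtype=torch.uint8)
assert k_scale.view(torch.float32).shape == (n, 1)

# (2) For distinct scores, the first k columns of a descending argsort are
#     exactly topk(k)[1]; both top-k call sites above rely on this equivalence.
logits = torch.randn(3, 16)
k = 5
assert torch.equal(logits.topk(k, dim=-1)[1],
                   torch.argsort(logits, dim=-1, descending=True)[..., :k])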


@@ -1,4 +1,5 @@
import torch
import xspeedgate_ops
def int8_mqa_logits(
q: torch.Tensor,
@@ -6,6 +7,10 @@ def int8_mqa_logits(
weights: torch.Tensor,
cu_seqlen_ks: torch.Tensor,
cu_seqlen_ke: torch.Tensor,
context_q_lens_xpu: torch.Tensor,
context_q_lens_cpu: torch.Tensor,
context_k_lens_xpu: torch.Tensor,
context_k_lens_cpu: torch.Tensor,
) -> torch.Tensor:
"""Compute FP8 MQA logits for a single sequence without KV paging.
@@ -24,28 +29,27 @@ def int8_mqa_logits(
Returns:
Logits tensor of shape [M, N], dtype `torch.float32`.
"""
logits = torch.empty((q.shape[0], kv[0].shape[0]), dtype=torch.float32, device=q.device)
context_q_lens_xpu = torch.tensor([0, q.shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
context_k_lens_xpu = torch.tensor([0, kv[0].shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
seq_len_q, seq_len_kv = q.shape[0], kv[0].shape[0]
logits = torch.empty((seq_len_q, seq_len_kv), dtype=torch.float32, device=q.device)
torch.ops._C.I8_mqa_logits(
q=q,
fused_kv_cache=kv,
weights=weights,
context_q_lens=(context_q_lens_xpu.cpu(), context_q_lens_xpu),
context_k_lens=(context_k_lens_xpu.cpu(), context_k_lens_xpu),
context_q_lens=(context_q_lens_cpu, context_q_lens_xpu),
context_k_lens=(context_k_lens_cpu, context_k_lens_xpu),
logits=logits,
clean_logits=True,
use_xfa_boost=False,
)
seq_len_kv = kv[0].shape[0]
# The mask follows the _ref_fp8_mqa_logits reference implementation in https://github.com/vllm-project/vllm/blob/v0.11.0/tests/kernels/attention/test_deepgemm_attention.py
mask_lo = (torch.arange(0, seq_len_kv, device=cu_seqlen_ks.device)[None, :]
>= cu_seqlen_ks[:, None])
mask_hi = (torch.arange(0, seq_len_kv, device=cu_seqlen_ke.device)[None, :]
< cu_seqlen_ke[:, None])
mask = mask_lo & mask_hi
logits = logits.masked_fill(~mask, float("-inf"))
torch.ops.xspeedgate_ops.mask_for_I8_mqa_logits(
seq_len_kv=seq_len_kv,
cu_seqlen_ks=cu_seqlen_ks,
cu_seqlen_ke=cu_seqlen_ke,
logits=logits,
)
return logits
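
The removed Python masking above built the mask from cu_seqlen_ks/cu_seqlen_ke and applied it with masked_fill; the new mask_for_I8_mqa_logits kernel takes over that step, and the precomputed context_*_lens_cpu arguments let the caller avoid the per-call .cpu() copies the old code performed. A minimal CPU reference of the masking semantics (a sketch for checking shapes and behaviour, not the kernel itself):

import torch

def ref_mask_logits(logits: torch.Tensor,
                    cu_seqlen_ks: torch.Tensor,
                    cu_seqlen_ke: torch.Tensor) -> torch.Tensor:
    """KV position j is visible to query row i iff ks[i] <= j < ke[i]."""
    seq_len_kv = logits.shape[-1]
    pos = torch.arange(seq_len_kv, device=logits.device)[None, :]
    mask = (pos >= cu_seqlen_ks[:, None]) & (pos < cu_seqlen_ke[:, None])
    return logits.masked_fill(~mask, float("-inf"))

# Toy check: 2 query rows over 6 KV slots with per-row [start, end) spans.
masked = ref_mask_logits(torch.zeros(2, 6),
                         torch.tensor([0, 2]),
                         torch.tensor([3, 6]))
assert torch.isinf(masked[0, 3:]).all()
assert torch.isfinite(masked[1, 2:]).all() and torch.isinf(masked[1, :2]).all()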


@@ -153,13 +153,7 @@ class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMetho
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
torch.ops._C.silu_and_mul(out1, y)
out = torch.empty(
M,
top_k,
layer.w2_weight.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device,
)
del y
out1 = out1.reshape(-1, out1.shape[-1])
x_shape = out1.shape
@@ -168,6 +162,14 @@ class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMetho
(x_shape[0], 1), dtype=torch.float32, device=moe_expand.device
)
torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True)
del out1, moe_expand
out = torch.empty(
M,
top_k,
layer.w2_weight.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device,
)
torch.ops._C.moe_fc(
x=x_q,
@@ -182,6 +184,7 @@ class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMetho
# sort_mode=False,
act=None,
)
del x_q, x_scale, sorted_tokens_num_lod, expert_m
dequant_scale = torch.ones([M, top_k], dtype=torch.float32, device=out.device)
output = torch.empty(
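
The MoE changes above are a lifetime reordering rather than a functional change: each intermediate is dropped with del as soon as it has been consumed, and the out buffer is now allocated only just before moe_fc needs it, which trims the set of activations alive at the same time. A schematic of the same pattern with stand-in ops and hypothetical sizes (none of this is the Kunlun implementation):

import torch
import torch.nn.functional as F

M, TOP_K, HIDDEN = 64, 4, 128  # hypothetical sizes

y = torch.randn(M * TOP_K, 2 * HIDDEN)
out1 = F.silu(y[:, :HIDDEN]) * y[:, HIDDEN:]   # stand-in for silu_and_mul
del y                                          # y is dead past this point

# Stand-in for quant2d: per-row symmetric int8 quantization.
x_scale = out1.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 127.0
x_q = (out1 / x_scale).round().clamp(-128, 127).to(torch.int8)
del out1                                       # the int8 copy replaces the fp buffer

# Allocated only when the (stand-in) expert GEMM is about to write into it,
# so it never coexists with y or out1.
out = torch.empty(M, TOP_K, HIDDEN)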


@@ -8,10 +8,30 @@ from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata,
split_decodes_and_prefills)
from vllm.v1.attention.backends.mla.indexer import (split_prefill_chunks,
DeepseekV32IndexerMetadataBuilder,
DeepseekV32IndexerPrefillMetadata)
kv_spans_from_batches)
logger = init_logger(__name__)
@dataclass
class DeepseekV32IndexerPrefillChunkMetadata:
block_table: torch.Tensor
cu_seqlen_ks: torch.Tensor
cu_seqlen_ke: torch.Tensor
cu_seq_lens: torch.Tensor
total_seq_lens: int
token_start: int
token_end: int
num_reqs: int
context_q_lens: torch.Tensor
context_q_lens_cpu: torch.Tensor
context_k_lens: torch.Tensor
context_k_lens_cpu: torch.Tensor
@dataclass
class DeepseekV32IndexerPrefillMetadata:
chunks: list[DeepseekV32IndexerPrefillChunkMetadata]
@dataclass
class DeepSeekV32IndexerDecodeMetadata:
block_table: torch.Tensor
@@ -50,6 +70,43 @@ class DeepseekV32IndexerMetadata:
decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None
prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None
def kunlun_build_one_prefill_chunk(self, reqs_start, reqs_end,
query_start_loc_cpu, seq_lens_cpu,
block_table):
prefill_query_start_loc = query_start_loc_cpu[
reqs_start:reqs_end + 1] - query_start_loc_cpu[reqs_start]
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end],
self.device)
token_start = query_start_loc_cpu[reqs_start].item()
token_end = query_start_loc_cpu[reqs_end].item()
total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
assert total_seq_lens <= self.max_prefill_buffer_size
cu_seq_lens = torch.cat([
torch.zeros(1, dtype=torch.int32),
seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0)
]).to(torch.int32).to(self.device)
seq_len_q = token_end - token_start
seq_len_kv = total_seq_lens
context_q_lens = torch.tensor([0, seq_len_q], dtype=torch.int32, device=self.device)
context_k_lens = torch.tensor([0, seq_len_kv], dtype=torch.int32, device=self.device)
context_q_lens_cpu = torch.tensor([0, seq_len_q], dtype=torch.int32, device="cpu")
context_k_lens_cpu = torch.tensor([0, seq_len_kv], dtype=torch.int32, device="cpu")
return DeepseekV32IndexerPrefillChunkMetadata(
cu_seqlen_ks=cu_seqlen_ks,
cu_seqlen_ke=cu_seqlen_ke,
cu_seq_lens=cu_seq_lens,
total_seq_lens=total_seq_lens,
block_table=block_table[reqs_start:reqs_end],
token_start=token_start,
token_end=token_end,
num_reqs=reqs_end - reqs_start,
context_q_lens=context_q_lens,
context_q_lens_cpu=context_q_lens_cpu,
context_k_lens=context_k_lens,
context_k_lens_cpu=context_k_lens_cpu,
)
def kunlun_build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
@@ -130,4 +187,5 @@ def kunlun_build(self,
# logger.info(f"attn_metadata: {attn_metadata}")
return attn_metadata
DeepseekV32IndexerMetadataBuilder.build_one_prefill_chunk = kunlun_build_one_prefill_chunk
DeepseekV32IndexerMetadataBuilder.build = kunlun_build
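
As a quick way to read kunlun_build_one_prefill_chunk, its arithmetic can be traced with made-up lengths. The snippet below just replays the tensor construction from the function for a hypothetical chunk of two requests (query lengths 2 and 3, KV lengths 4 and 6); it is not tied to any real batch:

import torch

# Hypothetical chunk: reqs_start = 0, reqs_end = 2.
query_start_loc_cpu = torch.tensor([0, 2, 5], dtype=torch.int32)  # query lens 2, 3
seq_lens_cpu = torch.tensor([4, 6], dtype=torch.int32)            # full KV lens

token_start = query_start_loc_cpu[0].item()        # 0
token_end = query_start_loc_cpu[2].item()          # 5
total_seq_lens = int(seq_lens_cpu.sum())           # 10

cu_seq_lens = torch.cat([
    torch.zeros(1, dtype=torch.int32),
    seq_lens_cpu.cumsum(dim=0),
]).to(torch.int32)                                  # [0, 4, 10]

seq_len_q = token_end - token_start                 # 5
context_q_lens_cpu = torch.tensor([0, seq_len_q], dtype=torch.int32)
context_k_lens_cpu = torch.tensor([0, total_seq_lens], dtype=torch.int32)

assert cu_seq_lens.tolist() == [0, 4, 10]
assert context_q_lens_cpu.tolist() == [0, 5]
assert context_k_lens_cpu.tolist() == [0, 10]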