From 74d4f804e8f09ade4eca55dcf19db48ed8394be2 Mon Sep 17 00:00:00 2001
From: fromck <74886593+fromck@users.noreply.github.com>
Date: Thu, 22 Jan 2026 10:29:28 +0800
Subject: [PATCH] add 2 kernels and optimize the calculation of topk_indices
 (#134)

Co-authored-by: chengxiaokang
---
 vllm_kunlun/models/deepseek_v2.py                | 35 +++++++----
 vllm_kunlun/ops/deep_gemm.py                     | 28 +++++----
 .../compressed_tensors_moe.py                    | 17 +++---
 .../v1/attention/backends/mla/indexer.py         | 60 ++++++++++++++++++-
 4 files changed, 108 insertions(+), 32 deletions(-)

diff --git a/vllm_kunlun/models/deepseek_v2.py b/vllm_kunlun/models/deepseek_v2.py
index 1c3c11a..c1c00f3 100644
--- a/vllm_kunlun/models/deepseek_v2.py
+++ b/vllm_kunlun/models/deepseek_v2.py
@@ -596,23 +596,33 @@ def sparse_attn_indexer_vllm_kunlun(
     if has_prefill:
         prefill_metadata = attn_metadata.prefill
         for chunk in prefill_metadata.chunks:
-            k_fp8, k_scale = cp_gather_indexer_k_quant_cache(
-                kv_cache,
-                chunk.block_table,
-                chunk.cu_seq_lens,
-                chunk.num_reqs,
-                head_dim,
+            k_fp8 = torch.empty([chunk.total_seq_lens, head_dim],
+                                device=k.device,
+                                dtype=torch.int8)
+            k_scale = torch.empty([chunk.total_seq_lens, 4],
+                                  device=k.device,
+                                  dtype=torch.uint8)
+            torch.ops.xspeedgate_ops.cp_gather_indexer_k_quant_cache(
+                kv_cache=kv_cache,
+                dst_k=k_fp8,
+                dst_scale=k_scale,
+                block_table=chunk.block_table,
+                cu_seq_lens=chunk.cu_seq_lens,
             )
-
             logits = int8_mqa_logits(
                 q_fp8[chunk.token_start:chunk.token_end],
-                (k_fp8, k_scale),
+                (k_fp8, k_scale.view(torch.float32)),
                 weights[chunk.token_start:chunk.token_end],
                 chunk.cu_seqlen_ks,
                 chunk.cu_seqlen_ke,
+                context_q_lens_xpu=chunk.context_q_lens,
+                context_q_lens_cpu=chunk.context_q_lens_cpu,
+                context_k_lens_xpu=chunk.context_k_lens,
+                context_k_lens_cpu=chunk.context_k_lens_cpu,
             )
-            topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
-                                       dim=-1)[1]
+            del k_fp8, k_scale
+            topk_indices = torch.argsort(logits, dim=-1, descending=True)[..., :min(topk_tokens, logits.shape[-1])]
+
             topk_indices -= chunk.cu_seqlen_ks[:, None]
             mask_lo = topk_indices >= 0
             mask_hi = topk_indices - (chunk.cu_seqlen_ke -
@@ -675,8 +685,9 @@ def sparse_attn_indexer_vllm_kunlun(
             mask = positions <= index_end_pos
             # mask: [B * N, L]
             logits = logits.masked_fill(~mask, float('-inf'))
-            topk_indices = logits.topk(topk_tokens,
-                                       dim=-1)[1].to(torch.int32)  # [B * N, K]
+            topk_indices = torch.argsort(logits, dim=-1,
+                descending=True)[..., :min(topk_tokens, logits.shape[-1])]  # [B * N, K]
+
             # ensure we don't set indices for the top k
             # that is out of range(masked already)
             # this will happen if context length is shorter than K
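
Note on the topk_indices change above: a descending argsort sliced to k selects
the same index set as logits.topk(k, dim=-1)[1]; only the ordering among exactly
tied logits can differ, and the range masking against cu_seqlen_ks/cu_seqlen_ke
that follows does not depend on that ordering. A minimal standalone sketch of
the equivalence, using stock PyTorch only (shapes are illustrative):

import torch

def topk_via_argsort(logits: torch.Tensor, k: int) -> torch.Tensor:
    # Sort each row in descending order and keep the first k column indices.
    k = min(k, logits.shape[-1])
    return torch.argsort(logits, dim=-1, descending=True)[..., :k]

logits = torch.randn(4, 128)  # e.g. [num_tokens, seq_len_kv]
idx_sorted = topk_via_argsort(logits, 16)
idx_topk = logits.topk(16, dim=-1)[1]
# Compare as per-row sets, since tie order may legitimately differ.
assert torch.equal(idx_sorted.sort(dim=-1).values, idx_topk.sort(dim=-1).values)
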
""" - logits = torch.empty((q.shape[0], kv[0].shape[0]), dtype=torch.float32, device=q.device) - context_q_lens_xpu = torch.tensor([0, q.shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device) - context_k_lens_xpu = torch.tensor([0, kv[0].shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device) + seq_len_q, seq_len_kv =q.shape[0], kv[0].shape[0] + logits = torch.empty((seq_len_q, seq_len_kv), dtype=torch.float32, device=q.device) torch.ops._C.I8_mqa_logits( q=q, fused_kv_cache=kv, weights=weights, - context_q_lens=(context_q_lens_xpu.cpu(), context_q_lens_xpu), - context_k_lens=(context_k_lens_xpu.cpu(), context_k_lens_xpu), + context_q_lens=(context_q_lens_cpu, context_q_lens_xpu), + context_k_lens=(context_k_lens_cpu, context_k_lens_xpu), logits=logits, clean_logits=True, use_xfa_boost=False, ) - seq_len_kv = kv[0].shape[0] + # mask参考 https://github.com/vllm-project/vllm/blob/v0.11.0/tests/kernels/attention/test_deepgemm_attention.py 的_ref_fp8_mqa_logits函数的实现 - mask_lo = (torch.arange(0, seq_len_kv, device=cu_seqlen_ks.device)[None, :] - >= cu_seqlen_ks[:, None]) - mask_hi = (torch.arange(0, seq_len_kv, device=cu_seqlen_ke.device)[None, :] - < cu_seqlen_ke[:, None]) - mask = mask_lo & mask_hi - logits = logits.masked_fill(~mask, float("-inf")) + torch.ops.xspeedgate_ops.mask_for_I8_mqa_logits( + seq_len_kv=seq_len_kv, + cu_seqlen_ks=cu_seqlen_ks, + cu_seqlen_ke=cu_seqlen_ke, + logits=logits, + ) return logits diff --git a/vllm_kunlun/ops/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm_kunlun/ops/quantization/compressed_tensors/compressed_tensors_moe.py index 76452a1..c46fa3b 100644 --- a/vllm_kunlun/ops/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm_kunlun/ops/quantization/compressed_tensors/compressed_tensors_moe.py @@ -153,13 +153,7 @@ class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMetho out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device) torch.ops._C.silu_and_mul(out1, y) - out = torch.empty( - M, - top_k, - layer.w2_weight.shape[1], - dtype=hidden_states.dtype, - device=hidden_states.device, - ) + del y out1 = out1.reshape(-1, out1.shape[-1]) x_shape = out1.shape @@ -168,6 +162,14 @@ class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMetho (x_shape[0], 1), dtype=torch.float32, device=moe_expand.device ) torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True) + del out1, moe_expand + out = torch.empty( + M, + top_k, + layer.w2_weight.shape[1], + dtype=hidden_states.dtype, + device=hidden_states.device, + ) torch.ops._C.moe_fc( x=x_q, @@ -182,6 +184,7 @@ class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMetho # sort_mode=False, act=None, ) + del x_q, x_scale, sorted_tokens_num_lod,expert_m dequant_scale = torch.ones([M, top_k], dtype=torch.float32, device=out.device) output = torch.empty( diff --git a/vllm_kunlun/v1/attention/backends/mla/indexer.py b/vllm_kunlun/v1/attention/backends/mla/indexer.py index 67471be..ab2a383 100644 --- a/vllm_kunlun/v1/attention/backends/mla/indexer.py +++ b/vllm_kunlun/v1/attention/backends/mla/indexer.py @@ -8,10 +8,30 @@ from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, split_decodes_and_prefills) from vllm.v1.attention.backends.mla.indexer import (split_prefill_chunks, DeepseekV32IndexerMetadataBuilder, - DeepseekV32IndexerPrefillMetadata) + kv_spans_from_batches) logger = init_logger(__name__) +@dataclass +class DeepseekV32IndexerPrefillChunkMetadata: + block_table: torch.Tensor + cu_seqlen_ks: 
diff --git a/vllm_kunlun/ops/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm_kunlun/ops/quantization/compressed_tensors/compressed_tensors_moe.py
index 76452a1..c46fa3b 100644
--- a/vllm_kunlun/ops/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm_kunlun/ops/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -153,13 +153,7 @@ class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMetho
         out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
         torch.ops._C.silu_and_mul(out1, y)
-        out = torch.empty(
-            M,
-            top_k,
-            layer.w2_weight.shape[1],
-            dtype=hidden_states.dtype,
-            device=hidden_states.device,
-        )
+        del y
 
         out1 = out1.reshape(-1, out1.shape[-1])
         x_shape = out1.shape
@@ -168,6 +162,14 @@ class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMetho
             (x_shape[0], 1), dtype=torch.float32, device=moe_expand.device
         )
         torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True)
+        del out1, moe_expand
+        out = torch.empty(
+            M,
+            top_k,
+            layer.w2_weight.shape[1],
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
 
         torch.ops._C.moe_fc(
             x=x_q,
@@ -182,6 +184,7 @@ class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMetho
             # sort_mode=False,
             act=None,
         )
+        del x_q, x_scale, sorted_tokens_num_lod, expert_m
 
         dequant_scale = torch.ones([M, top_k], dtype=torch.float32, device=out.device)
         output = torch.empty(
diff --git a/vllm_kunlun/v1/attention/backends/mla/indexer.py b/vllm_kunlun/v1/attention/backends/mla/indexer.py
index 67471be..ab2a383 100644
--- a/vllm_kunlun/v1/attention/backends/mla/indexer.py
+++ b/vllm_kunlun/v1/attention/backends/mla/indexer.py
@@ -8,10 +8,30 @@ from vllm.v1.attention.backends.utils import (
     CommonAttentionMetadata, split_decodes_and_prefills)
 from vllm.v1.attention.backends.mla.indexer import (split_prefill_chunks,
                                                     DeepseekV32IndexerMetadataBuilder,
-                                                    DeepseekV32IndexerPrefillMetadata)
+                                                    kv_spans_from_batches)
 
 logger = init_logger(__name__)
 
+
+@dataclass
+class DeepseekV32IndexerPrefillChunkMetadata:
+    block_table: torch.Tensor
+    cu_seqlen_ks: torch.Tensor
+    cu_seqlen_ke: torch.Tensor
+    cu_seq_lens: torch.Tensor
+    total_seq_lens: int
+    token_start: int
+    token_end: int
+    num_reqs: int
+    context_q_lens: torch.Tensor
+    context_q_lens_cpu: torch.Tensor
+    context_k_lens: torch.Tensor
+    context_k_lens_cpu: torch.Tensor
+
+
+@dataclass
+class DeepseekV32IndexerPrefillMetadata:
+    chunks: list[DeepseekV32IndexerPrefillChunkMetadata]
+
 
 @dataclass
 class DeepSeekV32IndexerDecodeMetadata:
     block_table: torch.Tensor
@@ -50,6 +70,43 @@ class DeepseekV32IndexerMetadata:
     decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None
     prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None
 
+
+def kunlun_build_one_prefill_chunk(self, reqs_start, reqs_end,
+                                   query_start_loc_cpu, seq_lens_cpu,
+                                   block_table):
+    prefill_query_start_loc = query_start_loc_cpu[
+        reqs_start:reqs_end + 1] - query_start_loc_cpu[reqs_start]
+    cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
+        prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end],
+        self.device)
+    token_start = query_start_loc_cpu[reqs_start].item()
+    token_end = query_start_loc_cpu[reqs_end].item()
+    total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
+    assert total_seq_lens <= self.max_prefill_buffer_size
+    cu_seq_lens = torch.cat([
+        torch.zeros(1, dtype=torch.int32),
+        seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0)
+    ]).to(torch.int32).to(self.device)
+    seq_len_q = token_end - token_start
+    seq_len_kv = total_seq_lens
+    context_q_lens = torch.tensor([0, seq_len_q], dtype=torch.int32, device=self.device)
+    context_k_lens = torch.tensor([0, seq_len_kv], dtype=torch.int32, device=self.device)
+    context_q_lens_cpu = torch.tensor([0, seq_len_q], dtype=torch.int32, device="cpu")
+    context_k_lens_cpu = torch.tensor([0, seq_len_kv], dtype=torch.int32, device="cpu")
+
+    return DeepseekV32IndexerPrefillChunkMetadata(
+        cu_seqlen_ks=cu_seqlen_ks,
+        cu_seqlen_ke=cu_seqlen_ke,
+        cu_seq_lens=cu_seq_lens,
+        total_seq_lens=total_seq_lens,
+        block_table=block_table[reqs_start:reqs_end],
+        token_start=token_start,
+        token_end=token_end,
+        num_reqs=reqs_end - reqs_start,
+        context_q_lens=context_q_lens,
+        context_q_lens_cpu=context_q_lens_cpu,
+        context_k_lens=context_k_lens,
+        context_k_lens_cpu=context_k_lens_cpu,
+    )
+
 
 def kunlun_build(self,
                  common_prefix_len: int,
                  common_attn_metadata: CommonAttentionMetadata,
@@ -130,4 +187,5 @@ def kunlun_build(self,
     # logger.info(f"attn_metadata: {attn_metadata}")
     return attn_metadata
 
+DeepseekV32IndexerMetadataBuilder.build_one_prefill_chunk = kunlun_build_one_prefill_chunk
 DeepseekV32IndexerMetadataBuilder.build = kunlun_build
\ No newline at end of file
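
The two assignments at the end attach the Kunlun-specific builders to vLLM's
upstream DeepseekV32IndexerMetadataBuilder at import time, so existing call
sites pick up the new build path without subclassing. A minimal sketch of the
pattern (the class and function names here are hypothetical, not from this
repo):

class UpstreamBuilder:
    def build(self):
        return "upstream"

def platform_specific_build(self):
    # A plain module-level function assigned to a class attribute becomes
    # an ordinary method on subsequent attribute lookups.
    return "kunlun"

UpstreamBuilder.build = platform_specific_build
assert UpstreamBuilder().build() == "kunlun"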