add 2 kernels and optimize the calculation of topk_indices (#134)
Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
This commit is contained in:
@@ -596,23 +596,33 @@ def sparse_attn_indexer_vllm_kunlun(
|
||||
if has_prefill:
|
||||
prefill_metadata = attn_metadata.prefill
|
||||
for chunk in prefill_metadata.chunks:
|
||||
k_fp8, k_scale = cp_gather_indexer_k_quant_cache(
|
||||
kv_cache,
|
||||
chunk.block_table,
|
||||
chunk.cu_seq_lens,
|
||||
chunk.num_reqs,
|
||||
head_dim,
|
||||
k_fp8 = torch.empty([chunk.total_seq_lens, head_dim],
|
||||
device=k.device,
|
||||
dtype=torch.int8)
|
||||
k_scale = torch.empty([chunk.total_seq_lens, 4],
|
||||
device=k.device,
|
||||
dtype=torch.uint8)
|
||||
torch.ops.xspeedgate_ops.cp_gather_indexer_k_quant_cache(
|
||||
kv_cache=kv_cache,
|
||||
dst_k=k_fp8,
|
||||
dst_scale=k_scale,
|
||||
block_table=chunk.block_table,
|
||||
cu_seq_lens=chunk.cu_seq_lens,
|
||||
)
|
||||
|
||||
logits = int8_mqa_logits(
|
||||
q_fp8[chunk.token_start:chunk.token_end],
|
||||
(k_fp8, k_scale),
|
||||
(k_fp8, k_scale.view(torch.float32)),
|
||||
weights[chunk.token_start:chunk.token_end],
|
||||
chunk.cu_seqlen_ks,
|
||||
chunk.cu_seqlen_ke,
|
||||
context_q_lens_xpu=chunk.context_q_lens,
|
||||
context_q_lens_cpu=chunk.context_q_lens_cpu,
|
||||
context_k_lens_xpu=chunk.context_k_lens,
|
||||
context_k_lens_cpu=chunk.context_k_lens_cpu,
|
||||
)
|
||||
topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
|
||||
dim=-1)[1]
|
||||
del k_fp8, k_scale
|
||||
topk_indices = torch.argsort(logits, dim=-1, descending=True)[..., :min(topk_tokens, logits.shape[-1])]
|
||||
|
||||
topk_indices -= chunk.cu_seqlen_ks[:, None]
|
||||
mask_lo = topk_indices >= 0
|
||||
mask_hi = topk_indices - (chunk.cu_seqlen_ke -
|
||||
@@ -675,8 +685,9 @@ def sparse_attn_indexer_vllm_kunlun(
|
||||
mask = positions <= index_end_pos
|
||||
# mask: [B * N, L]
|
||||
logits = logits.masked_fill(~mask, float('-inf'))
|
||||
topk_indices = logits.topk(topk_tokens,
|
||||
dim=-1)[1].to(torch.int32) # [B * N, K]
|
||||
topk_indices = torch.argsort(logits, dim=-1,
|
||||
descending=True)[..., :min(topk_tokens, logits.shape[-1])]# [B * N, K]
|
||||
|
||||
# ensure we don't set indices for the top k
|
||||
# that is out of range(masked already)
|
||||
# this will happen if context length is shorter than K
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import xspeedgate_ops
|
||||
|
||||
def int8_mqa_logits(
|
||||
q: torch.Tensor,
|
||||
@@ -6,6 +7,10 @@ def int8_mqa_logits(
|
||||
weights: torch.Tensor,
|
||||
cu_seqlen_ks: torch.Tensor,
|
||||
cu_seqlen_ke: torch.Tensor,
|
||||
context_q_lens_xpu: torch.Tensor,
|
||||
context_q_lens_cpu: torch.Tensor,
|
||||
context_k_lens_xpu: torch.Tensor,
|
||||
context_k_lens_cpu: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Compute FP8 MQA logits for a single sequence without KV paging.
|
||||
|
||||
@@ -24,28 +29,27 @@ def int8_mqa_logits(
|
||||
Returns:
|
||||
Logits tensor of shape [M, N], dtype `torch.float32`.
|
||||
"""
|
||||
logits = torch.empty((q.shape[0], kv[0].shape[0]), dtype=torch.float32, device=q.device)
|
||||
context_q_lens_xpu = torch.tensor([0, q.shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
|
||||
context_k_lens_xpu = torch.tensor([0, kv[0].shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
|
||||
seq_len_q, seq_len_kv =q.shape[0], kv[0].shape[0]
|
||||
logits = torch.empty((seq_len_q, seq_len_kv), dtype=torch.float32, device=q.device)
|
||||
|
||||
torch.ops._C.I8_mqa_logits(
|
||||
q=q,
|
||||
fused_kv_cache=kv,
|
||||
weights=weights,
|
||||
context_q_lens=(context_q_lens_xpu.cpu(), context_q_lens_xpu),
|
||||
context_k_lens=(context_k_lens_xpu.cpu(), context_k_lens_xpu),
|
||||
context_q_lens=(context_q_lens_cpu, context_q_lens_xpu),
|
||||
context_k_lens=(context_k_lens_cpu, context_k_lens_xpu),
|
||||
logits=logits,
|
||||
clean_logits=True,
|
||||
use_xfa_boost=False,
|
||||
)
|
||||
seq_len_kv = kv[0].shape[0]
|
||||
|
||||
# mask参考 https://github.com/vllm-project/vllm/blob/v0.11.0/tests/kernels/attention/test_deepgemm_attention.py 的_ref_fp8_mqa_logits函数的实现
|
||||
mask_lo = (torch.arange(0, seq_len_kv, device=cu_seqlen_ks.device)[None, :]
|
||||
>= cu_seqlen_ks[:, None])
|
||||
mask_hi = (torch.arange(0, seq_len_kv, device=cu_seqlen_ke.device)[None, :]
|
||||
< cu_seqlen_ke[:, None])
|
||||
mask = mask_lo & mask_hi
|
||||
logits = logits.masked_fill(~mask, float("-inf"))
|
||||
torch.ops.xspeedgate_ops.mask_for_I8_mqa_logits(
|
||||
seq_len_kv=seq_len_kv,
|
||||
cu_seqlen_ks=cu_seqlen_ks,
|
||||
cu_seqlen_ke=cu_seqlen_ke,
|
||||
logits=logits,
|
||||
)
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
@@ -153,13 +153,7 @@ class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMetho
|
||||
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
|
||||
torch.ops._C.silu_and_mul(out1, y)
|
||||
|
||||
out = torch.empty(
|
||||
M,
|
||||
top_k,
|
||||
layer.w2_weight.shape[1],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
del y
|
||||
|
||||
out1 = out1.reshape(-1, out1.shape[-1])
|
||||
x_shape = out1.shape
|
||||
@@ -168,6 +162,14 @@ class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMetho
|
||||
(x_shape[0], 1), dtype=torch.float32, device=moe_expand.device
|
||||
)
|
||||
torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True)
|
||||
del out1, moe_expand
|
||||
out = torch.empty(
|
||||
M,
|
||||
top_k,
|
||||
layer.w2_weight.shape[1],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
|
||||
torch.ops._C.moe_fc(
|
||||
x=x_q,
|
||||
@@ -182,6 +184,7 @@ class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMetho
|
||||
# sort_mode=False,
|
||||
act=None,
|
||||
)
|
||||
del x_q, x_scale, sorted_tokens_num_lod,expert_m
|
||||
|
||||
dequant_scale = torch.ones([M, top_k], dtype=torch.float32, device=out.device)
|
||||
output = torch.empty(
|
||||
|
||||
@@ -8,10 +8,30 @@ from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.attention.backends.mla.indexer import (split_prefill_chunks,
|
||||
DeepseekV32IndexerMetadataBuilder,
|
||||
DeepseekV32IndexerPrefillMetadata)
|
||||
kv_spans_from_batches)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@dataclass
|
||||
class DeepseekV32IndexerPrefillChunkMetadata:
|
||||
block_table: torch.Tensor
|
||||
cu_seqlen_ks: torch.Tensor
|
||||
cu_seqlen_ke: torch.Tensor
|
||||
cu_seq_lens: torch.Tensor
|
||||
total_seq_lens: int
|
||||
token_start: int
|
||||
token_end: int
|
||||
num_reqs: int
|
||||
context_q_lens: torch.Tensor
|
||||
context_q_lens_cpu: torch.Tensor
|
||||
context_k_lens: torch.Tensor
|
||||
context_k_lens_cpu: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepseekV32IndexerPrefillMetadata:
|
||||
chunks: list[DeepseekV32IndexerPrefillChunkMetadata]
|
||||
|
||||
@dataclass
|
||||
class DeepSeekV32IndexerDecodeMetadata:
|
||||
block_table: torch.Tensor
|
||||
@@ -50,6 +70,43 @@ class DeepseekV32IndexerMetadata:
|
||||
decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None
|
||||
prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None
|
||||
|
||||
def kunlun_build_one_prefill_chunk(self, reqs_start, reqs_end,
|
||||
query_start_loc_cpu, seq_lens_cpu,
|
||||
block_table):
|
||||
prefill_query_start_loc = query_start_loc_cpu[
|
||||
reqs_start:reqs_end + 1] - query_start_loc_cpu[reqs_start]
|
||||
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
|
||||
prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end],
|
||||
self.device)
|
||||
token_start = query_start_loc_cpu[reqs_start].item()
|
||||
token_end = query_start_loc_cpu[reqs_end].item()
|
||||
total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
|
||||
assert total_seq_lens <= self.max_prefill_buffer_size
|
||||
cu_seq_lens = torch.cat([
|
||||
torch.zeros(1, dtype=torch.int32),
|
||||
seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0)
|
||||
]).to(torch.int32).to(self.device)
|
||||
seq_len_q = token_end - token_start
|
||||
seq_len_kv = total_seq_lens
|
||||
context_q_lens = torch.tensor([0, seq_len_q], dtype=torch.int32, device=self.device)
|
||||
context_k_lens = torch.tensor([0, seq_len_kv], dtype=torch.int32, device=self.device)
|
||||
context_q_lens_cpu = torch.tensor([0, seq_len_q], dtype=torch.int32, device="cpu")
|
||||
context_k_lens_cpu = torch.tensor([0, seq_len_kv], dtype=torch.int32, device="cpu")
|
||||
|
||||
return DeepseekV32IndexerPrefillChunkMetadata(
|
||||
cu_seqlen_ks=cu_seqlen_ks,
|
||||
cu_seqlen_ke=cu_seqlen_ke,
|
||||
cu_seq_lens=cu_seq_lens,
|
||||
total_seq_lens=total_seq_lens,
|
||||
block_table=block_table[reqs_start:reqs_end],
|
||||
token_start=token_start,
|
||||
token_end=token_end,
|
||||
num_reqs=reqs_end - reqs_start,
|
||||
context_q_lens=context_q_lens,
|
||||
context_q_lens_cpu=context_q_lens_cpu,
|
||||
context_k_lens=context_k_lens,
|
||||
context_k_lens_cpu=context_k_lens_cpu,
|
||||
)
|
||||
def kunlun_build(self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
@@ -130,4 +187,5 @@ def kunlun_build(self,
|
||||
# logger.info(f"attn_metadata: {attn_metadata}")
|
||||
return attn_metadata
|
||||
|
||||
DeepseekV32IndexerMetadataBuilder.build_one_prefill_chunk= kunlun_build_one_prefill_chunk
|
||||
DeepseekV32IndexerMetadataBuilder.build = kunlun_build
|
||||
Reference in New Issue
Block a user