Add 2 kernels (cp_gather_indexer_k_quant_cache, mask_for_I8_mqa_logits) and optimize the calculation of topk_indices (#134)
Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
This commit is contained in:
@@ -596,23 +596,33 @@ def sparse_attn_indexer_vllm_kunlun(
|
|||||||
if has_prefill:
|
if has_prefill:
|
||||||
prefill_metadata = attn_metadata.prefill
|
prefill_metadata = attn_metadata.prefill
|
||||||
for chunk in prefill_metadata.chunks:
|
for chunk in prefill_metadata.chunks:
|
||||||
k_fp8, k_scale = cp_gather_indexer_k_quant_cache(
|
k_fp8 = torch.empty([chunk.total_seq_lens, head_dim],
|
||||||
kv_cache,
|
device=k.device,
|
||||||
chunk.block_table,
|
dtype=torch.int8)
|
||||||
chunk.cu_seq_lens,
|
k_scale = torch.empty([chunk.total_seq_lens, 4],
|
||||||
chunk.num_reqs,
|
device=k.device,
|
||||||
head_dim,
|
dtype=torch.uint8)
|
||||||
|
torch.ops.xspeedgate_ops.cp_gather_indexer_k_quant_cache(
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
dst_k=k_fp8,
|
||||||
|
dst_scale=k_scale,
|
||||||
|
block_table=chunk.block_table,
|
||||||
|
cu_seq_lens=chunk.cu_seq_lens,
|
||||||
)
|
)
|
||||||
|
|
||||||
logits = int8_mqa_logits(
|
logits = int8_mqa_logits(
|
||||||
q_fp8[chunk.token_start:chunk.token_end],
|
q_fp8[chunk.token_start:chunk.token_end],
|
||||||
(k_fp8, k_scale),
|
(k_fp8, k_scale.view(torch.float32)),
|
||||||
weights[chunk.token_start:chunk.token_end],
|
weights[chunk.token_start:chunk.token_end],
|
||||||
chunk.cu_seqlen_ks,
|
chunk.cu_seqlen_ks,
|
||||||
chunk.cu_seqlen_ke,
|
chunk.cu_seqlen_ke,
|
||||||
|
context_q_lens_xpu=chunk.context_q_lens,
|
||||||
|
context_q_lens_cpu=chunk.context_q_lens_cpu,
|
||||||
|
context_k_lens_xpu=chunk.context_k_lens,
|
||||||
|
context_k_lens_cpu=chunk.context_k_lens_cpu,
|
||||||
)
|
)
|
||||||
topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
|
del k_fp8, k_scale
|
||||||
dim=-1)[1]
|
topk_indices = torch.argsort(logits, dim=-1, descending=True)[..., :min(topk_tokens, logits.shape[-1])]
|
||||||
|
|
||||||
topk_indices -= chunk.cu_seqlen_ks[:, None]
|
topk_indices -= chunk.cu_seqlen_ks[:, None]
|
||||||
mask_lo = topk_indices >= 0
|
mask_lo = topk_indices >= 0
|
||||||
mask_hi = topk_indices - (chunk.cu_seqlen_ke -
|
mask_hi = topk_indices - (chunk.cu_seqlen_ke -
|
||||||
@@ -675,8 +685,9 @@ def sparse_attn_indexer_vllm_kunlun(
|
|||||||
mask = positions <= index_end_pos
|
mask = positions <= index_end_pos
|
||||||
# mask: [B * N, L]
|
# mask: [B * N, L]
|
||||||
logits = logits.masked_fill(~mask, float('-inf'))
|
logits = logits.masked_fill(~mask, float('-inf'))
|
||||||
topk_indices = logits.topk(topk_tokens,
|
topk_indices = torch.argsort(logits, dim=-1,
|
||||||
dim=-1)[1].to(torch.int32) # [B * N, K]
|
descending=True)[..., :min(topk_tokens, logits.shape[-1])]# [B * N, K]
|
||||||
|
|
||||||
# ensure we don't set indices for the top k
|
# ensure we don't set indices for the top k
|
||||||
# that is out of range(masked already)
|
# that is out of range(masked already)
|
||||||
# this will happen if context length is shorter than K
|
# this will happen if context length is shorter than K
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
|
import xspeedgate_ops
|
||||||
|
|
||||||
def int8_mqa_logits(
|
def int8_mqa_logits(
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
@@ -6,6 +7,10 @@ def int8_mqa_logits(
|
|||||||
weights: torch.Tensor,
|
weights: torch.Tensor,
|
||||||
cu_seqlen_ks: torch.Tensor,
|
cu_seqlen_ks: torch.Tensor,
|
||||||
cu_seqlen_ke: torch.Tensor,
|
cu_seqlen_ke: torch.Tensor,
|
||||||
|
context_q_lens_xpu: torch.Tensor,
|
||||||
|
context_q_lens_cpu: torch.Tensor,
|
||||||
|
context_k_lens_xpu: torch.Tensor,
|
||||||
|
context_k_lens_cpu: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Compute FP8 MQA logits for a single sequence without KV paging.
|
"""Compute FP8 MQA logits for a single sequence without KV paging.
|
||||||
|
|
||||||
@@ -24,28 +29,27 @@ def int8_mqa_logits(
|
|||||||
Returns:
|
Returns:
|
||||||
Logits tensor of shape [M, N], dtype `torch.float32`.
|
Logits tensor of shape [M, N], dtype `torch.float32`.
|
||||||
"""
|
"""
|
||||||
logits = torch.empty((q.shape[0], kv[0].shape[0]), dtype=torch.float32, device=q.device)
|
seq_len_q, seq_len_kv =q.shape[0], kv[0].shape[0]
|
||||||
context_q_lens_xpu = torch.tensor([0, q.shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
|
logits = torch.empty((seq_len_q, seq_len_kv), dtype=torch.float32, device=q.device)
|
||||||
context_k_lens_xpu = torch.tensor([0, kv[0].shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
|
|
||||||
|
|
||||||
torch.ops._C.I8_mqa_logits(
|
torch.ops._C.I8_mqa_logits(
|
||||||
q=q,
|
q=q,
|
||||||
fused_kv_cache=kv,
|
fused_kv_cache=kv,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
context_q_lens=(context_q_lens_xpu.cpu(), context_q_lens_xpu),
|
context_q_lens=(context_q_lens_cpu, context_q_lens_xpu),
|
||||||
context_k_lens=(context_k_lens_xpu.cpu(), context_k_lens_xpu),
|
context_k_lens=(context_k_lens_cpu, context_k_lens_xpu),
|
||||||
logits=logits,
|
logits=logits,
|
||||||
clean_logits=True,
|
clean_logits=True,
|
||||||
use_xfa_boost=False,
|
use_xfa_boost=False,
|
||||||
)
|
)
|
||||||
seq_len_kv = kv[0].shape[0]
|
|
||||||
# The mask follows the implementation of the _ref_fp8_mqa_logits function in https://github.com/vllm-project/vllm/blob/v0.11.0/tests/kernels/attention/test_deepgemm_attention.py
|
# The mask follows the implementation of the _ref_fp8_mqa_logits function in https://github.com/vllm-project/vllm/blob/v0.11.0/tests/kernels/attention/test_deepgemm_attention.py
|
||||||
mask_lo = (torch.arange(0, seq_len_kv, device=cu_seqlen_ks.device)[None, :]
|
torch.ops.xspeedgate_ops.mask_for_I8_mqa_logits(
|
||||||
>= cu_seqlen_ks[:, None])
|
seq_len_kv=seq_len_kv,
|
||||||
mask_hi = (torch.arange(0, seq_len_kv, device=cu_seqlen_ke.device)[None, :]
|
cu_seqlen_ks=cu_seqlen_ks,
|
||||||
< cu_seqlen_ke[:, None])
|
cu_seqlen_ke=cu_seqlen_ke,
|
||||||
mask = mask_lo & mask_hi
|
logits=logits,
|
||||||
logits = logits.masked_fill(~mask, float("-inf"))
|
)
|
||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|||||||
@@ -153,13 +153,7 @@ class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMetho
|
|||||||
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
|
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
|
||||||
torch.ops._C.silu_and_mul(out1, y)
|
torch.ops._C.silu_and_mul(out1, y)
|
||||||
|
|
||||||
out = torch.empty(
|
del y
|
||||||
M,
|
|
||||||
top_k,
|
|
||||||
layer.w2_weight.shape[1],
|
|
||||||
dtype=hidden_states.dtype,
|
|
||||||
device=hidden_states.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
out1 = out1.reshape(-1, out1.shape[-1])
|
out1 = out1.reshape(-1, out1.shape[-1])
|
||||||
x_shape = out1.shape
|
x_shape = out1.shape
|
||||||
@@ -168,6 +162,14 @@ class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMetho
|
|||||||
(x_shape[0], 1), dtype=torch.float32, device=moe_expand.device
|
(x_shape[0], 1), dtype=torch.float32, device=moe_expand.device
|
||||||
)
|
)
|
||||||
torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True)
|
torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True)
|
||||||
|
del out1, moe_expand
|
||||||
|
out = torch.empty(
|
||||||
|
M,
|
||||||
|
top_k,
|
||||||
|
layer.w2_weight.shape[1],
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
device=hidden_states.device,
|
||||||
|
)
|
||||||
|
|
||||||
torch.ops._C.moe_fc(
|
torch.ops._C.moe_fc(
|
||||||
x=x_q,
|
x=x_q,
|
||||||
@@ -182,6 +184,7 @@ class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMetho
|
|||||||
# sort_mode=False,
|
# sort_mode=False,
|
||||||
act=None,
|
act=None,
|
||||||
)
|
)
|
||||||
|
del x_q, x_scale, sorted_tokens_num_lod,expert_m
|
||||||
|
|
||||||
dequant_scale = torch.ones([M, top_k], dtype=torch.float32, device=out.device)
|
dequant_scale = torch.ones([M, top_k], dtype=torch.float32, device=out.device)
|
||||||
output = torch.empty(
|
output = torch.empty(
|
||||||
|
|||||||
@@ -8,10 +8,30 @@ from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata,
|
|||||||
split_decodes_and_prefills)
|
split_decodes_and_prefills)
|
||||||
from vllm.v1.attention.backends.mla.indexer import (split_prefill_chunks,
|
from vllm.v1.attention.backends.mla.indexer import (split_prefill_chunks,
|
||||||
DeepseekV32IndexerMetadataBuilder,
|
DeepseekV32IndexerMetadataBuilder,
|
||||||
DeepseekV32IndexerPrefillMetadata)
|
kv_spans_from_batches)
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DeepseekV32IndexerPrefillChunkMetadata:
# Metadata for one prefill chunk of the sparse-attention indexer.
# Instances are created by kunlun_build_one_prefill_chunk and consumed by the
# prefill loop of sparse_attn_indexer_vllm_kunlun (`for chunk in ...chunks`).
|
||||||
|
# Rows [reqs_start:reqs_end] of the batch block table; passed to the
# cp_gather_indexer_k_quant_cache op.
block_table: torch.Tensor
|
||||||
|
# First result of kv_spans_from_batches; passed to int8_mqa_logits and
# subtracted from the top-k indices afterwards.
# presumably the per-query-token start offset of its valid K range -- TODO confirm
cu_seqlen_ks: torch.Tensor
|
||||||
|
# Second result of kv_spans_from_batches; used together with cu_seqlen_ks
# to mask logits / top-k indices.
# presumably the per-query-token end offset of its valid K range -- TODO confirm
cu_seqlen_ke: torch.Tensor
|
||||||
|
# int32 cumulative sum of the chunk's per-request sequence lengths with a
# leading 0, placed on the builder's device.
cu_seq_lens: torch.Tensor
|
||||||
|
# Sum of the sequence lengths of the requests in this chunk; used as the row
# count of the k_fp8 / k_scale buffers allocated per chunk.
# NOTE(review): the builder stores the result of `.sum()` (a 0-dim tensor)
# without `.item()` -- confirm whether a plain int is required here.
total_seq_lens: int
|
||||||
|
# Start of the chunk's token range (query_start_loc_cpu[reqs_start]); used to
# slice q_fp8 and weights.
token_start: int
|
||||||
|
# End of the chunk's token range (query_start_loc_cpu[reqs_end]), exclusive in
# the slices that use it.
token_end: int
|
||||||
|
# Number of requests in the chunk (reqs_end - reqs_start).
num_reqs: int
|
||||||
|
# int32 tensor [0, token_end - token_start] on the builder's device.
context_q_lens: torch.Tensor
|
||||||
|
# Same values as context_q_lens, created directly on the CPU.
context_q_lens_cpu: torch.Tensor
|
||||||
|
# int32 tensor [0, total_seq_lens] on the builder's device.
context_k_lens: torch.Tensor
|
||||||
|
# Same values as context_k_lens, created directly on the CPU.
context_k_lens_cpu: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DeepseekV32IndexerPrefillMetadata:
# Prefill-side indexer metadata: the list of per-chunk metadata objects that
# the indexer iterates over (`prefill_metadata.chunks`).
# Defined in this module; the same-named class is no longer imported from
# vllm.v1.attention.backends.mla.indexer.
|
||||||
|
chunks: list[DeepseekV32IndexerPrefillChunkMetadata]
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DeepSeekV32IndexerDecodeMetadata:
|
class DeepSeekV32IndexerDecodeMetadata:
|
||||||
block_table: torch.Tensor
|
block_table: torch.Tensor
|
||||||
@@ -50,6 +70,43 @@ class DeepseekV32IndexerMetadata:
|
|||||||
decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None
|
decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None
|
||||||
prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None
|
prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None
|
||||||
|
|
||||||
|
def kunlun_build_one_prefill_chunk(self, reqs_start, reqs_end,
|
||||||
|
query_start_loc_cpu, seq_lens_cpu,
|
||||||
|
block_table):
# Build the DeepseekV32IndexerPrefillChunkMetadata for the requests in
# [reqs_start, reqs_end). This function is assigned to
# DeepseekV32IndexerMetadataBuilder.build_one_prefill_chunk at module level,
# so `self` is the builder instance (reads self.device and
# self.max_prefill_buffer_size).
|
||||||
|
# Query start offsets of the chunk's requests, rebased so that the chunk's
# first request starts at 0.
prefill_query_start_loc = query_start_loc_cpu[
|
||||||
|
reqs_start:reqs_end + 1] - query_start_loc_cpu[reqs_start]
|
||||||
|
# kv_spans_from_batches comes from vllm.v1.attention.backends.mla.indexer; its
# two results are stored unchanged as cu_seqlen_ks / cu_seqlen_ke.
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
|
||||||
|
prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end],
|
||||||
|
self.device)
|
||||||
|
# Token range of this chunk, converted to Python scalars with .item().
token_start = query_start_loc_cpu[reqs_start].item()
|
||||||
|
token_end = query_start_loc_cpu[reqs_end].item()
|
||||||
|
# NOTE(review): `.sum()` yields a 0-dim tensor, while the metadata field
# total_seq_lens is annotated `int` -- confirm whether `.item()` is intended.
total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
|
||||||
|
# The chunk's total sequence length must not exceed the prefill buffer size.
assert total_seq_lens <= self.max_prefill_buffer_size
|
||||||
|
# Cumulative per-request sequence lengths with a leading 0, cast to int32 and
# moved to the builder's device.
cu_seq_lens = torch.cat([
|
||||||
|
torch.zeros(1, dtype=torch.int32),
|
||||||
|
seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0)
|
||||||
|
]).to(torch.int32).to(self.device)
|
||||||
|
seq_len_q = token_end - token_start
|
||||||
|
seq_len_kv = total_seq_lens
|
||||||
|
# Two-element [0, length] int32 tensors, created once here on both the device
# and the CPU. The indexer forwards them to int8_mqa_logits as
# context_{q,k}_lens_xpu / context_{q,k}_lens_cpu, so that function no longer
# builds them and calls `.cpu()` on every invocation.
context_q_lens = torch.tensor([0, seq_len_q], dtype=torch.int32, device=self.device)
|
||||||
|
context_k_lens = torch.tensor([0, seq_len_kv], dtype=torch.int32, device=self.device)
|
||||||
|
context_q_lens_cpu = torch.tensor([0, seq_len_q], dtype=torch.int32, device="cpu")
|
||||||
|
context_k_lens_cpu = torch.tensor([0, seq_len_kv], dtype=torch.int32, device="cpu")
|
||||||
|
|
||||||
|
return DeepseekV32IndexerPrefillChunkMetadata(
|
||||||
|
cu_seqlen_ks=cu_seqlen_ks,
|
||||||
|
cu_seqlen_ke=cu_seqlen_ke,
|
||||||
|
cu_seq_lens=cu_seq_lens,
|
||||||
|
total_seq_lens=total_seq_lens,
|
||||||
|
# Only the block-table rows of this chunk's requests are kept.
block_table=block_table[reqs_start:reqs_end],
|
||||||
|
token_start=token_start,
|
||||||
|
token_end=token_end,
|
||||||
|
num_reqs=reqs_end - reqs_start,
|
||||||
|
context_q_lens=context_q_lens,
|
||||||
|
context_q_lens_cpu=context_q_lens_cpu,
|
||||||
|
context_k_lens=context_k_lens,
|
||||||
|
context_k_lens_cpu=context_k_lens_cpu,
|
||||||
|
)
|
||||||
def kunlun_build(self,
|
def kunlun_build(self,
|
||||||
common_prefix_len: int,
|
common_prefix_len: int,
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
@@ -130,4 +187,5 @@ def kunlun_build(self,
|
|||||||
# logger.info(f"attn_metadata: {attn_metadata}")
|
# logger.info(f"attn_metadata: {attn_metadata}")
|
||||||
return attn_metadata
|
return attn_metadata
|
||||||
|
|
||||||
|
# Replace the upstream builder's per-chunk method with the Kunlun version, which
# returns the locally defined chunk metadata including the context_*_lens tensors.
DeepseekV32IndexerMetadataBuilder.build_one_prefill_chunk= kunlun_build_one_prefill_chunk
|
||||||
DeepseekV32IndexerMetadataBuilder.build = kunlun_build
|
DeepseekV32IndexerMetadataBuilder.build = kunlun_build
|
||||||
Reference in New Issue
Block a user