Remove 200us slow concat kernel (part 2: srt) (#7020)
This commit is contained in:
@@ -233,25 +233,49 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
|
|||||||
layer: RadixAttention,
|
layer: RadixAttention,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
save_kv_cache: bool = True,
|
save_kv_cache: bool = True,
|
||||||
|
# For multi-head latent attention
|
||||||
|
q_rope: Optional[torch.Tensor] = None,
|
||||||
|
k_rope: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
cache_loc = forward_batch.out_cache_loc
|
cache_loc = forward_batch.out_cache_loc
|
||||||
|
|
||||||
if k is not None:
|
if k is not None:
|
||||||
assert v is not None
|
assert v is not None
|
||||||
if save_kv_cache:
|
if save_kv_cache:
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
if k_rope is not None:
|
||||||
layer,
|
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
|
||||||
cache_loc,
|
layer,
|
||||||
k,
|
cache_loc,
|
||||||
v,
|
k,
|
||||||
)
|
k_rope,
|
||||||
bs = forward_batch.batch_size
|
)
|
||||||
|
else:
|
||||||
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
|
layer,
|
||||||
|
cache_loc,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reshape inputs
|
||||||
|
if q_rope is not None:
|
||||||
|
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
|
||||||
|
q_rope = q_rope.view(
|
||||||
|
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||||
|
q_nope = reshaped_q[:, :, : layer.v_head_dim]
|
||||||
|
q_rope = reshaped_q[:, :, layer.v_head_dim :]
|
||||||
|
|
||||||
|
q_nope = q_nope.to(self.q_data_type)
|
||||||
|
q_rope = q_rope.to(self.q_data_type)
|
||||||
|
|
||||||
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||||
|
|
||||||
reshape_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
|
||||||
|
|
||||||
o = cutlass_mla_decode(
|
o = cutlass_mla_decode(
|
||||||
q_nope_and_q_pe=reshape_q.to(self.q_data_type),
|
q_nope=q_nope,
|
||||||
|
q_pe=q_rope,
|
||||||
kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim),
|
kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim),
|
||||||
seq_lens=forward_batch.seq_lens.to(torch.int32),
|
seq_lens=forward_batch.seq_lens.to(torch.int32),
|
||||||
page_table=self.forward_metadata.block_kv_indices,
|
page_table=self.forward_metadata.block_kv_indices,
|
||||||
|
|||||||
@@ -1013,7 +1013,11 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
def forward_absorb_core(
|
def forward_absorb_core(
|
||||||
self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
|
self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
|
||||||
):
|
):
|
||||||
if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
|
if (
|
||||||
|
self.attention_backend == "fa3"
|
||||||
|
or self.attention_backend == "flashinfer"
|
||||||
|
or self.attention_backend == "cutlass_mla"
|
||||||
|
):
|
||||||
attn_output = self.attn_mqa(
|
attn_output = self.attn_mqa(
|
||||||
q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
|
q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user