opt flashinfer mla cat (#5822)
Co-authored-by: xuyongfei.xyf <xuyongfei.xyf@antgroup.com>
@@ -339,22 +339,38 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
 
         cache_loc = forward_batch.out_cache_loc
         logits_soft_cap = layer.logit_cap
         prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
-        qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
         k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
 
         # Save kv cache
         if save_kv_cache and k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                if k_rope is not None:
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                        layer, cache_loc, k, k_rope
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+        if q_rope is not None:
+            q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
 
         if self.forward_metadata.use_ragged:
             # ragged prefill
+            if q_rope is not None:
+                q = torch.cat([q, q_rope], dim=-1)
+            qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            if k_rope is not None:
+                k = torch.cat([k, k_rope], dim=-1)
             o = self.prefill_wrapper_ragged.forward(
                 qall,
                 k.view(-1, layer.tp_k_head_num, layer.head_dim),
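The hunk above threads optional q_rope / k_rope tensors through forward_extend so the MLA "nope" and rope components arrive pre-split: the cache write goes through set_mla_kv_buffer and the paged kernel consumes the halves directly, while the ragged branch still concatenates because prefill_wrapper_ragged.forward wants the fused head_dim layout. Below is a minimal sketch of the allocation the split cache write avoids, with illustrative DeepSeek-style dimensions (512-dim latent half, 64-dim rope half); set_mla_kv_buffer's internals are not shown in this diff, so treat the slice-assignment picture as an assumption about its effect:

    import torch

    tokens, kv_heads, v_head_dim, rope_dim = 4, 1, 512, 64
    k_nope = torch.randn(tokens, kv_heads, v_head_dim)
    k_rope = torch.randn(tokens, kv_heads, rope_dim)

    # old path: materialize the fused key, then copy it into the cache
    k_cat = torch.cat([k_nope, k_rope], dim=-1)  # extra temporary + full copy

    # new path, conceptually: write each half into its slice of the cache entry
    cache_entry = torch.empty(tokens, kv_heads, v_head_dim + rope_dim)
    cache_entry[..., :v_head_dim] = k_nope
    cache_entry[..., v_head_dim:] = k_rope

    assert torch.equal(cache_entry, k_cat)  # same bytes, one fewer temporary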
@@ -365,11 +381,19 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             )
         else:
             # mla paged prefill
+            if q_rope is None:
+                qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+                q, q_rope = (
+                    qall[:, :, : layer.v_head_dim],
+                    qall[:, :, layer.v_head_dim :],
+                )
+            o = q.new_empty(q.shape)
             o = prefill_wrapper_paged.run(
-                qall[:, :, : layer.v_head_dim],
-                qall[:, :, layer.v_head_dim :],
+                q,
+                q_rope,
                 k_buf[:, :, : layer.v_head_dim],
                 k_buf[:, :, layer.v_head_dim :],
+                out=o,
             )
 
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
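In the paged-prefill branch the wrapper already accepts the query halves separately, so when the caller supplies q / q_rope no concatenation happens at all; the q_rope is None fallback recovers both halves by slicing, which produces views rather than copies, and the preallocated out= tensor removes a per-call output allocation. A small sketch of both ideas (dimensions are illustrative; torch.neg stands in for prefill_wrapper_paged.run):

    import torch

    tp_q_head_num, v_head_dim, head_dim = 16, 512, 576
    qall = torch.randn(8, tp_q_head_num, head_dim)

    # slicing yields views into qall's storage: no allocation, no copy
    q, q_rope = qall[:, :, :v_head_dim], qall[:, :, v_head_dim:]
    assert q.data_ptr() == qall.data_ptr()

    # preallocating the output and passing out= lets the kernel write in place
    o = q.new_empty(q.shape)
    torch.neg(q, out=o)  # stand-in for prefill_wrapper_paged.run(..., out=o)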
@@ -382,6 +406,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
         decode_wrapper = self.forward_metadata.decode_wrapper
         cache_loc = forward_batch.out_cache_loc
@@ -389,23 +416,42 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer,
-                    cache_loc,
-                    k,
-                    v,
-                )
+                if k_rope is not None:
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        k_rope,
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
 
         # Reshape inputs
-        reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+        if q_rope is not None:
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+        else:
+            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = reshaped_q[:, :, : layer.v_head_dim]
+            q_rope = reshaped_q[:, :, layer.v_head_dim :]
 
         k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
 
+        o = q_nope.new_empty(q_nope.shape)
         # Direct call to run without the wrapper
         o = decode_wrapper.run(
-            reshaped_q[:, :, : layer.v_head_dim],
-            reshaped_q[:, :, layer.v_head_dim :],
+            q_nope,
+            q_rope,
             k_buffer[:, :, : layer.v_head_dim],
             k_buffer[:, :, layer.v_head_dim :],
+            out=o,
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
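forward_decode mirrors the extend path: a rope-aware cache write, a branch that either reuses the caller's split q / q_rope or derives q_nope / q_rope by slicing a fused q, and a preallocated output for decode_wrapper.run. A quick self-check (illustrative sizes) that the two reshape branches yield identical halves:

    import torch

    tokens, tp_q_head_num, v_head_dim, head_dim = 2, 16, 512, 576
    rope_dim = head_dim - v_head_dim

    # branch taken when the caller passes the halves separately
    q_in = torch.randn(tokens, tp_q_head_num * v_head_dim)
    q_rope_in = torch.randn(tokens, tp_q_head_num * rope_dim)
    q_nope_a = q_in.view(-1, tp_q_head_num, v_head_dim)
    q_rope_a = q_rope_in.view(-1, tp_q_head_num, rope_dim)

    # fallback branch: split a fused q by slicing
    fused = torch.cat([q_nope_a, q_rope_a], dim=-1).reshape(tokens, -1)
    reshaped_q = fused.view(-1, tp_q_head_num, head_dim)
    q_nope_b = reshaped_q[:, :, :v_head_dim]
    q_rope_b = reshaped_q[:, :, v_head_dim:]

    assert torch.equal(q_nope_a, q_nope_b) and torch.equal(q_rope_a, q_rope_b)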
@@ -777,7 +777,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_nope_out = q_nope_out.transpose(0, 1)
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
-        if self.attention_backend == "fa3":
+        if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
             )
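On the model side, DeepseekV2AttentionMLA now sends the "flashinfer" backend down the same path as "fa3", forwarding q_pe / k_pe through the q_rope / k_rope keywords instead of concatenating them into fused q / k first. A hedged sketch of the dispatch shape (dispatch_mqa and the stand-in attn below are hypothetical illustrations, not the model's real code):

    import torch

    def dispatch_mqa(attention_backend, attn_mqa, q_nope_out, k_nope,
                     forward_batch, q_pe, k_pe):
        # backends that accept split inputs skip the concatenation entirely
        if attention_backend in ("fa3", "flashinfer"):
            return attn_mqa(q_nope_out, k_nope, k_nope, forward_batch,
                            q_rope=q_pe, k_rope=k_pe)
        # other backends still receive the fused q / k layout
        q = torch.cat([q_nope_out, q_pe], dim=-1)
        k = torch.cat([k_nope, k_pe], dim=-1)
        return attn_mqa(q, k, k, forward_batch)

    # toy usage: a stand-in attn that reports which layout it received
    attn = lambda q, k, v, batch, **kw: "split" if kw else "fused"
    print(dispatch_mqa("flashinfer", attn, torch.ones(1, 1), torch.ones(1, 1),
                       None, torch.ones(1, 1), torch.ones(1, 1)))  # -> split

A membership test like attention_backend in ("fa3", "flashinfer") is a slightly tighter way to write the chained == comparison in the diff, with identical behavior.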