feat: support flashinfer mla attention for deepseek v3 (#3550)

Yineng Zhang
2025-02-14 08:50:14 +08:00
committed by GitHub
parent 368de3661e
commit 70f894b810
12 changed files with 299 additions and 135 deletions


@@ -510,14 +510,20 @@ class DeepseekV2AttentionMLA(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        # Use normal computation for prefill and use weight absorption for extend/decode
-        if (
-            forward_batch.forward_mode.is_extend()
-            and forward_batch.extend_prefix_lens.sum() == 0
-        ):
-            return self.forward_normal(positions, hidden_states, forward_batch)
+        if global_server_args_dict["enable_flashinfer_mla"]:
+            if forward_batch.forward_mode.is_extend():
+                return self.forward_normal(positions, hidden_states, forward_batch)
+            else:
+                return self.forward_absorb(positions, hidden_states, forward_batch)
         else:
-            return self.forward_absorb(positions, hidden_states, forward_batch)
+            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
+            if (
+                forward_batch.forward_mode.is_extend()
+                and forward_batch.extend_prefix_lens.sum() == 0
+            ):
+                return self.forward_normal(positions, hidden_states, forward_batch)
+            else:
+                return self.forward_absorb(positions, hidden_states, forward_batch)

     def forward_normal(
         self,
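
For reference, the dispatch introduced by the hunk above can be summarized outside the model class. The sketch below is illustrative only: ForwardMode, Batch, and pick_mla_path are simplified stand-ins, not types from the commit; only the branching mirrors the diff (the FlashInfer MLA path uses forward_normal for any extend/prefill step and forward_absorb for decode, while the Triton path uses forward_normal only for prefill with no cached prefix and forward_absorb otherwise).

    # Illustrative sketch only: these classes are simplified stand-ins for
    # SGLang's ForwardBatch/ForwardMode, not the project's actual types.
    from dataclasses import dataclass
    from enum import Enum, auto


    class ForwardMode(Enum):
        EXTEND = auto()
        DECODE = auto()

        def is_extend(self) -> bool:
            return self is ForwardMode.EXTEND


    @dataclass
    class Batch:
        forward_mode: ForwardMode
        extend_prefix_len_sum: int  # total length of cached prefixes in the batch


    def pick_mla_path(batch: Batch, enable_flashinfer_mla: bool) -> str:
        if enable_flashinfer_mla:
            # FlashInfer MLA: normal attention for extend (prefill),
            # weight-absorbed attention for decode.
            return "forward_normal" if batch.forward_mode.is_extend() else "forward_absorb"
        # Triton backend: normal attention only for prefill with no cached prefix;
        # weight absorption for extend-with-prefix and for decode.
        if batch.forward_mode.is_extend() and batch.extend_prefix_len_sum == 0:
            return "forward_normal"
        return "forward_absorb"


    if __name__ == "__main__":
        prefill = Batch(ForwardMode.EXTEND, extend_prefix_len_sum=0)
        decode = Batch(ForwardMode.DECODE, extend_prefix_len_sum=0)
        assert pick_mla_path(prefill, enable_flashinfer_mla=True) == "forward_normal"
        assert pick_mla_path(decode, enable_flashinfer_mla=True) == "forward_absorb"
        assert pick_mla_path(decode, enable_flashinfer_mla=False) == "forward_absorb"

Assuming the new global_server_args_dict key maps one-to-one onto a server argument, the FlashInfer MLA path would be toggled by a flag along the lines of --enable-flashinfer-mla (the exact flag name is not shown in this hunk); the Triton path remains the default behavior.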