feat: support flashinfer mla with prefix cache (#3643)

This commit is contained in:
Yineng Zhang
2025-02-18 02:06:43 +08:00
committed by GitHub
parent c38f3aed24
commit 714f3e6362
4 changed files with 107 additions and 31 deletions

View File

@@ -511,8 +511,11 @@ class DeepseekV2AttentionMLA(nn.Module):
forward_batch: ForwardBatch,
) -> torch.Tensor:
if global_server_args_dict["enable_flashinfer_mla"]:
if forward_batch.forward_mode.is_extend():
return self.forward_normal(positions, hidden_states, forward_batch)
if global_server_args_dict["disable_radix_cache"]:
if forward_batch.forward_mode.is_extend():
return self.forward_normal(positions, hidden_states, forward_batch)
else:
return self.forward_absorb(positions, hidden_states, forward_batch)
else:
return self.forward_absorb(positions, hidden_states, forward_batch)
else: