Refactor flashinfer logic for deepseek v3 and fix accuracy bug (#3785)

Baizhou Zhang
2025-02-24 04:07:25 -08:00
committed by GitHub
parent 27a46317b6
commit b110084654
4 changed files with 565 additions and 19 deletions


@@ -510,25 +510,27 @@ class DeepseekV2AttentionMLA(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        if global_server_args_dict["enable_flashinfer_mla"]:
-            if global_server_args_dict["disable_radix_cache"]:
-                if forward_batch.forward_mode.is_extend():
-                    return self.forward_normal(positions, hidden_states, forward_batch)
-                else:
-                    return self.forward_absorb(positions, hidden_states, forward_batch)
-            else:
-                return self.forward_absorb(positions, hidden_states, forward_batch)
-        else:
-            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
-            if (
-                forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-                and forward_batch.extend_prefix_lens.sum() == 0
-            ):
-                return self.forward_normal(positions, hidden_states, forward_batch)
-            else:
-                return self.forward_absorb(positions, hidden_states, forward_batch)
+        def no_absorb() -> bool:
+            if global_server_args_dict["enable_flashinfer_mla"]:
+                # Flashinfer MLA: Only do not use absorb when prefilling/extending without radix cache
+                return (
+                    global_server_args_dict["disable_radix_cache"]
+                    and forward_batch.forward_mode.is_extend()
+                )
+            else:
+                # Triton: Use normal computation for prefill and use weight absorption for extend/decode
+                return (
+                    forward_batch.forward_mode.is_extend()
+                    and not forward_batch.forward_mode.is_target_verify()
+                    and not forward_batch.forward_mode.is_draft_extend()
+                    and forward_batch.extend_prefix_lens.sum() == 0
+                )
+
+        if no_absorb():
+            return self.forward_normal(positions, hidden_states, forward_batch)
+        else:
+            return self.forward_absorb(positions, hidden_states, forward_batch)
 
     def forward_normal(
         self,
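
For reference, a minimal standalone sketch of the refactored dispatch logic (not part of the commit): FakeMode, the explicit no_absorb arguments, and the print checks below are hypothetical stand-ins for illustration; only the key names and forward-mode predicates mirror those in the diff above.

# Illustrative only: FakeMode and the explicit arguments are hypothetical
# stand-ins; key names and predicate names follow the diff above.
from dataclasses import dataclass


@dataclass
class FakeMode:
    extend: bool = False
    target_verify: bool = False
    draft_extend: bool = False

    def is_extend(self) -> bool:
        return self.extend

    def is_target_verify(self) -> bool:
        return self.target_verify

    def is_draft_extend(self) -> bool:
        return self.draft_extend


def no_absorb(args: dict, mode: FakeMode, extend_prefix_len_sum: int) -> bool:
    if args["enable_flashinfer_mla"]:
        # Flashinfer MLA: skip weight absorption only when extending without radix cache
        return args["disable_radix_cache"] and mode.is_extend()
    # Triton: normal computation only for a pure prefill with no cached prefix
    return (
        mode.is_extend()
        and not mode.is_target_verify()
        and not mode.is_draft_extend()
        and extend_prefix_len_sum == 0
    )


# Flashinfer MLA with the radix cache enabled keeps absorption even on extend.
print(no_absorb({"enable_flashinfer_mla": True, "disable_radix_cache": False},
                FakeMode(extend=True), 0))   # False -> forward_absorb
# Triton path with a pure prefill (no cached prefix) uses the normal kernel.
print(no_absorb({"enable_flashinfer_mla": False, "disable_radix_cache": False},
                FakeMode(extend=True), 0))   # True -> forward_normal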