Refactor flashinfer logic for deepseek v3 and fix accuracy bug (#3785)
This commit is contained in:
@@ -510,25 +510,27 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
if global_server_args_dict["enable_flashinfer_mla"]:
|
||||
if global_server_args_dict["disable_radix_cache"]:
|
||||
if forward_batch.forward_mode.is_extend():
|
||||
return self.forward_normal(positions, hidden_states, forward_batch)
|
||||
else:
|
||||
return self.forward_absorb(positions, hidden_states, forward_batch)
|
||||
|
||||
def no_absorb() -> bool:
|
||||
if global_server_args_dict["enable_flashinfer_mla"]:
|
||||
# Flashinfer MLA: Skip weight absorption only when prefilling/extending with the radix cache disabled
|
||||
return (
|
||||
global_server_args_dict["disable_radix_cache"]
|
||||
and forward_batch.forward_mode.is_extend()
|
||||
)
|
||||
else:
|
||||
return self.forward_absorb(positions, hidden_states, forward_batch)
|
||||
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
|
||||
return (
|
||||
forward_batch.forward_mode.is_extend()
|
||||
and not forward_batch.forward_mode.is_target_verify()
|
||||
and not forward_batch.forward_mode.is_draft_extend()
|
||||
and forward_batch.extend_prefix_lens.sum() == 0
|
||||
)
|
||||
|
||||
if no_absorb():
|
||||
return self.forward_normal(positions, hidden_states, forward_batch)
|
||||
else:
|
||||
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
|
||||
if (
|
||||
forward_batch.forward_mode.is_extend()
|
||||
and not forward_batch.forward_mode.is_target_verify()
|
||||
and not forward_batch.forward_mode.is_draft_extend()
|
||||
and forward_batch.extend_prefix_lens.sum() == 0
|
||||
):
|
||||
return self.forward_normal(positions, hidden_states, forward_batch)
|
||||
else:
|
||||
return self.forward_absorb(positions, hidden_states, forward_batch)
|
||||
return self.forward_absorb(positions, hidden_states, forward_batch)
|
||||
|
||||
def forward_normal(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user