feat: support flashinfer mla with prefix cache (#3643)
@@ -511,8 +511,11 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         if global_server_args_dict["enable_flashinfer_mla"]:
-            if forward_batch.forward_mode.is_extend():
-                return self.forward_normal(positions, hidden_states, forward_batch)
+            if global_server_args_dict["disable_radix_cache"]:
+                if forward_batch.forward_mode.is_extend():
+                    return self.forward_normal(positions, hidden_states, forward_batch)
+                else:
+                    return self.forward_absorb(positions, hidden_states, forward_batch)
+            else:
+                return self.forward_absorb(positions, hidden_states, forward_batch)
-            else:
-                return self.forward_absorb(positions, hidden_states, forward_batch)
         else:
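For readability, the following is a minimal sketch of how the dispatch reads after this change. It is not the verbatim file: the surrounding DeepseekV2AttentionMLA class, global_server_args_dict, and the forward_normal / forward_absorb methods are assumed from the original source, and the final non-flashinfer branch (outside this hunk) is only summarized in a comment. The apparent intent is that, with the radix (prefix) cache enabled, the weight-absorbed path is used for both extend and decode so flashinfer MLA can serve prefix-cached requests.

    # Sketch of the resulting dispatch (assumed context, see note above).
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        if global_server_args_dict["enable_flashinfer_mla"]:
            if global_server_args_dict["disable_radix_cache"]:
                # No prefix (radix) cache: extend/prefill uses the normal MLA path,
                # decode uses the weight-absorbed path.
                if forward_batch.forward_mode.is_extend():
                    return self.forward_normal(positions, hidden_states, forward_batch)
                else:
                    return self.forward_absorb(positions, hidden_states, forward_batch)
            else:
                # Prefix cache enabled: always take the weight-absorbed path,
                # for both extend and decode.
                return self.forward_absorb(positions, hidden_states, forward_batch)
        else:
            ...  # non-flashinfer-MLA dispatch, unchanged by this commit (not shown in the hunk)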