Support FlashMLA backend (#4472)

Co-authored-by: yinfan98 <1106310035@qq.com>
This commit is contained in:
lukec
2025-03-17 00:07:06 +08:00
committed by GitHub
parent 1b859295f4
commit a53fe428f9
6 changed files with 209 additions and 1 deletion

View File

@@ -71,6 +71,7 @@ global_server_args_dict = {
     "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
     "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
     "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
+    "enable_flashmla": ServerArgs.enable_flashmla,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
 }
@@ -1273,7 +1274,10 @@ class ScheduleBatch:
     def get_model_worker_batch(self) -> ModelWorkerBatch:
         if self.forward_mode.is_decode_or_idle():
-            if global_server_args_dict["enable_flashinfer_mla"]:
+            if (
+                global_server_args_dict["enable_flashinfer_mla"]
+                or global_server_args_dict["enable_flashmla"]
+            ):
                 decode_seq_lens = self.seq_lens.cpu()
             else:
                 decode_seq_lens = None