Support FlashMLA backend (#4472)
Co-authored-by: yinfan98 <1106310035@qq.com>
This commit is contained in:
@@ -71,6 +71,7 @@ global_server_args_dict = {
     "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
     "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
     "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
+    "enable_flashmla": ServerArgs.enable_flashmla,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
 }
@@ -1273,7 +1274,10 @@ class ScheduleBatch:

     def get_model_worker_batch(self) -> ModelWorkerBatch:
         if self.forward_mode.is_decode_or_idle():
-            if global_server_args_dict["enable_flashinfer_mla"]:
+            if (
+                global_server_args_dict["enable_flashinfer_mla"]
+                or global_server_args_dict["enable_flashmla"]
+            ):
                 decode_seq_lens = self.seq_lens.cpu()
             else:
                 decode_seq_lens = None
||||
Reference in New Issue
Block a user