Replace enable_flashinfer_mla argument with attention_backend (#5005)
This commit is contained in:
@@ -151,7 +151,6 @@ class ModelRunner:
|
||||
"device": server_args.device,
|
||||
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
|
||||
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
|
||||
-"enable_flashinfer_mla": server_args.enable_flashinfer_mla,
|
||||
"enable_flashmla": server_args.enable_flashmla,
|
||||
"disable_radix_cache": server_args.disable_radix_cache,
|
||||
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
|
||||
@@ -223,10 +222,14 @@ class ModelRunner:
|
||||
):
|
||||
# TODO: add MLA optimization on CPU
|
||||
if server_args.device != "cpu":
|
||||
-if server_args.enable_flashinfer_mla:
|
||||
+if (
|
||||
+server_args.attention_backend == "flashinfer"
|
||||
+or server_args.enable_flashinfer_mla
|
||||
+):
|
||||
logger.info(
|
||||
-"MLA optimization is turned on. Use flashinfer mla backend."
|
||||
+"MLA optimization is turned on. Use flashinfer backend."
|
||||
)
|
||||
+# Here we use a special flashinfer_mla tag to differentiate it from normal flashinfer backend
|
||||
server_args.attention_backend = "flashinfer_mla"
|
||||
elif server_args.enable_flashmla:
|
||||
logger.info("MLA optimization is turned on. Use flashmla decode.")
|
||||
|
||||
Reference in New Issue
Block a user