Deprecate enable-flashinfer-mla and enable-flashmla (#5480)
@@ -76,7 +76,6 @@ global_server_args_dict = {
     "device": ServerArgs.device,
     "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
     "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
-    "enable_flashmla": ServerArgs.enable_flashmla,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
     "chunked_prefill_size": ServerArgs.chunked_prefill_size,
@@ -1480,7 +1479,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 global_server_args_dict["use_mla_backend"]
                 and global_server_args_dict["attention_backend"] == "flashinfer"
             )
-            or global_server_args_dict["enable_flashmla"]
+            or global_server_args_dict["attention_backend"] == "flashmla"
            or global_server_args_dict["attention_backend"] == "fa3"
         ):
             seq_lens_cpu = self.seq_lens.cpu()
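With the boolean flag gone, call sites gate MLA-specific work on the backend name alone. A minimal sketch of the consolidated check, assuming the dict shape shown above; "needs_seq_lens_cpu" is a hypothetical helper, not code from this commit:

# Sketch only: mirrors the gating condition in the hunk above.
def needs_seq_lens_cpu(args: dict) -> bool:
    backend = args["attention_backend"]
    return (
        args["use_mla_backend"] and backend == "flashinfer"
    ) or backend in ("flashmla", "fa3")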
@@ -156,7 +156,6 @@ class ModelRunner:
             "device": server_args.device,
             "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
             "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
-            "enable_flashmla": server_args.enable_flashmla,
             "disable_radix_cache": server_args.disable_radix_cache,
             "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
             "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
@@ -225,15 +224,7 @@ class ModelRunner:
     def model_specific_adjustment(self):
         server_args = self.server_args
 
-        if server_args.enable_flashinfer_mla:
-            # TODO: remove this branch after enable_flashinfer_mla is deprecated
-            logger.info("MLA optimization is turned on. Use flashinfer backend.")
-            server_args.attention_backend = "flashinfer"
-        elif server_args.enable_flashmla:
-            # TODO: remove this branch after enable_flashmla is deprecated
-            logger.info("MLA optimization is turned on. Use flashmla decode.")
-            server_args.attention_backend = "flashmla"
-        elif server_args.attention_backend is None:
+        if server_args.attention_backend is None:
             # By default, use flashinfer for non-mla attention and triton for mla attention
             if not self.use_mla_backend:
                 if (
@@ -259,7 +250,12 @@ class ModelRunner:
         elif self.use_mla_backend:
             # TODO: add MLA optimization on CPU
             if server_args.device != "cpu":
-                if server_args.attention_backend in ["flashinfer", "fa3", "triton"]:
+                if server_args.attention_backend in [
+                    "flashinfer",
+                    "fa3",
+                    "triton",
+                    "flashmla",
+                ]:
                     logger.info(
                         f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                     )
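Taken together with the previous hunk, model_specific_adjustment now resolves the backend purely from attention_backend. A condensed sketch of the resulting flow; the names mirror the hunks above, but this is not the repository's exact code, and the default-selection details are elided:

# Sketch only: flashmla is now just another valid MLA backend name.
def resolve_mla_backend(attention_backend, use_mla_backend, device):
    if attention_backend is None:
        # By default, use flashinfer for non-mla attention and triton for mla attention
        return "triton" if use_mla_backend else "flashinfer"
    if use_mla_backend and device != "cpu":
        if attention_backend in ["flashinfer", "fa3", "triton", "flashmla"]:
            return attention_backend
        raise ValueError(f"Invalid attention backend for MLA: {attention_backend}")
    return attention_backend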
@@ -179,8 +179,6 @@ class ServerArgs:
     tool_call_parser: Optional[str] = None
     enable_hierarchical_cache: bool = False
     hicache_ratio: float = 2.0
-    enable_flashinfer_mla: bool = False  # TODO: remove this argument
-    enable_flashmla: bool = False
     flashinfer_mla_disable_ragged: bool = False
     warmups: Optional[str] = None
     n_share_experts_fusion: int = 0
@@ -254,7 +252,7 @@ class ServerArgs:
 
         assert self.chunked_prefill_size % self.page_size == 0
 
-        if self.enable_flashmla is True:
+        if self.attention_backend == "flashmla":
             logger.warning(
                 "FlashMLA only supports a page_size of 64, change page_size to 64."
             )
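The hunk truncates before the branch body ends. Based on the warning text, a plausible completion forces the page size to FlashMLA's requirement; the assignment here is an assumption, not shown in the diff, and the helper is hypothetical:

import logging

logger = logging.getLogger(__name__)

# Sketch only: assumes the assignment the warning describes follows it.
def adjust_page_size(attention_backend: str, page_size: int) -> int:
    if attention_backend == "flashmla" and page_size != 64:
        logger.warning(
            "FlashMLA only supports a page_size of 64, change page_size to 64."
        )
        return 64  # assumed follow-up to the warning in the diff
    return page_size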
@@ -823,7 +821,7 @@ class ServerArgs:
         parser.add_argument(
             "--attention-backend",
             type=str,
-            choices=["flashinfer", "triton", "torch_native", "fa3"],
+            choices=["flashinfer", "triton", "torch_native", "fa3", "flashmla"],
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
         )
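With "flashmla" added to choices, the backend is now selected directly on the command line. A standalone sketch of the parsing path, with the parser wiring reduced to this one argument:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--attention-backend",
    type=str,
    choices=["flashinfer", "triton", "torch_native", "fa3", "flashmla"],
    default=None,
    help="Choose the kernels for attention layers.",
)

# What was previously "--enable-flashmla" becomes:
args = parser.parse_args(["--attention-backend", "flashmla"])
assert args.attention_backend == "flashmla"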
@@ -843,13 +841,13 @@ class ServerArgs:
         )
         parser.add_argument(
             "--enable-flashinfer-mla",
-            action="store_true",
-            help="Enable FlashInfer MLA optimization. This argument will be deprecated soon! Please use '--attention-backend flashinfer' instead for switching on flashfiner mla!",
+            action=DeprecatedAction,
+            help="--enable-flashinfer-mla is deprecated. Please use '--attention-backend flashinfer' instead.",
         )
         parser.add_argument(
             "--enable-flashmla",
-            action="store_true",
-            help="Enable FlashMLA decode optimization",
+            action=DeprecatedAction,
+            help="--enable-flashmla is deprecated. Please use '--attention-backend flashmla' instead.",
         )
         parser.add_argument(
             "--flashinfer-mla-disable-ragged",