diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 0019fab76..6384532cd 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -147,6 +147,27 @@ from sglang.srt.weight_sync.tensor_bucket import (
     FlattenedTensorMetadata,
 )
 
+# Attention backends that support MLA (Multi-head Latent Attention)
+# optimization. Deliberately a mutable module-level list: out-of-tree
+# backends can register themselves at runtime via add_mla_attention_backend().
+MLA_ATTENTION_BACKENDS = [
+    "aiter",
+    "flashinfer",
+    "fa3",
+    "fa4",
+    "triton",
+    "flashmla",
+    "cutlass_mla",
+    "trtllm_mla",
+    "ascend",
+]
+
+
+def add_mla_attention_backend(backend_name: str) -> None:
+    """Register *backend_name* as MLA-capable; no-op if already registered."""
+    if backend_name not in MLA_ATTENTION_BACKENDS:
+        MLA_ATTENTION_BACKENDS.append(backend_name)
+        logger.info(f"Added {backend_name} to MLA_ATTENTION_BACKENDS.")
+
+
 _is_hip = is_hip()
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
@@ -513,17 +536,7 @@ class ModelRunner:
             )
         elif self.use_mla_backend:
             if server_args.device != "cpu":
-                if server_args.attention_backend in [
-                    "aiter",
-                    "flashinfer",
-                    "fa3",
-                    "fa4",
-                    "triton",
-                    "flashmla",
-                    "cutlass_mla",
-                    "trtllm_mla",
-                    "ascend",
-                ]:
+                if server_args.attention_backend in MLA_ATTENTION_BACKENDS:
                     logger.info(
                         f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                     )