From b3710d2c93b6f1ef608990096d71817c5cf35608 Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Tue, 17 Sep 2024 22:07:53 +0800 Subject: [PATCH] Fix attention backend (#1448) --- python/sglang/srt/model_executor/model_runner.py | 8 ++++++++ python/sglang/srt/server_args.py | 4 ---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 9e614b81d..dc8dcd4ed 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -86,6 +86,14 @@ class ModelRunner: self.is_multimodal_model = is_multimodal_model( self.model_config.hf_config.architectures ) + + if ( + self.model_config.attention_arch == AttentionArch.MLA + and not self.server_args.disable_mla + ): + logger.info("MLA optimization is turned on. Use triton backend.") + self.server_args.attention_backend = "triton" + global_server_args_dict.update( { "attention_backend": server_args.attention_backend, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index a741d43a2..7eef08b71 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -173,10 +173,6 @@ class ServerArgs: self.sampling_backend = "pytorch" # Default kernel backends - if not self.disable_mla: - logger.info("MLA optimization is tunred on. Use triton backend.") - self.attention_backend = "triton" - if self.attention_backend is None: self.attention_backend = "flashinfer"