Fix attention backend (#1448)

Author: Ke Bao
Date: 2024-09-17 22:07:53 +08:00
Committed by: GitHub
Parent: c6b6d2e71b
Commit: b3710d2c93
2 changed files with 8 additions and 4 deletions


@@ -86,6 +86,14 @@ class ModelRunner:
         self.is_multimodal_model = is_multimodal_model(
             self.model_config.hf_config.architectures
         )
+
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and not self.server_args.disable_mla
+        ):
+            logger.info("MLA optimization is turned on. Use triton backend.")
+            self.server_args.attention_backend = "triton"
+
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,


@@ -173,10 +173,6 @@ class ServerArgs:
             self.sampling_backend = "pytorch"
 
         # Default kernel backends
-        if not self.disable_mla:
-            logger.info("MLA optimization is turned on. Use triton backend.")
-            self.attention_backend = "triton"
-
         if self.attention_backend is None:
             self.attention_backend = "flashinfer"