diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index d7f2ebe2b..0d6c794e6 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -24,7 +24,7 @@ import tempfile
 from typing import List, Literal, Optional, Union
 
 from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
-from sglang.srt.layers.utils import is_sm100_supported
+from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
 from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.utils import (
@@ -2117,11 +2117,25 @@ class ServerArgs:
         model_arch = hf_config.architectures[0]
         if model_arch in ["GptOssForCausalLM"]:
             if self.attention_backend is None:
-                self.attention_backend = "triton"
+                if is_sm100_supported():
+                    self.attention_backend = "trtllm_mha"
+                elif is_sm90_supported():
+                    self.attention_backend = "fa3"
+                else:
+                    self.attention_backend = "triton"
             supported_backends = ["triton", "trtllm_mha", "fa3"]
+            logger.info(
+                f"Use {self.attention_backend} as attention backend for GptOssForCausalLM"
+            )
             assert (
                 self.attention_backend in supported_backends
             ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
+
+            if is_sm100_supported():
+                self.enable_flashinfer_allreduce_fusion = True
+                logger.info(
+                    "Enable FlashInfer AllReduce Fusion on sm100 for GptOssForCausalLM"
+                )
         quantization_config = getattr(hf_config, "quantization_config", None)
         is_mxfp4_quant_format = (
             quantization_config is not None