feat: update model config (#9202)
@@ -24,7 +24,7 @@ import tempfile
 from typing import List, Literal, Optional, Union
 
 from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
-from sglang.srt.layers.utils import is_sm100_supported
+from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
 from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.utils import (
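The import hunk above only adds is_sm90_supported next to the existing is_sm100_supported from sglang.srt.layers.utils. As a rough, hypothetical sketch of what such helpers typically check (not the actual sglang implementation, which may also gate on the installed CUDA version), they key off the CUDA compute-capability major version reported by the device:

# Hypothetical sketch only -- not the real sglang.srt.layers.utils code.
# SM90 (Hopper) reports compute capability major 9; SM100 (Blackwell) reports 10.
import torch


def is_sm90_supported(device=None) -> bool:
    # True when the current (or given) CUDA device is a Hopper-class GPU.
    return torch.cuda.is_available() and torch.cuda.get_device_capability(device)[0] == 9


def is_sm100_supported(device=None) -> bool:
    # True when the current (or given) CUDA device is a Blackwell-class GPU.
    return torch.cuda.is_available() and torch.cuda.get_device_capability(device)[0] == 10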
@@ -2117,11 +2117,25 @@ class ServerArgs:
         model_arch = hf_config.architectures[0]
         if model_arch in ["GptOssForCausalLM"]:
             if self.attention_backend is None:
-                self.attention_backend = "triton"
+                if is_sm100_supported():
+                    self.attention_backend = "trtllm_mha"
+                elif is_sm90_supported():
+                    self.attention_backend = "fa3"
+                else:
+                    self.attention_backend = "triton"
             supported_backends = ["triton", "trtllm_mha", "fa3"]
+            logger.info(
+                f"Use {self.attention_backend} as attention backend for GptOssForCausalLM"
+            )
             assert (
                 self.attention_backend in supported_backends
             ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
+
+            if is_sm100_supported():
+                self.enable_flashinfer_allreduce_fusion = True
+                logger.info(
+                    "Enable FlashInfer AllReduce Fusion on sm100 for GptOssForCausalLM"
+                )
             quantization_config = getattr(hf_config, "quantization_config", None)
             is_mxfp4_quant_format = (
                 quantization_config is not None
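Taken together, the hunk above means that when no attention backend is set explicitly, GptOssForCausalLM now auto-selects trtllm_mha on SM100, fa3 on SM90, and triton everywhere else, and additionally enables FlashInfer allreduce fusion on SM100. A small standalone sketch of that selection rule follows; the helper name and signature are illustrative only and are not part of the diff:

# Standalone sketch of the backend-selection logic introduced above.
# The SM90/SM100 preferences and the supported-backend list follow the diff;
# the function itself is hypothetical.
from typing import Optional


def pick_gpt_oss_attention_backend(
    requested: Optional[str],
    sm90: bool,
    sm100: bool,
) -> str:
    """Mirror the auto-selection: trtllm_mha on SM100, fa3 on SM90, else triton."""
    if requested is None:
        if sm100:
            requested = "trtllm_mha"
        elif sm90:
            requested = "fa3"
        else:
            requested = "triton"
    supported = ["triton", "trtllm_mha", "fa3"]
    assert requested in supported, (
        f"GptOssForCausalLM requires one of {supported}, but got '{requested}'"
    )
    return requested


# Example: a Hopper (SM90) machine with no explicit backend falls back to fa3.
assert pick_gpt_oss_attention_backend(None, sm90=True, sm100=False) == "fa3"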