Co-authored-by: averyhuang <averyh@nvidia.com>
@@ -445,7 +445,11 @@ class ServerArgs:
                 "trtllm_mla backend does not support speculative decoding yet."
             )
 
-        if self.attention_backend == "trtllm_mha":
+        if (
+            self.attention_backend == "trtllm_mha"
+            or self.decode_attention_backend == "trtllm_mha"
+            or self.prefill_attention_backend == "trtllm_mha"
+        ):
             if not is_sm100_supported():
                 raise ValueError(
                     "TRTLLM MHA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
@@ -459,11 +463,18 @@ class ServerArgs:
 
             if self.speculative_algorithm is not None:
                 raise ValueError(
-                    "trtllm_mla backend does not support speculative decoding yet."
+                    "trtllm_mha backend does not support speculative decoding yet."
                 )
 
         model_arch = self.get_hf_config().architectures[0]
         if model_arch in ["GptOssForCausalLM"]:
-            self.attention_backend = "triton"
+            if self.attention_backend is None:
+                # default is triton, but we could have trtllm_mha as an option
+                self.attention_backend = "triton"
+            assert (
+                self.attention_backend == "trtllm_mha"
+                or self.attention_backend == "triton"
+            )
 
             # Check if FlashInfer MXFP4 MoE is enabled
             from sglang.srt.utils import get_bool_env_var
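For reviewers, a minimal standalone sketch of the validation flow this diff produces. It is an illustration under stated assumptions, not sglang's real API: ServerArgsSketch, check_attention_backend, the model_arch field, and the is_sm100_supported stub are simplifications standing in for ServerArgs, its internal checks, get_hf_config(), and sglang.srt.utils.

    # Standalone sketch; stubs stand in for sglang internals.
    from dataclasses import dataclass
    from typing import Optional


    def is_sm100_supported() -> bool:
        # Stub: the real helper checks for SM100 (Blackwell) compute capability.
        return True


    @dataclass
    class ServerArgsSketch:
        attention_backend: Optional[str] = None
        decode_attention_backend: Optional[str] = None
        prefill_attention_backend: Optional[str] = None
        speculative_algorithm: Optional[str] = None
        model_arch: str = "GptOssForCausalLM"  # stands in for get_hf_config()

        def check_attention_backend(self) -> None:
            # After this diff, the trtllm_mha gate also fires when only the
            # decode or prefill backend is trtllm_mha, not just the global one.
            if (
                self.attention_backend == "trtllm_mha"
                or self.decode_attention_backend == "trtllm_mha"
                or self.prefill_attention_backend == "trtllm_mha"
            ):
                if not is_sm100_supported():
                    raise ValueError(
                        "TRTLLM MHA backend is only supported on Blackwell GPUs (SM100)."
                    )
                if self.speculative_algorithm is not None:
                    raise ValueError(
                        "trtllm_mha backend does not support speculative decoding yet."
                    )

            if self.model_arch in ["GptOssForCausalLM"]:
                # GPT-OSS now defaults to triton only when nothing was chosen,
                # so an explicit trtllm_mha selection survives.
                if self.attention_backend is None:
                    self.attention_backend = "triton"
                assert self.attention_backend in ("trtllm_mha", "triton")


    args = ServerArgsSketch(attention_backend="trtllm_mha")
    args.check_attention_backend()
    print(args.attention_backend)  # "trtllm_mha" is kept, no longer forced to triton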