diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py
index 8459c98b8..23a334fd6 100644
--- a/python/sglang/srt/managers/controller/model_runner.py
+++ b/python/sglang/srt/managers/controller/model_runner.py
@@ -121,7 +121,7 @@ class ModelRunner:
             skip_tokenizer_init=True,
         )

-        if is_llama3_405b_fp8(self.model_config):
+        if is_llama3_405b_fp8(self.model_config) and self.tp_size <= 8:
             # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
             self.model_config.hf_config.num_key_value_heads = 8
             vllm_model_config.hf_config.num_key_value_heads = 8
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index e4367f4a4..ed90b9267 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -626,6 +626,7 @@ def is_llama3_405b_fp8(model_config):
         and model_config.hf_config.intermediate_size == 53248
         and model_config.hf_config.num_hidden_layers == 126
         and model_config.hf_config.num_key_value_heads == 16
+        and hasattr(model_config.hf_config, "quantization_config")
         and model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8"
     ):
         return True