fix: small bug for llama-405b fp16 (#733)
This commit is contained in:
@@ -121,7 +121,7 @@ class ModelRunner:
|
|||||||
skip_tokenizer_init=True,
|
skip_tokenizer_init=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_llama3_405b_fp8(self.model_config):
|
if is_llama3_405b_fp8(self.model_config) and self.tp_size <= 8:
|
||||||
# A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
|
# A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
|
||||||
self.model_config.hf_config.num_key_value_heads = 8
|
self.model_config.hf_config.num_key_value_heads = 8
|
||||||
vllm_model_config.hf_config.num_key_value_heads = 8
|
vllm_model_config.hf_config.num_key_value_heads = 8
|
||||||
|
|||||||
@@ -626,6 +626,7 @@ def is_llama3_405b_fp8(model_config):
|
|||||||
and model_config.hf_config.intermediate_size == 53248
|
and model_config.hf_config.intermediate_size == 53248
|
||||||
and model_config.hf_config.num_hidden_layers == 126
|
and model_config.hf_config.num_hidden_layers == 126
|
||||||
and model_config.hf_config.num_key_value_heads == 16
|
and model_config.hf_config.num_key_value_heads == 16
|
||||||
|
and hasattr(model_config.hf_config, "quantization_config")
|
||||||
and model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8"
|
and model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8"
|
||||||
):
|
):
|
||||||
return True
|
return True
|
||||||
|
|||||||
Reference in New Issue
Block a user