Update grok 1 model (#1095)
This commit is contained in:
@@ -53,7 +53,7 @@ from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
get_available_gpu_memory,
|
||||
is_generation_model,
|
||||
is_llama3_405b_fp8,
|
||||
is_llama3_405b_fp8_head_16,
|
||||
is_multimodal_model,
|
||||
monkey_patch_vllm_dummy_weight_loader,
|
||||
monkey_patch_vllm_p2p_access_check,
|
||||
@@ -158,7 +158,7 @@ class ModelRunner:
|
||||
skip_tokenizer_init=True,
|
||||
)
|
||||
|
||||
if is_llama3_405b_fp8(self.model_config) and self.tp_size <= 8:
|
||||
if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
|
||||
# A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
|
||||
self.model_config.hf_config.num_key_value_heads = 8
|
||||
vllm_model_config.hf_config.num_key_value_heads = 8
|
||||
|
||||
Reference in New Issue
Block a user