minor: fix config (#1524)
This commit is contained in:
@@ -49,13 +49,13 @@ class ModelConfig:
|
|||||||
if context_length is not None:
|
if context_length is not None:
|
||||||
self.context_len = context_length
|
self.context_len = context_length
|
||||||
else:
|
else:
|
||||||
self.context_len = get_context_length(self.hf_config)
|
self.context_len = get_context_length(self.hf_text_config)
|
||||||
|
|
||||||
# Unify the config keys for hf_config
|
# Unify the config keys for hf_text_config
|
||||||
self.head_dim = getattr(
|
self.head_dim = getattr(
|
||||||
self.hf_config,
|
self.hf_text_config,
|
||||||
"head_dim",
|
"head_dim",
|
||||||
self.hf_config.hidden_size // self.hf_config.num_attention_heads,
|
self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
|
||||||
)
|
)
|
||||||
|
|
||||||
# FIXME: temporary special judge for deepseek v2 MLA architecture
|
# FIXME: temporary special judge for deepseek v2 MLA architecture
|
||||||
@@ -72,8 +72,10 @@ class ModelConfig:
|
|||||||
else:
|
else:
|
||||||
self.attention_arch = AttentionArch.MHA
|
self.attention_arch = AttentionArch.MHA
|
||||||
|
|
||||||
self.num_attention_heads = self.hf_config.num_attention_heads
|
self.num_attention_heads = self.hf_text_config.num_attention_heads
|
||||||
self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
|
self.num_key_value_heads = getattr(
|
||||||
|
self.hf_text_config, "num_key_value_heads", None
|
||||||
|
)
|
||||||
|
|
||||||
# for Dbrx and MPT models
|
# for Dbrx and MPT models
|
||||||
if self.hf_config.model_type in ["dbrx", "mpt"]:
|
if self.hf_config.model_type in ["dbrx", "mpt"]:
|
||||||
@@ -83,9 +85,9 @@ class ModelConfig:
|
|||||||
|
|
||||||
if self.num_key_value_heads is None:
|
if self.num_key_value_heads is None:
|
||||||
self.num_key_value_heads = self.num_attention_heads
|
self.num_key_value_heads = self.num_attention_heads
|
||||||
self.hidden_size = self.hf_config.hidden_size
|
self.hidden_size = self.hf_text_config.hidden_size
|
||||||
self.num_hidden_layers = self.hf_config.num_hidden_layers
|
self.num_hidden_layers = self.hf_text_config.num_hidden_layers
|
||||||
self.vocab_size = self.hf_config.vocab_size
|
self.vocab_size = self.hf_text_config.vocab_size
|
||||||
|
|
||||||
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
|
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
|
||||||
def get_total_num_kv_heads(self) -> int:
|
def get_total_num_kv_heads(self) -> int:
|
||||||
|
|||||||
Reference in New Issue
Block a user