Support DeepSeek V3.2 Exp (#11061)
Co-authored-by: Stefan He <11166516+hebiao064@users.noreply.github.com> Co-authored-by: Liangsheng Yin <95566987+hnyls2002@users.noreply.github.com> Co-authored-by: Baizhou Zhang <56809903+fridge003@users.noreply.github.com> Co-authored-by: DarkSharpness <76582120+darksharpness@users.noreply.github.com> Co-authored-by: ZhengdQin <46387172+zhengdqin@users.noreply.github.com> Co-authored-by: DarkSharpness <2040703891@qq.com> Co-authored-by: hnyls2002 <lsyincs@gmail.com> Co-authored-by: Zhengda Qin <zhengdqin@gmail.com> Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com> Co-authored-by: HAI <hixiao@gmail.com> Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
This commit is contained in:
@@ -49,6 +49,30 @@ class ModelImpl(str, Enum):
|
||||
TRANSFORMERS = "transformers"
|
||||
|
||||
|
||||
def is_deepseek_nsa(config: "PretrainedConfig") -> bool:
    """Return True if *config* describes a DeepSeek model using NSA
    (Native Sparse Attention).

    A config qualifies when its first declared architecture is a DeepSeek
    V3 / V3.2 causal-LM and it declares an ``index_topk`` attribute
    (the NSA indexer's top-k), which non-NSA checkpoints omit.

    Args:
        config: a Hugging Face ``PretrainedConfig``-like object.

    Returns:
        bool: whether the config is a DeepSeek NSA config.
    """
    architectures = getattr(config, "architectures", None)
    return bool(
        # Truthiness guards against both None AND an empty list; the
        # previous `is not None` check alone let `architectures[0]`
        # raise IndexError on `architectures == []`.
        architectures
        and architectures[0] in ("DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM")
        and getattr(config, "index_topk", None) is not None
    )
|
||||
|
||||
|
||||
def get_nsa_index_head_dim(config: PretrainedConfig) -> int:
    """Return the NSA indexer head dimension (``index_head_dim``).

    Only meaningful for DeepSeek NSA configs; asserts that *config*
    satisfies :func:`is_deepseek_nsa` before reading the attribute.
    """
    assert is_deepseek_nsa(config)
    head_dim: int = config.index_head_dim
    return head_dim
|
||||
|
||||
|
||||
def get_nsa_index_topk(config: PretrainedConfig) -> int:
    """Return the NSA indexer top-k (``index_topk``).

    Only meaningful for DeepSeek NSA configs; asserts that *config*
    satisfies :func:`is_deepseek_nsa` before reading the attribute.
    """
    assert is_deepseek_nsa(config)
    topk: int = config.index_topk
    return topk
|
||||
|
||||
|
||||
def get_nsa_index_n_heads(config: PretrainedConfig) -> int:
    """Return the NSA indexer head count (``index_n_heads``).

    Only meaningful for DeepSeek NSA configs; asserts that *config*
    satisfies :func:`is_deepseek_nsa` before reading the attribute.
    """
    assert is_deepseek_nsa(config)
    n_heads: int = config.index_n_heads
    return n_heads
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -271,6 +295,7 @@ class ModelConfig:
|
||||
# FIXME: temporary special judge for MLA architecture
|
||||
if (
|
||||
"DeepseekV2ForCausalLM" in self.hf_config.architectures
|
||||
or "DeepseekV32ForCausalLM" in self.hf_config.architectures
|
||||
or "DeepseekV3ForCausalLM" in self.hf_config.architectures
|
||||
or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
|
||||
or "LongcatFlashForCausalLM" in self.hf_config.architectures
|
||||
@@ -283,6 +308,11 @@ class ModelConfig:
|
||||
self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
|
||||
self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
|
||||
self.v_head_dim = self.hf_config.v_head_dim
|
||||
self.index_head_dim = (
|
||||
get_nsa_index_head_dim(self.hf_config)
|
||||
if is_deepseek_nsa(self.hf_config)
|
||||
else None
|
||||
)
|
||||
|
||||
# Handle rope scaling with yarn
|
||||
self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
|
||||
|
||||
Reference in New Issue
Block a user