Support DeepSeek V3.2 Exp (#11061)

Co-authored-by: Stefan He <11166516+hebiao064@users.noreply.github.com>
Co-authored-by: Liangsheng Yin <95566987+hnyls2002@users.noreply.github.com>
Co-authored-by: Baizhou Zhang <56809903+fridge003@users.noreply.github.com>
Co-authored-by: DarkSharpness <76582120+darksharpness@users.noreply.github.com>
Co-authored-by: ZhengdQin <46387172+zhengdqin@users.noreply.github.com>
Co-authored-by: DarkSharpness <2040703891@qq.com>
Co-authored-by: hnyls2002 <lsyincs@gmail.com>
Co-authored-by: Zhengda Qin <zhengdqin@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Co-authored-by: HAI <hixiao@gmail.com>
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
This commit is contained in:
fzyzcjy
2025-10-06 15:24:15 +08:00
committed by GitHub
parent 292a867ad9
commit efbc687c28
29 changed files with 4540 additions and 139 deletions

View File

@@ -49,6 +49,30 @@ class ModelImpl(str, Enum):
TRANSFORMERS = "transformers"
def is_deepseek_nsa(config: PretrainedConfig) -> bool:
    """Return True if ``config`` describes a DeepSeek model that uses NSA
    (native sparse attention) indexing.

    A model qualifies when its primary architecture is a DeepSeek V3/V3.2
    causal LM *and* the config carries an ``index_topk`` field (present only
    on NSA-enabled checkpoints).
    """
    # Use truthiness rather than `is not None` so an empty `architectures`
    # list is rejected instead of raising IndexError on `[0]`.
    architectures = getattr(config, "architectures", None)
    return bool(
        architectures
        and architectures[0] in ("DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM")
        and getattr(config, "index_topk", None) is not None
    )
def get_nsa_index_head_dim(config: PretrainedConfig) -> int:
    """Return the NSA indexer head dimension (``index_head_dim``).

    Raises:
        AssertionError: if ``config`` is not an NSA-enabled DeepSeek config.
    """
    # Message makes the failure self-explanatory instead of a bare AssertionError.
    assert is_deepseek_nsa(config), "config is not a DeepSeek NSA model"
    return config.index_head_dim
def get_nsa_index_topk(config: PretrainedConfig) -> int:
    """Return the NSA indexer top-k value (``index_topk``).

    Raises:
        AssertionError: if ``config`` is not an NSA-enabled DeepSeek config.
    """
    # Message makes the failure self-explanatory instead of a bare AssertionError.
    assert is_deepseek_nsa(config), "config is not a DeepSeek NSA model"
    return config.index_topk
def get_nsa_index_n_heads(config: PretrainedConfig) -> int:
    """Return the number of NSA indexer heads (``index_n_heads``).

    Raises:
        AssertionError: if ``config`` is not an NSA-enabled DeepSeek config.
    """
    # Message makes the failure self-explanatory instead of a bare AssertionError.
    assert is_deepseek_nsa(config), "config is not a DeepSeek NSA model"
    return config.index_n_heads
class ModelConfig:
def __init__(
self,
@@ -271,6 +295,7 @@ class ModelConfig:
# FIXME: temporary special judge for MLA architecture
if (
"DeepseekV2ForCausalLM" in self.hf_config.architectures
or "DeepseekV32ForCausalLM" in self.hf_config.architectures
or "DeepseekV3ForCausalLM" in self.hf_config.architectures
or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
or "LongcatFlashForCausalLM" in self.hf_config.architectures
@@ -283,6 +308,11 @@ class ModelConfig:
self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
self.v_head_dim = self.hf_config.v_head_dim
self.index_head_dim = (
get_nsa_index_head_dim(self.hf_config)
if is_deepseek_nsa(self.hf_config)
else None
)
# Handle rope scaling with yarn
self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)