[feat] Support tp mode for DeepSeek-R1-W4AFP8 (#8118)

Co-authored-by: yuhyao <827623970@qq.com>
This commit is contained in:
chenxj
2025-09-02 13:17:26 +08:00
committed by GitHub
parent 21e1bc475c
commit d4a938417d
11 changed files with 291 additions and 120 deletions

View File

@@ -2185,6 +2185,8 @@ class DeepseekV2ForCausalLM(nn.Module):
disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
elif get_moe_expert_parallel_world_size() > 1:
disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."
elif self.quant_config.get_name() == "w4afp8":
disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."
if disable_reason is not None:
global_server_args_dict["disable_shared_experts_fusion"] = True
@@ -2496,6 +2498,9 @@ class DeepseekV2ForCausalLM(nn.Module):
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
)
# Params for special naming rules in mixed-precision models, for example:
# model.layers.xx.mlp.experts.xx.w1.input_scale. For details,
# see https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/blob/main.
if self.quant_config and self.quant_config.get_name() == "w4afp8":
expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping(
num_experts=self.config.n_routed_experts