[DeepseekR1-FP4] Add Support for nvidia/DeepSeekR1-FP4 model (#6853)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
Pavani Majety
2025-06-07 17:24:35 -07:00
committed by GitHub
parent 23881fa60c
commit c2c4f57f63
3 changed files with 386 additions and 13 deletions

View File

@@ -1746,7 +1746,7 @@ class DeepseekV2ForCausalLM(nn.Module):
global_server_args_dict["disable_shared_experts_fusion"] = False
log_info_on_rank0(
logger,
"Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
"Deepseek V3/R1 with fp8/fp4 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
)
def get_input_embeddings(self) -> nn.Embedding:
@@ -1926,6 +1926,7 @@ class DeepseekV2ForCausalLM(nn.Module):
self_attn.use_deep_gemm_bmm = True
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
if is_nextn:
if hasattr(self.config, "num_nextn_predict_layers"):
num_nextn_layers = self.config.num_nextn_predict_layers
@@ -1982,6 +1983,21 @@ class DeepseekV2ForCausalLM(nn.Module):
"up_proj.qzeros",
"up_proj.scales",
]
elif self.quant_config.get_name() == "modelopt_fp4":
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale",
"down_proj.weight_scale_2",
"down_proj.input_scale",
"gate_proj.weight",
"gate_proj.weight_scale",
"gate_proj.weight_scale_2",
"gate_proj.input_scale",
"up_proj.weight",
"up_proj.weight_scale",
"up_proj.weight_scale_2",
"up_proj.input_scale",
]
else:
raise ValueError(
f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
@@ -2125,7 +2141,6 @@ class DeepseekV2ForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if fuse_qkv_a_proj and (
"q_a_proj" in name or "kv_a_proj_with_mqa" in name
):
@@ -2151,9 +2166,12 @@ class DeepseekV2ForCausalLM(nn.Module):
fused_weight = torch.cat(
[q_a_proj_weight, kv_a_proj_weight], dim=0
)
param_name = name.replace(
"q_a_proj", "fused_qkv_a_proj_with_mqa"
param_name = (
name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
if "q_a_proj" in name
else name.replace(
"kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
)
)
param = params_dict[param_name]
@@ -2164,6 +2182,16 @@ class DeepseekV2ForCausalLM(nn.Module):
cached_a_proj.pop(q_a_proj_name)
cached_a_proj.pop(kv_a_proj_name)
else:
if (
"k_scale" in name or "v_scale" in name
) and name not in params_dict:
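# ModelOpt checkpoints name the attention k/v scales differently, so a
# k_scale/v_scale name missing from params_dict is rewritten
# ("_proj" -> "attn_mqa") before the parameter lookup below.
# NOTE(review): the any(...) check repeats the outer condition, so the
# warning branch cannot run; also confirm the rewritten name is the intended key.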
if any(scale in name for scale in ["k_scale", "v_scale"]):
name = name.replace("_proj", "attn_mqa")
else:
logger.warning(
f"Unknown scale found in checkpoint: {name}"
)
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader