[DeepseekR1-FP4] Add Support for nvidia/DeepSeekR1-FP4 model (#6853)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
@@ -1746,7 +1746,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 global_server_args_dict["disable_shared_experts_fusion"] = False
                 log_info_on_rank0(
                     logger,
-                    "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
+                    "Deepseek V3/R1 with fp8/fp4 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
                 )
 
     def get_input_embeddings(self) -> nn.Embedding:
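The only change in this hunk is the log message: the shared-experts-fusion path now advertises fp4 alongside fp8. As a rough illustration of the kind of gate that keeps the fusion enabled, the sketch below is an assumption, not code from this commit; the helper name and `quant_method` argument are hypothetical.

import torch

def shared_experts_fusion_allowed(quant_method: str) -> bool:
    # Hypothetical helper: fusion is advertised for fp8/fp4 checkpoints on SM >= 90.
    major, minor = torch.cuda.get_device_capability()
    sm_version = major * 10 + minor
    return quant_method in ("fp8", "modelopt_fp4") and sm_version >= 90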
@@ -1926,6 +1926,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                     self_attn.use_deep_gemm_bmm = True
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
+
         if is_nextn:
            if hasattr(self.config, "num_nextn_predict_layers"):
                num_nextn_layers = self.config.num_nextn_predict_layers
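load_weights takes an is_nextn flag so the same loader can handle the multi-token-prediction (nextn) layers. A minimal sketch of how the nextn layer indices can be derived from the config, assuming the checkpoint appends the MTP layers after the regular decoder layers (an assumption about the layout, not something this diff states):

def nextn_layer_ids(config) -> list:
    # Hypothetical helper: DeepSeek-V3-style configs expose num_nextn_predict_layers,
    # and the MTP layers follow the num_hidden_layers regular layers in the checkpoint.
    num_nextn = getattr(config, "num_nextn_predict_layers", 0)
    return [config.num_hidden_layers + i for i in range(num_nextn)]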
@@ -1982,6 +1983,21 @@ class DeepseekV2ForCausalLM(nn.Module):
                     "up_proj.qzeros",
                     "up_proj.scales",
                 ]
+            elif self.quant_config.get_name() == "modelopt_fp4":
+                suffix_list = [
+                    "down_proj.weight",
+                    "down_proj.weight_scale",
+                    "down_proj.weight_scale_2",
+                    "down_proj.input_scale",
+                    "gate_proj.weight",
+                    "gate_proj.weight_scale",
+                    "gate_proj.weight_scale_2",
+                    "gate_proj.input_scale",
+                    "up_proj.weight",
+                    "up_proj.weight_scale",
+                    "up_proj.weight_scale_2",
+                    "up_proj.input_scale",
+                ]
             else:
                 raise ValueError(
                     f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
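The new modelopt_fp4 branch lists the tensors an NVFP4 (ModelOpt) checkpoint ships per projection: the packed weight, its block-wise weight_scale, the global weight_scale_2, and the activation input_scale. A minimal sketch of how a suffix list like this can drive the shared-experts-fusion renaming; the module paths ("mlp.shared_experts", "mlp.experts.{idx}") and the helper name are assumptions for illustration, not taken from this diff.

def shared_expert_renames(layer_id: int, fused_expert_idx: int, suffix_list):
    # Map every shared-expert tensor onto the extra fused-expert slot.
    return {
        f"model.layers.{layer_id}.mlp.shared_experts.{suffix}":
        f"model.layers.{layer_id}.mlp.experts.{fused_expert_idx}.{suffix}"
        for suffix in suffix_list
    }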
@@ -2125,7 +2141,6 @@ class DeepseekV2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-
                 if fuse_qkv_a_proj and (
                     "q_a_proj" in name or "kv_a_proj_with_mqa" in name
                 ):
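For context, fuse_qkv_a_proj covers the case where q_a_proj and kv_a_proj_with_mqa are fused into a single fused_qkv_a_proj_with_mqa parameter: each half is cached until its partner has been seen, then the two are concatenated (the torch.cat in the next hunk). A compact sketch of that idea, with the surrounding loop structure assumed rather than copied from the file:

import torch

cached_a_proj = {}  # checkpoint name -> tensor; holds one half until its partner arrives

def maybe_fuse(name: str, weight: torch.Tensor):
    cached_a_proj[name] = weight
    q_name = name if "q_a_proj" in name else name.replace("kv_a_proj_with_mqa", "q_a_proj")
    kv_name = q_name.replace("q_a_proj", "kv_a_proj_with_mqa")
    if q_name in cached_a_proj and kv_name in cached_a_proj:
        # Both halves present: stack them along the output dimension.
        return torch.cat([cached_a_proj[q_name], cached_a_proj[kv_name]], dim=0)
    return None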
@@ -2151,9 +2166,12 @@ class DeepseekV2ForCausalLM(nn.Module):
                         fused_weight = torch.cat(
                             [q_a_proj_weight, kv_a_proj_weight], dim=0
                         )
 
-                        param_name = name.replace(
-                            "q_a_proj", "fused_qkv_a_proj_with_mqa"
+                        param_name = (
+                            name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
+                            if "q_a_proj" in name
+                            else name.replace(
+                                "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
+                            )
                         )
                         param = params_dict[param_name]
 
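The rewritten param_name expression matters because the fusion is triggered by whichever half arrives second: when that is the kv_a_proj_with_mqa shard, the old unconditional replace of "q_a_proj" left the name unchanged, so the params_dict lookup below could not resolve it. Both spellings now land on the same fused key; a quick check (the layer index is illustrative):

name_q = "model.layers.0.self_attn.q_a_proj.weight"
name_kv = "model.layers.0.self_attn.kv_a_proj_with_mqa.weight"
fused_q = name_q.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
fused_kv = name_kv.replace("kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa")
assert fused_q == fused_kv == "model.layers.0.self_attn.fused_qkv_a_proj_with_mqa.weight"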
@@ -2164,6 +2182,16 @@ class DeepseekV2ForCausalLM(nn.Module):
                             cached_a_proj.pop(q_a_proj_name)
                             cached_a_proj.pop(kv_a_proj_name)
                 else:
+                    if (
+                        "k_scale" in name or "v_scale" in name
+                    ) and name not in params_dict:
+                        # modelopt attn kv scale is named differently
+                        if any(scale in name for scale in ["k_scale", "v_scale"]):
+                            name = name.replace("_proj", "attn_mqa")
+                        else:
+                            logger.warning(
+                                f"Unknown scale found in checkpoint: {name}"
+                            )
                     param = params_dict[name]
                     weight_loader = getattr(
                         param, "weight_loader", default_weight_loader
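The new branch handles ModelOpt checkpoints whose attention kv-cache scales (k_scale / v_scale) arrive under a projection-style name with no matching parameter; rewriting "_proj" to "attn_mqa" retargets the key at the attention module's expected name before the params_dict lookup. Once the name resolves, loading falls through to the usual weight_loader path; below is a sketch of what a default loader typically does when a parameter defines no custom weight_loader (simplified, assumed behavior rather than this file's exact helper).

import torch

def default_weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> None:
    # Plain copy: no sharding or transformation, just a shape-checked assignment.
    assert param.size() == loaded_weight.size()
    param.data.copy_(loaded_weight)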