[Bugfix] Fix Weightloading for the original nvidia/Deepseek-R1-FP4 checkpoint (#9940)
Signed-off-by: Pavani Majety <pmajety@nvidia.com> Co-authored-by: Yineng Zhang <me@zhyncs.com> Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
This commit is contained in:
@@ -642,10 +642,22 @@ class ModelOptFp4Config(QuantizationConfig):
|
||||
def is_layer_excluded(self, prefix: str, exclude_modules: list):
|
||||
import regex as re
|
||||
|
||||
fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]
|
||||
prefix_split = prefix.split(".")
|
||||
for pattern in exclude_modules:
|
||||
regex_str = pattern.replace(".", r"\.").replace("*", r".*")
|
||||
pattern_split = pattern.split(".")
|
||||
if re.fullmatch(regex_str, prefix):
|
||||
return True
|
||||
elif (
|
||||
pattern_split[-1] in fused_patterns
|
||||
and pattern_split[-1] in prefix_split[-1]
|
||||
):
|
||||
# Check if the last part of the excluded pattern is contained in the last part of the prefix
|
||||
# This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
|
||||
# e.g., model.layers.{i}.self_attn.{fused_weight_name}
|
||||
assert len(prefix_split) == 5 and len(pattern_split) == 5
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_quant_method(
|
||||
@@ -1250,8 +1262,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
layer.w13_weight_scale,
|
||||
)
|
||||
|
||||
logger.info_once("Applied flashinfer weight processing for both w13 and w2")
|
||||
|
||||
else:
|
||||
# CUTLASS processing - handle w13 and w2 separately
|
||||
|
||||
@@ -1268,7 +1278,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
|
||||
|
||||
# Both flashinfer cutlass and regular cutlass use same processing for w2
|
||||
logger.info_once("Applied weight processing for both w13 and w2")
|
||||
|
||||
# Set up CUTLASS MoE parameters
|
||||
device = layer.w13_weight.device
|
||||
|
||||
@@ -654,11 +654,13 @@ class ServerArgs:
|
||||
], "The expert parallel size must be 1 or the same as the tensor parallel size"
|
||||
|
||||
if self.moe_runner_backend == "flashinfer_trtllm":
|
||||
if not self.disable_shared_experts_fusion:
|
||||
self.disable_shared_experts_fusion = True
|
||||
logger.warning(
|
||||
"FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
|
||||
)
|
||||
assert (
|
||||
self.quantization == "modelopt_fp4" or self.quantization == "fp8"
|
||||
), "modelopt_fp4 quantization is required for Flashinfer TRTLLM MoE"
|
||||
self.disable_shared_experts_fusion = True
|
||||
logger.warning(
|
||||
"FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
|
||||
)
|
||||
|
||||
# DeepEP MoE
|
||||
if self.moe_a2a_backend == "deepep":
|
||||
|
||||
Reference in New Issue
Block a user