[Bugfix] Fix Weightloading for the original nvidia/Deepseek-R1-FP4 checkpoint (#9940)
Signed-off-by: Pavani Majety <pmajety@nvidia.com> Co-authored-by: Yineng Zhang <me@zhyncs.com> Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
This commit is contained in:
@@ -642,10 +642,22 @@ class ModelOptFp4Config(QuantizationConfig):
|
|||||||
def is_layer_excluded(self, prefix: str, exclude_modules: list):
|
def is_layer_excluded(self, prefix: str, exclude_modules: list):
|
||||||
import regex as re
|
import regex as re
|
||||||
|
|
||||||
|
fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]
|
||||||
|
prefix_split = prefix.split(".")
|
||||||
for pattern in exclude_modules:
|
for pattern in exclude_modules:
|
||||||
regex_str = pattern.replace(".", r"\.").replace("*", r".*")
|
regex_str = pattern.replace(".", r"\.").replace("*", r".*")
|
||||||
|
pattern_split = pattern.split(".")
|
||||||
if re.fullmatch(regex_str, prefix):
|
if re.fullmatch(regex_str, prefix):
|
||||||
return True
|
return True
|
||||||
|
elif (
|
||||||
|
pattern_split[-1] in fused_patterns
|
||||||
|
and pattern_split[-1] in prefix_split[-1]
|
||||||
|
):
|
||||||
|
# Check if the last part of the excluded pattern is contained in the last part of the prefix
|
||||||
|
# This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
|
||||||
|
# e.g., model.layers.{i}.self_attn.{fused_weight_name}
|
||||||
|
assert len(prefix_split) == 5 and len(pattern_split) == 5
|
||||||
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
@@ -1250,8 +1262,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
|||||||
layer.w13_weight_scale,
|
layer.w13_weight_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info_once("Applied flashinfer weight processing for both w13 and w2")
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# CUTLASS processing - handle w13 and w2 separately
|
# CUTLASS processing - handle w13 and w2 separately
|
||||||
|
|
||||||
@@ -1268,7 +1278,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
|||||||
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
|
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
|
||||||
|
|
||||||
# Both flashinfer cutlass and regular cutlass use same processing for w2
|
# Both flashinfer cutlass and regular cutlass use same processing for w2
|
||||||
logger.info_once("Applied weight processing for both w13 and w2")
|
|
||||||
|
|
||||||
# Set up CUTLASS MoE parameters
|
# Set up CUTLASS MoE parameters
|
||||||
device = layer.w13_weight.device
|
device = layer.w13_weight.device
|
||||||
|
|||||||
@@ -654,11 +654,13 @@ class ServerArgs:
|
|||||||
], "The expert parallel size must be 1 or the same as the tensor parallel size"
|
], "The expert parallel size must be 1 or the same as the tensor parallel size"
|
||||||
|
|
||||||
if self.moe_runner_backend == "flashinfer_trtllm":
|
if self.moe_runner_backend == "flashinfer_trtllm":
|
||||||
if not self.disable_shared_experts_fusion:
|
assert (
|
||||||
self.disable_shared_experts_fusion = True
|
self.quantization == "modelopt_fp4" or self.quantization == "fp8"
|
||||||
logger.warning(
|
), "modelopt_fp4 quantization is required for Flashinfer TRTLLM MoE"
|
||||||
"FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
|
self.disable_shared_experts_fusion = True
|
||||||
)
|
logger.warning(
|
||||||
|
"FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
|
||||||
|
)
|
||||||
|
|
||||||
# DeepEP MoE
|
# DeepEP MoE
|
||||||
if self.moe_a2a_backend == "deepep":
|
if self.moe_a2a_backend == "deepep":
|
||||||
|
|||||||
Reference in New Issue
Block a user