[ModelOpt] Fix Weight Loading for DSR1-FP4 Quantization (#9712)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
@@ -599,6 +599,13 @@ class ModelOptFp4Config(QuantizationConfig):
|
||||
regex_str = pattern.replace(".", r"\.").replace("*", r".*")
|
||||
if re.fullmatch(regex_str, prefix):
|
||||
return True
|
||||
|
||||
# Check if the last part of the excluded pattern is contained in the last part of the prefix
|
||||
# This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
|
||||
pattern_last_part = pattern.split(".")[-1]
|
||||
prefix_last_part = prefix.split(".")[-1]
|
||||
if pattern_last_part in prefix_last_part:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_quant_method(
|
||||
|
||||
Reference in New Issue
Block a user