[ModelOpt] Fix Weight Loading for DSR1-FP4 Quantization (#9712)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
@@ -235,8 +235,9 @@ class ReplicatedLinear(LinearBase):
|
||||
loaded_weight = loaded_weight[:1]
|
||||
else:
|
||||
raise ValueError(f"{loaded_weight} are not all equal")
|
||||
|
||||
assert param.size() == loaded_weight.size()
|
||||
assert (
|
||||
param.size() == loaded_weight.size()
|
||||
), f"Loading weight error: param: {param.size()}, loaded_weight: {loaded_weight.size()}"
|
||||
param.data.copy_(loaded_weight)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
|
||||
@@ -599,6 +599,13 @@ class ModelOptFp4Config(QuantizationConfig):
|
||||
regex_str = pattern.replace(".", r"\.").replace("*", r".*")
|
||||
if re.fullmatch(regex_str, prefix):
|
||||
return True
|
||||
|
||||
# Fallback: substring-match the last dot-separated component of the excluded pattern against the last component of the prefix
|
||||
# This handles fused modules like fused_qkv_a_proj_with_mqa, which combine q_a_proj and kv_a_proj_with_mqa (the fused name contains "kv_a_proj_with_mqa" as a substring, but not "q_a_proj")
|
||||
pattern_last_part = pattern.split(".")[-1]
|
||||
prefix_last_part = prefix.split(".")[-1]
|
||||
if pattern_last_part in prefix_last_part:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_quant_method(
|
||||
|
||||
Reference in New Issue
Block a user