Fix MTP MoE weight loading with NVFP4 target model. (#10758)
This commit is contained in:
@@ -575,7 +575,10 @@ class FusedMoE(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Flashinfer assumes w31 format for w13_weight. Same for the scales.
|
# Flashinfer assumes w31 format for w13_weight. Same for the scales.
|
||||||
if should_use_flashinfer_trtllm_moe():
|
if (
|
||||||
|
should_use_flashinfer_trtllm_moe()
|
||||||
|
and self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
|
||||||
|
):
|
||||||
shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
|
shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
|
||||||
|
|
||||||
WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
|
WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
|
||||||
|
|||||||
Reference in New Issue
Block a user