Tiny cleanup deepseek_v2.py (#11163)

This commit is contained in:
fzyzcjy
2025-10-02 21:54:52 +08:00
committed by GitHub
parent 948278f173
commit b65db0287b
2 changed files with 38 additions and 38 deletions

View File

@@ -234,6 +234,13 @@ class FusedMoE(torch.nn.Module):
self.quant_method.create_moe_runner(self, self.moe_runner_config)
self.dispatcher = StandardDispatcher()
self.should_fuse_routed_scaling_factor_in_topk = isinstance(
self.quant_method, ModelOptNvFp4FusedMoEMethod
) or (
isinstance(self.quant_method, Fp8MoEMethod)
and self.quant_method.use_cutlass_fused_experts_fp8
)
def _load_per_tensor_weight_scale(
self,
shard_id: str,
@@ -936,12 +943,6 @@ class FusedMoE(torch.nn.Module):
for shard_id in ["w1", "w2", "w3"]
]
def should_fuse_routed_scaling_factor_in_topk(self):
    """Report whether the active quantization method meets the fusion criteria.

    The result is truthy when ``self.quant_method`` is a
    ``ModelOptNvFp4FusedMoEMethod``, or when it is an ``Fp8MoEMethod``
    whose ``use_cutlass_fused_experts_fp8`` attribute is truthy.

    NOTE(review): judging by the name, callers presumably use this to decide
    whether the routed scaling factor is applied inside top-k selection;
    confirm against the call sites.
    """
    method = self.quant_method
    # NV FP4 ModelOpt method qualifies unconditionally.
    if isinstance(method, ModelOptNvFp4FusedMoEMethod):
        return True
    # FP8 qualifies only when its cutlass fused-experts flag is set.
    return isinstance(method, Fp8MoEMethod) and method.use_cutlass_fused_experts_fp8
class FlashInferFusedMoE(FusedMoE):
def __init__(self, *args, **kwargs):