Fix global input scale incompatible with CuTe DSL moe (#10370)
This commit is contained in:
@@ -1187,6 +1187,21 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe:
|
||||
w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
|
||||
w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
|
||||
elif self.enable_flashinfer_cutedsl_moe:
|
||||
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
|
||||
w2_input_scale = layer.w2_input_scale
|
||||
|
||||
def _slice_scale(w):
|
||||
assert w.shape == (layer.num_experts,)
|
||||
assert layer.moe_ep_size * layer.num_local_experts == layer.num_experts
|
||||
return w[
|
||||
layer.moe_ep_rank
|
||||
* layer.num_local_experts : (layer.moe_ep_rank + 1)
|
||||
* layer.num_local_experts
|
||||
]
|
||||
|
||||
w13_input_scale = _slice_scale(w13_input_scale)
|
||||
w2_input_scale = _slice_scale(w2_input_scale)
|
||||
else:
|
||||
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
|
||||
w2_input_scale = layer.w2_input_scale
|
||||
|
||||
Reference in New Issue
Block a user