Fix global input scale incompatible with CuTe DSL moe (#10370)
This commit is contained in:
@@ -1187,6 +1187,21 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
|||||||
if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe:
|
if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe:
|
||||||
w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
|
w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
|
||||||
w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
|
w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
|
||||||
|
elif self.enable_flashinfer_cutedsl_moe:
|
||||||
|
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
|
||||||
|
w2_input_scale = layer.w2_input_scale
|
||||||
|
|
||||||
|
def _slice_scale(w):
|
||||||
|
assert w.shape == (layer.num_experts,)
|
||||||
|
assert layer.moe_ep_size * layer.num_local_experts == layer.num_experts
|
||||||
|
return w[
|
||||||
|
layer.moe_ep_rank
|
||||||
|
* layer.num_local_experts : (layer.moe_ep_rank + 1)
|
||||||
|
* layer.num_local_experts
|
||||||
|
]
|
||||||
|
|
||||||
|
w13_input_scale = _slice_scale(w13_input_scale)
|
||||||
|
w2_input_scale = _slice_scale(w2_input_scale)
|
||||||
else:
|
else:
|
||||||
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
|
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
|
||||||
w2_input_scale = layer.w2_input_scale
|
w2_input_scale = layer.w2_input_scale
|
||||||
|
|||||||
Reference in New Issue
Block a user