[LongCat] Optimize zero_experts_compute_triton by changing mask (#10303)
This commit is contained in:
@@ -1416,7 +1416,7 @@ def zero_experts_compute_triton(
|
||||
zero_expert_scales[zero_expert_mask] = 0.0
|
||||
|
||||
normal_expert_mask = expert_indices >= num_experts
|
||||
- expert_indices[normal_expert_mask] = 0
|
||||
+ expert_indices[normal_expert_mask] = -1
|
||||
expert_scales[normal_expert_mask] = 0.0
|
||||
|
||||
output = torch.zeros_like(hidden_states).to(hidden_states.device)
|
||||
|
||||
Reference in New Issue
Block a user