[LongCat] Optimize zero_experts_compute_triton by changing mask (#10303)
This commit is contained in:
@@ -1416,7 +1416,7 @@ def zero_experts_compute_triton(
|
|||||||
zero_expert_scales[zero_expert_mask] = 0.0
|
zero_expert_scales[zero_expert_mask] = 0.0
|
||||||
|
|
||||||
normal_expert_mask = expert_indices >= num_experts
|
normal_expert_mask = expert_indices >= num_experts
|
||||||
expert_indices[normal_expert_mask] = 0
|
expert_indices[normal_expert_mask] = -1
|
||||||
expert_scales[normal_expert_mask] = 0.0
|
expert_scales[normal_expert_mask] = 0.0
|
||||||
|
|
||||||
output = torch.zeros_like(hidden_states).to(hidden_states.device)
|
output = torch.zeros_like(hidden_states).to(hidden_states.device)
|
||||||
|
|||||||
Reference in New Issue
Block a user