diff --git a/vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/layers/sparse_moe_mlp.py b/vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/layers/sparse_moe_mlp.py
index efd726c..d509fcb 100644
--- a/vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/layers/sparse_moe_mlp.py
+++ b/vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/layers/sparse_moe_mlp.py
@@ -284,34 +284,28 @@ class SparseMoeMlp(nn.Module):
     def forward_experts_nofused(self, hidden_states, expert_logits):
-        hidden_states_shape = hidden_states.shape
+        # Dense approach: each expert processes ALL tokens, then mask by routing
+        # weights. This avoids data-dependent control flow (variable-size slicing,
+        # conditional branches, torch.unique, torch.tensor creation) that is
+        # incompatible with MLU graph capture.
+        num_tokens, hidden_size = hidden_states.shape
         topk_values, topk_indices = self.topk_softmax(expert_logits)
-        expand_gather_idx, scatter_idx, expand_token_count, cusum_token_count = self.generate_gather_idx(
-            topk_indices)
-        # no expert is routed, then expand_gather_idx, expand_scatter_idx has no item,
-        # expand_token_count and expand_cusum_token_count has item but the value is all zero
-        # so this rank should only return final_hidden_states with zero value
-        if expand_gather_idx.numel() == 0:
-            final_hidden_states = torch.zeros_like(hidden_states,
-                                                   dtype=hidden_states.dtype,
-                                                   device=hidden_states.device)
-            return final_hidden_states
-        expand_hidden_states = self.expand_input(hidden_states, expand_gather_idx)
+        final_hidden_states = torch.zeros(
+            num_tokens, hidden_size,
+            dtype=hidden_states.dtype, device=hidden_states.device)
 
-        expand_output_list = []
-        expand_cusum_token_count = cusum_token_count[self.start_expert_id:self.end_expert_id +
-                                                     1] - cusum_token_count[self.start_expert_id]
-        for expert_idx, num_tokens_per_expert in enumerate(expand_token_count):
-            if num_tokens_per_expert > 0:
-                expert_hidden_states = expand_hidden_states[
-                    expand_cusum_token_count[expert_idx]:expand_cusum_token_count[expert_idx + 1]]
-                expert_output = self.experts[expert_idx](expert_hidden_states)
-                expert_output = expert_output[0] if isinstance(expert_output, (tuple, list)) else expert_output
-                expand_output_list.append(expert_output)
-        expand_output = torch.cat(expand_output_list, dim=0)
-        final_hidden_states = self.combine_moe(expand_output, scatter_idx, cusum_token_count, hidden_states_shape,
-                                               topk_values)
+        for expert_idx in range(self.num_experts_per_rank):
+            global_expert_idx = self.start_expert_id + expert_idx
+            expert_output = self.experts[expert_idx](hidden_states)
+            expert_output = expert_output[0] if isinstance(
+                expert_output, (tuple, list)) else expert_output
+
+            # Routing weight per token for this expert
+            expert_mask = (topk_indices == global_expert_idx).to(topk_values.dtype)
+            expert_weights = (topk_values * expert_mask).sum(dim=-1, keepdim=True)
+
+            final_hidden_states = final_hidden_states + expert_output * expert_weights
         return final_hidden_states