diff --git a/vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/layers/sparse_moe_mlp.py b/vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/layers/sparse_moe_mlp.py
index 269cac4..efd726c 100644
--- a/vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/layers/sparse_moe_mlp.py
+++ b/vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/layers/sparse_moe_mlp.py
@@ -425,9 +425,9 @@ class SparseMoeMlp(nn.Module):
         scatter_idx=torch.zeros((indices.numel(),), dtype=seqs.dtype,
                                 device=seqs.device).scatter(0, indices, seqs)
         # token_count: [self.num_experts_per_rank]
-        partial_token_index, partial_token_count = sorted_expert_id.unique(sorted=True, return_counts=True)
-        zero_token_count = torch.zeros(self.num_total_experts, dtype=partial_token_count.dtype, device=device)
-        token_count = zero_token_count.scatter(dim=0, index=partial_token_index, src=partial_token_count)
+        # Use scatter_add_ instead of torch.unique for MLU graph capture compatibility
+        token_count = torch.zeros(self.num_total_experts, dtype=sorted_expert_id.dtype, device=device)
+        token_count.scatter_add_(0, sorted_expert_id, torch.ones_like(sorted_expert_id))
         # cusum_token_count: [self.num_experts_per_rank + 1]
         cusum_token_count = torch.cat(
             [torch.tensor([0], dtype=token_count.dtype, device=device),
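
For context, here is a minimal standalone sketch (not part of the patch) of why the replacement is equivalent and why it helps graph capture: `torch.unique` returns tensors whose shapes depend on how many distinct expert ids appear in the input, so the old path cannot be captured into a static graph, whereas `scatter_add_` accumulates into a buffer whose shape is fixed at `num_total_experts`. The value of `num_total_experts` and the sample `sorted_expert_id` contents below are assumptions chosen for illustration.

```python
import torch

# Assumed for illustration: 4 experts, token-to-expert assignments sorted by expert id.
num_total_experts = 4
sorted_expert_id = torch.tensor([0, 0, 2, 2, 2, 3])

# Old approach: torch.unique yields data-dependent output shapes
# ([0, 2, 3] and [2, 3, 1] here), which blocks static graph capture.
partial_index, partial_count = sorted_expert_id.unique(sorted=True, return_counts=True)
token_count_old = torch.zeros(num_total_experts, dtype=partial_count.dtype)
token_count_old = token_count_old.scatter(0, partial_index, partial_count)

# New approach: scatter_add_ adds a one at each token's expert slot; the
# output buffer shape is fixed at num_total_experts regardless of the data.
token_count_new = torch.zeros(num_total_experts, dtype=sorted_expert_id.dtype)
token_count_new.scatter_add_(0, sorted_expert_id, torch.ones_like(sorted_expert_id))

assert torch.equal(token_count_old, token_count_new)  # tensor([2, 0, 3, 1])
```

Note that experts receiving zero tokens fall out naturally in both variants (expert 1 above), but only the `scatter_add_` form keeps every tensor shape independent of the input values.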