diff --git a/vllm_ascend/ops/moe/moe_mlp.py b/vllm_ascend/ops/moe/moe_mlp.py index 25dedd0..6e7b890 100644 --- a/vllm_ascend/ops/moe/moe_mlp.py +++ b/vllm_ascend/ops/moe/moe_mlp.py @@ -41,7 +41,7 @@ def cumsum_group_list(group_list: torch.Tensor, return group_list.cumsum(dim=0) if src_list_type == 0 and dst_list_type == 1: group_diff = torch.diff(group_list) - new_group = torch.cat([group_diff[0].unsqueeze(0), group_diff], dim=0) + new_group = torch.cat([group_list[0].unsqueeze(0), group_diff], dim=0) return new_group if src_list_type == 2 and dst_list_type == 0: experts = pad(group_list[:, 0], (1, 0))