diff --git a/tests/ut/ops/test_token_dispatcher.py b/tests/ut/ops/test_token_dispatcher.py index cc2d307..ed32b93 100644 --- a/tests/ut/ops/test_token_dispatcher.py +++ b/tests/ut/ops/test_token_dispatcher.py @@ -269,6 +269,7 @@ class TestTokenDispatcherWithAllGather(TestBase): def test_token_combine_with_expert_map(self): self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3]) self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1]) + self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1]) self.dispatcher.sorted_weights = torch.tensor( [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]) self.dispatcher.original_shape = (3, 128) diff --git a/vllm_ascend/ops/moe/token_dispatcher.py b/vllm_ascend/ops/moe/token_dispatcher.py index b36cc44..5e17815 100644 --- a/vllm_ascend/ops/moe/token_dispatcher.py +++ b/vllm_ascend/ops/moe/token_dispatcher.py @@ -383,7 +383,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): assert self.original_shape is not None final_hidden_states = torch_npu.npu_moe_token_unpermute( permuted_tokens=hidden_states, - sorted_indices=self.expanded_row_idx, + sorted_indices=torch.abs(self.expanded_row_idx), probs=self.topk_weights) if len(self.original_shape) == 3: final_hidden_states = final_hidden_states.view(self.original_shape)