[bugfix] Fix moe bug: allgather error. (#3279)
It will crash when deepseek model executed in A2. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/releases/v0.11.0 --------- Signed-off-by: weijinqian_v1 <weijinqian@huawei.com> Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
@@ -269,6 +269,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
def test_token_combine_with_expert_map(self):
|
||||
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
|
||||
self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
|
||||
self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1])
|
||||
self.dispatcher.sorted_weights = torch.tensor(
|
||||
[0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
|
||||
self.dispatcher.original_shape = (3, 128)
|
||||
|
||||
@@ -383,7 +383,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
assert self.original_shape is not None
|
||||
final_hidden_states = torch_npu.npu_moe_token_unpermute(
|
||||
permuted_tokens=hidden_states,
|
||||
sorted_indices=self.expanded_row_idx,
|
||||
sorted_indices=torch.abs(self.expanded_row_idx),
|
||||
probs=self.topk_weights)
|
||||
if len(self.original_shape) == 3:
|
||||
final_hidden_states = final_hidden_states.view(self.original_shape)
|
||||
|
||||
Reference in New Issue
Block a user