[bugfix] Fix moe bug: allgather error. (#3279)
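Fixes a crash that occurs when a DeepSeek model is executed on A2. In TokenDispatcherWithAllGather, the token-combine step now passes torch.abs(self.expanded_row_idx) instead of self.expanded_row_idx as sorted_indices to torch_npu.npu_moe_token_unpermute. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/releases/v0.11.0 --------- Signed-off-by: weijinqian_v1 <weijinqian@huawei.com> Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>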
This commit is contained in:
@@ -269,6 +269,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
|||||||
def test_token_combine_with_expert_map(self):
|
def test_token_combine_with_expert_map(self):
|
||||||
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
|
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
|
||||||
self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
|
self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
|
||||||
|
self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1])
|
||||||
self.dispatcher.sorted_weights = torch.tensor(
|
self.dispatcher.sorted_weights = torch.tensor(
|
||||||
[0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
|
[0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
|
||||||
self.dispatcher.original_shape = (3, 128)
|
self.dispatcher.original_shape = (3, 128)
|
||||||
|
|||||||
@@ -383,7 +383,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
|||||||
assert self.original_shape is not None
|
assert self.original_shape is not None
|
||||||
final_hidden_states = torch_npu.npu_moe_token_unpermute(
|
final_hidden_states = torch_npu.npu_moe_token_unpermute(
|
||||||
permuted_tokens=hidden_states,
|
permuted_tokens=hidden_states,
|
||||||
sorted_indices=self.expanded_row_idx,
|
sorted_indices=torch.abs(self.expanded_row_idx),
|
||||||
probs=self.topk_weights)
|
probs=self.topk_weights)
|
||||||
if len(self.original_shape) == 3:
|
if len(self.original_shape) == 3:
|
||||||
final_hidden_states = final_hidden_states.view(self.original_shape)
|
final_hidden_states = final_hidden_states.view(self.original_shape)
|
||||||
|
|||||||
Reference in New Issue
Block a user