[Bugfix] bugfix for moe_mlp in vllm-ascend/v0.11.0-dev (#4885)
### What this PR does / why we need it? This PR fixes a bug in the moe_mlp module by correcting the arguments passed to the torch_npu.npu_dequant_swiglu_quant function.It properly converts group_list from a cumulative sum to counts for the group_index parameter. ### Does this PR introduce _any_ user-facing change? No - vLLM version: v0.12.0 - vLLM main: https://github.com/vllm-project/vllm/main --------- Signed-off-by: tanqingshan (A) <50050625@china.huawei.com> Signed-off-by: tanqingshan (A) <50050625@china.huawei.com> Co-authored-by: tanqingshan (A) <50050625@china.huawei.com> Co-authored-by: Mercykid-bash <ruanche0218@gmail.com>
This commit is contained in:
@@ -293,13 +293,13 @@ class TestCumsumGroupList(TestBase):
|
||||
def test_cumsum_group_list_with_type_0(self):
|
||||
group_list = self.experts.cumsum(dim=0)
|
||||
group_list_type = 0
|
||||
result = cumsum_group_list(group_list, group_list_type)
|
||||
result = cumsum_group_list(group_list, group_list_type, 0)
|
||||
self.assertTrue(torch.equal(result, self.group_list))
|
||||
|
||||
def test_cumsum_group_list_with_type_1(self):
|
||||
group_list = self.experts
|
||||
group_list_type = 1
|
||||
result = cumsum_group_list(group_list, group_list_type)
|
||||
result = cumsum_group_list(group_list, group_list_type, 0)
|
||||
self.assertTrue(torch.equal(result, self.group_list))
|
||||
|
||||
def test_cumsum_group_list_with_type_2(self):
|
||||
@@ -312,6 +312,7 @@ class TestCumsumGroupList(TestBase):
|
||||
group_list_type = 2
|
||||
result = cumsum_group_list(group_list,
|
||||
group_list_type,
|
||||
0,
|
||||
active_num=self.active_num,
|
||||
expert_num=self.expert_num)
|
||||
self.assertTrue(torch.equal(result, self.group_list))
|
||||
|
||||
Reference in New Issue
Block a user