[Bugfix] bugfix for moe_mlp in vllm-ascend/v0.11.0-dev (#4885)

### What this PR does / why we need it?
This PR fixes a bug in the moe_mlp module by correcting the arguments
passed to the torch_npu.npu_dequant_swiglu_quant function.It properly
converts group_list from a cumulative sum to counts for the group_index
parameter.

### Does this PR introduce _any_ user-facing change?
No


- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/main

---------

Signed-off-by: tanqingshan (A)  <50050625@china.huawei.com>
Signed-off-by: tanqingshan (A) <50050625@china.huawei.com>
Co-authored-by: tanqingshan (A) <50050625@china.huawei.com>
Co-authored-by: Mercykid-bash <ruanche0218@gmail.com>
This commit is contained in:
Clorist33
2025-12-12 14:51:47 +08:00
committed by GitHub
parent 9c0ad46c1a
commit 4f0dddc9ee
5 changed files with 41 additions and 34 deletions

View File

@@ -26,31 +26,39 @@ from vllm_ascend.utils import dispose_tensor, is_310p
def cumsum_group_list(group_list: torch.Tensor,
group_list_type: int,
src_list_type: int,
dst_list_type: int,
active_num: int = 0,
expert_num: int = 0) -> torch.Tensor:
if group_list_type not in [0, 1, 2]:
if src_list_type not in [0, 1, 2]:
raise ValueError(
f"group_list_type should be in [0, 1, 2], but received {group_list_type}"
f"group_list_type should be in [0, 1, 2], but received {src_list_type}"
)
if group_list_type == 0:
if src_list_type == dst_list_type:
return group_list
if group_list_type == 1:
if src_list_type == 1 and dst_list_type == 0:
return group_list.cumsum(dim=0)
if src_list_type == 0 and dst_list_type == 1:
group_diff = torch.diff(group_list)
new_group = torch.cat([group_diff[0].unsqueeze(0), group_diff], dim=0)
return new_group
if src_list_type == 2 and dst_list_type == 0:
experts = pad(group_list[:, 0], (1, 0))
tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0))
cumsum_group_list = torch.full(size=(expert_num, ),
fill_value=active_num,
dtype=group_list.dtype,
device=group_list.device)
experts = pad(group_list[:, 0], (1, 0))
tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0))
cumsum_group_list = torch.full(size=(expert_num, ),
fill_value=active_num,
dtype=group_list.dtype,
device=group_list.device)
for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])):
if end > start:
cumsum_group_list[start:end] = tokens[i]
for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])):
if end > start:
cumsum_group_list[start:end] = tokens[i]
return cumsum_group_list
return cumsum_group_list
raise NotImplementedError(
f"Conversion from src_list_type={src_list_type} to dst_list_type={dst_list_type} is not implemented yet. "
"This feature is under development.")
def quant_apply_mlp(hidden_states: torch.Tensor,
@@ -89,7 +97,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
x=hidden_states,
weight=w1,
group_list=cumsum_group_list(group_list, group_list_type),
group_list=cumsum_group_list(group_list, group_list_type, 0),
weight_scale=w1_scale,
x_scale=pertoken_scale)
else:
@@ -105,9 +113,6 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
group_list=group_list,
output_dtype=torch.int32)[0]
# act_fn: swiglu
group_diff = torch.diff(group_list)
new_group = torch.cat([group_list[0].unsqueeze(0), group_diff],
dim=0)
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=w1_scale,
@@ -115,7 +120,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=new_group,
group_index=cumsum_group_list(group_list, group_list_type, 1),
activate_left=True,
quant_mode=1,
)
@@ -148,7 +153,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
x=hidden_states,
weight=w1,
bias=bias1,
group_list=cumsum_group_list(group_list, group_list_type),
group_list=cumsum_group_list(group_list, group_list_type, 0),
weight_scale=w1_scale,
x_scale=pertoken_scale)
else:
@@ -258,4 +263,4 @@ def unified_apply_mlp(hidden_states: torch.Tensor,
group_list=group_list,
group_list_type=group_list_type,
topk_scales=topk_scales,
need_trans=need_trans)
need_trans=need_trans)