[Bugfix] bugfix for moe_mlp (#4822)
### What this PR does / why we need it?
This PR fixes a bug in the moe_mlp module by correcting the arguments
passed to the `torch_npu.npu_dequant_swiglu_quant` function: `group_list` is now
converted from a cumulative sum into per-expert counts before being passed as the
`group_index` parameter. To support this, `cumsum_group_list` now takes an explicit
source and destination layout type instead of a single `group_list_type`.
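
For intuition, here is a minimal sketch of the two layouts involved (the tensor values are invented for illustration and do not come from the patch):

```python
import torch

# Per-expert token counts (group_list_type == 1) for four experts.
counts = torch.tensor([3, 0, 2, 4])

# The cumulative-sum layout (group_list_type == 0) of the same routing.
offsets = counts.cumsum(dim=0)               # tensor([3, 3, 5, 9])

# Recovering counts from the cumulative sums: keep the first element,
# then take successive differences.
recovered = torch.cat([offsets[:1], torch.diff(offsets)])
assert torch.equal(recovered, counts)
```

Passing the cumulative-sum form where counts are expected (or the other way around) is the kind of layout mismatch this PR addresses.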
### Does this PR introduce _any_ user-facing change?
No
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: tanqingshan (A) <50050625@china.huawei.com>
Co-authored-by: tanqingshan (A) <50050625@china.huawei.com>
@@ -294,13 +294,13 @@ class TestCumsumGroupList(TestBase):
     def test_cumsum_group_list_with_type_0(self):
         group_list = self.experts.cumsum(dim=0)
         group_list_type = 0
-        result = cumsum_group_list(group_list, group_list_type)
+        result = cumsum_group_list(group_list, group_list_type, 0)
         self.assertTrue(torch.equal(result, self.group_list))

     def test_cumsum_group_list_with_type_1(self):
         group_list = self.experts
         group_list_type = 1
-        result = cumsum_group_list(group_list, group_list_type)
+        result = cumsum_group_list(group_list, group_list_type, 0)
         self.assertTrue(torch.equal(result, self.group_list))

     def test_cumsum_group_list_with_type_2(self):
@@ -313,6 +313,7 @@ class TestCumsumGroupList(TestBase):
         group_list_type = 2
         result = cumsum_group_list(group_list,
                                    group_list_type,
+                                   0,
                                    active_num=self.active_num,
                                    expert_num=self.expert_num)
         self.assertTrue(torch.equal(result, self.group_list))
@@ -31,31 +31,39 @@ def _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):


 def cumsum_group_list(group_list: torch.Tensor,
-                      group_list_type: int,
+                      src_list_type: int,
+                      dst_list_type: int,
                       active_num: int = 0,
                       expert_num: int = 0) -> torch.Tensor:
-    if group_list_type not in [0, 1, 2]:
+    if src_list_type not in [0, 1, 2]:
         raise ValueError(
-            f"group_list_type should be in [0, 1, 2], but received {group_list_type}"
+            f"group_list_type should be in [0, 1, 2], but received {src_list_type}"
         )

-    if group_list_type == 0:
+    if src_list_type == dst_list_type:
         return group_list
-    if group_list_type == 1:
+    if src_list_type == 1 and dst_list_type == 0:
         return group_list.cumsum(dim=0)
-
-    experts = pad(group_list[:, 0], (1, 0))
-    tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0))
-    cumsum_group_list = torch.full(size=(expert_num, ),
-                                   fill_value=active_num,
-                                   dtype=group_list.dtype,
-                                   device=group_list.device)
-
-    for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])):
-        if end > start:
-            cumsum_group_list[start:end] = tokens[i]
-
-    return cumsum_group_list
+    if src_list_type == 0 and dst_list_type == 1:
+        group_diff = torch.diff(group_list)
+        new_group = torch.cat([group_diff[0].unsqueeze(0), group_diff], dim=0)
+        return new_group
+    if src_list_type == 2 and dst_list_type == 0:
+        experts = pad(group_list[:, 0], (1, 0))
+        tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0))
+        cumsum_group_list = torch.full(size=(expert_num, ),
+                                       fill_value=active_num,
+                                       dtype=group_list.dtype,
+                                       device=group_list.device)
+        for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])):
+            if end > start:
+                cumsum_group_list[start:end] = tokens[i]
+        return cumsum_group_list
+    raise NotImplementedError(
+        f"Conversion from src_list_type={src_list_type} to dst_list_type={dst_list_type} is not implemented yet. "
+        "This feature is under development.")


 def quant_apply_mlp(hidden_states: torch.Tensor,
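Reading the branches above, the three layouts appear to be: 0 = cumulative token offsets per expert, 1 = per-expert token counts, and 2 = (expert id, count) pairs for non-empty experts only. A small self-contained sketch of the type-2 expansion, mirroring the `src_list_type == 2 and dst_list_type == 0` branch (the values, `expert_num`, and `active_num` are invented for illustration):

```python
import torch
from torch.nn.functional import pad

expert_num, active_num = 4, 9
# Type 2: only experts 0, 2 and 3 received tokens (3, 2 and 4 of them).
pairs = torch.tensor([[0, 3], [2, 2], [3, 4]])

experts = pad(pairs[:, 0], (1, 0))               # tensor([0, 0, 2, 3])
tokens = pad(pairs[:, 1].cumsum(dim=0), (1, 0))  # tensor([0, 3, 5, 9])
offsets = torch.full((expert_num, ), active_num, dtype=pairs.dtype)
for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])):
    if end > start:
        offsets[start:end] = tokens[i]

# Dense cumulative offsets, identical to counts [3, 0, 2, 4].cumsum(dim=0).
assert torch.equal(offsets, torch.tensor([3, 3, 5, 9]))
```

This 2 → 0 path is the case exercised by `test_cumsum_group_list_with_type_2` above via the new third positional argument.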
@@ -106,14 +114,15 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
                 weight=w1,
                 weight_scale=w1_scale,
                 x_scale=pertoken_scale,
-                group_list=cumsum_group_list(group_list, group_list_type),
+                group_list=cumsum_group_list(group_list, group_list_type,
+                                             0),
             ))
     elif fusion and not dynamic_eplb:
         # gmm1: gate_up_proj & act_fn: swiglu
         hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
             x=hidden_states,
             weight=w1[0],
-            group_list=cumsum_group_list(group_list, group_list_type),
+            group_list=cumsum_group_list(group_list, group_list_type, 0),
             weight_scale=w1_scale[0],
             x_scale=pertoken_scale)
         if quantized_hidden_states is not None:
@@ -140,7 +149,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
                 bias=None,
                 quant_scale=None,
                 quant_offset=None,
-                group_index=group_list,
+                group_index=cumsum_group_list(group_list, group_list_type, 1),
                 activate_left=True,
                 quant_mode=1,
             )
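This hunk is the core of the fix: the same `group_list` is now converted to whichever layout the consuming kernel expects, cumulative offsets (`dst_list_type=0`) for the grouped-matmul `group_list` arguments and per-expert counts (`dst_list_type=1`) for `group_index`. A rough sketch of the distinction, assuming the incoming tensor holds per-expert counts (`group_list_type == 1`); values are invented and no torch_npu call is made:

```python
import torch

group_list = torch.tensor([3, 0, 2, 4])   # per-expert counts (type 1)

# Layout for the grouped matmuls (dst_list_type=0): running offsets.
as_offsets = group_list.cumsum(dim=0)     # tensor([3, 3, 5, 9])

# Layout for group_index (dst_list_type=1): the counts themselves,
# so a type-1 input passes through unchanged.
as_counts = group_list                    # tensor([3, 0, 2, 4])
```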
@@ -202,7 +211,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
                 weight=w1,
                 weight_scale=w1_scale,
                 x_scale=pertoken_scale,
-                group_list=cumsum_group_list(group_list, group_list_type),
+                group_list=cumsum_group_list(group_list, group_list_type,
+                                             0),
                 bias=bias1,
             ))
     elif fusion and not dynamic_eplb:
@@ -211,7 +221,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
             x=hidden_states,
             weight=w1[0],
             bias=bias1,
-            group_list=cumsum_group_list(group_list, group_list_type),
+            group_list=cumsum_group_list(group_list, group_list_type, 0),
             weight_scale=w1_scale[0],
             x_scale=pertoken_scale)
         if quantized_hidden_states is not None: