[Bugfix] bugfix for moe_mlp in vllm-ascend/v0.11.0-dev (#4885)
### What this PR does / why we need it? This PR fixes a bug in the moe_mlp module by correcting the arguments passed to the torch_npu.npu_dequant_swiglu_quant function.It properly converts group_list from a cumulative sum to counts for the group_index parameter. ### Does this PR introduce _any_ user-facing change? No - vLLM version: v0.12.0 - vLLM main: https://github.com/vllm-project/vllm/main --------- Signed-off-by: tanqingshan (A) <50050625@china.huawei.com> Signed-off-by: tanqingshan (A) <50050625@china.huawei.com> Co-authored-by: tanqingshan (A) <50050625@china.huawei.com> Co-authored-by: Mercykid-bash <ruanche0218@gmail.com>
This commit is contained in:
@@ -47,8 +47,8 @@ def test_generate_task_and_state_flow(mock_adaptor):
|
|||||||
loader_obj.state = loader.ExpertWeightUpdateState.WAITING
|
loader_obj.state = loader.ExpertWeightUpdateState.WAITING
|
||||||
|
|
||||||
loader_obj.generate_expert_d2d_transfer_task([], [], {}, 0)
|
loader_obj.generate_expert_d2d_transfer_task([], [], {}, 0)
|
||||||
assert loader_obj.comm_op_list is None
|
assert not loader_obj.comm_op_list
|
||||||
assert loader_obj.state == loader.ExpertWeightUpdateState.WAITING
|
assert loader_obj.state == loader.ExpertWeightUpdateState.READY
|
||||||
|
|
||||||
|
|
||||||
def test_asyn_transfer_and_update(mock_adaptor):
|
def test_asyn_transfer_and_update(mock_adaptor):
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ class TestExpertLoadBalancer(TestBase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(expert_placement_map.shape,
|
self.assertEqual(expert_placement_map.shape,
|
||||||
(self.expert_load_balancer.layers_num,
|
(self.expert_load_balancer.layers_num,
|
||||||
self.expert_load_balancer.ranks_num, 10))
|
self.expert_load_balancer.ranks_num, 8))
|
||||||
self.assertTrue(torch.all(expert_placement_map >= -1))
|
self.assertTrue(torch.all(expert_placement_map >= -1))
|
||||||
|
|
||||||
def test_generate_log2phy_expert_map(self):
|
def test_generate_log2phy_expert_map(self):
|
||||||
@@ -90,7 +90,7 @@ class TestExpertLoadBalancer(TestBase):
|
|||||||
log2phy_map = self.expert_load_balancer.generate_log2phy_expert_map(
|
log2phy_map = self.expert_load_balancer.generate_log2phy_expert_map(
|
||||||
layer_id)
|
layer_id)
|
||||||
self.assertEqual(log2phy_map.shape,
|
self.assertEqual(log2phy_map.shape,
|
||||||
(self.expert_load_balancer.ranks_num, 10))
|
(self.expert_load_balancer.ranks_num, 8))
|
||||||
self.assertTrue(torch.all(log2phy_map >= -1))
|
self.assertTrue(torch.all(log2phy_map >= -1))
|
||||||
|
|
||||||
@mock.patch("torch_npu.npu._lazy_init")
|
@mock.patch("torch_npu.npu._lazy_init")
|
||||||
@@ -101,7 +101,7 @@ class TestExpertLoadBalancer(TestBase):
|
|||||||
rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map(
|
rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map(
|
||||||
layer_id, rank_id)
|
layer_id, rank_id)
|
||||||
self.assertEqual(rank_local_expert_num, 5)
|
self.assertEqual(rank_local_expert_num, 5)
|
||||||
expected_tensor = torch.tensor([2, -1, 1, 3, -1, 4, -1, 0, -1, -1],
|
expected_tensor = torch.tensor([2, -1, 1, 3, -1, 4, -1, 0],
|
||||||
dtype=torch.int32).to(
|
dtype=torch.int32).to(
|
||||||
rank_expert_map.device)
|
rank_expert_map.device)
|
||||||
self.assertTrue(rank_expert_map.equal(expected_tensor))
|
self.assertTrue(rank_expert_map.equal(expected_tensor))
|
||||||
@@ -109,7 +109,7 @@ class TestExpertLoadBalancer(TestBase):
|
|||||||
rank_id = 1
|
rank_id = 1
|
||||||
rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map(
|
rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map(
|
||||||
layer_id, rank_id)
|
layer_id, rank_id)
|
||||||
expected_tensor = torch.tensor([-1, 1, 4, -1, 2, -1, 0, 3, -1, -1],
|
expected_tensor = torch.tensor([-1, 1, 4, -1, 2, -1, 0, 3],
|
||||||
dtype=torch.int32).to(
|
dtype=torch.int32).to(
|
||||||
rank_expert_map.device)
|
rank_expert_map.device)
|
||||||
self.assertTrue(rank_expert_map.equal(expected_tensor))
|
self.assertTrue(rank_expert_map.equal(expected_tensor))
|
||||||
@@ -119,7 +119,7 @@ class TestExpertLoadBalancer(TestBase):
|
|||||||
rank_id = 0
|
rank_id = 0
|
||||||
log2phy_map = self.expert_load_balancer.get_rank_log2phy_map(
|
log2phy_map = self.expert_load_balancer.get_rank_log2phy_map(
|
||||||
layer_id, rank_id)
|
layer_id, rank_id)
|
||||||
expected_tensor = torch.tensor([2, 6, 1, 3, 7, 4, 5, 0, -1, -1],
|
expected_tensor = torch.tensor([2, 6, 1, 3, 7, 4, 5, 0],
|
||||||
dtype=torch.int32).to(
|
dtype=torch.int32).to(
|
||||||
log2phy_map.device)
|
log2phy_map.device)
|
||||||
self.assertTrue(log2phy_map.equal(expected_tensor))
|
self.assertTrue(log2phy_map.equal(expected_tensor))
|
||||||
@@ -127,7 +127,7 @@ class TestExpertLoadBalancer(TestBase):
|
|||||||
rank_id = 1
|
rank_id = 1
|
||||||
log2phy_map = self.expert_load_balancer.get_rank_log2phy_map(
|
log2phy_map = self.expert_load_balancer.get_rank_log2phy_map(
|
||||||
layer_id, rank_id)
|
layer_id, rank_id)
|
||||||
expected_tensor = torch.tensor([2, 6, 9, 3, 7, 4, 5, 8, -1, -1],
|
expected_tensor = torch.tensor([2, 6, 9, 3, 7, 4, 5, 8],
|
||||||
dtype=torch.int32).to(
|
dtype=torch.int32).to(
|
||||||
log2phy_map.device)
|
log2phy_map.device)
|
||||||
self.assertTrue(log2phy_map.equal(expected_tensor))
|
self.assertTrue(log2phy_map.equal(expected_tensor))
|
||||||
|
|||||||
@@ -293,13 +293,13 @@ class TestCumsumGroupList(TestBase):
|
|||||||
def test_cumsum_group_list_with_type_0(self):
|
def test_cumsum_group_list_with_type_0(self):
|
||||||
group_list = self.experts.cumsum(dim=0)
|
group_list = self.experts.cumsum(dim=0)
|
||||||
group_list_type = 0
|
group_list_type = 0
|
||||||
result = cumsum_group_list(group_list, group_list_type)
|
result = cumsum_group_list(group_list, group_list_type, 0)
|
||||||
self.assertTrue(torch.equal(result, self.group_list))
|
self.assertTrue(torch.equal(result, self.group_list))
|
||||||
|
|
||||||
def test_cumsum_group_list_with_type_1(self):
|
def test_cumsum_group_list_with_type_1(self):
|
||||||
group_list = self.experts
|
group_list = self.experts
|
||||||
group_list_type = 1
|
group_list_type = 1
|
||||||
result = cumsum_group_list(group_list, group_list_type)
|
result = cumsum_group_list(group_list, group_list_type, 0)
|
||||||
self.assertTrue(torch.equal(result, self.group_list))
|
self.assertTrue(torch.equal(result, self.group_list))
|
||||||
|
|
||||||
def test_cumsum_group_list_with_type_2(self):
|
def test_cumsum_group_list_with_type_2(self):
|
||||||
@@ -312,6 +312,7 @@ class TestCumsumGroupList(TestBase):
|
|||||||
group_list_type = 2
|
group_list_type = 2
|
||||||
result = cumsum_group_list(group_list,
|
result = cumsum_group_list(group_list,
|
||||||
group_list_type,
|
group_list_type,
|
||||||
|
0,
|
||||||
active_num=self.active_num,
|
active_num=self.active_num,
|
||||||
expert_num=self.expert_num)
|
expert_num=self.expert_num)
|
||||||
self.assertTrue(torch.equal(result, self.group_list))
|
self.assertTrue(torch.equal(result, self.group_list))
|
||||||
|
|||||||
@@ -130,7 +130,7 @@ class TestTokenDispatcherWithMC2(TestBase):
|
|||||||
self.dispatcher.need_extra_args = True
|
self.dispatcher.need_extra_args = True
|
||||||
self.dispatcher.enable_dispatch_v2 = True
|
self.dispatcher.enable_dispatch_v2 = True
|
||||||
self.dispatcher.output = torch.randint(0, 8, (10, 1))
|
self.dispatcher.output = torch.randint(0, 8, (10, 1))
|
||||||
|
self.dispatcher.moe_expert_num = len(self.dispatcher.expert_map)
|
||||||
kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states)
|
kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states)
|
||||||
self.assertIn("tp_send_counts", kwargs)
|
self.assertIn("tp_send_counts", kwargs)
|
||||||
|
|
||||||
@@ -148,6 +148,7 @@ class TestTokenDispatcherWithMC2(TestBase):
|
|||||||
self.dispatcher.enable_dispatch_v2 = True
|
self.dispatcher.enable_dispatch_v2 = True
|
||||||
self.dispatcher.swiglu_out_scale = torch.randint(0, 8, (10, 1))
|
self.dispatcher.swiglu_out_scale = torch.randint(0, 8, (10, 1))
|
||||||
self.dispatcher.output = torch.randint(0, 8, (10, 1))
|
self.dispatcher.output = torch.randint(0, 8, (10, 1))
|
||||||
|
self.dispatcher.moe_expert_num = len(self.dispatcher.expert_map)
|
||||||
self.hidden_states = torch.randn(10, 128)
|
self.hidden_states = torch.randn(10, 128)
|
||||||
|
|
||||||
with patch("torch_npu.npu_moe_distribute_combine_v2",
|
with patch("torch_npu.npu_moe_distribute_combine_v2",
|
||||||
|
|||||||
@@ -26,31 +26,39 @@ from vllm_ascend.utils import dispose_tensor, is_310p
|
|||||||
|
|
||||||
|
|
||||||
def cumsum_group_list(group_list: torch.Tensor,
|
def cumsum_group_list(group_list: torch.Tensor,
|
||||||
group_list_type: int,
|
src_list_type: int,
|
||||||
|
dst_list_type: int,
|
||||||
active_num: int = 0,
|
active_num: int = 0,
|
||||||
expert_num: int = 0) -> torch.Tensor:
|
expert_num: int = 0) -> torch.Tensor:
|
||||||
if group_list_type not in [0, 1, 2]:
|
if src_list_type not in [0, 1, 2]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"group_list_type should be in [0, 1, 2], but received {group_list_type}"
|
f"group_list_type should be in [0, 1, 2], but received {src_list_type}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if group_list_type == 0:
|
if src_list_type == dst_list_type:
|
||||||
return group_list
|
return group_list
|
||||||
if group_list_type == 1:
|
if src_list_type == 1 and dst_list_type == 0:
|
||||||
return group_list.cumsum(dim=0)
|
return group_list.cumsum(dim=0)
|
||||||
|
if src_list_type == 0 and dst_list_type == 1:
|
||||||
|
group_diff = torch.diff(group_list)
|
||||||
|
new_group = torch.cat([group_diff[0].unsqueeze(0), group_diff], dim=0)
|
||||||
|
return new_group
|
||||||
|
if src_list_type == 2 and dst_list_type == 0:
|
||||||
|
experts = pad(group_list[:, 0], (1, 0))
|
||||||
|
tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0))
|
||||||
|
cumsum_group_list = torch.full(size=(expert_num, ),
|
||||||
|
fill_value=active_num,
|
||||||
|
dtype=group_list.dtype,
|
||||||
|
device=group_list.device)
|
||||||
|
|
||||||
experts = pad(group_list[:, 0], (1, 0))
|
for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])):
|
||||||
tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0))
|
if end > start:
|
||||||
cumsum_group_list = torch.full(size=(expert_num, ),
|
cumsum_group_list[start:end] = tokens[i]
|
||||||
fill_value=active_num,
|
|
||||||
dtype=group_list.dtype,
|
|
||||||
device=group_list.device)
|
|
||||||
|
|
||||||
for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])):
|
return cumsum_group_list
|
||||||
if end > start:
|
raise NotImplementedError(
|
||||||
cumsum_group_list[start:end] = tokens[i]
|
f"Conversion from src_list_type={src_list_type} to dst_list_type={dst_list_type} is not implemented yet. "
|
||||||
|
"This feature is under development.")
|
||||||
return cumsum_group_list
|
|
||||||
|
|
||||||
|
|
||||||
def quant_apply_mlp(hidden_states: torch.Tensor,
|
def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||||
@@ -89,7 +97,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
|||||||
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
|
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
|
||||||
x=hidden_states,
|
x=hidden_states,
|
||||||
weight=w1,
|
weight=w1,
|
||||||
group_list=cumsum_group_list(group_list, group_list_type),
|
group_list=cumsum_group_list(group_list, group_list_type, 0),
|
||||||
weight_scale=w1_scale,
|
weight_scale=w1_scale,
|
||||||
x_scale=pertoken_scale)
|
x_scale=pertoken_scale)
|
||||||
else:
|
else:
|
||||||
@@ -105,9 +113,6 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
|||||||
group_list=group_list,
|
group_list=group_list,
|
||||||
output_dtype=torch.int32)[0]
|
output_dtype=torch.int32)[0]
|
||||||
# act_fn: swiglu
|
# act_fn: swiglu
|
||||||
group_diff = torch.diff(group_list)
|
|
||||||
new_group = torch.cat([group_list[0].unsqueeze(0), group_diff],
|
|
||||||
dim=0)
|
|
||||||
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||||
x=hidden_states,
|
x=hidden_states,
|
||||||
weight_scale=w1_scale,
|
weight_scale=w1_scale,
|
||||||
@@ -115,7 +120,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
|||||||
bias=None,
|
bias=None,
|
||||||
quant_scale=None,
|
quant_scale=None,
|
||||||
quant_offset=None,
|
quant_offset=None,
|
||||||
group_index=new_group,
|
group_index=cumsum_group_list(group_list, group_list_type, 1),
|
||||||
activate_left=True,
|
activate_left=True,
|
||||||
quant_mode=1,
|
quant_mode=1,
|
||||||
)
|
)
|
||||||
@@ -148,7 +153,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
|||||||
x=hidden_states,
|
x=hidden_states,
|
||||||
weight=w1,
|
weight=w1,
|
||||||
bias=bias1,
|
bias=bias1,
|
||||||
group_list=cumsum_group_list(group_list, group_list_type),
|
group_list=cumsum_group_list(group_list, group_list_type, 0),
|
||||||
weight_scale=w1_scale,
|
weight_scale=w1_scale,
|
||||||
x_scale=pertoken_scale)
|
x_scale=pertoken_scale)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user