Allgather token dispatcher: use fused ops (#2689)
### What this PR does / why we need it?
Use `npu_moe_init_routing_v2` and `npu_moe_token_unpermute` in place of `npu_moe_init_routing`, `npu_moe_compute_expert_tokens`, and `npu_moe_finalize_routing` in the allgather token dispatcher to optimize performance.
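For context, below is a minimal pure-PyTorch sketch of the permute/unpermute semantics these fused ops cover. The `moe_permute`/`moe_unpermute` helpers are illustrative stand-ins written for this description, not the `torch_npu` kernels, and the real op signatures are not reproduced here.

```python
import torch


def moe_permute(hidden_states: torch.Tensor, topk_ids: torch.Tensor):
    """Dispatch step: replicate each token once per selected expert and sort
    the copies by expert id so every expert sees a contiguous block.
    Returns the permuted activations, the permutation indices needed to undo
    the sort, and per-expert token counts."""
    num_tokens, top_k = topk_ids.shape
    src_rows = torch.arange(num_tokens).repeat_interleave(top_k)  # source token of each copy
    flat_expert_ids = topk_ids.reshape(-1)
    sort_order = torch.argsort(flat_expert_ids, stable=True)      # group copies by expert
    sorted_hidden_states = hidden_states[src_rows[sort_order]]
    expert_tokens = torch.bincount(flat_expert_ids)               # tokens per expert
    return sorted_hidden_states, sort_order, expert_tokens


def moe_unpermute(expert_output: torch.Tensor, sort_order: torch.Tensor,
                  topk_weights: torch.Tensor) -> torch.Tensor:
    """Combine step: undo the expert-major sort, then reduce each token's
    top-k copies with the router weights."""
    num_tokens, top_k = topk_weights.shape
    unsorted = torch.empty_like(expert_output)
    unsorted[sort_order] = expert_output                 # back to token-major order
    unsorted = unsorted.view(num_tokens, top_k, -1)
    return (unsorted * topk_weights.unsqueeze(-1)).sum(dim=1)


if __name__ == "__main__":
    x = torch.randn(3, 128)
    topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
    topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
    sorted_x, order, expert_tokens = moe_permute(x, topk_ids)
    combined = moe_unpermute(sorted_x, order, topk_weights)  # "experts" are identity here
    print(sorted_x.shape, expert_tokens.tolist(), combined.shape)
```

The fused NPU ops perform the equivalent reshuffling inside fused kernels; the table below compares end-to-end serving metrics with and without them.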
### Does this PR introduce _any_ user-facing change?
| Branch | TPS (tokens/s) | TTFT (ms) | TPOT (ms) |
| --- | --- | --- | --- |
| main | 733.98 | 280.05 | 34.30 |
| main + fused ops | 740.33 | 273.34 | 33.99 |
### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main: 6997a25ac6
Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
@@ -171,32 +171,25 @@ class TestTokenDispatcherWithAllGather(TestBase):
         self.dispatcher = TokenDispatcherWithAllGather(**kwargs)

         # Mock NPU functions
-        self.patcher_moe_init_routing = patch('torch_npu.npu_moe_init_routing')
-        self.mock_moe_init_routing = self.patcher_moe_init_routing.start()
-        self.mock_moe_init_routing.return_value = (
-            torch.tensor([0, 1, 0, 1, 0, 1]),  # expanded_expert_idx
-            torch.tensor([0, 1, 0, 1, 0, 1]))
-
-        self.patcher_moe_compute_expert_tokens = patch(
-            'torch_npu.npu_moe_compute_expert_tokens')
-        self.mock_moe_compute_expert_tokens = self.patcher_moe_compute_expert_tokens.start(
-        )
-        self.mock_moe_compute_expert_tokens.return_value = torch.tensor(
-            [3, 3])  # expert_tokens
-
-        self.patcher_moe_finalize_routing = patch(
-            'torch_npu.npu_moe_finalize_routing')
-        self.mock_moe_finalize_routing = self.patcher_moe_finalize_routing.start(
-        )
-        self.mock_moe_finalize_routing.return_value = torch.randn(3, 128)
+        self.patcher_npu_moe_init_routing_v2 = patch(
+            'torch_npu.npu_moe_init_routing_v2')
+        self.mock_npu_moe_init_routing_v2 = self.patcher_npu_moe_init_routing_v2.start(
+        )
+        self.mock_npu_moe_init_routing_v2.return_value = (
+            torch.randn(6, 128),  # sorted_hidden_states
+            torch.tensor([0, 1, 2, 3, 4, 5]),  # expanded_row_idx
+            torch.tensor([0, 1, 0, 1, 0, 1])  # expanded_expert_idx
+        )
+
+        self.row_idx = torch.arange(10, dtype=torch.int32)
+        self.patcher_npu_moe_token_unpermute = patch(
+            'torch_npu.npu_moe_token_unpermute')
+        self.mock_npu_moe_token_unpermute = self.patcher_npu_moe_token_unpermute.start(
+        )
+        self.mock_npu_moe_token_unpermute.return_value = torch.randn(6, 128)

     def tearDown(self):
-        self.patcher_moe_init_routing.stop()
-        self.patcher_moe_compute_expert_tokens.stop()
-        self.patcher_moe_finalize_routing.stop()
+        self.patcher_npu_moe_init_routing_v2.stop()
+        self.patcher_npu_moe_token_unpermute.stop()

     def test_token_dispatch_without_expert_map(self):
         hidden_states = torch.randn(3, 128)
@@ -207,10 +200,25 @@ class TestTokenDispatcherWithAllGather(TestBase):
                                                  topk_ids, self.row_idx, None)

         # Verify npu_moe_init_routing is called
-        self.mock_moe_init_routing.assert_called_once()
-        args, kwargs = self.mock_moe_init_routing.call_args
+        self.mock_npu_moe_init_routing_v2.assert_called_once()
+        args, kwargs = self.mock_npu_moe_init_routing_v2.call_args

-        self.assertEqual(results["group_list_type"], 0)
+        self.assertEqual(results["group_list_type"], 1)
+
+    def test_token_dispatch_with_expert_map(self):
+        self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
+        hidden_states = torch.randn(3, 128)
+        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
+        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
+
+        results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
+                                                 topk_ids, self.row_idx, None)
+
+        # Verify npu_moe_init_routing is called
+        self.mock_npu_moe_init_routing_v2.assert_called_once()
+        args, kwargs = self.mock_npu_moe_init_routing_v2.call_args
+
+        self.assertEqual(results["group_list_type"], 1)

     def test_token_dispatch_with_quant(self):
         kwargs = {
@@ -230,7 +238,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
                                                  topk_weights, topk_ids,
                                                  self.row_idx, None)

-        self.assertEqual(results["group_list_type"], 0)
+        self.assertEqual(results["group_list_type"], 1)

     def test_token_combine_with_expert_map(self):
         self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
@@ -242,9 +250,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
         hidden_states = torch.randn(6, 128)

         final_hidden_states = self.dispatcher.token_combine(hidden_states)

-        # Verify index_add_ is applied correctly
-        self.assertEqual(final_hidden_states.shape, (3, 128))
+        self.assertEqual(final_hidden_states.shape, (6, 128))

     def test_token_combine_without_expert_map(self):
         self.dispatcher.with_quant = False
@@ -260,10 +266,10 @@ class TestTokenDispatcherWithAllGather(TestBase):
         final_hidden_states = self.dispatcher.token_combine(hidden_states)

         # Verify npu_moe_finalize_routing is called
-        self.mock_moe_finalize_routing.assert_called_once()
-        args, kwargs = self.mock_moe_finalize_routing.call_args
+        self.mock_npu_moe_token_unpermute.assert_called_once()
+        args, kwargs = self.mock_npu_moe_token_unpermute.call_args

-        self.assertEqual(final_hidden_states.shape, (3, 128))
+        self.assertEqual(final_hidden_states.shape, (6, 128))

     def test_token_dispatch_with_router_weight(self):
         self.dispatcher.apply_router_weight_on_input = True
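One detail worth calling out in the diff above is that the expected `group_list_type` flips from 0 to 1. My reading, which is an assumption rather than something the PR states, is that 0 denotes a cumulative group list of the kind `npu_moe_compute_expert_tokens` produced, while 1 denotes per-expert token counts as the v2 routing path now returns. The two encodings carry the same information:

```python
import torch

# Hypothetical illustration of the two group-list encodings for four experts
# that received 1, 2, 2 and 1 tokens respectively.
counts = torch.tensor([1, 2, 2, 1])                   # per-expert counts ("type 1", assumed)
cumsum = torch.cumsum(counts, dim=0)                  # running totals    ("type 0", assumed)
assert torch.equal(torch.diff(cumsum, prepend=torch.tensor([0])), counts)
print(counts.tolist(), cumsum.tolist())               # [1, 2, 2, 1] [1, 3, 5, 6]
```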