allgather: use fused ops (#2689)

### What this PR does / why we need it?
Replace `npu_moe_init_routing`, `npu_moe_compute_expert_tokens`, and `npu_moe_finalize_routing` with the fused ops `npu_moe_init_routing_v2` and `npu_moe_token_unpermute` to improve performance.
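
In effect, the three-op dispatch/combine pipeline becomes two fused calls. Below is a simplified before/after sketch, not the exact dispatcher code: the `torch_npu` calls and their arguments are taken from this diff, while the function names and free parameters (`num_experts_local`, `active_num`, etc.) stand in for the dispatcher's attributes.

```python
import torch_npu  # Ascend NPU ops; requires an NPU environment to run


def dispatch_combine_before(hidden_states, row_idx, topk_ids, topk_weights,
                            active_num, num_experts_local):
    """Old pipeline: three separate torch_npu ops."""
    sorted_hidden_states, expanded_row_idx, expanded_expert_idx = \
        torch_npu.npu_moe_init_routing(
            hidden_states, row_idx=row_idx, expert_idx=topk_ids,
            active_num=active_num)
    expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
        expanded_expert_idx, num_experts_local)
    # ... grouped matmul over sorted_hidden_states / expert_tokens ...
    return torch_npu.npu_moe_finalize_routing(
        hidden_states, skip1=None, skip2=None, bias=None,
        scales=topk_weights,
        expanded_src_to_dst_row=expanded_row_idx,
        export_for_source_row=topk_ids)


def dispatch_combine_after(hidden_states, topk_ids, topk_weights, num_tokens,
                           top_k, global_num_experts, first_expert_idx,
                           last_expert_idx):
    """New pipeline: one fused routing op plus one fused unpermute."""
    sorted_hidden_states, expanded_row_idx, expert_tokens, _ = \
        torch_npu.npu_moe_init_routing_v2(
            hidden_states, topk_ids,
            active_num=num_tokens * top_k,
            expert_num=global_num_experts,
            expert_tokens_num_type=1,  # return per-expert token counts
            expert_tokens_num_flag=True,
            active_expert_range=[first_expert_idx, last_expert_idx],
            quant_mode=-1)
    # ... grouped matmul over sorted_hidden_states / expert_tokens ...
    return torch_npu.npu_moe_token_unpermute(
        permuted_tokens=hidden_states,
        sorted_indices=expanded_row_idx,
        probs=topk_weights)
```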
### Does this PR introduce _any_ user-facing change?
| branch | TPS | TTFT | TPOT |
| --- | --- | --- | --- |
| main | 733.98 | 280.05 | 34.30 |
| main + fused ops | 740.33 | 273.34 | 33.99 |
### How was this patch tested?


- vLLM version: v0.10.1.1
- vLLM main: 6997a25ac6

Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Authored by sherie on 2025-09-04 11:56:29 +08:00; committed by GitHub.
Parent: 7d47d8f4f6
Commit: f86596a66c
3 changed files with 66 additions and 160 deletions


```diff
@@ -33,7 +33,7 @@ from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
     TokenDispatcherWithAllGather
 
 NUM_EXPERTS = [8, 64]
-EP_SIZE = [1, 4]
+EP_SIZE = [1]
 TOP_KS = [2, 6]
 DEVICE = ["npu"]
@@ -115,19 +115,6 @@ def test_token_dispatcher_with_all_gather(
     w1_local = w1
     w2_local = w2
-    if ep_size > 1:
-        local_e = e // ep_size
-        e_ids = torch.arange(local_e * 0,
-                             local_e * (0 + 1),
-                             device=device,
-                             dtype=torch.int32)
-        expert_map = torch.full((e, ), -1, device=device, dtype=torch.int32)
-        expert_map[e_ids] = torch.arange(local_e,
-                                         device=device,
-                                         dtype=torch.int32)
-        w1_local = w1[e_ids]
-        w2_local = w2[e_ids]
 
     score = torch.softmax(score, dim=-1, dtype=dtype)
     topk_weights, topk_ids = torch.topk(score, topk)
     topk_ids = topk_ids.to(torch.int32)
```


```diff
@@ -171,32 +171,25 @@ class TestTokenDispatcherWithAllGather(TestBase):
         self.dispatcher = TokenDispatcherWithAllGather(**kwargs)
 
         # Mock NPU functions
-        self.patcher_moe_init_routing = patch('torch_npu.npu_moe_init_routing')
-        self.mock_moe_init_routing = self.patcher_moe_init_routing.start()
-        self.mock_moe_init_routing.return_value = (
+        self.patcher_npu_moe_init_routing_v2 = patch(
+            'torch_npu.npu_moe_init_routing_v2')
+        self.mock_npu_moe_init_routing_v2 = self.patcher_npu_moe_init_routing_v2.start(
+        )
+        self.mock_npu_moe_init_routing_v2.return_value = (
             torch.randn(6, 128),  # sorted_hidden_states
             torch.tensor([0, 1, 2, 3, 4, 5]),  # expanded_row_idx
-            torch.tensor([0, 1, 0, 1, 0, 1])  # expanded_expert_idx
-        )
-
-        self.patcher_moe_compute_expert_tokens = patch(
-            'torch_npu.npu_moe_compute_expert_tokens')
-        self.mock_moe_compute_expert_tokens = self.patcher_moe_compute_expert_tokens.start(
-        )
-        self.mock_moe_compute_expert_tokens.return_value = torch.tensor(
-            [3, 3])  # expert_tokens
-
-        self.patcher_moe_finalize_routing = patch(
-            'torch_npu.npu_moe_finalize_routing')
-        self.mock_moe_finalize_routing = self.patcher_moe_finalize_routing.start(
-        )
-        self.mock_moe_finalize_routing.return_value = torch.randn(3, 128)
+            torch.tensor([0, 1, 0, 1, 0, 1]),  # expanded_expert_idx
+            torch.tensor([0, 1, 0, 1, 0, 1]))
+        self.row_idx = torch.arange(10, dtype=torch.int32)
+
+        self.patcher_npu_moe_token_unpermute = patch(
+            'torch_npu.npu_moe_token_unpermute')
+        self.mock_npu_moe_token_unpermute = self.patcher_npu_moe_token_unpermute.start(
+        )
+        self.mock_npu_moe_token_unpermute.return_value = torch.randn(6, 128)
 
     def tearDown(self):
-        self.patcher_moe_init_routing.stop()
-        self.patcher_moe_compute_expert_tokens.stop()
-        self.patcher_moe_finalize_routing.stop()
+        self.patcher_npu_moe_init_routing_v2.stop()
+        self.patcher_npu_moe_token_unpermute.stop()
 
     def test_token_dispatch_without_expert_map(self):
         hidden_states = torch.randn(3, 128)
@@ -207,10 +200,25 @@ class TestTokenDispatcherWithAllGather(TestBase):
                                                  topk_ids, self.row_idx, None)
 
         # Verify npu_moe_init_routing is called
-        self.mock_moe_init_routing.assert_called_once()
-        args, kwargs = self.mock_moe_init_routing.call_args
+        self.mock_npu_moe_init_routing_v2.assert_called_once()
+        args, kwargs = self.mock_npu_moe_init_routing_v2.call_args
 
-        self.assertEqual(results["group_list_type"], 0)
+        self.assertEqual(results["group_list_type"], 1)
+
+    def test_token_dispatch_with_expert_map(self):
+        self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
+        hidden_states = torch.randn(3, 128)
+        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
+        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
+        results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
+                                                 topk_ids, self.row_idx, None)
+
+        # Verify npu_moe_init_routing is called
+        self.mock_npu_moe_init_routing_v2.assert_called_once()
+        args, kwargs = self.mock_npu_moe_init_routing_v2.call_args
+
+        self.assertEqual(results["group_list_type"], 1)
 
     def test_token_dispatch_with_quant(self):
         kwargs = {
@@ -230,7 +238,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
                                                  topk_weights, topk_ids,
                                                  self.row_idx, None)
-        self.assertEqual(results["group_list_type"], 0)
+        self.assertEqual(results["group_list_type"], 1)
 
     def test_token_combine_with_expert_map(self):
         self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
@@ -242,9 +250,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
         hidden_states = torch.randn(6, 128)
         final_hidden_states = self.dispatcher.token_combine(hidden_states)
 
-        # Verify index_add_ is applied correctly
-        self.assertEqual(final_hidden_states.shape, (3, 128))
+        self.assertEqual(final_hidden_states.shape, (6, 128))
 
     def test_token_combine_without_expert_map(self):
         self.dispatcher.with_quant = False
@@ -260,10 +266,10 @@ class TestTokenDispatcherWithAllGather(TestBase):
         final_hidden_states = self.dispatcher.token_combine(hidden_states)
 
         # Verify npu_moe_finalize_routing is called
-        self.mock_moe_finalize_routing.assert_called_once()
-        args, kwargs = self.mock_moe_finalize_routing.call_args
+        self.mock_npu_moe_token_unpermute.assert_called_once()
+        args, kwargs = self.mock_npu_moe_token_unpermute.call_args
 
-        self.assertEqual(final_hidden_states.shape, (3, 128))
+        self.assertEqual(final_hidden_states.shape, (6, 128))
 
     def test_token_dispatch_with_router_weight(self):
         self.dispatcher.apply_router_weight_on_input = True
```


```diff
@@ -338,8 +338,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
         self.original_shape = hidden_states.shape
 
         num_tokens = hidden_states.shape[:-1].numel()
-        dtype = hidden_states.dtype
-        device = hidden_states.device
         self.expert_map = expert_map
         self.topk_weights = topk_weights
         self.topk_ids = topk_ids
@@ -353,67 +351,31 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
             ), "Only support topk=1 when `apply_router_weight_on_input` is True"
             hidden_states = hidden_states * \
                 topk_weights.to(hidden_states.dtype)
 
         if expert_map is not None:
-            # Generate token indices and flatten
-            token_indices = (torch.arange(
-                num_tokens, device=device,
-                dtype=torch.int64).unsqueeze(1).expand(-1,
-                                                       self.top_k).reshape(-1))
-
-            # Flatten token-to-expert mappings and map to local experts
-            weights_flat = topk_weights.view(-1)
-            experts_flat = topk_ids.view(-1)
-            local_experts_flat = expert_map[experts_flat]
-
-            # Filter valid token-expert pairs
-            self.mask = local_experts_flat != -1
-            filtered_weights = torch.where(
-                self.mask, weights_flat,
-                torch.zeros_like(weights_flat)).to(dtype)
-            filtered_experts = torch.where(
-                self.mask, local_experts_flat,
-                torch.full_like(local_experts_flat,
-                                self.num_experts_local)).to(topk_ids.dtype)
-
-            # Sort by local expert IDs
-            sort_indices = torch.argsort(filtered_experts.view(torch.float32))
-            self.sorted_token_indices = token_indices[sort_indices]
-            self.sorted_weights = filtered_weights[sort_indices]
-
-            # Compute token counts with minlength of num_experts
-            # This is equivalent to but faster than:
-            # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
-            token_counts = torch.zeros(self.num_experts_local + 1,
-                                       device=device,
-                                       dtype=torch.int64)
-            ones = torch.ones_like(filtered_experts, dtype=torch.int64)
-            token_counts.scatter_add_(0, filtered_experts.to(torch.int64),
-                                      ones)
-            token_counts = token_counts[:self.num_experts_local]
-
-            # Rearrange hidden_states
-            sorted_hidden_states = hidden_states[self.sorted_token_indices]
-
-            if self.with_quant:
-                group_list_type = 1
-                expert_tokens = token_counts
-            else:
-                expert_tokens = torch.cumsum(token_counts,
-                                             dim=0,
-                                             dtype=torch.int64)
-                group_list_type = 0
+            global_num_experts = len(expert_map)
+            mask = (expert_map[topk_ids] != -1)
+            self.topk_weights = topk_weights * mask
+            first_expert_idx = get_ep_group(
+            ).rank_in_group * self.num_experts_local
+            last_expert_idx = first_expert_idx + self.num_experts_local
         else:
-            active_num = self.max_num_tokens if self.max_num_tokens is not None else num_tokens
-            sorted_hidden_states, self.expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
-                hidden_states,
-                row_idx=row_idx,
-                expert_idx=topk_ids,
-                active_num=active_num)
-
-            expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
-                expanded_expert_idx, self.num_experts_local)
-            expert_tokens = expert_tokens.to(torch.int64)
-            group_list_type = 0
+            first_expert_idx = 0
+            last_expert_idx = self.num_experts_local
+            global_num_experts = self.num_experts_local
+
+        sorted_hidden_states, self.expanded_row_idx, expert_tokens, _ = (
+            torch_npu.npu_moe_init_routing_v2(
+                hidden_states,
+                topk_ids,
+                active_num=num_tokens * self.top_k,
+                expert_num=global_num_experts,
+                expert_tokens_num_type=1,
+                expert_tokens_num_flag=True,
+                active_expert_range=[first_expert_idx, last_expert_idx],
+                quant_mode=-1,
+            ))
+        expert_tokens = expert_tokens.to(torch.int64)
+        group_list_type = 1  # `count` mode
 
         return {
             "group_list_type": group_list_type,
             "hidden_states": sorted_hidden_states,
```
```diff
@@ -424,61 +386,12 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
                       hidden_states: torch.Tensor,
                       bias: torch.Tensor = None):
         assert self.original_shape is not None
 
-        dtype = hidden_states.dtype
-        device = hidden_states.device
-        if self.expert_map is not None:
-            assert self.mask is not None
-            assert self.sorted_token_indices is not None
-            assert self.sorted_weights is not None
-
-            weighted_down_out = hidden_states * \
-                self.sorted_weights.unsqueeze(1)
-
-            final_hidden_states = torch.zeros(*self.original_shape,
-                                              device=hidden_states.device,
-                                              dtype=hidden_states.dtype)
-
-            # TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
-            # This created multiple NaN and index_add_ will mix them up which harms accuracy
-            # remove this mask and filter after it being fixed
-            num_valid_tokens = self.mask.sum()
-            valid_token_mask = torch.arange(
-                0, self.sorted_token_indices.shape[0],
-                device=device).unsqueeze(1) < num_valid_tokens
-            valid_output = torch.where(
-                valid_token_mask, weighted_down_out,
-                torch.zeros_like(weighted_down_out)).to(dtype)
-            final_hidden_states.index_add_(0, self.sorted_token_indices,
-                                           valid_output)
-        else:
-            if self.with_quant:
-                final_hidden_states = torch_npu.npu_moe_finalize_routing(
-                    hidden_states,
-                    skip1=None,
-                    skip2=None,
-                    bias=None,
-                    scales=self.topk_weights,
-                    expanded_src_to_dst_row=self.expanded_row_idx,
-                    export_for_source_row=self.topk_ids,
-                )
-                if len(self.original_shape) == 3:
-                    final_hidden_states = final_hidden_states.view(
-                        self.original_shape)
-            else:
-                scales = torch.ones_like(
-                    self.topk_weights
-                ) if self.apply_router_weight_on_input else self.topk_weights
-                # TODO: Reorder device memory 2 times here, replace the current
-                # implementation here when suitable operators become available.
-                final_hidden_states = torch_npu.npu_moe_finalize_routing(
-                    hidden_states,
-                    skip1=None,
-                    skip2=None,
-                    bias=None,
-                    scales=scales,
-                    expanded_src_to_dst_row=self.expanded_row_idx,
-                    export_for_source_row=self.topk_ids,
-                )
+        final_hidden_states = torch_npu.npu_moe_token_unpermute(
+            permuted_tokens=hidden_states,
+            sorted_indices=self.expanded_row_idx,
+            probs=self.topk_weights)
 
+        if len(self.original_shape) == 3:
+            final_hidden_states = final_hidden_states.view(self.original_shape)
         return final_hidden_states
```
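
For intuition, here is a rough pure-PyTorch sketch of what the fused combine step computes, assuming (from the arguments above) that `npu_moe_token_unpermute` scatters the expert-ordered rows back to token order and reduces over the top-k slots with the routing probabilities; the exact op semantics on the NPU may differ:

```python
import torch


def unpermute_reference(permuted_tokens: torch.Tensor,
                        sorted_indices: torch.Tensor,
                        probs: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference for the fused unpermute, for intuition only."""
    num_tokens, top_k = probs.shape
    hidden = permuted_tokens.shape[-1]
    # Route each expert-ordered row back to its original (token, k) slot.
    unpermuted = torch.zeros(num_tokens * top_k, hidden,
                             dtype=permuted_tokens.dtype)
    unpermuted[sorted_indices] = permuted_tokens
    # Weight by the router probabilities and sum over the top-k slots.
    unpermuted = unpermuted.view(num_tokens, top_k, hidden)
    return (unpermuted * probs.unsqueeze(-1).to(unpermuted.dtype)).sum(dim=1)
```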