allgather use fusedop. (#2689)
### What this PR does / why we need it?
Use 'npu_moe_init_routing_v2' &'npu_moe_token_unpermute' repalce
'npu_moe_init_routing' &‘npu_moe_compute_expert_tokens’&
'npu_moe_finalize_routing' to optimize performance
### Does this PR introduce _any_ user-facing change?
| branch| tps| TTFT |TPOT |
| --- | --- | --- |--- |
|main |733.98 | 280.05 |34.30 |
|main+fusedop | 740.33 | 273.34 |33.99 |
### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main:
6997a25ac6
Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
@@ -33,7 +33,7 @@ from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
|
||||
TokenDispatcherWithAllGather
|
||||
|
||||
NUM_EXPERTS = [8, 64]
|
||||
EP_SIZE = [1, 4]
|
||||
EP_SIZE = [1]
|
||||
TOP_KS = [2, 6]
|
||||
DEVICE = ["npu"]
|
||||
|
||||
@@ -115,19 +115,6 @@ def test_token_dispatcher_with_all_gather(
|
||||
w1_local = w1
|
||||
w2_local = w2
|
||||
|
||||
if ep_size > 1:
|
||||
local_e = e // ep_size
|
||||
e_ids = torch.arange(local_e * 0,
|
||||
local_e * (0 + 1),
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
expert_map = torch.full((e, ), -1, device=device, dtype=torch.int32)
|
||||
expert_map[e_ids] = torch.arange(local_e,
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
w1_local = w1[e_ids]
|
||||
w2_local = w2[e_ids]
|
||||
|
||||
score = torch.softmax(score, dim=-1, dtype=dtype)
|
||||
topk_weights, topk_ids = torch.topk(score, topk)
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
|
||||
@@ -171,32 +171,25 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
self.dispatcher = TokenDispatcherWithAllGather(**kwargs)
|
||||
|
||||
# Mock NPU functions
|
||||
self.patcher_moe_init_routing = patch('torch_npu.npu_moe_init_routing')
|
||||
self.mock_moe_init_routing = self.patcher_moe_init_routing.start()
|
||||
self.mock_moe_init_routing.return_value = (
|
||||
self.patcher_npu_moe_init_routing_v2 = patch(
|
||||
'torch_npu.npu_moe_init_routing_v2')
|
||||
self.mock_npu_moe_init_routing_v2 = self.patcher_npu_moe_init_routing_v2.start(
|
||||
)
|
||||
self.mock_npu_moe_init_routing_v2.return_value = (
|
||||
torch.randn(6, 128), # sorted_hidden_states
|
||||
torch.tensor([0, 1, 2, 3, 4, 5]), # expanded_row_idx
|
||||
torch.tensor([0, 1, 0, 1, 0, 1]) # expanded_expert_idx
|
||||
)
|
||||
|
||||
self.patcher_moe_compute_expert_tokens = patch(
|
||||
'torch_npu.npu_moe_compute_expert_tokens')
|
||||
self.mock_moe_compute_expert_tokens = self.patcher_moe_compute_expert_tokens.start(
|
||||
)
|
||||
self.mock_moe_compute_expert_tokens.return_value = torch.tensor(
|
||||
[3, 3]) # expert_tokens
|
||||
|
||||
self.patcher_moe_finalize_routing = patch(
|
||||
'torch_npu.npu_moe_finalize_routing')
|
||||
self.mock_moe_finalize_routing = self.patcher_moe_finalize_routing.start(
|
||||
)
|
||||
self.mock_moe_finalize_routing.return_value = torch.randn(3, 128)
|
||||
torch.tensor([0, 1, 0, 1, 0, 1]), # expanded_expert_idx
|
||||
torch.tensor([0, 1, 0, 1, 0, 1]))
|
||||
self.row_idx = torch.arange(10, dtype=torch.int32)
|
||||
self.patcher_npu_moe_token_unpermute = patch(
|
||||
'torch_npu.npu_moe_token_unpermute')
|
||||
self.mock_npu_moe_token_unpermute = self.patcher_npu_moe_token_unpermute.start(
|
||||
)
|
||||
self.mock_npu_moe_token_unpermute.return_value = torch.randn(6, 128)
|
||||
|
||||
def tearDown(self):
|
||||
self.patcher_moe_init_routing.stop()
|
||||
self.patcher_moe_compute_expert_tokens.stop()
|
||||
self.patcher_moe_finalize_routing.stop()
|
||||
self.patcher_npu_moe_init_routing_v2.stop()
|
||||
self.patcher_npu_moe_token_unpermute.stop()
|
||||
|
||||
def test_token_dispatch_without_expert_map(self):
|
||||
hidden_states = torch.randn(3, 128)
|
||||
@@ -207,10 +200,25 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
topk_ids, self.row_idx, None)
|
||||
|
||||
# Verify npu_moe_init_routing is called
|
||||
self.mock_moe_init_routing.assert_called_once()
|
||||
args, kwargs = self.mock_moe_init_routing.call_args
|
||||
self.mock_npu_moe_init_routing_v2.assert_called_once()
|
||||
args, kwargs = self.mock_npu_moe_init_routing_v2.call_args
|
||||
|
||||
self.assertEqual(results["group_list_type"], 0)
|
||||
self.assertEqual(results["group_list_type"], 1)
|
||||
|
||||
def test_token_dispatch_with_expert_map(self):
|
||||
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
|
||||
hidden_states = torch.randn(3, 128)
|
||||
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
|
||||
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
|
||||
|
||||
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
|
||||
topk_ids, self.row_idx, None)
|
||||
|
||||
# Verify npu_moe_init_routing is called
|
||||
self.mock_npu_moe_init_routing_v2.assert_called_once()
|
||||
args, kwargs = self.mock_npu_moe_init_routing_v2.call_args
|
||||
|
||||
self.assertEqual(results["group_list_type"], 1)
|
||||
|
||||
def test_token_dispatch_with_quant(self):
|
||||
kwargs = {
|
||||
@@ -230,7 +238,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
topk_weights, topk_ids,
|
||||
self.row_idx, None)
|
||||
|
||||
self.assertEqual(results["group_list_type"], 0)
|
||||
self.assertEqual(results["group_list_type"], 1)
|
||||
|
||||
def test_token_combine_with_expert_map(self):
|
||||
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
|
||||
@@ -242,9 +250,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
hidden_states = torch.randn(6, 128)
|
||||
|
||||
final_hidden_states = self.dispatcher.token_combine(hidden_states)
|
||||
|
||||
# Verify index_add_ is applied correctly
|
||||
self.assertEqual(final_hidden_states.shape, (3, 128))
|
||||
self.assertEqual(final_hidden_states.shape, (6, 128))
|
||||
|
||||
def test_token_combine_without_expert_map(self):
|
||||
self.dispatcher.with_quant = False
|
||||
@@ -260,10 +266,10 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
final_hidden_states = self.dispatcher.token_combine(hidden_states)
|
||||
|
||||
# Verify npu_moe_finalize_routing is called
|
||||
self.mock_moe_finalize_routing.assert_called_once()
|
||||
args, kwargs = self.mock_moe_finalize_routing.call_args
|
||||
self.mock_npu_moe_token_unpermute.assert_called_once()
|
||||
args, kwargs = self.mock_npu_moe_token_unpermute.call_args
|
||||
|
||||
self.assertEqual(final_hidden_states.shape, (3, 128))
|
||||
self.assertEqual(final_hidden_states.shape, (6, 128))
|
||||
|
||||
def test_token_dispatch_with_router_weight(self):
|
||||
self.dispatcher.apply_router_weight_on_input = True
|
||||
|
||||
@@ -338,8 +338,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
self.original_shape = hidden_states.shape
|
||||
|
||||
num_tokens = hidden_states.shape[:-1].numel()
|
||||
dtype = hidden_states.dtype
|
||||
device = hidden_states.device
|
||||
self.expert_map = expert_map
|
||||
self.topk_weights = topk_weights
|
||||
self.topk_ids = topk_ids
|
||||
@@ -353,67 +351,31 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||
hidden_states = hidden_states * \
|
||||
topk_weights.to(hidden_states.dtype)
|
||||
|
||||
if expert_map is not None:
|
||||
# Generate token indices and flatten
|
||||
token_indices = (torch.arange(
|
||||
num_tokens, device=device,
|
||||
dtype=torch.int64).unsqueeze(1).expand(-1,
|
||||
self.top_k).reshape(-1))
|
||||
|
||||
# Flatten token-to-expert mappings and map to local experts
|
||||
weights_flat = topk_weights.view(-1)
|
||||
experts_flat = topk_ids.view(-1)
|
||||
local_experts_flat = expert_map[experts_flat]
|
||||
|
||||
# Filter valid token-expert pairs
|
||||
self.mask = local_experts_flat != -1
|
||||
filtered_weights = torch.where(
|
||||
self.mask, weights_flat,
|
||||
torch.zeros_like(weights_flat)).to(dtype)
|
||||
filtered_experts = torch.where(
|
||||
self.mask, local_experts_flat,
|
||||
torch.full_like(local_experts_flat,
|
||||
self.num_experts_local)).to(topk_ids.dtype)
|
||||
|
||||
# Sort by local expert IDs
|
||||
sort_indices = torch.argsort(filtered_experts.view(torch.float32))
|
||||
self.sorted_token_indices = token_indices[sort_indices]
|
||||
self.sorted_weights = filtered_weights[sort_indices]
|
||||
|
||||
# Compute token counts with minlength of num_experts
|
||||
# This is equivalent to but faster than:
|
||||
# >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
|
||||
token_counts = torch.zeros(self.num_experts_local + 1,
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
|
||||
token_counts.scatter_add_(0, filtered_experts.to(torch.int64),
|
||||
ones)
|
||||
token_counts = token_counts[:self.num_experts_local]
|
||||
|
||||
# Rearrange hidden_states
|
||||
sorted_hidden_states = hidden_states[self.sorted_token_indices]
|
||||
if self.with_quant:
|
||||
group_list_type = 1
|
||||
expert_tokens = token_counts
|
||||
else:
|
||||
expert_tokens = torch.cumsum(token_counts,
|
||||
dim=0,
|
||||
dtype=torch.int64)
|
||||
group_list_type = 0
|
||||
global_num_experts = len(expert_map)
|
||||
mask = (expert_map[topk_ids] != -1)
|
||||
self.topk_weights = topk_weights * mask
|
||||
first_expert_idx = get_ep_group(
|
||||
).rank_in_group * self.num_experts_local
|
||||
last_expert_idx = first_expert_idx + self.num_experts_local
|
||||
else:
|
||||
active_num = self.max_num_tokens if self.max_num_tokens is not None else num_tokens
|
||||
sorted_hidden_states, self.expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||
hidden_states,
|
||||
row_idx=row_idx,
|
||||
expert_idx=topk_ids,
|
||||
active_num=active_num)
|
||||
first_expert_idx = 0
|
||||
last_expert_idx = self.num_experts_local
|
||||
global_num_experts = self.num_experts_local
|
||||
|
||||
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
||||
expanded_expert_idx, self.num_experts_local)
|
||||
expert_tokens = expert_tokens.to(torch.int64)
|
||||
group_list_type = 0
|
||||
sorted_hidden_states, self.expanded_row_idx, expert_tokens, _ = (
|
||||
torch_npu.npu_moe_init_routing_v2(
|
||||
hidden_states,
|
||||
topk_ids,
|
||||
active_num=num_tokens * self.top_k,
|
||||
expert_num=global_num_experts,
|
||||
expert_tokens_num_type=1,
|
||||
expert_tokens_num_flag=True,
|
||||
active_expert_range=[first_expert_idx, last_expert_idx],
|
||||
quant_mode=-1,
|
||||
))
|
||||
expert_tokens = expert_tokens.to(torch.int64)
|
||||
group_list_type = 1 # `count` mode
|
||||
return {
|
||||
"group_list_type": group_list_type,
|
||||
"hidden_states": sorted_hidden_states,
|
||||
@@ -424,61 +386,12 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
hidden_states: torch.Tensor,
|
||||
bias: torch.Tensor = None):
|
||||
assert self.original_shape is not None
|
||||
dtype = hidden_states.dtype
|
||||
device = hidden_states.device
|
||||
if self.expert_map is not None:
|
||||
assert self.mask is not None
|
||||
assert self.sorted_token_indices is not None
|
||||
assert self.sorted_weights is not None
|
||||
|
||||
weighted_down_out = hidden_states * \
|
||||
self.sorted_weights.unsqueeze(1)
|
||||
|
||||
final_hidden_states = torch.zeros(*self.original_shape,
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype)
|
||||
|
||||
# TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
|
||||
# This created multiple NaN and index_add_ will mix them up which harms accuracy
|
||||
# remove this mask and filter after it being fixed
|
||||
num_valid_tokens = self.mask.sum()
|
||||
valid_token_mask = torch.arange(
|
||||
0, self.sorted_token_indices.shape[0],
|
||||
device=device).unsqueeze(1) < num_valid_tokens
|
||||
valid_output = torch.where(
|
||||
valid_token_mask, weighted_down_out,
|
||||
torch.zeros_like(weighted_down_out)).to(dtype)
|
||||
final_hidden_states.index_add_(0, self.sorted_token_indices,
|
||||
valid_output)
|
||||
else:
|
||||
if self.with_quant:
|
||||
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
||||
hidden_states,
|
||||
skip1=None,
|
||||
skip2=None,
|
||||
bias=None,
|
||||
scales=self.topk_weights,
|
||||
expanded_src_to_dst_row=self.expanded_row_idx,
|
||||
export_for_source_row=self.topk_ids,
|
||||
)
|
||||
if len(self.original_shape) == 3:
|
||||
final_hidden_states = final_hidden_states.view(
|
||||
self.original_shape)
|
||||
else:
|
||||
scales = torch.ones_like(
|
||||
self.topk_weights
|
||||
) if self.apply_router_weight_on_input else self.topk_weights
|
||||
# TODO: Reorder device memory 2 times here, replace the current
|
||||
# implementation here when suitable operators become available.
|
||||
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
||||
hidden_states,
|
||||
skip1=None,
|
||||
skip2=None,
|
||||
bias=None,
|
||||
scales=scales,
|
||||
expanded_src_to_dst_row=self.expanded_row_idx,
|
||||
export_for_source_row=self.topk_ids,
|
||||
)
|
||||
final_hidden_states = torch_npu.npu_moe_token_unpermute(
|
||||
permuted_tokens=hidden_states,
|
||||
sorted_indices=self.expanded_row_idx,
|
||||
probs=self.topk_weights)
|
||||
if len(self.original_shape) == 3:
|
||||
final_hidden_states = final_hidden_states.view(self.original_shape)
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user