allgather use fusedop. (#2689)
### What this PR does / why we need it?
Use 'npu_moe_init_routing_v2' & 'npu_moe_token_unpermute' to replace
'npu_moe_init_routing' & 'npu_moe_compute_expert_tokens' &
'npu_moe_finalize_routing' to optimize performance.
### Does this PR introduce _any_ user-facing change?
| branch | tps | TTFT | TPOT |
| --- | --- | --- | --- |
|main |733.98 | 280.05 |34.30 |
|main+fusedop | 740.33 | 273.34 |33.99 |
### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main:
6997a25ac6
Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
@@ -33,7 +33,7 @@ from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
|
|||||||
TokenDispatcherWithAllGather
|
TokenDispatcherWithAllGather
|
||||||
|
|
||||||
NUM_EXPERTS = [8, 64]
|
NUM_EXPERTS = [8, 64]
|
||||||
EP_SIZE = [1, 4]
|
EP_SIZE = [1]
|
||||||
TOP_KS = [2, 6]
|
TOP_KS = [2, 6]
|
||||||
DEVICE = ["npu"]
|
DEVICE = ["npu"]
|
||||||
|
|
||||||
@@ -115,19 +115,6 @@ def test_token_dispatcher_with_all_gather(
|
|||||||
w1_local = w1
|
w1_local = w1
|
||||||
w2_local = w2
|
w2_local = w2
|
||||||
|
|
||||||
if ep_size > 1:
|
|
||||||
local_e = e // ep_size
|
|
||||||
e_ids = torch.arange(local_e * 0,
|
|
||||||
local_e * (0 + 1),
|
|
||||||
device=device,
|
|
||||||
dtype=torch.int32)
|
|
||||||
expert_map = torch.full((e, ), -1, device=device, dtype=torch.int32)
|
|
||||||
expert_map[e_ids] = torch.arange(local_e,
|
|
||||||
device=device,
|
|
||||||
dtype=torch.int32)
|
|
||||||
w1_local = w1[e_ids]
|
|
||||||
w2_local = w2[e_ids]
|
|
||||||
|
|
||||||
score = torch.softmax(score, dim=-1, dtype=dtype)
|
score = torch.softmax(score, dim=-1, dtype=dtype)
|
||||||
topk_weights, topk_ids = torch.topk(score, topk)
|
topk_weights, topk_ids = torch.topk(score, topk)
|
||||||
topk_ids = topk_ids.to(torch.int32)
|
topk_ids = topk_ids.to(torch.int32)
|
||||||
|
|||||||
@@ -171,32 +171,25 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
|||||||
self.dispatcher = TokenDispatcherWithAllGather(**kwargs)
|
self.dispatcher = TokenDispatcherWithAllGather(**kwargs)
|
||||||
|
|
||||||
# Mock NPU functions
|
# Mock NPU functions
|
||||||
self.patcher_moe_init_routing = patch('torch_npu.npu_moe_init_routing')
|
self.patcher_npu_moe_init_routing_v2 = patch(
|
||||||
self.mock_moe_init_routing = self.patcher_moe_init_routing.start()
|
'torch_npu.npu_moe_init_routing_v2')
|
||||||
self.mock_moe_init_routing.return_value = (
|
self.mock_npu_moe_init_routing_v2 = self.patcher_npu_moe_init_routing_v2.start(
|
||||||
|
)
|
||||||
|
self.mock_npu_moe_init_routing_v2.return_value = (
|
||||||
torch.randn(6, 128), # sorted_hidden_states
|
torch.randn(6, 128), # sorted_hidden_states
|
||||||
torch.tensor([0, 1, 2, 3, 4, 5]), # expanded_row_idx
|
torch.tensor([0, 1, 2, 3, 4, 5]), # expanded_row_idx
|
||||||
torch.tensor([0, 1, 0, 1, 0, 1]) # expanded_expert_idx
|
torch.tensor([0, 1, 0, 1, 0, 1]), # expanded_expert_idx
|
||||||
)
|
torch.tensor([0, 1, 0, 1, 0, 1]))
|
||||||
|
|
||||||
self.patcher_moe_compute_expert_tokens = patch(
|
|
||||||
'torch_npu.npu_moe_compute_expert_tokens')
|
|
||||||
self.mock_moe_compute_expert_tokens = self.patcher_moe_compute_expert_tokens.start(
|
|
||||||
)
|
|
||||||
self.mock_moe_compute_expert_tokens.return_value = torch.tensor(
|
|
||||||
[3, 3]) # expert_tokens
|
|
||||||
|
|
||||||
self.patcher_moe_finalize_routing = patch(
|
|
||||||
'torch_npu.npu_moe_finalize_routing')
|
|
||||||
self.mock_moe_finalize_routing = self.patcher_moe_finalize_routing.start(
|
|
||||||
)
|
|
||||||
self.mock_moe_finalize_routing.return_value = torch.randn(3, 128)
|
|
||||||
self.row_idx = torch.arange(10, dtype=torch.int32)
|
self.row_idx = torch.arange(10, dtype=torch.int32)
|
||||||
|
self.patcher_npu_moe_token_unpermute = patch(
|
||||||
|
'torch_npu.npu_moe_token_unpermute')
|
||||||
|
self.mock_npu_moe_token_unpermute = self.patcher_npu_moe_token_unpermute.start(
|
||||||
|
)
|
||||||
|
self.mock_npu_moe_token_unpermute.return_value = torch.randn(6, 128)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
self.patcher_moe_init_routing.stop()
|
self.patcher_npu_moe_init_routing_v2.stop()
|
||||||
self.patcher_moe_compute_expert_tokens.stop()
|
self.patcher_npu_moe_token_unpermute.stop()
|
||||||
self.patcher_moe_finalize_routing.stop()
|
|
||||||
|
|
||||||
def test_token_dispatch_without_expert_map(self):
|
def test_token_dispatch_without_expert_map(self):
|
||||||
hidden_states = torch.randn(3, 128)
|
hidden_states = torch.randn(3, 128)
|
||||||
@@ -207,10 +200,25 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
|||||||
topk_ids, self.row_idx, None)
|
topk_ids, self.row_idx, None)
|
||||||
|
|
||||||
# Verify npu_moe_init_routing is called
|
# Verify npu_moe_init_routing is called
|
||||||
self.mock_moe_init_routing.assert_called_once()
|
self.mock_npu_moe_init_routing_v2.assert_called_once()
|
||||||
args, kwargs = self.mock_moe_init_routing.call_args
|
args, kwargs = self.mock_npu_moe_init_routing_v2.call_args
|
||||||
|
|
||||||
self.assertEqual(results["group_list_type"], 0)
|
self.assertEqual(results["group_list_type"], 1)
|
||||||
|
|
||||||
|
def test_token_dispatch_with_expert_map(self):
|
||||||
|
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
|
||||||
|
hidden_states = torch.randn(3, 128)
|
||||||
|
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
|
||||||
|
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
|
||||||
|
|
||||||
|
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
|
||||||
|
topk_ids, self.row_idx, None)
|
||||||
|
|
||||||
|
# Verify npu_moe_init_routing is called
|
||||||
|
self.mock_npu_moe_init_routing_v2.assert_called_once()
|
||||||
|
args, kwargs = self.mock_npu_moe_init_routing_v2.call_args
|
||||||
|
|
||||||
|
self.assertEqual(results["group_list_type"], 1)
|
||||||
|
|
||||||
def test_token_dispatch_with_quant(self):
|
def test_token_dispatch_with_quant(self):
|
||||||
kwargs = {
|
kwargs = {
|
||||||
@@ -230,7 +238,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
|||||||
topk_weights, topk_ids,
|
topk_weights, topk_ids,
|
||||||
self.row_idx, None)
|
self.row_idx, None)
|
||||||
|
|
||||||
self.assertEqual(results["group_list_type"], 0)
|
self.assertEqual(results["group_list_type"], 1)
|
||||||
|
|
||||||
def test_token_combine_with_expert_map(self):
|
def test_token_combine_with_expert_map(self):
|
||||||
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
|
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
|
||||||
@@ -242,9 +250,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
|||||||
hidden_states = torch.randn(6, 128)
|
hidden_states = torch.randn(6, 128)
|
||||||
|
|
||||||
final_hidden_states = self.dispatcher.token_combine(hidden_states)
|
final_hidden_states = self.dispatcher.token_combine(hidden_states)
|
||||||
|
self.assertEqual(final_hidden_states.shape, (6, 128))
|
||||||
# Verify index_add_ is applied correctly
|
|
||||||
self.assertEqual(final_hidden_states.shape, (3, 128))
|
|
||||||
|
|
||||||
def test_token_combine_without_expert_map(self):
|
def test_token_combine_without_expert_map(self):
|
||||||
self.dispatcher.with_quant = False
|
self.dispatcher.with_quant = False
|
||||||
@@ -260,10 +266,10 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
|||||||
final_hidden_states = self.dispatcher.token_combine(hidden_states)
|
final_hidden_states = self.dispatcher.token_combine(hidden_states)
|
||||||
|
|
||||||
# Verify npu_moe_finalize_routing is called
|
# Verify npu_moe_finalize_routing is called
|
||||||
self.mock_moe_finalize_routing.assert_called_once()
|
self.mock_npu_moe_token_unpermute.assert_called_once()
|
||||||
args, kwargs = self.mock_moe_finalize_routing.call_args
|
args, kwargs = self.mock_npu_moe_token_unpermute.call_args
|
||||||
|
|
||||||
self.assertEqual(final_hidden_states.shape, (3, 128))
|
self.assertEqual(final_hidden_states.shape, (6, 128))
|
||||||
|
|
||||||
def test_token_dispatch_with_router_weight(self):
|
def test_token_dispatch_with_router_weight(self):
|
||||||
self.dispatcher.apply_router_weight_on_input = True
|
self.dispatcher.apply_router_weight_on_input = True
|
||||||
|
|||||||
@@ -338,8 +338,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
|||||||
self.original_shape = hidden_states.shape
|
self.original_shape = hidden_states.shape
|
||||||
|
|
||||||
num_tokens = hidden_states.shape[:-1].numel()
|
num_tokens = hidden_states.shape[:-1].numel()
|
||||||
dtype = hidden_states.dtype
|
|
||||||
device = hidden_states.device
|
|
||||||
self.expert_map = expert_map
|
self.expert_map = expert_map
|
||||||
self.topk_weights = topk_weights
|
self.topk_weights = topk_weights
|
||||||
self.topk_ids = topk_ids
|
self.topk_ids = topk_ids
|
||||||
@@ -353,67 +351,31 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
|||||||
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||||
hidden_states = hidden_states * \
|
hidden_states = hidden_states * \
|
||||||
topk_weights.to(hidden_states.dtype)
|
topk_weights.to(hidden_states.dtype)
|
||||||
|
|
||||||
if expert_map is not None:
|
if expert_map is not None:
|
||||||
# Generate token indices and flatten
|
global_num_experts = len(expert_map)
|
||||||
token_indices = (torch.arange(
|
mask = (expert_map[topk_ids] != -1)
|
||||||
num_tokens, device=device,
|
self.topk_weights = topk_weights * mask
|
||||||
dtype=torch.int64).unsqueeze(1).expand(-1,
|
first_expert_idx = get_ep_group(
|
||||||
self.top_k).reshape(-1))
|
).rank_in_group * self.num_experts_local
|
||||||
|
last_expert_idx = first_expert_idx + self.num_experts_local
|
||||||
# Flatten token-to-expert mappings and map to local experts
|
|
||||||
weights_flat = topk_weights.view(-1)
|
|
||||||
experts_flat = topk_ids.view(-1)
|
|
||||||
local_experts_flat = expert_map[experts_flat]
|
|
||||||
|
|
||||||
# Filter valid token-expert pairs
|
|
||||||
self.mask = local_experts_flat != -1
|
|
||||||
filtered_weights = torch.where(
|
|
||||||
self.mask, weights_flat,
|
|
||||||
torch.zeros_like(weights_flat)).to(dtype)
|
|
||||||
filtered_experts = torch.where(
|
|
||||||
self.mask, local_experts_flat,
|
|
||||||
torch.full_like(local_experts_flat,
|
|
||||||
self.num_experts_local)).to(topk_ids.dtype)
|
|
||||||
|
|
||||||
# Sort by local expert IDs
|
|
||||||
sort_indices = torch.argsort(filtered_experts.view(torch.float32))
|
|
||||||
self.sorted_token_indices = token_indices[sort_indices]
|
|
||||||
self.sorted_weights = filtered_weights[sort_indices]
|
|
||||||
|
|
||||||
# Compute token counts with minlength of num_experts
|
|
||||||
# This is equivalent to but faster than:
|
|
||||||
# >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
|
|
||||||
token_counts = torch.zeros(self.num_experts_local + 1,
|
|
||||||
device=device,
|
|
||||||
dtype=torch.int64)
|
|
||||||
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
|
|
||||||
token_counts.scatter_add_(0, filtered_experts.to(torch.int64),
|
|
||||||
ones)
|
|
||||||
token_counts = token_counts[:self.num_experts_local]
|
|
||||||
|
|
||||||
# Rearrange hidden_states
|
|
||||||
sorted_hidden_states = hidden_states[self.sorted_token_indices]
|
|
||||||
if self.with_quant:
|
|
||||||
group_list_type = 1
|
|
||||||
expert_tokens = token_counts
|
|
||||||
else:
|
|
||||||
expert_tokens = torch.cumsum(token_counts,
|
|
||||||
dim=0,
|
|
||||||
dtype=torch.int64)
|
|
||||||
group_list_type = 0
|
|
||||||
else:
|
else:
|
||||||
active_num = self.max_num_tokens if self.max_num_tokens is not None else num_tokens
|
first_expert_idx = 0
|
||||||
sorted_hidden_states, self.expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
last_expert_idx = self.num_experts_local
|
||||||
hidden_states,
|
global_num_experts = self.num_experts_local
|
||||||
row_idx=row_idx,
|
|
||||||
expert_idx=topk_ids,
|
|
||||||
active_num=active_num)
|
|
||||||
|
|
||||||
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
sorted_hidden_states, self.expanded_row_idx, expert_tokens, _ = (
|
||||||
expanded_expert_idx, self.num_experts_local)
|
torch_npu.npu_moe_init_routing_v2(
|
||||||
expert_tokens = expert_tokens.to(torch.int64)
|
hidden_states,
|
||||||
group_list_type = 0
|
topk_ids,
|
||||||
|
active_num=num_tokens * self.top_k,
|
||||||
|
expert_num=global_num_experts,
|
||||||
|
expert_tokens_num_type=1,
|
||||||
|
expert_tokens_num_flag=True,
|
||||||
|
active_expert_range=[first_expert_idx, last_expert_idx],
|
||||||
|
quant_mode=-1,
|
||||||
|
))
|
||||||
|
expert_tokens = expert_tokens.to(torch.int64)
|
||||||
|
group_list_type = 1 # `count` mode
|
||||||
return {
|
return {
|
||||||
"group_list_type": group_list_type,
|
"group_list_type": group_list_type,
|
||||||
"hidden_states": sorted_hidden_states,
|
"hidden_states": sorted_hidden_states,
|
||||||
@@ -424,61 +386,12 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
bias: torch.Tensor = None):
|
bias: torch.Tensor = None):
|
||||||
assert self.original_shape is not None
|
assert self.original_shape is not None
|
||||||
dtype = hidden_states.dtype
|
final_hidden_states = torch_npu.npu_moe_token_unpermute(
|
||||||
device = hidden_states.device
|
permuted_tokens=hidden_states,
|
||||||
if self.expert_map is not None:
|
sorted_indices=self.expanded_row_idx,
|
||||||
assert self.mask is not None
|
probs=self.topk_weights)
|
||||||
assert self.sorted_token_indices is not None
|
if len(self.original_shape) == 3:
|
||||||
assert self.sorted_weights is not None
|
final_hidden_states = final_hidden_states.view(self.original_shape)
|
||||||
|
|
||||||
weighted_down_out = hidden_states * \
|
|
||||||
self.sorted_weights.unsqueeze(1)
|
|
||||||
|
|
||||||
final_hidden_states = torch.zeros(*self.original_shape,
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=hidden_states.dtype)
|
|
||||||
|
|
||||||
# TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
|
|
||||||
# This created multiple NaN and index_add_ will mix them up which harms accuracy
|
|
||||||
# remove this mask and filter after it being fixed
|
|
||||||
num_valid_tokens = self.mask.sum()
|
|
||||||
valid_token_mask = torch.arange(
|
|
||||||
0, self.sorted_token_indices.shape[0],
|
|
||||||
device=device).unsqueeze(1) < num_valid_tokens
|
|
||||||
valid_output = torch.where(
|
|
||||||
valid_token_mask, weighted_down_out,
|
|
||||||
torch.zeros_like(weighted_down_out)).to(dtype)
|
|
||||||
final_hidden_states.index_add_(0, self.sorted_token_indices,
|
|
||||||
valid_output)
|
|
||||||
else:
|
|
||||||
if self.with_quant:
|
|
||||||
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
|
||||||
hidden_states,
|
|
||||||
skip1=None,
|
|
||||||
skip2=None,
|
|
||||||
bias=None,
|
|
||||||
scales=self.topk_weights,
|
|
||||||
expanded_src_to_dst_row=self.expanded_row_idx,
|
|
||||||
export_for_source_row=self.topk_ids,
|
|
||||||
)
|
|
||||||
if len(self.original_shape) == 3:
|
|
||||||
final_hidden_states = final_hidden_states.view(
|
|
||||||
self.original_shape)
|
|
||||||
else:
|
|
||||||
scales = torch.ones_like(
|
|
||||||
self.topk_weights
|
|
||||||
) if self.apply_router_weight_on_input else self.topk_weights
|
|
||||||
# TODO: Reorder device memory 2 times here, replace the current
|
|
||||||
# implementation here when suitable operators become available.
|
|
||||||
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
|
||||||
hidden_states,
|
|
||||||
skip1=None,
|
|
||||||
skip2=None,
|
|
||||||
bias=None,
|
|
||||||
scales=scales,
|
|
||||||
expanded_src_to_dst_row=self.expanded_row_idx,
|
|
||||||
export_for_source_row=self.topk_ids,
|
|
||||||
)
|
|
||||||
return final_hidden_states
|
return final_hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user