diff --git a/tests/ut/quantization/test_w8a8_dynamic.py b/tests/ut/quantization/test_w8a8_dynamic.py
new file mode 100644
index 0000000..59ab604
--- /dev/null
+++ b/tests/ut/quantization/test_w8a8_dynamic.py
@@ -0,0 +1,75 @@
+from unittest.mock import MagicMock, patch
+
+import torch
+
+from tests.ut.base import TestBase
+from vllm_ascend.quantization.w8a8_dynamic import fused_experts_with_all2all
+
+
+class TestAscendW8A8FusedMoEMethod(TestBase):  # exercises fused_experts_with_all2all with the torch_npu ops and the collective patched out
+
+    def setUp(self):  # builds one (num_tokens, hidden_size) bfloat16 tensor reused below as a stand-in for every tensor argument
+        self.hidden_size = 128
+        self.num_tokens = 128
+        self.placeholder = torch.randn(self.num_tokens,
+                                       self.hidden_size,
+                                       dtype=torch.bfloat16)
+
+    @patch("torch.distributed.all_to_all_single")  # decorators apply bottom-up: the lowest @patch supplies the first mock argument
+    @patch("torch_npu.npu_moe_re_routing")
+    @patch("torch_npu.npu_grouped_matmul")
+    @patch("torch_npu.npu_swiglu")
+    @patch("torch_npu.npu_dynamic_quant")
+    @patch("torch_npu.npu_moe_finalize_routing")
+    @patch("torch_npu.npu_moe_init_routing")  # NOTE(review): npu_moe_init_routing_quant is not patched; the hasattr() check in the function under test decides whether this mock is used
+    def test_fused_experts_with_all2all(self, mock_moe_init_routing,
+                                        mock_moe_finalize_routing,
+                                        mock_dynamic_quant, mock_swiglu,
+                                        mock_grouped_matmul,
+                                        mock_moe_re_routing,
+                                        mock_all_to_all_single):
+        expert_map = MagicMock()  # non-None, so the `expert_map is not None` branch is taken
+        ep_group = MagicMock()
+        placeholder_int8 = torch.randint(0,
+                                         100,
+                                         (self.num_tokens, self.hidden_size),
+                                         dtype=torch.int8)
+        placeholder_ones = torch.ones(self.num_tokens, dtype=torch.int32)
+        mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(  # fake collective: copy the input tensor into the output buffer
+            input)
+        mock_moe_init_routing.return_value = (  # (hidden_states, expanded_row_idx, expanded_expert_idx), the three values init_routing_quant unpacks
+            placeholder_int8,
+            placeholder_ones,
+            placeholder_ones,
+        )
+        mock_moe_re_routing.return_value = (placeholder_int8, self.placeholder,  # unpacked by the caller as (hidden_states, dynamic_scale, inverse_indices, expert_tokens)
+                                            torch.randint(0,
+                                                          100,
+                                                          (self.num_tokens, ),
+                                                          dtype=torch.int32),
+                                            self.placeholder)
+        mock_grouped_matmul.return_value = self.placeholder
+        mock_swiglu.return_value = self.placeholder
+        mock_dynamic_quant.return_value = (  # (quantized_tokens, token_scales)
+            placeholder_int8,
+            torch.randn(self.num_tokens),
+        )
+        mock_moe_finalize_routing.return_value = self.placeholder  # this mocked value is what the assertions below check against
+
+        result = fused_experts_with_all2all(
+            hidden_states=self.placeholder,
+            w1=self.placeholder,
+            w1_scale=self.placeholder,
+            w2=self.placeholder,
+            w2_scale=self.placeholder,
+            topk_weights=self.placeholder,
+            topk_ids=self.placeholder,
+            top_k=8,
+            expert_map=expert_map,
+            ep_group=ep_group,
+            log2phy=None,
+            global_redundant_expert_num=256,
+        )
+        self.assertIsNotNone(result)
+        self.assertEqual(result.dtype, torch.bfloat16)  # dtype of self.placeholder
+        self.assertEqual(result.shape, (128, 128))  # (num_tokens, hidden_size) from setUp
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index b20ffa3..e4afbb5 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -334,6 +334,29 @@ def fused_experts_with_mc2(
         return hidden_states, shared_output
 
 
+def init_routing_quant(hidden_states, top_k, topk_ids, global_num_experts):  # fallback used by fused_experts_with_all2all when torch_npu has no npu_moe_init_routing_quant
+    num_tokens, _ = hidden_states.shape
+    row_idx_len = num_tokens * top_k
+    row_idx = (torch.arange(0,  # arange(num_tokens * top_k) viewed as (top_k, -1), then transposed and made contiguous
+                            row_idx_len,
+                            dtype=torch.int32,
+                            device=hidden_states.device).view(
+                                top_k, -1).permute(1, 0).contiguous())
+    hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
+        hidden_states,
+        row_idx=row_idx,
+        expert_idx=topk_ids,
+        active_num=num_tokens)
+
+    expanded_row_idx = (expanded_row_idx.view(top_k, -1).permute(  # view as (top_k, -1), transpose, then flatten to 1-D
+        1, 0).contiguous().view(-1))
+    global_expert_tokens = torch.bincount(expanded_expert_idx,  # occurrences of each expert index, with at least global_num_experts bins
+                                          minlength=global_num_experts)
+    global_expert_tokens = global_expert_tokens.to(torch.int32)
+    quantized_tokens, token_scales = torch_npu.npu_dynamic_quant(hidden_states)  # quantizes the routed hidden_states; second value is kept as the per-token scales
+    return quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales
+
+
 # currently expert parallelism implemented with all2all
 # is under-optimized.
 def fused_experts_with_all2all(
@@ -358,50 +381,54 @@ def fused_experts_with_all2all(
 
     num_tokens, _ = hidden_states.shape
     num_experts = w1.shape[0]
-    device = hidden_states.device
 
     if expert_map is not None:
         global_num_experts = len(expert_map) + global_redundant_expert_num
-        local_num_experts = global_num_experts // ep_group.world_size
-        row_idx_len = num_tokens * top_k
-        row_idx = (torch.arange(0,
-                                row_idx_len,
-                                dtype=torch.int32,
-                                device=device).view(top_k, -1).permute(
-                                    1, 0).contiguous())
-        hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
-            hidden_states,
-            row_idx=row_idx,
-            expert_idx=topk_ids,
-            active_num=num_tokens)
+        if hasattr(torch_npu, "npu_moe_init_routing_quant"):  # use the single routing+quant op when this torch_npu build exposes it
+            quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant(  # the fourth returned value is discarded
+                hidden_states,
+                expert_idx=topk_ids.to(torch.int32),
+                active_num=0,
+                expert_capacity=0,
+                expert_num=global_num_experts,
+                drop_pad_mode=0,
+                expert_tokens_num_mode=2,
+                expert_tokens_before_capacity_flag=False,
+                quant_mode=1,
+            )
+        else:
+            quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales = init_routing_quant(  # fallback: separate init-routing, bincount and dynamic-quant steps
+                hidden_states, top_k, topk_ids, global_num_experts)
 
-        global_expert_tokens = torch.bincount(expanded_expert_idx,
-                                              minlength=global_num_experts)
-        scatter_sizes = global_expert_tokens.view(ep_group.world_size,
-                                                  -1).sum(-1)
+        gather_sizes = global_expert_tokens.new_empty(  # receive buffer with the same length as global_expert_tokens
+            global_expert_tokens.shape[0])
+        dist.all_to_all_single(gather_sizes, global_expert_tokens)  # NOTE(review): no group= is passed, while the removed code used group=ep_group.device_group — confirm the default process group is the intended one
 
-        gather_sizes = torch.empty_like(scatter_sizes)
-        dist.all_to_all_single(gather_sizes,
-                               scatter_sizes,
-                               group=ep_group.device_group)
-        scatter_size_list = scatter_sizes.cpu().tolist()
-        gather_size_list = gather_sizes.cpu().tolist()
+        token_counts_combined = torch.stack(  # row 0 = gather_sizes (received counts), row 1 = global_expert_tokens (local counts)
+            [gather_sizes, global_expert_tokens], dim=0)
+        token_counts_combined = token_counts_combined.view(  # sum each of the world_size chunks -> shape (2, world_size)
+            2, ep_group.world_size, -1).sum(dim=2)
+        token_counts_combined_cpu = token_counts_combined.to(  # NOTE(review): non_blocking copy followed by .numpy(); presumably the all_tokens.item() call below provides the sync before the values are consumed — confirm
+            torch.device("cpu"), non_blocking=True).numpy()
+        all_tokens = gather_sizes.sum()  # sum of the received counts; used as the row count of the receive buffers
 
-        expanded_expert_idx = expanded_expert_idx % local_num_experts
-        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
-                                            scatter_size_list,
-                                            gather_size_list)
-        local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
-                                               scatter_size_list,
-                                               gather_size_list)
+        gathered_tokens = quantized_tokens.new_empty(all_tokens.item(),
+                                                     quantized_tokens.shape[1])
+        dynamic_scale = token_scales.new_empty(gathered_tokens.shape[0])
+        gather_size_list = token_counts_combined_cpu[1]  # NOTE(review): row 1 is derived from global_expert_tokens (the local counts), despite the "gather" name
+        scatter_size_list = token_counts_combined_cpu[0]  # NOTE(review): row 0 is derived from gather_sizes (the received counts), despite the "scatter" name
 
-        sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
+        dist.all_to_all_single(gathered_tokens, quantized_tokens,  # positional order is (output, input, output_split_sizes, input_split_sizes); again no group= is passed
+                               scatter_size_list, gather_size_list)
+        dist.all_to_all_single(dynamic_scale, token_scales, scatter_size_list,  # same exchange for the per-token scales
+                               gather_size_list)
 
-        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
-            sorted_local_expert_idx, local_num_experts).to(torch.int64)
-
-        hidden_states = hidden_states[sorted_idx]
-        group_list_type = 0
+        hidden_states, dynamic_scale, inverse_indices, expert_tokens = torch_npu.npu_moe_re_routing(  # replaces the removed torch.sort + npu_moe_compute_expert_tokens sequence
+            gathered_tokens,
+            gather_sizes.view(ep_group.world_size, -1),
+            per_token_scales=dynamic_scale)
+        expert_tokens = expert_tokens.to(torch.int64)
+        group_list_type = 1  # NOTE(review): the removed code set 0 here; the meaning of this flag is defined by apply_mlp (not shown) — confirm
     else:
         row_idx_len = num_tokens * top_k
         row_idx = torch.arange(0,
@@ -419,6 +446,7 @@ def fused_experts_with_all2all(
             expanded_expert_idx, num_experts)
         expert_tokens = expert_tokens.to(torch.int64)
         group_list_type = 0
+        dynamic_scale = None  # no scales are computed on the expert_map-is-None path; apply_mlp receives None
 
     # `hidden_states` will be disposed in the `apply_mlp` function
     hidden_states = apply_mlp(
@@ -428,14 +456,19 @@ def fused_experts_with_all2all(
         w2,
         w2_scale,
         expert_tokens,  #16
+        dynamic_scale=dynamic_scale,  # scales from npu_moe_re_routing on the expert-parallel path, otherwise None
         group_list_type=group_list_type)
 
     if expert_map is not None:
-        resorted_idx = torch.argsort(sorted_idx)
-        hidden_states = hidden_states[resorted_idx]
-        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
-                                            gather_size_list,
-                                            scatter_size_list)
+        reordered_outputs = torch.index_select(  # gathers rows in argsort(inverse_indices) order; presumably this undoes the npu_moe_re_routing reordering — confirm
+            hidden_states,
+            dim=0,
+            # Workaround: Convert to float so that argsort runs on AI Core instead of slower AICPU
+            index=inverse_indices.to(torch.float32).argsort().to(torch.int32))
+
+        hidden_states = reordered_outputs.new_empty(*quantized_tokens.shape)  # receive buffer shaped like the pre-exchange quantized_tokens, allocated from reordered_outputs
+        dist.all_to_all_single(hidden_states, reordered_outputs,  # return trip: the two split-size lists are swapped relative to the forward exchange
+                               gather_size_list, scatter_size_list)
 
         final_hidden_states = torch_npu.npu_moe_finalize_routing(
             hidden_states,
@@ -444,8 +477,8 @@ def fused_experts_with_all2all(
             bias=None,
             scales=topk_weights,
             expanded_src_to_dst_row=expanded_row_idx,
-            export_for_source_row=topk_ids,
-        )
+            export_for_source_row=None,  # replaces topk_ids from the removed line
+            drop_pad_mode=2)  # argument newly passed by this change
     else:
         # TODO: Reorder device memory 2 times here, replace the current
         # implementation here when suitable operators become available.