diff --git a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py index 832cbc5..be212e2 100644 --- a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py +++ b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py @@ -416,6 +416,7 @@ def torchair_fused_experts_with_all2all( num_experts = w1.shape[0] if expert_map is not None: + assert ep_group is not None, "ep_group must be provided when expert_map is given" global_num_experts = len(expert_map) + global_redundant_expert_num if hasattr(torch_npu, "npu_moe_init_routing_quant"): quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant( @@ -435,8 +436,9 @@ def torchair_fused_experts_with_all2all( gather_sizes = global_expert_tokens.new_empty( global_expert_tokens.shape[0]) - dist.all_to_all_single(gather_sizes, global_expert_tokens) - + dist.all_to_all_single(gather_sizes, + global_expert_tokens, + group=ep_group.device_group) token_counts_combined = torch.stack( [gather_sizes, global_expert_tokens], dim=0) token_counts_combined = token_counts_combined.view( @@ -451,10 +453,16 @@ def torchair_fused_experts_with_all2all( gather_size_list = token_counts_combined_cpu[1] scatter_size_list = token_counts_combined_cpu[0] - dist.all_to_all_single(gathered_tokens, quantized_tokens, - scatter_size_list, gather_size_list) - dist.all_to_all_single(dynamic_scale, token_scales, scatter_size_list, - gather_size_list) + dist.all_to_all_single(gathered_tokens, + quantized_tokens, + scatter_size_list, + gather_size_list, + group=ep_group.device_group) + dist.all_to_all_single(dynamic_scale, + token_scales, + scatter_size_list, + gather_size_list, + group=ep_group.device_group) hidden_states, dynamic_scale, inverse_indices, expert_tokens = torch_npu.npu_moe_re_routing( gathered_tokens, @@ -502,9 +510,11 @@ def torchair_fused_experts_with_all2all( index=inverse_indices.to(torch.float32).argsort().to(torch.int32)) hidden_states = reordered_outputs.new_empty(*quantized_tokens.shape) - dist.all_to_all_single(hidden_states, reordered_outputs, - gather_size_list, scatter_size_list) - + dist.all_to_all_single(hidden_states, + reordered_outputs, + gather_size_list, + scatter_size_list, + group=ep_group.device_group) final_hidden_states = torch_npu.npu_moe_finalize_routing( hidden_states, skip1=None,