diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 6c44a6a..aeadc7b 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -130,6 +130,7 @@ def fused_experts_with_mc2( "shared_expert_rank_num": 0, "moe_expert_num": moe_expert_num, "global_bs": global_bs, + "expert_scales": topk_weights.to(torch.float32), } rank = torch.distributed.get_rank() @@ -158,8 +159,8 @@ def fused_experts_with_mc2( output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2) # comm_stream.wait_stream(torch.npu.current_stream()) - expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[ - 0:5] + expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts, _, expand_scales = output[ + 0:7] if shared_experts is not None: with npu_stream_switch("moe_secondary", 0): @@ -187,6 +188,7 @@ def fused_experts_with_mc2( "shared_expert_rank_num": 0, "moe_expert_num": moe_expert_num, "global_bs": 0, + "expand_scales": expand_scales, } tp_recv_counts = torch.empty(1, dtype=torch.int32,