diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp index 811b2ce9..9180b06d 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp +++ b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp @@ -114,7 +114,6 @@ __aicore__ inline void moe_init_routing_quant_v2( srcToDstAndGatherOp.Init(x, scale, expandedRowIdx, expandedX, dynamicQuantScale, workspace, tilingData, &srcToDstGatherPipe); srcToDstAndGatherOp.Process(); srcToDstGatherPipe.Destroy(); - return; } } diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 14db8976..28882924 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -734,6 +734,9 @@ class MtpProposer(Proposer): num_input_tokens, self.runner.with_prefill) moe_comm_type = self.runner._select_moe_comm_method(num_input_tokens) + # TODO: remove this after moe_comm_type selection logic is finalized + moe_comm_type = (MoECommType.ALLTOALL if moe_comm_type + == MoECommType.FUSED_ALLTOALL else moe_comm_type) # Enable shared_expert_dp and MTP FULL graph may cause accuracy issues. if scheduler_output and not self.enable_shared_expert_dp: diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index ef74c354..1dad4a28 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -920,16 +920,17 @@ def calculate_ep_buffer_size() -> int: try: from vllm.config import get_current_vllm_config vllm_config = get_current_vllm_config() + tp_size = vllm_config.parallel_config.tensor_parallel_size hf_config = vllm_config.model_config.hf_config hidden_size = hf_config.hidden_size - topk = getattr(hf_config, "num_experts_per_token", 1) - batch_size = vllm_config.scheduler_config.max_num_batched_tokens + topk = getattr(hf_config, "num_experts_per_tok", 1) + batch_size = vllm_config.scheduler_config.max_num_batched_tokens // tp_size int8_size = torch.iinfo(torch.int8).bits // 8 bf16_size = torch.finfo(torch.bfloat16).bits // 8 ep_buffer_size = math.ceil( (batch_size * hidden_size * topk * - (int8_size * 2 + bf16_size)) / (1024 * 1024)) + (int8_size + bf16_size) * 3) / (1024 * 1024)) except Exception: pass return max(ep_buffer_size, _DEFAULT_BUFFER_SIZE)