[BugFix] Fix input parameter bug of dispatch_gmm_combine_decode[RFC: issue 5476] (#5932)
### What this PR does / why we need it?
In [PR 5040](https://github.com/vllm-project/vllm-ascend/pull/5040), the
`dispatch_gmm_combine_decode` operator was configured with an incorrect
`global_bs` parameter. This PR fixes that bug.
The `global_bs` input should carry the same meaning as in the
`moe_distributed_dispatch` operator, namely: (the maximum batch
size across all cards) * (expert parallel world size).
However, the implementation incorrectly used the variable
`max_num_tokens`, which does not account for tensor parallelism. This
likely resulted in an unnecessarily large (overestimated) value.
For more information about this operator, please refer to the RFC:
https://github.com/vllm-project/vllm-ascend/issues/5476
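
The corrected computation can be sketched as below. The names follow the `TokenDispatcherWithMC2` snippet in this PR's diff; the standalone function, its parameters, and the example values are illustrative, not part of the actual implementation.

```python
def compute_global_bs(max_num_reqs: int,
                      uniform_decode_query_len: int,
                      tp_size: int,
                      ep_world_size: int) -> int:
    """global_bs = (max token count per card) * (expert parallel world size).

    Each TP rank only holds a 1/tp_size slice of the decode tokens, so the
    per-card maximum is max_num_tokens ceil-divided by tp_size. The buggy
    version multiplied max_num_tokens by ep_world_size directly, ignoring
    tensor parallelism and overestimating by roughly a factor of tp_size.
    """
    max_num_tokens = min(max_num_reqs * uniform_decode_query_len, 512)
    # Ceiling division: tokens are sharded across the TP group.
    num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
    return num_tokens_per_tp_rank * ep_world_size

# Hypothetical example: 512 decode tokens, tp_size=4, ep16.
print(compute_global_bs(max_num_reqs=512, uniform_decode_query_len=1,
                        tp_size=4, ep_world_size=16))  # 128 * 16 = 2048
```

With tp_size=1 the two formulas coincide, which is why the overestimation only shows up when tensor parallelism is enabled.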
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Accuracy test: qwen3-235b with EPLB on a single A3 node (ep16),
with `dispatch_gmm_combine_decode` enabled.
| dataset | version | metric | mode | vllm-api-stream-chat |
|----- | ----- | ----- | ----- | -----|
| aime2024 | 604a78 | accuracy | gen | 80.00 |
- vLLM version: v0.13.0
- vLLM main:
11b6af5280
Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
```diff
@@ -343,7 +343,7 @@ class FusedMC2CommImpl(MoECommMethod):
                 ep_rank_size=self.token_dispatcher.ep_world_size,
                 ep_rank_id=self.token_dispatcher.ep_rank_id,
                 moe_expert_num=self.moe_config.num_experts,
-                global_bs=self.token_dispatcher.fused_global_bs)
+                global_bs=self.token_dispatcher.global_bs)
         else:
             raise ValueError(
                 f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
@@ -137,7 +137,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
         max_num_tokens = min(max_num_reqs * uniform_decode_query_len, 512)
         num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
         self.global_bs = num_tokens_per_tp_rank * self.ep_world_size
-        self.fused_global_bs = max_num_tokens * self.ep_world_size

     def get_dispatch_mc2_kwargs(
         self,
```