xc-llm-ascend/vllm_ascend/eplb/core/eplb_utils.py

Dynamic Expert Load Balance with Zero-like Overhead (#2956)

### Motivation
Dynamic expert balancing currently stops the world. Asynchronous expert load balancing avoids the following problems:
- Host-bound latency: EPLB involves many CPU operations (the EPLB algorithm, creating p2p ops, and log2phy expert conversion) that take a long time on the host, around 1s.
- Communication latency: transfer time is high without NVLink. Since an expert's weights may be transferred to multiple new positions, one expert needs N send/recv calls, resulting in long latency. In our tests, batch_isend_irecv cost more than 100ms to transmit 16 experts' weights on an Ascend A2 server.

SwiftBalancer no longer stops the world: in our tests on NPU it costs 1~2ms per layer while saving 5ms-8ms of decode latency with ep_size = 64. The following updates have been made:
1. Expert distribution recording with lower cost.
2. Async CPU computation for the EPLB algorithm and other Python operators.
3. A new EPLB algorithm that rebalances fewer experts with almost the same effect.

### Proposed Change
We will gradually migrate the EPLB logic to the vLLM community and implement a generalized design. Relevant RFC: https://github.com/vllm-project/vllm/issues/22246

The overall workflow:
1. Record the expert distribution during forward. We use expert_token_num after dispatch instead of topk_ids, so the recorded tensor is much smaller, reducing the cost of HBM recording and the add operator.
2. All-gather the expert distribution. All-gather is used instead of all-reduce because it generates less traffic.
3. Wake up the EPLB worker process with the expert distribution when num_iterations arrives, and run the EPLB algorithm in the worker.
4. Generate the p2p send/recv ops there as well; these and other operators such as log2phy cost significant CPU time.
5. Launch ibatch_send_recv in async_stream before forward.
6. After forward, wait for ibatch_send_recv to finish, then update the expert map and expert weights.

### Co-author
Co-authored-by: raindaywhu raindaywhu@163.com
Co-authored-by: njuyuan yuanjl19@smail.nju.edu.cn
Co-authored-by: qmkakaxi wjh1594260677@qq.com
Co-authored-by: Skywalker-EP 173723846@qq.com

- vLLM version: v0.10.2
- vLLM main: https://github.com/vllm-project/vllm/commit/567939953b7a9cb0ded6bf0bb21a76917b8fed97

Signed-off-by: offline0806 <z00858301@china.huawei.com>
Co-authored-by: offline0806 <z00858301@china.huawei.com>
2025-09-17 10:36:43 +08:00
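The size reduction in workflow step 1 can be sketched with NumPy. This is a minimal illustration of the idea, not the vllm-ascend implementation; the `bincount`-based recorder and all variable names here are assumptions for demonstration:

```python
import numpy as np

# Assume 8 experts on this rank and a batch of 16 tokens with top-2 routing
# (illustrative numbers, not from the actual workload).
num_experts = 8
topk_ids = np.random.randint(0, num_experts, size=(16, 2))  # (tokens, top_k)

# Naive recording accumulates the full (tokens, top_k) tensor every step.
# Recording per-expert token counts instead needs only a (num_experts,) vector:
expert_token_num = np.bincount(topk_ids.ravel(), minlength=num_experts)

assert expert_token_num.sum() == topk_ids.size  # every routed token is counted
```

The per-expert count vector is also what makes all-gather cheaper than reducing the raw routing tensors: each rank contributes only `num_experts` integers per layer.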
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
# TODO: Remove these EPLB utils once https://github.com/vllm-project/vllm/issues/22246 is merged in vLLM.
[EPLB][refactor] Modification of the initialization logic for expert_map and log2phy (depends on PR 5285) (#5311)

### What this PR does / why we need it?
Unify the loading logic for expert_map and log2phy.
1. The map generated when the redundancy expert is enabled is incorrect. The community map-generation function only accepts the number of global experts; when we pass in the number of logical experts plus redundant experts, the local expert ID on the last card indexes an expert ID that does not exist. Now we ensure every index points to a real expert ID and that each expert can be accessed. When redundant experts are not enabled, our function's output stays consistent with the community function's.
2. The map we generate is based on the physical expert count, but we only need the logical expert count. Since the map is padded later anyway, we can simply generate a map with the length of the logical experts.
3. Unify the initialization logic across scenarios and simplify the fused_moe code.

**Before refactoring**
- map path is not None:
  - expert map: get_rank_placement_map from 'expert_load_balancer.py'; maintains the map for all ranks and all layers.
  - log2phy: get_rank_log2phy_map from 'expert_load_balancer.py'; maintains the map for all ranks and all layers.
- map path is None:
  - expert map: determine_expert_map from vllm's layer module; does not support vllm-ascend's redundant experts.
  - log2phy: determine_default_log2phy_map from 'eplb_utils.py'; does not support vllm-ascend's redundant experts.

**Refactoring**
- eplb_utils.py
  - init_eplb_config
    - generate placement
    - generate expert map
    - generate log2phy

### Does this PR introduce any user-facing change?
No

### How was this patch tested?
Expert mapping test generation. Entries that are not -1 indicate the experts this rank is responsible for.

ep size: 16, num of experts: 256, num of redundant experts: 16, map for rank 15:
vllm map: [-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16]
Improved map: [16 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15]

ep size: 16, num of experts: 256, num of redundant experts: 0, map for rank 15:
vllm map: [-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15]
Improved map: [-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15]

dsr1 baseline:
| dataset | version | metric | mode | vllm-api-general-chat |
| ----- | ----- | ----- | ----- | ----- |
| gsm8k-lite | 7cd45e | accuracy | gen | 100.00 |

dsr1 eplb:
| dataset | version | metric | mode | vllm-api-general-chat |
| ----- | ----- | ----- | ----- | ----- |
| gsm8k-lite | 7cd45e | accuracy | gen | 100.00 |

- vLLM version: release/v0.13.0
- vLLM main: https://github.com/vllm-project/vllm/commit/5fbfa8d9ef15948599631baeb91e8220b2ee9bcc

Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
2025-12-29 09:26:14 +08:00
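The key property of the improved map, as described above, can be sketched as follows: the map has logical length (256, not 272), and every non -1 entry is a valid local slot index on that rank. This is a simplified reconstruction of the idea, not the actual `init_eplb_config`; in particular, the round-robin policy for assigning redundant experts is an assumption (the printed example assigns expert 0 to rank 15, so the real assignment may differ):

```python
import numpy as np

def build_rank_expert_map(rank, ep_size, n_logical, n_redundant):
    """Map global logical expert id -> local slot on `rank` (-1 if absent).

    The map has length n_logical (not n_logical + n_redundant), and every
    redundant slot is assigned a real logical expert id, so no index can
    point past the last existing expert.
    """
    per_rank = n_logical // ep_size
    expert_map = np.full(n_logical, -1, dtype=np.int32)
    # Regular slots: this rank owns a contiguous block of logical experts.
    owned = np.arange(rank * per_rank, (rank + 1) * per_rank)
    expert_map[owned] = np.arange(per_rank)
    # Redundant slots: round-robin one extra real expert per rank (assumed policy).
    redundant_per_rank = n_redundant // ep_size
    for slot in range(redundant_per_rank):
        logical = (slot * ep_size + rank) % n_logical
        expert_map[logical] = per_rank + slot
    return expert_map
```

With rank 15, ep size 16, 256 experts, and 16 redundant experts, this yields a 256-entry map with 17 populated slots (local slots 0 through 16), matching the shape of the "Improved map" above rather than the 272-entry community map.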
import json
import os.path
from collections import defaultdict
import numpy as np
import torch
from vllm.logger import logger
def expert_file_to_tensor(expert_map_path, layer_id):
    with open(expert_map_path) as f:
        data = json.load(f)
    physical_count = 0
    device_data = []
    if layer_id > data["moe_layer_count"]:
        raise ValueError("Invalid EPLB Table")
    if layer_id == data["moe_layer_count"]:
        logger.warning("Init expert map of mtp/eagle when using sample.")
        return None, None
    for device in data["layer_list"][layer_id]["device_list"]:
        physical_count += len(device["device_expert"])
        device_data.append(device["device_expert"])
    global_placement = torch.tensor(device_data, dtype=torch.int32)
    return global_placement, physical_count

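
# Example EPLB table accepted by expert_file_to_tensor, inferred from the
# parsing above (field values are illustrative, not from a real deployment):
#
# {
#     "moe_layer_count": 2,
#     "layer_list": [
#         {"device_list": [{"device_expert": [0, 1, 2, 3]},
#                          {"device_expert": [4, 5, 6, 7]}]},
#         {"device_list": [{"device_expert": [4, 5, 6, 7]},
#                          {"device_expert": [0, 1, 2, 3]}]}
#     ]
# }
#
# The returned global_placement then has shape (num_devices,
# experts_per_device), and physical_count is the total number of
# device_expert entries across all devices in that layer.
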
def generate_global_placement(n_expert, ep_size, n_redundant):
    all_experts = np.arange(n_expert)
    groups = np.array_split(all_experts, ep_size)
    for i in range(n_redundant):
        j = i % ep_size + 1
        if len(groups[-j]) == 0:
            groups[-j] = np.append(groups[-j], j)
        else:
            groups[-j] = np.append(groups[-j], (groups[-j][-1] + 1) % n_expert)
    return torch.tensor(groups, dtype=torch.int32)


def init_eplb_config(eplb_config, layer_id, moe_config):
    """Resolve the EPLB settings for a single MoE layer: the global expert
    placement, the number of redundant experts, and whether EPLB is enabled."""
    expert_map_path = eplb_config.expert_map_path
    n_experts = moe_config.num_experts
    ep_size = moe_config.ep_size
    global_placement = None
    eplb_enable = eplb_config.dynamic_eplb
    n_redundant = eplb_config.num_redundant_experts if eplb_enable else 0
    if expert_map_path:
        if not (os.path.exists(expert_map_path)
                and os.access(expert_map_path, os.R_OK)):
            raise ValueError(
                f"expert_map_path '{expert_map_path}' does not exist or is not readable"
            )
        eplb_enable = True
        global_placement, physical_count = expert_file_to_tensor(
            expert_map_path, layer_id)
        if physical_count is not None:
            n_redundant = physical_count - n_experts
        if not moe_config.supports_eplb:
            raise ValueError("EPLB supports only w8a8_dynamic quantization.")
    else:
        eplb_enable = False
    if global_placement is None:
        global_placement = generate_global_placement(n_experts, ep_size, n_redundant)
    if ep_size == 1:
        assert not eplb_enable, "EPLB requires expert parallelism (ep_size > 1)."
        return None, None, None, n_redundant
    global_expert_map = []
    for rankid in range(ep_size):
        # -1 marks global experts that are not hosted on this rank.
        expert_map = torch.full((n_experts,), -1, dtype=torch.int32)
[EPLB][refactor] Modification of the initialization logic for expert_map and log2phy(depend on pr5285) (#5311) ### What this PR does / why we need it? Unify the loading logic for expert_map and log2phy. 1. The map generated when enabling the redundancy expert is incorrect. The community generation map function only accepts the number of global experts. When we pass in the number of logical experts plus redundant experts, the local expert ID of the last card will index to an expert ID that does not exist. Now we ensure that the index points to a real existing expert ID, and each expert can be accessed. Moreover, when redundant experts are not enabled, the output of our function remains consistent with the community's function. 2. The map we generate is based on the length of the physical expert, but in reality, we only need to use the length of the logical expert. Later on, we will need to pad it accordingly, so we can simply generate a map with the length of the logical [expert.] 3. Unify the initialization logic across different scenarios and simplify the code for fused_moe. **Before refactoring** - map path is not None: expert map: get_rank_placement_map from _'expert_load_balancer.py'_, maintains the map for all ranks and all layers. log2phy: get_rank_log2phy_map from _'expert_load_balancer.py'_, maintains the map for all ranks and all layers. - map path is None: expert map: determine_expert_map from '_vllm.laye_r', The function does not support the redundant experts of vllm-ascend. log2phy: determine_default_log2phy_map from _'eplb_utils.py'_. The function does not support the redundant experts of vllm-ascend. **Refactoring** eplb_utils.py &nbsp;&nbsp;&nbsp;&nbsp;init_eplb_config &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; generate placement &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; generate expert map &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; generate log2phy ### Does this PR introduce _any_ user-facing change? 
No ### How was this patch tested? Expert Mapping Test Generation: ep size: 16, num of experts: 256, num of redundant experts: 16 +++++++++++++++++++++++++++++++++++++++++ Expert Mapping (Non-1 indicates the expert responsible for this rank) for Rank 15: vllm map: [-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16] +++++++++++++++++++++++++++++++++++++++++ Improved map: [16 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15] Expert Mapping Test Generation: ep size: 16, num of experts: 256, num of redundant experts: 0 +++++++++++++++++++++++++++++++++++++++++ Expert Mapping (Non-1 indicates the expert responsible for this rank) for Rank 15: vllm map: [-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15] +++++++++++++++++++++++++++++++++++++++ Improved map: [-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15] dsr1 baselie: | dataset | version | metric | mode | vllm-api-general-chat | |----- | ----- | ----- | ----- | -----| | gsm8k-lite | 7cd45e | accuracy | gen | 100.00 | dsr1 eplb: | dataset | version | metric | mode | vllm-api-general-chat | |----- | ----- | ----- | ----- | -----| | gsm8k-lite | 7cd45e | accuracy | gen | 100.00 | - vLLM version: release/v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/5fbfa8d9ef15948599631baeb91e8220b2ee9bcc Signed-off-by: shenchuxiaofugui <1311027364@qq.com> Co-authored-by: weijinqian0 <1184188277@qq.com>
2025-12-29 09:26:14 +08:00
        local_placement = global_placement[rankid]
        expert_map[local_placement] = torch.arange(local_placement.shape[0],
                                                   dtype=torch.int32)
        global_expert_map.append(expert_map)
        if rankid == moe_config.ep_rank:
            local_expert_map = expert_map.npu()
    log2phy = (generate_log2phy_map(global_expert_map, moe_config.ep_rank).npu()
               if eplb_enable else None)
    return torch.stack(global_expert_map), local_expert_map, log2phy, n_redundant


def generate_log2phy_map(global_expert_map, ep_rank):
    """Build this rank's logical-to-physical expert map.

    A logical expert may be hosted on several EP ranks (a redundant
    expert); collect every physical candidate for it, then let each
    rank pick one round-robin by its ep_rank.
    """
    log2phy_map = defaultdict(list)
    valid_count = torch.sum(global_expert_map[0] != -1)
    for rankid, map_per_rank in enumerate(global_expert_map):
        for idx, val in enumerate(map_per_rank):
            val = val.item()
            if val != -1:
                # Global physical id = local slot id + per-rank offset.
                log2phy_map[idx].append(val + rankid * valid_count)
    for key in log2phy_map:
        num_of_duplications = len(log2phy_map[key])
        log2phy_map[key] = log2phy_map[key][ep_rank % num_of_duplications]
    log2phy_map = torch.scatter(
        torch.zeros(len(log2phy_map), dtype=torch.int32),
        0,
        torch.tensor(list(log2phy_map), dtype=torch.int64),
        torch.tensor(list(log2phy_map.values()), dtype=torch.int32),
    )
    return log2phy_map
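The mapping logic above can be sketched standalone. This is a hypothetical illustration (plain Python lists instead of torch tensors, and the `demo_log2phy` name is invented for this sketch); the real function additionally packs the result into a torch tensor via `torch.scatter`.

```python
from collections import defaultdict


def demo_log2phy(global_expert_map, ep_rank):
    # One row per EP rank; -1 means "expert not hosted on this rank",
    # other entries are local physical slot indices on that rank.
    log2phy = defaultdict(list)
    valid_count = sum(1 for v in global_expert_map[0] if v != -1)
    for rankid, row in enumerate(global_expert_map):
        for logical_id, local_slot in enumerate(row):
            if local_slot != -1:
                # Global physical id = local slot + per-rank offset.
                log2phy[logical_id].append(local_slot + rankid * valid_count)
    # A redundant logical expert has several physical candidates;
    # each rank picks one round-robin by its ep_rank.
    return {k: v[ep_rank % len(v)] for k, v in log2phy.items()}


# Two ranks, three slots each; logical experts 0 and 2 are duplicated.
gmap = [[0, 1, 2, -1],
        [2, -1, 0, 1]]
print(demo_log2phy(gmap, ep_rank=0))  # {0: 0, 1: 1, 2: 2, 3: 4}
print(demo_log2phy(gmap, ep_rank=1))  # {0: 5, 1: 1, 2: 3, 3: 4}
```

The round-robin pick is why different ranks route the same logical expert to different physical replicas, spreading load across the duplicated copies.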