[EPLB]Eplb Config Renaming (#5533)

### What this PR does / why we need it?
1. Rename num_iterations_eplb_update to expert_heat_collection_interval.
2. Rename num_wait_worker_iterations to algorithm_execution_interval.
3. Rename init_redundancy_expert to num_redundant_experts because the
variable with the same meaning in vLLM is named this way.
4. Delete gate_eplb because we don't need this feature.
5. Move eplb config into a dict in additional config.
6. Depend on pr5817

### Does this PR introduce _any_ user-facing change?

before this pr:
`--additional-config '{"dynamic_eplb":true,
"num_iterations_eplb_update": 4000, "num_wait_worker_iterations": 150,
"init_redundancy_expert": 16, "expert_map_path": "xxx.json"}'`

after this pr: 
`--additional-config
'{"eplb_config":{"dynamic_eplb":true,"expert_heat_collection_interval":4000,
"algorithm_execution_interval":150,"num_redundant_experts": 16,
"expert_map_path": "xxx.json"}}'`

### How was this patch tested?

#### test qwen3-235b eplb num_redundant_experts=16

without pr5817
| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| aime2024 | 604a78 | accuracy | gen | 83.33 |

with pr5817
| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| aime2024 | 604a78 | accuracy | gen | 86.67 |

- vLLM version: v0.13.0
- vLLM main:
45c1ca1ca1

Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
This commit is contained in:
LI SHENGYONG
2026-01-15 10:26:44 +08:00
committed by GitHub
parent ea01aeaab7
commit da958ee386
21 changed files with 174 additions and 349 deletions

View File

@@ -17,15 +17,12 @@
# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove eplb utils.
import json
import os.path
import sys
from collections import defaultdict
import numpy as np
import torch
from vllm.logger import logger
import vllm_ascend.envs as envs_ascend
def expert_file_to_tensor(expert_map_path, layer_id):
with open(expert_map_path, "r") as f:
@@ -56,13 +53,13 @@ def generate_global_placement(n_expert, ep_size, n_redundant):
return torch.tensor(groups, dtype=torch.int32)
def init_eplb_config(ascend_config, layer_id, moe_config):
expert_map_path = ascend_config.expert_map_path
def init_eplb_config(eplb_config, layer_id, moe_config):
expert_map_path = eplb_config.expert_map_path
n_experts = moe_config.num_experts
ep_size = moe_config.ep_size
global_placement = None
eplb_enable = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
n_redundant = ascend_config.init_redundancy_expert if eplb_enable else 0
eplb_enable = eplb_config.dynamic_eplb
n_redundant = eplb_config.num_redundant_experts if eplb_enable else 0
if expert_map_path:
if not (os.path.exists(expert_map_path)
and os.access(expert_map_path, os.R_OK)):
@@ -83,6 +80,7 @@ def init_eplb_config(ascend_config, layer_id, moe_config):
n_redundant)
if ep_size == 1:
assert not eplb_enable, "EPLB must used in expert parallelism."
return None, None, n_redundant
global_expert_map = []
for rankid in range(ep_size):
@@ -116,73 +114,3 @@ def generate_log2phy_map(global_expert_map, ep_rank):
torch.tensor(list(log2phy_map.values()), dtype=torch.int32))
return log2phy_map
class EPLBParamUtils:
@staticmethod
def check_iterations(iterations):
if not isinstance(iterations, int):
raise TypeError(f"The {iterations} is not int.")
if iterations <= 0:
raise ValueError(
f"The {iterations} can not less than or equal to 0.")
if iterations > sys.maxsize:
raise ValueError(
f"The {iterations} can not large than {sys.maxsize}")
@staticmethod
def check_dynamic_eplb(dynamic_eplb):
if dynamic_eplb is None:
return
if not isinstance(dynamic_eplb, bool):
raise TypeError("The dynamic_eplb is not bool.")
if dynamic_eplb and envs_ascend.DYNAMIC_EPLB not in ("true", "1"):
raise ValueError(
'Can not enable dynamic_eplb when DYNAMIC_EPLB is not set to "true" or "1".'
)
@staticmethod
def check_expert_map_path(expert_map):
if expert_map is None:
return
if not isinstance(expert_map, str):
raise TypeError("The expert_map is not str.")
if not expert_map.strip():
raise ValueError("The expert_map is not empty.")
_, ext = os.path.splitext(expert_map)
if ext.lower() != ".json":
raise TypeError("The expert_map is not json.")
if not os.path.exists(expert_map):
raise ValueError("The expert_map is not exist.")
try:
with open(expert_map, "w", encoding='utf-8') as f:
f.read()
except Exception as e:
raise IOError(
f"Fail read expert info from {expert_map}, please check the reading permission of {expert_map} : {e}"
)
@staticmethod
def check_expert_map_record_path(expert_map_record_path):
if expert_map_record_path is None:
return
if not isinstance(expert_map_record_path, str):
raise TypeError("The expert_map_record_path is not str.")
if not expert_map_record_path.strip():
raise ValueError("The expert_map_record_path is empty.")
_, ext = os.path.splitext(expert_map_record_path)
if ext.lower() != ".json":
raise TypeError("The expert_map_record_path is not json.")
if os.getenv("EXPERT_MAP_RECORD", "false") != "true":
raise ValueError(
'Can not enable expert_map_record_path when not export EXPERT_MAP_RECORD="true".'
)
try:
with open(expert_map_record_path, "w", encoding='utf-8') as f:
f.write("")
except Exception as e:
raise IOError(
f"Fail write expert info to {expert_map_record_path}, please check the writing permission of {expert_map_record_path} : {e}"
)