eplb redundant expert bugfix (#4291)
### What this PR does / why we need it?
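Fixes redundant-expert handling in EPLB: the number of redundant experts is now derived from the expert map file, and the global expert count includes them.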
### Does this PR introduce _any_ user-facing change?
After configuring the expert map path (`expert_map_path`), users no longer need to
configure `init_redundancy_expert`.
### How was this patch tested?
The accuracy of EPLB was tested with and without the use of redundant
experts.
- vLLM version: v0.11.0
- vLLM main:
2918c1b49c
---------
Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
This commit is contained in:
@@ -12,6 +12,13 @@ Expert balancing for MoE models in LLM serving is essential for optimal performa
|
|||||||
- Adaptive Scaling: Automatically adjusts to workload fluctuations while maintaining stable performance.
|
- Adaptive Scaling: Automatically adjusts to workload fluctuations while maintaining stable performance.
|
||||||
- Fault Tolerance: Redundant expert placement ensures system resilience during hardware failures.
|
- Fault Tolerance: Redundant expert placement ensures system resilience during hardware failures.
|
||||||
|
|
||||||
|
## Support Scenarios
|
||||||
|
|
||||||
|
### Models:
|
||||||
|
DeepseekV3/V3.1/R1, Qwen3-MOE
|
||||||
|
### MOE QuantType:
|
||||||
|
W8A8-dynamic
|
||||||
|
|
||||||
## How to Use EPLB
|
## How to Use EPLB
|
||||||
|
|
||||||
### Dynamic EPLB
|
### Dynamic EPLB
|
||||||
|
|||||||
@@ -9,39 +9,6 @@ from vllm_ascend.eplb.core import eplb_utils
|
|||||||
from vllm_ascend.eplb.core.eplb_utils import EPLBParamUtils
|
from vllm_ascend.eplb.core.eplb_utils import EPLBParamUtils
|
||||||
|
|
||||||
|
|
||||||
def test_determine_default_expert_map_single_world():
|
|
||||||
count, expert_map = eplb_utils.determine_default_expert_map(
|
|
||||||
global_expert_num=4,
|
|
||||||
world_size=1,
|
|
||||||
rank_id=0,
|
|
||||||
global_redundant_expert_num=0)
|
|
||||||
assert count == 4
|
|
||||||
assert torch.equal(expert_map, torch.arange(4, dtype=torch.int32))
|
|
||||||
|
|
||||||
|
|
||||||
def test_determine_default_expert_map_multiple_worlds_no_redundant():
|
|
||||||
count, expert_map = eplb_utils.determine_default_expert_map(
|
|
||||||
global_expert_num=8,
|
|
||||||
world_size=2,
|
|
||||||
rank_id=0,
|
|
||||||
global_redundant_expert_num=0)
|
|
||||||
|
|
||||||
assert count == 4
|
|
||||||
assert torch.all(expert_map[:4] >= 0)
|
|
||||||
assert torch.all(expert_map[4:] == -1)
|
|
||||||
|
|
||||||
|
|
||||||
def test_determine_default_expert_map_multiple_worlds_with_redundant():
|
|
||||||
count, expert_map = eplb_utils.determine_default_expert_map(
|
|
||||||
global_expert_num=5,
|
|
||||||
world_size=2,
|
|
||||||
rank_id=0,
|
|
||||||
global_redundant_expert_num=1)
|
|
||||||
|
|
||||||
assert count == 2
|
|
||||||
assert torch.all(expert_map[0:2] >= 0)
|
|
||||||
|
|
||||||
|
|
||||||
def test_generate_log2phy_map_single_rank_holding():
|
def test_generate_log2phy_map_single_rank_holding():
|
||||||
|
|
||||||
expert_map = torch.tensor([[0, -1], [-1, 0]], dtype=torch.int32)
|
expert_map = torch.tensor([[0, -1], [-1, 0]], dtype=torch.int32)
|
||||||
@@ -64,21 +31,17 @@ def test_generate_log2phy_map_multiple_rank_holding(monkeypatch):
|
|||||||
|
|
||||||
|
|
||||||
def test_determine_default_log2phy_map_world_size_1():
|
def test_determine_default_log2phy_map_world_size_1():
|
||||||
log2phy = eplb_utils.determine_default_log2phy_map(
|
log2phy = eplb_utils.determine_default_log2phy_map(global_expert_num=3,
|
||||||
global_expert_num=3,
|
|
||||||
world_size=1,
|
world_size=1,
|
||||||
rank_id=0,
|
rank_id=0)
|
||||||
global_redundant_expert_num=0)
|
|
||||||
assert log2phy.shape == (3, )
|
assert log2phy.shape == (3, )
|
||||||
assert (log2phy >= 0).all()
|
assert (log2phy >= 0).all()
|
||||||
|
|
||||||
|
|
||||||
def test_determine_default_log2phy_map_world_size_multiple():
|
def test_determine_default_log2phy_map_world_size_multiple():
|
||||||
log2phy = eplb_utils.determine_default_log2phy_map(
|
log2phy = eplb_utils.determine_default_log2phy_map(global_expert_num=6,
|
||||||
global_expert_num=6,
|
|
||||||
world_size=2,
|
world_size=2,
|
||||||
rank_id=1,
|
rank_id=1)
|
||||||
global_redundant_expert_num=1)
|
|
||||||
assert log2phy.shape == (6, )
|
assert log2phy.shape == (6, )
|
||||||
assert (log2phy >= 0).all()
|
assert (log2phy >= 0).all()
|
||||||
|
|
||||||
|
|||||||
@@ -48,8 +48,7 @@ class TestExpertLoadBalancer(TestBase):
|
|||||||
with open(json_file, 'r') as f:
|
with open(json_file, 'r') as f:
|
||||||
self.expert_map: MockData = json.load(f)
|
self.expert_map: MockData = json.load(f)
|
||||||
|
|
||||||
self.expert_load_balancer = ExpertLoadBalancer(json_file,
|
self.expert_load_balancer = ExpertLoadBalancer(json_file, 8)
|
||||||
global_expert_num=8)
|
|
||||||
|
|
||||||
def test_init(self):
|
def test_init(self):
|
||||||
|
|
||||||
@@ -83,7 +82,7 @@ class TestExpertLoadBalancer(TestBase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(expert_placement_map.shape,
|
self.assertEqual(expert_placement_map.shape,
|
||||||
(self.expert_load_balancer.layers_num,
|
(self.expert_load_balancer.layers_num,
|
||||||
self.expert_load_balancer.ranks_num, 8))
|
self.expert_load_balancer.ranks_num, 10))
|
||||||
self.assertTrue(torch.all(expert_placement_map >= -1))
|
self.assertTrue(torch.all(expert_placement_map >= -1))
|
||||||
|
|
||||||
def test_generate_log2phy_expert_map(self):
|
def test_generate_log2phy_expert_map(self):
|
||||||
@@ -91,7 +90,7 @@ class TestExpertLoadBalancer(TestBase):
|
|||||||
log2phy_map = self.expert_load_balancer.generate_log2phy_expert_map(
|
log2phy_map = self.expert_load_balancer.generate_log2phy_expert_map(
|
||||||
layer_id)
|
layer_id)
|
||||||
self.assertEqual(log2phy_map.shape,
|
self.assertEqual(log2phy_map.shape,
|
||||||
(self.expert_load_balancer.ranks_num, 8))
|
(self.expert_load_balancer.ranks_num, 10))
|
||||||
self.assertTrue(torch.all(log2phy_map >= -1))
|
self.assertTrue(torch.all(log2phy_map >= -1))
|
||||||
|
|
||||||
@mock.patch("torch_npu.npu._lazy_init")
|
@mock.patch("torch_npu.npu._lazy_init")
|
||||||
@@ -102,7 +101,7 @@ class TestExpertLoadBalancer(TestBase):
|
|||||||
rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map(
|
rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map(
|
||||||
layer_id, rank_id)
|
layer_id, rank_id)
|
||||||
self.assertEqual(rank_local_expert_num, 5)
|
self.assertEqual(rank_local_expert_num, 5)
|
||||||
expected_tensor = torch.tensor([2, -1, 1, 3, -1, 4, -1, 0],
|
expected_tensor = torch.tensor([2, -1, 1, 3, -1, 4, -1, 0, -1, -1],
|
||||||
dtype=torch.int32).to(
|
dtype=torch.int32).to(
|
||||||
rank_expert_map.device)
|
rank_expert_map.device)
|
||||||
self.assertTrue(rank_expert_map.equal(expected_tensor))
|
self.assertTrue(rank_expert_map.equal(expected_tensor))
|
||||||
@@ -110,7 +109,7 @@ class TestExpertLoadBalancer(TestBase):
|
|||||||
rank_id = 1
|
rank_id = 1
|
||||||
rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map(
|
rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map(
|
||||||
layer_id, rank_id)
|
layer_id, rank_id)
|
||||||
expected_tensor = torch.tensor([-1, 1, 4, -1, 2, -1, 0, 3],
|
expected_tensor = torch.tensor([-1, 1, 4, -1, 2, -1, 0, 3, -1, -1],
|
||||||
dtype=torch.int32).to(
|
dtype=torch.int32).to(
|
||||||
rank_expert_map.device)
|
rank_expert_map.device)
|
||||||
self.assertTrue(rank_expert_map.equal(expected_tensor))
|
self.assertTrue(rank_expert_map.equal(expected_tensor))
|
||||||
@@ -120,7 +119,7 @@ class TestExpertLoadBalancer(TestBase):
|
|||||||
rank_id = 0
|
rank_id = 0
|
||||||
log2phy_map = self.expert_load_balancer.get_rank_log2phy_map(
|
log2phy_map = self.expert_load_balancer.get_rank_log2phy_map(
|
||||||
layer_id, rank_id)
|
layer_id, rank_id)
|
||||||
expected_tensor = torch.tensor([2, 6, 1, 3, 7, 4, 5, 0],
|
expected_tensor = torch.tensor([2, 6, 1, 3, 7, 4, 5, 0, -1, -1],
|
||||||
dtype=torch.int32).to(
|
dtype=torch.int32).to(
|
||||||
log2phy_map.device)
|
log2phy_map.device)
|
||||||
self.assertTrue(log2phy_map.equal(expected_tensor))
|
self.assertTrue(log2phy_map.equal(expected_tensor))
|
||||||
@@ -128,7 +127,7 @@ class TestExpertLoadBalancer(TestBase):
|
|||||||
rank_id = 1
|
rank_id = 1
|
||||||
log2phy_map = self.expert_load_balancer.get_rank_log2phy_map(
|
log2phy_map = self.expert_load_balancer.get_rank_log2phy_map(
|
||||||
layer_id, rank_id)
|
layer_id, rank_id)
|
||||||
expected_tensor = torch.tensor([2, 6, 9, 3, 7, 4, 5, 8],
|
expected_tensor = torch.tensor([2, 6, 9, 3, 7, 4, 5, 8, -1, -1],
|
||||||
dtype=torch.int32).to(
|
dtype=torch.int32).to(
|
||||||
log2phy_map.device)
|
log2phy_map.device)
|
||||||
self.assertTrue(log2phy_map.equal(expected_tensor))
|
self.assertTrue(log2phy_map.equal(expected_tensor))
|
||||||
|
|||||||
@@ -25,32 +25,6 @@ from vllm.logger import logger
|
|||||||
import vllm_ascend.envs as envs_ascend
|
import vllm_ascend.envs as envs_ascend
|
||||||
|
|
||||||
|
|
||||||
def determine_default_expert_map(global_expert_num, world_size, rank_id,
|
|
||||||
global_redundant_expert_num):
|
|
||||||
if world_size == 1:
|
|
||||||
local_ids = torch.arange(global_expert_num, dtype=torch.int32)
|
|
||||||
return (global_expert_num, local_ids)
|
|
||||||
|
|
||||||
local_num_experts = global_expert_num // world_size
|
|
||||||
|
|
||||||
expert_map = torch.full((global_expert_num, ), -1, dtype=torch.int32)
|
|
||||||
|
|
||||||
if rank_id < world_size - 1:
|
|
||||||
start = rank_id * local_num_experts
|
|
||||||
end = (rank_id + 1) * local_num_experts
|
|
||||||
local_count = local_num_experts
|
|
||||||
else:
|
|
||||||
start = rank_id * local_num_experts
|
|
||||||
end = global_expert_num
|
|
||||||
local_count = global_expert_num - rank_id * local_num_experts
|
|
||||||
|
|
||||||
if isinstance(local_count, int):
|
|
||||||
local_ids = torch.arange(local_count, dtype=torch.int32)
|
|
||||||
expert_map[start:end] = local_ids
|
|
||||||
|
|
||||||
return (local_count, expert_map)
|
|
||||||
|
|
||||||
|
|
||||||
def generate_log2phy_map(expert_map):
|
def generate_log2phy_map(expert_map):
|
||||||
num_local_experts = expert_map.max() + 1
|
num_local_experts = expert_map.max() + 1
|
||||||
log2phy_map = expert_map.clone()
|
log2phy_map = expert_map.clone()
|
||||||
@@ -90,8 +64,7 @@ def generate_log2phy_map(expert_map):
|
|||||||
return log2phy_map
|
return log2phy_map
|
||||||
|
|
||||||
|
|
||||||
def determine_default_log2phy_map(global_expert_num, world_size, rank_id,
|
def determine_default_log2phy_map(global_expert_num, world_size, rank_id):
|
||||||
global_redundant_expert_num):
|
|
||||||
if world_size == 1:
|
if world_size == 1:
|
||||||
local_ids = torch.arange(global_expert_num, dtype=torch.int32)
|
local_ids = torch.arange(global_expert_num, dtype=torch.int32)
|
||||||
expert_map_all = local_ids.unsqueeze(0).expand(world_size, -1)
|
expert_map_all = local_ids.unsqueeze(0).expand(world_size, -1)
|
||||||
|
|||||||
@@ -8,12 +8,14 @@ import torch.distributed as dist
|
|||||||
|
|
||||||
class ExpertLoadBalancer(object):
|
class ExpertLoadBalancer(object):
|
||||||
|
|
||||||
def __init__(self, expert_map_path, global_expert_num):
|
def __init__(self, expert_map_path, num_experts):
|
||||||
self.expert_map_path = expert_map_path
|
self.expert_map_path = expert_map_path
|
||||||
self.global_expert_num = global_expert_num
|
self.num_experts = num_experts
|
||||||
self.tensor_data = []
|
self.tensor_data = []
|
||||||
self.expert_map_tensor, self.layers_num, self.ranks_num = (
|
self.expert_map_tensor, self.layers_num, self.ranks_num = (
|
||||||
self._expert_file_to_tensor())
|
self._expert_file_to_tensor())
|
||||||
|
self.global_expert_num = num_experts + self.get_global_redundant_expert_num(
|
||||||
|
)
|
||||||
self.expert_placement_map = self.generate_expert_placement_map()
|
self.expert_placement_map = self.generate_expert_placement_map()
|
||||||
|
|
||||||
def _expert_file_to_tensor(self):
|
def _expert_file_to_tensor(self):
|
||||||
@@ -95,7 +97,7 @@ class ExpertLoadBalancer(object):
|
|||||||
def get_global_redundant_expert_num(self):
|
def get_global_redundant_expert_num(self):
|
||||||
global_redundant_expert_num = (
|
global_redundant_expert_num = (
|
||||||
len(self.expert_map_tensor[0][0]) * self.ranks_num -
|
len(self.expert_map_tensor[0][0]) * self.ranks_num -
|
||||||
self.global_expert_num)
|
self.num_experts)
|
||||||
return global_redundant_expert_num
|
return global_redundant_expert_num
|
||||||
|
|
||||||
def check_expert_map_tensor(self):
|
def check_expert_map_tensor(self):
|
||||||
|
|||||||
@@ -32,8 +32,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
|
|||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.ascend_forward_context import MoECommType
|
from vllm_ascend.ascend_forward_context import MoECommType
|
||||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||||
from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
|
from vllm_ascend.eplb.core.eplb_utils import determine_default_log2phy_map
|
||||||
determine_default_log2phy_map)
|
|
||||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||||
from vllm_ascend.ops.fused_moe.moe_comm_method import setup_moe_comm_method
|
from vllm_ascend.ops.fused_moe.moe_comm_method import setup_moe_comm_method
|
||||||
@@ -183,10 +182,8 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
AscendFusedMoE.moe_counter += 1
|
AscendFusedMoE.moe_counter += 1
|
||||||
self.moe_instance_id = AscendFusedMoE.moe_counter
|
self.moe_instance_id = AscendFusedMoE.moe_counter
|
||||||
|
|
||||||
self.global_num_experts = num_experts
|
|
||||||
self.expert_map = None
|
self.expert_map = None
|
||||||
self.log2phy = None
|
self.log2phy = None
|
||||||
self.global_redundant_expert_num = 0
|
|
||||||
|
|
||||||
if self.quant_config is None:
|
if self.quant_config is None:
|
||||||
self.quant_method = AscendUnquantizedFusedMoEMethod(
|
self.quant_method = AscendUnquantizedFusedMoEMethod(
|
||||||
@@ -210,15 +207,24 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
self.e_score_correction_bias.data = self.e_score_correction_bias.data.to(
|
self.e_score_correction_bias.data = self.e_score_correction_bias.data.to(
|
||||||
dtype=vllm_config.model_config.dtype)
|
dtype=vllm_config.model_config.dtype)
|
||||||
|
|
||||||
|
# init moe.
|
||||||
|
if vllm_version_is("0.11.0"):
|
||||||
|
self.local_num_experts, self.expert_map = determine_expert_map(
|
||||||
|
self.ep_size, self.ep_rank, self.global_num_experts)
|
||||||
|
else:
|
||||||
|
self.local_num_experts, self.expert_map, _ = determine_expert_map(
|
||||||
|
self.ep_size, self.ep_rank, self.global_num_experts)
|
||||||
# static eplb initializing with expert_map_path
|
# static eplb initializing with expert_map_path
|
||||||
if self.expert_map_path and os.path.exists(
|
if self.expert_map_path and os.path.exists(
|
||||||
self.expert_map_path) and os.access(self.expert_map_path,
|
self.expert_map_path) and os.access(self.expert_map_path,
|
||||||
os.R_OK):
|
os.R_OK):
|
||||||
self.expert_load_balancer = ExpertLoadBalancer(
|
self.expert_load_balancer = ExpertLoadBalancer(
|
||||||
self.expert_map_path, self.global_num_experts)
|
self.expert_map_path, num_experts)
|
||||||
self.expert_load_balancer.check_expert_map_tensor()
|
self.expert_load_balancer.check_expert_map_tensor()
|
||||||
self.global_redundant_expert_num = (
|
self.global_redundant_expert_num = (
|
||||||
self.expert_load_balancer.get_global_redundant_expert_num())
|
self.expert_load_balancer.get_global_redundant_expert_num())
|
||||||
|
self.global_num_experts = num_experts + self.global_redundant_expert_num
|
||||||
try:
|
try:
|
||||||
self.local_num_experts, self.expert_map = (
|
self.local_num_experts, self.expert_map = (
|
||||||
self.expert_load_balancer.get_rank_placement_map(
|
self.expert_load_balancer.get_rank_placement_map(
|
||||||
@@ -228,39 +234,15 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Init expert map of mtp/eagle when using sample.{e}")
|
f"Init expert map of mtp/eagle when using sample.{e}")
|
||||||
self.local_num_experts, self.expert_map = determine_default_expert_map(
|
|
||||||
self.global_num_experts, self.ep_size, self.ep_rank,
|
|
||||||
self.global_redundant_expert_num)
|
|
||||||
self.log2phy = determine_default_log2phy_map(
|
self.log2phy = determine_default_log2phy_map(
|
||||||
self.global_num_experts, self.ep_size, self.ep_rank,
|
self.global_num_experts, self.ep_size, self.ep_rank).npu()
|
||||||
self.global_redundant_expert_num).npu()
|
|
||||||
if self.expert_map is not None and isinstance(
|
|
||||||
self.expert_map, torch.Tensor):
|
|
||||||
logger.info_once(
|
|
||||||
"[EP Rank %s/%s] Expert parallelism is enabled. Local/global"
|
|
||||||
" number of experts: %s/%s. Experts local to global index map:"
|
|
||||||
" %s.", self.ep_rank, self.ep_size, self.local_num_experts,
|
|
||||||
self.global_num_experts,
|
|
||||||
get_compressed_expert_map(self.expert_map))
|
|
||||||
else:
|
else:
|
||||||
# init moe.
|
|
||||||
if vllm_version_is("0.11.0"):
|
|
||||||
self.local_num_experts, self.expert_map = determine_expert_map(
|
|
||||||
self.ep_size, self.ep_rank, self.global_num_experts)
|
|
||||||
else:
|
|
||||||
self.local_num_experts, self.expert_map, _ = determine_expert_map(
|
|
||||||
self.ep_size, self.ep_rank, self.global_num_experts)
|
|
||||||
# dynamic eplb initializing with not expert_map_path
|
# dynamic eplb initializing with not expert_map_path
|
||||||
if self.dynamic_eplb:
|
if self.dynamic_eplb:
|
||||||
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
|
|
||||||
self.local_num_experts, self.expert_map = determine_default_expert_map(
|
|
||||||
self.global_num_experts, self.ep_size, self.ep_rank,
|
|
||||||
self.global_redundant_expert_num)
|
|
||||||
self.log2phy = determine_default_log2phy_map(
|
self.log2phy = determine_default_log2phy_map(
|
||||||
self.global_num_experts, self.ep_size, self.ep_rank,
|
self.global_num_experts, self.ep_size, self.ep_rank).npu()
|
||||||
self.global_redundant_expert_num).npu()
|
if self.expert_map is not None and isinstance(self.expert_map,
|
||||||
if self.expert_map is not None and isinstance(
|
torch.Tensor):
|
||||||
self.expert_map, torch.Tensor):
|
|
||||||
logger.info_once(
|
logger.info_once(
|
||||||
"[EP Rank %s/%s] Expert parallelism is enabled. Local/global"
|
"[EP Rank %s/%s] Expert parallelism is enabled. Local/global"
|
||||||
" number of experts: %s/%s. Experts local to global index map:"
|
" number of experts: %s/%s. Experts local to global index map:"
|
||||||
|
|||||||
@@ -342,7 +342,7 @@ class AscendW4A8DynamicFusedMoEMethod:
|
|||||||
scoring_func: str = "softmax",
|
scoring_func: str = "softmax",
|
||||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||||
is_prefill: bool = True,
|
is_prefill: bool = True,
|
||||||
enable_force_load_balance: bool = True,
|
enable_force_load_balance: bool = False,
|
||||||
log2phy: torch.Tensor = None,
|
log2phy: torch.Tensor = None,
|
||||||
global_redundant_expert_num: int = 0,
|
global_redundant_expert_num: int = 0,
|
||||||
shared_experts: Optional[Any] = None,
|
shared_experts: Optional[Any] = None,
|
||||||
@@ -371,7 +371,8 @@ class AscendW4A8DynamicFusedMoEMethod:
|
|||||||
# to avoid accumulating too much tokens on a single rank.
|
# to avoid accumulating too much tokens on a single rank.
|
||||||
# currently it is only activated when doing profile runs.
|
# currently it is only activated when doing profile runs.
|
||||||
if enable_force_load_balance:
|
if enable_force_load_balance:
|
||||||
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
topk_ids = torch.randint_like(
|
||||||
|
topk_ids, 0, global_num_experts - global_redundant_expert_num)
|
||||||
|
|
||||||
topk_weights = topk_weights.to(x.dtype)
|
topk_weights = topk_weights.to(x.dtype)
|
||||||
|
|
||||||
|
|||||||
@@ -213,7 +213,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
|||||||
scoring_func: str = "softmax",
|
scoring_func: str = "softmax",
|
||||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||||
is_prefill: bool = True,
|
is_prefill: bool = True,
|
||||||
enable_force_load_balance: bool = True,
|
enable_force_load_balance: bool = False,
|
||||||
log2phy: torch.Tensor = None,
|
log2phy: torch.Tensor = None,
|
||||||
global_redundant_expert_num: int = 0,
|
global_redundant_expert_num: int = 0,
|
||||||
shared_experts: Optional[Any] = None,
|
shared_experts: Optional[Any] = None,
|
||||||
@@ -242,7 +242,8 @@ class AscendW8A8DynamicFusedMoEMethod:
|
|||||||
# to avoid accumulating too much tokens on a single rank.
|
# to avoid accumulating too much tokens on a single rank.
|
||||||
# currently it is only activated when doing profile runs.
|
# currently it is only activated when doing profile runs.
|
||||||
if enable_force_load_balance:
|
if enable_force_load_balance:
|
||||||
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
topk_ids = torch.randint_like(
|
||||||
|
topk_ids, 0, global_num_experts - global_redundant_expert_num)
|
||||||
|
|
||||||
topk_weights = topk_weights.to(self.in_dtype)
|
topk_weights = topk_weights.to(self.in_dtype)
|
||||||
|
|
||||||
|
|||||||
@@ -43,8 +43,7 @@ from vllm.model_executor.layers.quantization.base_config import \
|
|||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.ascend_forward_context import FusedMoEState
|
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||||
from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
|
from vllm_ascend.eplb.core.eplb_utils import determine_default_log2phy_map
|
||||||
determine_default_log2phy_map)
|
|
||||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||||
from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
|
from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
|
||||||
from vllm_ascend.torchair.ops.sequence_parallel import MetadataForPadding
|
from vllm_ascend.torchair.ops.sequence_parallel import MetadataForPadding
|
||||||
@@ -1042,7 +1041,7 @@ class TorchairAscendFusedMoE(FusedMoE):
|
|||||||
self.expert_map_path) and os.access(self.expert_map_path,
|
self.expert_map_path) and os.access(self.expert_map_path,
|
||||||
os.R_OK):
|
os.R_OK):
|
||||||
self.expert_load_balancer = ExpertLoadBalancer(
|
self.expert_load_balancer = ExpertLoadBalancer(
|
||||||
self.expert_map_path, self.global_num_experts)
|
self.expert_map_path, num_experts)
|
||||||
self.expert_load_balancer.check_expert_map_tensor()
|
self.expert_load_balancer.check_expert_map_tensor()
|
||||||
self.global_redundant_expert_num = (
|
self.global_redundant_expert_num = (
|
||||||
self.expert_load_balancer.get_global_redundant_expert_num())
|
self.expert_load_balancer.get_global_redundant_expert_num())
|
||||||
@@ -1052,15 +1051,14 @@ class TorchairAscendFusedMoE(FusedMoE):
|
|||||||
self.moe_instance_id, self.ep_rank))
|
self.moe_instance_id, self.ep_rank))
|
||||||
self.log2phy = self.expert_load_balancer.get_rank_log2phy_map(
|
self.log2phy = self.expert_load_balancer.get_rank_log2phy_map(
|
||||||
self.moe_instance_id, self.ep_rank).npu()
|
self.moe_instance_id, self.ep_rank).npu()
|
||||||
|
self.global_num_experts = num_experts + self.global_redundant_expert_num
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Init expert map of mtp/eagle when using sample.{e}")
|
f"Init expert map of mtp/eagle when using sample.{e}")
|
||||||
self.local_num_experts, self.expert_map = determine_default_expert_map(
|
self.local_num_experts, self.expert_map = determine_expert_map(
|
||||||
self.global_num_experts, self.ep_size, self.ep_rank,
|
self.ep_size, self.ep_rank, self.global_num_experts)
|
||||||
self.global_redundant_expert_num)
|
|
||||||
self.log2phy = determine_default_log2phy_map(
|
self.log2phy = determine_default_log2phy_map(
|
||||||
self.global_num_experts, self.ep_size, self.ep_rank,
|
self.global_num_experts, self.ep_size, self.ep_rank).npu()
|
||||||
self.global_redundant_expert_num).npu()
|
|
||||||
if self.expert_map is not None and isinstance(
|
if self.expert_map is not None and isinstance(
|
||||||
self.expert_map, torch.Tensor):
|
self.expert_map, torch.Tensor):
|
||||||
logger.info_once(
|
logger.info_once(
|
||||||
@@ -1079,13 +1077,8 @@ class TorchairAscendFusedMoE(FusedMoE):
|
|||||||
self.ep_size, self.ep_rank, self.global_num_experts)
|
self.ep_size, self.ep_rank, self.global_num_experts)
|
||||||
# dynamic eplb initializing with not expert_map_path
|
# dynamic eplb initializing with not expert_map_path
|
||||||
if self.dynamic_eplb:
|
if self.dynamic_eplb:
|
||||||
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
|
|
||||||
self.local_num_experts, self.expert_map = determine_default_expert_map(
|
|
||||||
self.global_num_experts, self.ep_size, self.ep_rank,
|
|
||||||
self.global_redundant_expert_num)
|
|
||||||
self.log2phy = determine_default_log2phy_map(
|
self.log2phy = determine_default_log2phy_map(
|
||||||
self.global_num_experts, self.ep_size, self.ep_rank,
|
self.global_num_experts, self.ep_size, self.ep_rank).npu()
|
||||||
self.global_redundant_expert_num).npu()
|
|
||||||
if self.expert_map is not None and isinstance(
|
if self.expert_map is not None and isinstance(
|
||||||
self.expert_map, torch.Tensor):
|
self.expert_map, torch.Tensor):
|
||||||
logger.info_once(
|
logger.info_once(
|
||||||
|
|||||||
@@ -990,7 +990,9 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
|
|||||||
# to avoid accumulating too much tokens on a single rank.
|
# to avoid accumulating too much tokens on a single rank.
|
||||||
# currently it is only activated when doing profile runs.
|
# currently it is only activated when doing profile runs.
|
||||||
if enable_force_load_balance:
|
if enable_force_load_balance:
|
||||||
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
topk_ids = torch.randint_like(
|
||||||
|
topk_ids, 0,
|
||||||
|
global_num_experts - global_redundant_expert_num)
|
||||||
topk_weights = topk_weights.to(x.dtype)
|
topk_weights = topk_weights.to(x.dtype)
|
||||||
|
|
||||||
if fused_moe_state == FusedMoEState.AllGatherEP:
|
if fused_moe_state == FusedMoEState.AllGatherEP:
|
||||||
|
|||||||
Reference in New Issue
Block a user