diff --git a/tests/ut/ops/expert_map.json b/tests/ut/eplb/core/expert_map.json similarity index 100% rename from tests/ut/ops/expert_map.json rename to tests/ut/eplb/core/expert_map.json diff --git a/tests/ut/eplb/core/test_eplb_utils.py b/tests/ut/eplb/core/test_eplb_utils.py index 28d5d425..530dbf4f 100644 --- a/tests/ut/eplb/core/test_eplb_utils.py +++ b/tests/ut/eplb/core/test_eplb_utils.py @@ -1,49 +1,67 @@ -import random +import os import sys +import unittest from unittest.mock import patch +# isort: off import pytest import torch +from vllm.config import VllmConfig +from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, + FusedMoEParallelConfig + ) -from vllm_ascend.eplb.core import eplb_utils -from vllm_ascend.eplb.core.eplb_utils import EPLBParamUtils +from vllm_ascend.ascend_config import init_ascend_config +from vllm_ascend.eplb.core.eplb_utils import EPLBParamUtils, init_eplb_config +# isort: on -def test_generate_log2phy_map_single_rank_holding(): +class TestAscendConfig(unittest.TestCase): - expert_map = torch.tensor([[0, -1], [-1, 0]], dtype=torch.int32) - log2phy_map = eplb_utils.generate_log2phy_map(expert_map) + def setUp(self): + vllm_config = VllmConfig() + ascend_config = init_ascend_config(vllm_config) + ascend_config.dynamic_eplb = True + ascend_config.init_redundancy_expert = 2 + moe_parallel_config = FusedMoEParallelConfig(2, 0, 1, 2, 1, 1, 1, 1, + True, "hccl") + moe_config = FusedMoEConfig(8, 8, 8192, 5, moe_parallel_config, + torch.float16) + moe_config.supports_eplb = True + self.ascend_config = ascend_config + self.moe_config = moe_config + self.mock_npu = patch("torch.Tensor.npu", + new=lambda self: self).start() - assert torch.all(log2phy_map[:, 0] == log2phy_map[0, 0]) - assert torch.all(log2phy_map[:, 1] == log2phy_map[1, 1]) + def test_init_eplb_config_with_eplb(self): + expert_map, log2phy, redundant_experts = init_eplb_config( + self.ascend_config, 0, self.moe_config) + gt_expert_map = torch.tensor([4, -1, -1, -1, 0, 1, 2, 3]) + gt_log2phy = torch.tensor([9, 1, 2, 3, 5, 6, 7, 8]) + self.assertTrue(torch.equal(expert_map, gt_expert_map)) + self.assertTrue(torch.equal(log2phy, gt_log2phy)) + self.assertEqual(redundant_experts, 2) + def test_init_eplb_config_with_eplb_withmap(self): + _TEST_DIR = os.path.dirname(__file__) + self.ascend_config.expert_map_path = _TEST_DIR + "/expert_map.json" + expert_map, log2phy, redundant_experts = init_eplb_config( + self.ascend_config, 0, self.moe_config) + gt_expert_map = torch.tensor([-1, 1, 4, -1, 2, -1, 0, 3]) + gt_log2phy = torch.tensor([2, 6, 9, 3, 7, 4, 5, 8]) + self.assertTrue(torch.equal(expert_map, gt_expert_map)) + self.assertTrue(torch.equal(log2phy, gt_log2phy)) + self.assertEqual(redundant_experts, 2) -def test_generate_log2phy_map_multiple_rank_holding(monkeypatch): - - expert_map = torch.tensor([[0], [0]], dtype=torch.int32) - - monkeypatch.setattr(random, "choice", lambda x: x[0]) - - log2phy_map = eplb_utils.generate_log2phy_map(expert_map) - - assert log2phy_map.shape == (2, 1) - assert (log2phy_map >= 0).all() - - -def test_determine_default_log2phy_map_world_size_1(): - log2phy = eplb_utils.determine_default_log2phy_map(global_expert_num=3, - world_size=1, - rank_id=0) - assert log2phy.shape == (3, ) - assert (log2phy >= 0).all() - - -def test_determine_default_log2phy_map_world_size_multiple(): - log2phy = eplb_utils.determine_default_log2phy_map(global_expert_num=6, - world_size=2, - rank_id=1) - assert log2phy.shape == (6, ) - assert (log2phy >= 0).all() + def test_init_eplb_config_without_eplb(self): + self.ascend_config.dynamic_eplb = False + self.ascend_config.expert_map_path = None + expert_map, log2phy, redundant_experts = init_eplb_config( + self.ascend_config, 0, self.moe_config) + gt_expert_map = torch.tensor([-1, -1, -1, -1, 0, 1, 2, 3]) + print(expert_map, log2phy, redundant_experts) + self.assertTrue(torch.equal(expert_map, gt_expert_map)) + self.assertEqual(redundant_experts, 0) class TestEPLBParamUtils: diff --git a/tests/ut/ops/test_expert_load_balancer.py b/tests/ut/ops/test_expert_load_balancer.py deleted file mode 100644 index f7f68472..00000000 --- a/tests/ut/ops/test_expert_load_balancer.py +++ /dev/null @@ -1,140 +0,0 @@ -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. -# - -import json -import os -from typing import List, TypedDict -from unittest import mock - -import torch - -from tests.ut.base import TestBase -from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer - - -class Device(TypedDict): - device_id: int - device_expert: List[int] - - -class Layer(TypedDict): - layer_id: int - device_count: int - device_list: List[Device] - - -class MockData(TypedDict): - moe_layer_count: int - layer_list: List[Layer] - - -class TestExpertLoadBalancer(TestBase): - - def setUp(self): - _TEST_DIR = os.path.dirname(__file__) - json_file = _TEST_DIR + "/expert_map.json" - with open(json_file, 'r') as f: - self.expert_map: MockData = json.load(f) - - self.expert_load_balancer = ExpertLoadBalancer(json_file, 8) - - def test_init(self): - - self.assertIsInstance(self.expert_load_balancer.expert_map_tensor, - torch.Tensor) - self.assertEqual(self.expert_load_balancer.layers_num, - self.expert_map["moe_layer_count"]) - self.assertEqual(self.expert_load_balancer.ranks_num, - self.expert_map["layer_list"][0]["device_count"]) - - def test_generate_index_dicts(self): - tensor_2d = torch.tensor([[7, 2, 0, 3, 5], [6, 1, 4, 7, 2]]) - result = self.expert_load_balancer.generate_index_dicts(tensor_2d) - expected_result = [{ - 7: 0, - 2: 1, - 0: 2, - 3: 3, - 5: 4 - }, { - 6: 5, - 1: 6, - 4: 7, - 7: 8, - 2: 9 - }] - self.assertEqual(result, expected_result) - - def test_generate_expert_placement_map(self): - expert_placement_map = self.expert_load_balancer.generate_expert_placement_map( - ) - self.assertEqual(expert_placement_map.shape, - (self.expert_load_balancer.layers_num, - self.expert_load_balancer.ranks_num, 10)) - self.assertTrue(torch.all(expert_placement_map >= -1)) - - def test_generate_log2phy_expert_map(self): - layer_id = 0 - log2phy_map = self.expert_load_balancer.generate_log2phy_expert_map( - layer_id) - self.assertEqual(log2phy_map.shape, - (self.expert_load_balancer.ranks_num, 10)) - self.assertTrue(torch.all(log2phy_map >= -1)) - - @mock.patch("torch_npu.npu._lazy_init") - @mock.patch("torch.npu.current_device", return_value="cpu") - def test_get_rank_placement_map(self, mock_current_device, mock_lazy_init): - layer_id = 0 - rank_id = 0 - rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map( - layer_id, rank_id) - self.assertEqual(rank_local_expert_num, 5) - expected_tensor = torch.tensor([2, -1, 1, 3, -1, 4, -1, 0, -1, -1], - dtype=torch.int32).to( - rank_expert_map.device) - self.assertTrue(rank_expert_map.equal(expected_tensor)) - - rank_id = 1 - rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map( - layer_id, rank_id) - expected_tensor = torch.tensor([-1, 1, 4, -1, 2, -1, 0, 3, -1, -1], - dtype=torch.int32).to( - rank_expert_map.device) - self.assertTrue(rank_expert_map.equal(expected_tensor)) - - def test_get_rank_log2phy_map(self): - layer_id = 0 - rank_id = 0 - log2phy_map = self.expert_load_balancer.get_rank_log2phy_map( - layer_id, rank_id) - expected_tensor = torch.tensor([2, 6, 1, 3, 7, 4, 5, 0, -1, -1], - dtype=torch.int32).to( - log2phy_map.device) - self.assertTrue(log2phy_map.equal(expected_tensor)) - - rank_id = 1 - log2phy_map = self.expert_load_balancer.get_rank_log2phy_map( - layer_id, rank_id) - expected_tensor = torch.tensor([2, 6, 9, 3, 7, 4, 5, 8, -1, -1], - dtype=torch.int32).to( - log2phy_map.device) - self.assertTrue(log2phy_map.equal(expected_tensor)) - - def test_get_global_redundant_expert_num(self): - redundant_expert_num = self.expert_load_balancer.get_global_redundant_expert_num( - ) - expected_redundant_expert_num = len(self.expert_map["layer_list"][0]["device_list"][0]["device_expert"]) * \ - self.expert_map["layer_list"][0]["device_count"] - 8 - self.assertEqual(redundant_expert_num, expected_redundant_expert_num) diff --git a/tests/ut/ops/test_fused_moe.py b/tests/ut/ops/test_fused_moe.py index d82e46d8..2d519b92 100644 --- a/tests/ut/ops/test_fused_moe.py +++ b/tests/ut/ops/test_fused_moe.py @@ -117,8 +117,8 @@ def mock_dist_env(mocker: MockerFixture): enable_multistream_moe=False, expert_map_path=None )), \ - patch('vllm_ascend.ops.fused_moe.fused_moe.determine_expert_map', - return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \ + patch('vllm_ascend.ops.fused_moe.fused_moe.init_eplb_config', + return_value=(torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]), None, 0)), \ patch('vllm_ascend.ops.fused_moe.fused_moe.get_forward_context', return_value=mock_forward_context_obj), \ patch('vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context', diff --git a/vllm_ascend/eplb/core/eplb_utils.py b/vllm_ascend/eplb/core/eplb_utils.py index b43b85b6..4920de30 100644 --- a/vllm_ascend/eplb/core/eplb_utils.py +++ b/vllm_ascend/eplb/core/eplb_utils.py @@ -15,87 +15,111 @@ # This file is a part of the vllm-ascend project. # # Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove eplb utils. +import json import os.path -import random import sys +from collections import defaultdict +import numpy as np import torch from vllm.logger import logger import vllm_ascend.envs as envs_ascend -def generate_log2phy_map(expert_map): - num_local_experts = expert_map.max() + 1 - log2phy_map = expert_map.clone() - num_ranks, num_global_expert = log2phy_map.shape +def expert_file_to_tensor(expert_map_path, layer_id): + with open(expert_map_path, "r") as f: + data = json.load(f) + physical_count = 0 + device_data = [] + if layer_id > data["moe_layer_count"]: + raise ValueError("Invalid EPLB Table") + if layer_id == data["moe_layer_count"]: + logger.warning("Init expert map of mtp/eagle when using sample.") + return None, None + for device in data["layer_list"][layer_id]["device_list"]: + physical_count += len(device["device_expert"]) + device_data.append(device["device_expert"]) + global_placement = torch.tensor(device_data, dtype=torch.int32) + return global_placement, physical_count - row_indices = torch.arange(num_ranks).view(-1, 1).expand(num_ranks, \ - num_global_expert) * num_local_experts - log2phy_map[log2phy_map != -1] += row_indices[log2phy_map != -1] - for idx in range(num_global_expert): - positive_rank_idx = torch.where(log2phy_map[:, idx] != -1)[0] - negative_rank_idx = torch.where(log2phy_map[:, idx] == -1)[0] - num_rank_holding_expert = positive_rank_idx.size(0) - - if num_rank_holding_expert == 0: - log2phy_map[:, idx] = torch.full((num_ranks, ), - 0, - dtype=log2phy_map.dtype) - - if num_rank_holding_expert == 1: - log2phy_map[negative_rank_idx, idx] = torch.full( - (num_ranks - 1, ), - log2phy_map[positive_rank_idx, idx].item(), - dtype=log2phy_map.dtype) +def generate_global_placement(n_expert, ep_size, n_redundant): + all_experts = np.arange(n_expert) + groups = np.array_split(all_experts, ep_size) + for i in range(n_redundant): + j = i % ep_size + 1 + if len(groups[-j]) == 0: + groups[-j] = np.append(groups[-j], j) else: - try: - random_list = [ - random.choice(log2phy_map[positive_rank_idx, idx]) - for _ in range(num_ranks - num_rank_holding_expert) - ] - log2phy_map[negative_rank_idx, - idx] = torch.tensor(random_list, - dtype=log2phy_map.dtype) - except Exception as e: - logger.error(f"Fail to get log2phy_map: {str(e)}") + groups[-j] = np.append(groups[-j], (groups[-j][-1] + 1) % n_expert) + return torch.tensor(groups, dtype=torch.int32) + + +def init_eplb_config(ascend_config, layer_id, moe_config): + expert_map_path = ascend_config.expert_map_path + n_experts = moe_config.num_experts + ep_size = moe_config.ep_size + global_placement = None + eplb_enable = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path + n_redundant = ascend_config.init_redundancy_expert if eplb_enable else 0 + if expert_map_path: + if not (os.path.exists(expert_map_path) + and os.access(expert_map_path, os.R_OK)): + raise ValueError("Invalid EPLB path") + eplb_enable = True + global_placement, physical_count = expert_file_to_tensor( + expert_map_path, layer_id) + if physical_count is not None: + n_redundant = physical_count - n_experts + if not moe_config.supports_eplb: + raise ValueError( + "Eplb supports only w8a8_dynamic quantization.") + else: + eplb_enable = False + + if global_placement is None: + global_placement = generate_global_placement(n_experts, ep_size, + n_redundant) + + if ep_size == 1: + return None, None, n_redundant + global_expert_map = [] + for rankid in range(ep_size): + expert_map = torch.full((n_experts, ), -1, dtype=torch.int32) + local_placement = global_placement[rankid] + expert_map[local_placement] = torch.arange(local_placement.shape[0], + dtype=torch.int32) + global_expert_map.append(expert_map) + local_expert_map = global_expert_map[moe_config.ep_rank].npu() + log2phy = generate_log2phy_map( + global_expert_map, moe_config.ep_rank).npu() if eplb_enable else None + + return local_expert_map, log2phy, n_redundant + + +def generate_log2phy_map(global_expert_map, ep_rank): + log2phy_map = defaultdict(list) + valid_count = torch.sum(global_expert_map[0] != -1) + for rankid, map_per_rank in enumerate(global_expert_map): + for idx, val in enumerate(map_per_rank): + val = val.item() + # 计算value:当前值 + i * 有效元素个数 + if val != -1: + log2phy_map[idx].append(val + rankid * valid_count) + + for key in log2phy_map.keys(): + num_of_duplications = len(log2phy_map[key]) + log2phy_map[key] = log2phy_map[key][ep_rank % num_of_duplications] + + log2phy_map = torch.scatter( + torch.zeros(len(log2phy_map.keys()), dtype=torch.int32), 0, + torch.tensor(list(log2phy_map.keys()), dtype=torch.int64), + torch.tensor(list(log2phy_map.values()), dtype=torch.int32)) return log2phy_map -def determine_default_log2phy_map(global_expert_num, world_size, rank_id): - if world_size == 1: - local_ids = torch.arange(global_expert_num, dtype=torch.int32) - expert_map_all = local_ids.unsqueeze(0).expand(world_size, -1) - log2phy_map_all = generate_log2phy_map(expert_map_all) - return log2phy_map_all[rank_id] - - local_num_experts = global_expert_num // world_size - - expert_map_all = torch.full((world_size, global_expert_num), - -1, - dtype=torch.int32) - - for r in range(world_size): - if r < world_size - 1: - start = r * local_num_experts - end = (r + 1) * local_num_experts - local_count = local_num_experts - else: - start = r * local_num_experts - end = global_expert_num - local_count = global_expert_num - r * local_num_experts - - if isinstance(local_count, int): - local_ids = torch.arange(local_count, dtype=torch.int32) - expert_map_all[r, start:end] = local_ids - - log2phy_map_all = generate_log2phy_map(expert_map_all) - - return log2phy_map_all[rank_id] - - class EPLBParamUtils: @staticmethod diff --git a/vllm_ascend/eplb/core/eplb_worker.py b/vllm_ascend/eplb/core/eplb_worker.py index 15dbbd02..03bf126f 100644 --- a/vllm_ascend/eplb/core/eplb_worker.py +++ b/vllm_ascend/eplb/core/eplb_worker.py @@ -377,8 +377,8 @@ class EplbWorker: maps.append(new_expert_map[self.rank_id].numpy().tolist()) - log2phy_map = generate_log2phy_map(new_expert_map) - log2phy_all.append(log2phy_map[self.rank_id].numpy().tolist()) + log2phy_map = generate_log2phy_map(new_expert_map, self.rank_id) + log2phy_all.append(log2phy_map.numpy().tolist()) layer_ids.append(layer_id) diff --git a/vllm_ascend/ops/expert_load_balancer.py b/vllm_ascend/ops/expert_load_balancer.py deleted file mode 100644 index 7e8a9aef..00000000 --- a/vllm_ascend/ops/expert_load_balancer.py +++ /dev/null @@ -1,118 +0,0 @@ -import json -import random -from typing import Dict, List - -import torch -import torch.distributed as dist - - -class ExpertLoadBalancer(object): - - def __init__(self, expert_map_path, num_experts): - self.expert_map_path = expert_map_path - self.num_experts = num_experts - self.tensor_data = [] - self.expert_map_tensor, self.layers_num, self.ranks_num = ( - self._expert_file_to_tensor()) - self.global_expert_num = num_experts + self.get_global_redundant_expert_num( - ) - self.expert_placement_map = self.generate_expert_placement_map() - - def _expert_file_to_tensor(self): - with open(self.expert_map_path, "r") as f: - data = json.load(f) - layers_num = data["moe_layer_count"] - gpus_num = data["layer_list"][0]["device_count"] - for layer in data["layer_list"]: - device_data = [] - for device in layer["device_list"]: - device_data.append(device["device_expert"]) - self.tensor_data.append(device_data) - expert_map_tensor = torch.tensor(self.tensor_data, dtype=torch.int32) - return expert_map_tensor, layers_num, gpus_num - - def generate_index_dicts(self, tensor_2d): - dict_list = [] - current_idx = 0 - - for row in tensor_2d: - value_to_index = {} - for i in range(row.size(0)): - value = row[i].item() - value_to_index[value] = current_idx + i - dict_list.append(value_to_index) - current_idx += row.size(0) - - return dict_list - - def generate_expert_placement_map(self): - expert_placement_map = torch.full( - (self.layers_num, self.ranks_num, self.global_expert_num), - -1, - dtype=torch.int32, - ) - for layer_id in range(self.layers_num): - for gpu_id in range(self.ranks_num): - e_ids = self.expert_map_tensor[layer_id, gpu_id] - expert_placement_map[layer_id, gpu_id, - e_ids] = torch.arange(len(e_ids), - dtype=torch.int32) - return expert_placement_map - - def generate_log2phy_expert_map(self, layer_id): - concatenated = torch.flatten(self.expert_map_tensor[layer_id]) - rank_expert_to_global = self.generate_index_dicts( - self.expert_map_tensor[layer_id]) - result_dict: Dict[int, List[int]] = {} - for idx, value in enumerate(concatenated): - key = value.item() - if key not in result_dict: - result_dict[key] = [] - result_dict[key].append(idx) - - log2phy_map = torch.full((self.ranks_num, self.global_expert_num), - -1, - dtype=torch.int32) - for rank in range(self.ranks_num): - for key in result_dict: - indices_in_concat = result_dict[key] - if key in rank_expert_to_global[rank]: - log2phy_map[rank][key] = rank_expert_to_global[rank][key] - else: - chosen_index = random.choice(indices_in_concat) - log2phy_map[rank][key] = chosen_index - return log2phy_map - - def get_rank_placement_map(self, layer_id, rank_id): - layer_expert_map = self.expert_placement_map[layer_id] - rank_expert_map = layer_expert_map[rank_id].to( - torch.npu.current_device()) - rank_local_expert_num = torch.sum(torch.ne(rank_expert_map, -1)).item() - return rank_local_expert_num, rank_expert_map - - def get_rank_log2phy_map(self, layer_id, rank_id): - layer_log2phy_map = self.generate_log2phy_expert_map(layer_id) - return layer_log2phy_map[rank_id] - - def get_global_redundant_expert_num(self): - global_redundant_expert_num = ( - len(self.expert_map_tensor[0][0]) * self.ranks_num - - self.num_experts) - return global_redundant_expert_num - - def check_expert_map_tensor(self): - if dist.is_initialized(): - try: - rank = dist.get_rank() - world_size = dist.get_world_size() - all_expert_maps = [None for _ in range(world_size)] - dist.all_gather_object(all_expert_maps, self.tensor_data) - for rank_id, expert_map_tensor in enumerate(all_expert_maps): - if self.tensor_data != expert_map_tensor: - raise ValueError( - f"The expert map of rank{rank} is not equal to rank{rank_id}" - ) - return True - except Exception as e: - raise ValueError( - f"The expert maps of all ranks are inconsistency: {e}") diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py index a9547a5a..daaca8b9 100644 --- a/vllm_ascend/ops/fused_moe/fused_moe.py +++ b/vllm_ascend/ops/fused_moe/fused_moe.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import os.path from typing import Any, Callable, Optional import torch @@ -25,19 +24,17 @@ from vllm.forward_context import get_forward_context from vllm.logger import logger from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map, - get_compressed_expert_map) + FusedMoE, UnquantizedFusedMoEMethod, get_compressed_expert_map) from vllm.model_executor.layers.fused_moe.shared_fused_moe import \ SharedFusedMoE from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.distributed.parallel_state import get_mc2_group -from vllm_ascend.eplb.core.eplb_utils import determine_default_log2phy_map +from vllm_ascend.eplb.core.eplb_utils import init_eplb_config from vllm_ascend.eplb.utils import moe_load_async_stream from vllm_ascend.flash_common3_context import (get_flash_common3_context, set_flash_common3_context) -from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.ops.fused_moe.experts_selector import select_experts from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl, setup_moe_comm_method) @@ -164,11 +161,8 @@ class AscendFusedMoE(FusedMoE): self.moe_config.dp_group = get_dp_group() self.moe_config.ep_group = get_ep_group() self.moe_config.mc2_group = get_mc2_group() + self.moe_config.supports_eplb = self.quant_method.supports_eplb ascend_config = get_ascend_config() - self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path - self.expert_map_path = ascend_config.expert_map_path - self.global_redundant_expert_num = ascend_config.init_redundancy_expert - self.global_num_experts = num_experts + self.global_redundant_expert_num # flashcommon3 gate stream self.multistream_overlap_gate = ascend_config.multistream_overlap_gate if self.multistream_overlap_gate and AscendFusedMoE.gate_stream is None: @@ -178,66 +172,33 @@ class AscendFusedMoE(FusedMoE): self.e_score_correction_bias.data = self.e_score_correction_bias.data.to( dtype=vllm_config.model_config.dtype) - # init moe. - self.local_num_experts, self._expert_map, _ = determine_expert_map( - self.ep_size, self.ep_rank, self.global_num_experts) - # TODO: Temporary flag to indicate if static EPLB is enabled. This is a - # workaround to bypass a quantization check that fails with float weights. - init_eplb_enable = False - # static eplb initializing with expert_map_path - if self.expert_map_path and os.path.exists( - self.expert_map_path) and os.access(self.expert_map_path, - os.R_OK): - self.expert_load_balancer = ExpertLoadBalancer( - self.expert_map_path, num_experts) - self.expert_load_balancer.check_expert_map_tensor() - self.global_redundant_expert_num = ( - self.expert_load_balancer.get_global_redundant_expert_num()) - self.global_num_experts = num_experts + self.global_redundant_expert_num - try: - self.local_num_experts, self._expert_map = ( - self.expert_load_balancer.get_rank_placement_map( - self.moe_instance_id, self.ep_rank)) - self.log2phy = self.expert_load_balancer.get_rank_log2phy_map( - self.moe_instance_id, self.ep_rank).npu() - init_eplb_enable = True - except Exception as e: - logger.warning( - f"Init expert map of mtp/eagle when using sample.{e}") - self.log2phy = determine_default_log2phy_map( - self.global_num_experts, self.ep_size, self.ep_rank).npu() - else: - # dynamic eplb initializing with not expert_map_path - if self.dynamic_eplb: - self.log2phy = determine_default_log2phy_map( - self.global_num_experts, self.ep_size, self.ep_rank).npu() - if self._expert_map is not None and isinstance(self._expert_map, - torch.Tensor): + # init moe + self._expert_map, self.log2phy, self.global_redundant_expert_num = init_eplb_config( + ascend_config, self.moe_instance_id, self.moe_config) + self.global_num_experts = num_experts + self.global_redundant_expert_num + self.dynamic_eplb = (ascend_config.dynamic_eplb + or ascend_config.expert_map_record_path) and ( + self.log2phy is not None) + self.local_num_experts = (torch.sum( + self._expert_map != -1) if self._expert_map is not None else + self.global_num_experts) + if self._expert_map is not None: logger.info_once( "[EP Rank %s/%s] Expert parallelism is enabled. Local/global" " number of experts: %s/%s. Experts local to global index map:" " %s.", self.ep_rank, self.ep_size, self.local_num_experts, self.global_num_experts, get_compressed_expert_map(self._expert_map)) - local_num_experts = (torch.sum( - self._expert_map != -1) if self._expert_map is not None else - self.global_num_experts) if self.dynamic_eplb: - self.moe_load = torch.zeros(local_num_experts, + self.moe_load = torch.zeros(self.local_num_experts, dtype=torch.int64).npu() - if init_eplb_enable and ( - not hasattr(self.quant_method, "quant_method") - or not isinstance(self.quant_method.quant_method, - AscendW8A8DynamicFusedMoEMethod)): - raise ValueError("Eplb supports only w8a8_dynamic quantization.") - self.moe_config.num_experts = self.global_num_experts self.moe_config.num_local_experts = self.local_num_experts self.moe_config.original_num_experts = num_experts moe_quant_params = { - "num_experts": local_num_experts, + "num_experts": self.local_num_experts, "hidden_size": self.hidden_size, "intermediate_size_per_partition": self.intermediate_size_per_partition, @@ -373,7 +334,7 @@ class AscendFusedMoE(FusedMoE): renormalize=self.renormalize, use_grouped_topk=self.use_grouped_topk, global_num_experts=self.global_num_experts, - expert_map=self.expert_map, + expert_map=self._expert_map, topk_group=self.topk_group, num_expert_group=self.num_expert_group, custom_routing_function=self.custom_routing_function, diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index 8477fdb8..8c8b7518 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -536,6 +536,11 @@ class AscendFusedMoEMethod(FusedMoEMethodBase): # TODO: implement this function pass + @property + def supports_eplb(self): + supports_eplb = getattr(self.quant_method, "supports_eplb", False) + return supports_eplb + class AscendEmbeddingMethod(AscendLinearMethod): """Embedding method for Ascend quantization. diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 1c158d09..986f6fd2 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -115,6 +115,7 @@ class AscendW8A8DynamicFusedMoEMethod: self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path self.in_dtype = vllm_config.model_config.dtype + self.supports_eplb = True try: device_group = get_mc2_group().device_group