diff --git a/tests/ut/eplb/adaptor/test_abstract_adaptor.py b/tests/ut/eplb/adaptor/test_abstract_adaptor.py
deleted file mode 100644
index 9929cb77..00000000
--- a/tests/ut/eplb/adaptor/test_abstract_adaptor.py
+++ /dev/null
@@ -1,61 +0,0 @@
-import pytest
-
-from vllm_ascend.eplb.adaptor.abstract_adaptor import EplbAdaptor
-
-
-class DummyAdaptor(EplbAdaptor):
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        self.args = kwargs
-
-    def get_rank_expert_workload(self):
-        return "workload"
-
-    def do_update_expert_map(self, layer_id, updated_expert_map):
-        return {"layer_id": layer_id, "map": updated_expert_map}
-
-    def do_update_expert_weight(self, layer_id, local_expert_to_replace,
-                                buffer_tensor_id):
-        return {
-            "layer_id": layer_id,
-            "replace": local_expert_to_replace,
-            "buffer": buffer_tensor_id,
-        }
-
-
-def test_base_class_methods_raise():
-    adaptor = EplbAdaptor()
-    with pytest.raises(NotImplementedError):
-        adaptor.get_rank_expert_workload()
-    with pytest.raises(NotImplementedError):
-        adaptor.do_update_expert_map(1, {})
-    with pytest.raises(NotImplementedError):
-        adaptor.do_update_expert_weight(1, "x", "y")
-
-
-def test_dummy_adaptor_init_and_args():
-    adaptor = DummyAdaptor(test_arg=123)
-    assert adaptor.args["test_arg"] == 123
-
-
-def test_get_rank_expert_workload():
-    adaptor = DummyAdaptor()
-    result = adaptor.get_rank_expert_workload()
-    assert result == "workload"
-
-
-def test_do_update_expert_map():
-    adaptor = DummyAdaptor()
-    updated = {"expert": 1}
-    result = adaptor.do_update_expert_map(2, updated)
-    assert result["layer_id"] == 2
-    assert result["map"] == updated
-
-
-def test_do_update_expert_weight():
-    adaptor = DummyAdaptor()
-    result = adaptor.do_update_expert_weight(1, "expertA", "bufferX")
-    assert result["layer_id"] == 1
-    assert result["replace"] == "expertA"
-    assert result["buffer"] == "bufferX"
diff --git a/tests/ut/eplb/adaptor/test_vllm_adaptor.py b/tests/ut/eplb/adaptor/test_vllm_adaptor.py
new file mode 100644
index 00000000..a5dc0559
--- /dev/null
+++ b/tests/ut/eplb/adaptor/test_vllm_adaptor.py
@@ -0,0 +1,39 @@
+import unittest
+from unittest.mock import MagicMock, patch
+
+import torch
+
+from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
+from transformers import DeepseekV2Config
+
+
+class TestVllmAdaptor(unittest.TestCase):
+    def setUp(self):
+        n_routed_experts = 256
+        mock_model = MagicMock()
+        mock_model.model.named_parameters.return_value = dict()
+        config = DeepseekV2Config(n_routed_experts=n_routed_experts)
+        mock_model.config = config
+        mock_model.get_expert_map.return_value = [i for i in range(n_routed_experts)]
+        mock_model.get_log2phy_map.return_value = [i for i in range(n_routed_experts)]
+        self.model = mock_model
+
+        self.mock_rank = patch("vllm_ascend.eplb.adaptor.vllm_adaptor.dist.get_rank", return_value=0).start()
+        self.mock_size = patch("vllm_ascend.eplb.adaptor.vllm_adaptor.dist.get_world_size", return_value=4).start()
+
+    @patch("torch.empty_like", return_value=torch.zeros(16, 32))
+    def test_init_fp16(self, mock_func):
+        self.model.quant_config = None
+        VllmEplbAdaptor(self.model)
+
+    @patch("torch.empty_like", return_value=torch.zeros(16, 32))
+    def test_init_w8a8(self, mock_func):
+        VllmEplbAdaptor(self.model)
+
+    def tearDown(self):
+        self.mock_rank.stop()
+        self.mock_size.stop()
+
+if __name__ == "__main__":
+    unittest.main()
+    
\ No newline at end of file
diff --git a/vllm_ascend/eplb/adaptor/vllm_adaptor.py b/vllm_ascend/eplb/adaptor/vllm_adaptor.py
index 8d718213..1e3783c5 100644
--- a/vllm_ascend/eplb/adaptor/vllm_adaptor.py
+++ b/vllm_ascend/eplb/adaptor/vllm_adaptor.py
@@ -31,46 +31,19 @@ class VllmEplbAdaptor(EplbAdaptor):
         self.model = model
         self.rank_id = dist.get_rank()
         self.world_size = dist.get_world_size()
-        self.param_dict = dict(self.model.named_parameters())
         self.num_dense_layers = getattr(self.model.config, "first_k_dense_replace", 0)
         self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers
-        for i in range(self.num_dense_layers, self.model.config.num_hidden_layers):
-            self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_list"] = self.model.model.layers[
-                i
-            ].mlp.experts.w13_weight_list
-            self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_list"] = self.model.model.layers[
-                i
-            ].mlp.experts.w2_weight_list
-            self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_scale_fp32_list"] = (
-                self.model.model.layers[i].mlp.experts.w13_weight_scale_fp32_list
-            )
-            self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_list"] = (
-                self.model.model.layers[i].mlp.experts.w2_weight_scale_list
-            )
 
-        # TODO: init self.expert_weight_names depending on different model types.
-        # Only deepseek v3 w8a8 and qwen3-moe is supported here
-        if self.model.quant_config is not None:
-            self.expert_weight_names = [
-                "w13_weight_list",
-                "w2_weight_list",
-                "w13_weight_scale_fp32_list",
-                "w13_weight_offset",
-                "w2_weight_scale_list",
-                "w2_weight_offset",
-            ]
-        else:
-            self.expert_weight_names = ["w13_weight", "w2_weight"]
-
         self.expert_map_per_layer_cpu = dict()  # copy of expert map on CPU to avoid device synchronize frequently
 
-        num_buffer_tensor = self.model.model.layers[-1].mlp.experts.local_num_experts
-        self.buffer_tensor_list: list[list[Any]] = [[] for _ in range(num_buffer_tensor)]
-        self.init_buffer_tensor(num_buffer_tensor)
-
+        self.num_local_experts = self.model.model.layers[-1].mlp.experts.local_num_experts
         self.expert_param_per_layer = dict()
         self.init_expert_param_per_layer()
 
+        num_buffer_tensor = self.num_local_experts
+        self.buffer_tensor_list: list[list[Any]] = [[] for _ in range(num_buffer_tensor)]
+        self.init_buffer_tensor(num_buffer_tensor)
+
         self.log2phy_map_per_layer = dict()
         for layer_idx in range(self.num_moe_layers):
             self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] = self.model.get_log2phy_map(
@@ -81,38 +54,34 @@
         for buffer_id in range(num_buffer_tensor):
             for name in self.expert_weight_names:
                 complete_name = "model.layers." + str(self.num_dense_layers) + ".mlp.experts." + name
-                if name in ["w13_weight_list", "w2_weight_list", "w13_weight_scale_fp32_list", "w2_weight_scale_list"]:
-                    expert_tensor = self.param_dict[complete_name][0]
-                    expert_tensor = expert_tensor.clone()
-                else:
-                    expert_tensor = self.param_dict[complete_name][0].data[0]
+                expert_tensor = self.param_dict[complete_name][0]
                 buffer_tensor = torch.empty_like(expert_tensor)
                 self.buffer_tensor_list[buffer_id].append(buffer_tensor)
 
     def init_expert_param_per_layer(self):
-        key = f"model.layers.{self.num_dense_layers}.mlp.experts.{self.expert_weight_names[0]}"
-        num_local_expert = len(self.param_dict[key])
-        for moe_layer_id in range(self.num_moe_layers):
-            layer_idx = self.num_dense_layers + moe_layer_id
+        self.param_dict = dict()
+        if self.model.quant_config is not None:
+            self.expert_weight_names = [
+                "w13_weight_list",
+                "w2_weight_list",
+                "w13_weight_scale_fp32_list",
+                "w2_weight_scale_list",
+            ]
+        else:
+            self.expert_weight_names = ["w13_weight", "w2_weight"]
+
+        for layer_idx in range(self.num_dense_layers, self.model.config.num_hidden_layers):
             self.expert_param_per_layer[layer_idx] = list()
-            for local_expert_id in range(num_local_expert):
+            for name in self.expert_weight_names:
+                param_key = f"model.layers.{layer_idx}.mlp.experts.{name}"
+                param_value = getattr(self.model.model.layers[layer_idx].mlp.experts, name)
+                self.param_dict[param_key] = param_value
+            for local_expert_id in range(self.num_local_experts):
                 per_expert_param = list()
                 for name in self.expert_weight_names:
-                    if name in [
-                        "w13_weight_list",
-                        "w2_weight_list",
-                        "w13_weight_scale_fp32_list",
-                        "w2_weight_scale_list",
-                    ]:
-                        per_expert_param.append(
-                            self.param_dict["model.layers." + str(layer_idx) + ".mlp.experts." + name][local_expert_id]
-                        )
-                    else:
-                        per_expert_param.append(
-                            self.param_dict["model.layers." + str(layer_idx) + ".mlp.experts." + name][0].data[
-                                local_expert_id
-                            ]
-                        )
+                    per_expert_param.append(
+                        self.param_dict["model.layers." + str(layer_idx) + ".mlp.experts." + name][local_expert_id]
+                    )
                 self.expert_param_per_layer[layer_idx].append(per_expert_param)
 
     def get_rank_expert_workload(self) -> torch.Tensor:
diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py
index 3d289124..f94871fc 100644
--- a/vllm_ascend/ops/fused_moe/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe/fused_moe.py
@@ -102,6 +102,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
               expert_map: Optional[torch.Tensor] = None,
               apply_router_weight_on_input: bool = False,
               enable_force_load_balance: bool = False,
+              log2phy: torch.Tensor = None,
               **kwargs) -> torch.Tensor:
         zero_expert_num = getattr(layer, "zero_expert_num", 0)
         zero_expert_type = getattr(layer, "zero_expert_type", None)
@@ -149,6 +150,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input,
            dynamic_eplb=self.dynamic_eplb,
+           log2phy=log2phy,
            mc2_mask=kwargs.get("mc2_mask", None))
        if zero_expert_num > 0 and zero_expert_type is not None:
            final_hidden_states += zero_expert_result
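Illustrative sketch (not part of the patch): the refactored init_expert_param_per_layer()/init_buffer_tensor() pair builds a flat param_dict keyed by "model.layers.{i}.mlp.experts.{name}", a per-layer list of per-expert weight groups, and one empty_like staging buffer per weight name for every local expert slot. The snippet below reproduces that layout for the unquantized ("w13_weight"/"w2_weight") case using tiny stand-in tensors; the ToyExperts class, tensor shapes, and layer/expert counts are assumptions for illustration only, not vLLM-Ascend APIs.

# Sketch of the data layout built by the refactored adaptor, under the
# assumptions stated above. Only torch is required to run it.
import torch


class ToyExperts:
    """Hypothetical stand-in for an MoE experts module on one rank."""

    def __init__(self, num_local_experts: int):
        self.local_num_experts = num_local_experts
        # One tensor per local expert, indexable by local_expert_id.
        self.w13_weight = [torch.randn(4, 8) for _ in range(num_local_experts)]
        self.w2_weight = [torch.randn(8, 4) for _ in range(num_local_experts)]


num_dense_layers, num_hidden_layers, num_local_experts = 1, 3, 2
expert_weight_names = ["w13_weight", "w2_weight"]
layers = {i: ToyExperts(num_local_experts) for i in range(num_dense_layers, num_hidden_layers)}

# param_dict: flat "model.layers.{i}.mlp.experts.{name}" -> per-expert weights.
# expert_param_per_layer[layer][expert] -> [weight tensor per name].
param_dict = {}
expert_param_per_layer = {}
for layer_idx, experts in layers.items():
    expert_param_per_layer[layer_idx] = []
    for name in expert_weight_names:
        param_dict[f"model.layers.{layer_idx}.mlp.experts.{name}"] = getattr(experts, name)
    for local_expert_id in range(num_local_experts):
        expert_param_per_layer[layer_idx].append(
            [param_dict[f"model.layers.{layer_idx}.mlp.experts.{name}"][local_expert_id]
             for name in expert_weight_names]
        )

# buffer_tensor_list: one empty_like staging buffer per weight name for each
# local expert slot, sized from the first MoE layer's expert tensors.
buffer_tensor_list = [
    [torch.empty_like(param_dict[f"model.layers.{num_dense_layers}.mlp.experts.{name}"][0])
     for name in expert_weight_names]
    for _ in range(num_local_experts)
]

print(len(expert_param_per_layer[num_dense_layers]), len(buffer_tensor_list))  # 2 2

With this layout, swapping an expert's weights reduces to copying each tensor in expert_param_per_layer[layer][expert] into the matching slot of a buffer group and back, which is the indexing the patch relies on.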