From bc67696a02fffff3b5d246efe17d0a01e52ab2a4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=AC=A7=E6=B4=BE=E6=9E=9C=E5=A5=B6=E6=88=91=E8=BF=98?=
 =?UTF-8?q?=E8=A6=81?= <47294568+845473182@users.noreply.github.com>
Date: Sun, 30 Nov 2025 22:52:05 +0800
Subject: [PATCH] [EPLB][Ops] Integrate
 grouped_matmul_swiglu_quant_weight_nz_tensor_list operator into dynamic EPLB
 (#4216)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What this PR does / why we need it?
Integrate grouped_matmul_swiglu_quant_weight_nz_tensor_list into dynamic EPLB to support list-type parameters.
This PR also modifies the model-loading logic in the dynamic-EPLB scenario.
The operator is based on this PR: https://github.com/vllm-project/vllm-ascend/pull/3804

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
```
vllm serve /home/weight/DeepSeek-V3.1_w8a8mix_mtp \
--max_num_seqs 8 \
--max-model-len 8192 \
--max-num-batched-tokens 16384 \
--tensor-parallel-size 8 \
--data-parallel-size 2 \
--enable-expert-parallel \
--served-model-name ds_r1 \
--enable-auto-tool-choice \
--tool-call-parser hermes \
--no-enable-prefix-caching \
--port 8999 \
--quantization "ascend" \
--gpu-memory-utilization 0.85 \
--trust-remote-code \
--compilation_config '{"cudagraph_capture_sizes":[1,2,4,8,16,32]}' \
--additional-config='{"dynamic_eplb":true, "num_iterations_eplb_update":100, "num_wait_worker_iterations":100}'
```
Input & output lengths: 2k / 2k.
This PR: [benchmark screenshot "fusion"]
Baseline: [benchmark screenshot "baseline"]

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

---------

Signed-off-by: 白永斌
Signed-off-by: 欧派果奶我还要 <845473182@qq.com>
Co-authored-by: 白永斌
---
 tests/ut/ops/test_moe_comm_method.py         |  4 +-
 vllm_ascend/eplb/adaptor/vllm_adaptor.py     | 47 +++++++++---
 vllm_ascend/ops/fused_moe/moe_comm_method.py |  8 +-
 vllm_ascend/ops/fused_moe/moe_mlp.py         | 81 +++++++++++++-------
 vllm_ascend/quantization/w4a8_dynamic.py     |  8 +-
 vllm_ascend/quantization/w8a8_dynamic.py     | 41 +++++++++-
 6 files changed, 139 insertions(+), 50 deletions(-)

diff --git a/tests/ut/ops/test_moe_comm_method.py b/tests/ut/ops/test_moe_comm_method.py
index f258f8e7..8adde876 100644
--- a/tests/ut/ops/test_moe_comm_method.py
+++ b/tests/ut/ops/test_moe_comm_method.py
@@ -226,8 +226,8 @@ class TestMoECommMethod(TestBase):
         w2 = w2.contiguous()
 
         result = comm_impl.fused_experts(hidden_states=hidden_states,
-                                         w1=w1,
-                                         w2=w2,
+                                         w1=[w1],
+                                         w2=[w2],
                                          topk_weights=topk_weights,
                                          topk_ids=topk_ids,
                                          activation="silu")
diff --git a/vllm_ascend/eplb/adaptor/vllm_adaptor.py b/vllm_ascend/eplb/adaptor/vllm_adaptor.py
index d200fa6c..47a99d1b 100644
--- a/vllm_ascend/eplb/adaptor/vllm_adaptor.py
+++ b/vllm_ascend/eplb/adaptor/vllm_adaptor.py
@@ -44,11 +44,22 @@ class VllmEplbAdaptor(EplbAdaptor):
         self.init_redundancy_expert = get_ascend_config(
         ).init_redundancy_expert
 
+        for i in range(self.num_dense_layers,
+                       self.model.config.num_hidden_layers):
+            self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_list"] = \
+                self.model.model.layers[i].mlp.experts.w13_weight_list
+            self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_list"] = \
+                self.model.model.layers[i].mlp.experts.w2_weight_list
+            self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_scale_fp32_list"] = \
+                self.model.model.layers[i].mlp.experts.w13_weight_scale_fp32_list
+            self.param_dict["model.layers." + str(i) + ".mlp.experts."
+ "w2_weight_scale_list"] = \ + self.model.model.layers[i].mlp.experts.w2_weight_scale_list # TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here if self.model.quant_config is not None: self.expert_weight_names = [ - "w13_weight", "w2_weight", "w13_weight_scale", - "w13_weight_offset", "w2_weight_scale", "w2_weight_offset" + "w13_weight_list", "w2_weight_list", + "w13_weight_scale_fp32_list", "w13_weight_offset", + "w2_weight_scale_list", "w2_weight_offset" ] else: self.expert_weight_names = ["w13_weight", "w2_weight"] @@ -84,9 +95,14 @@ class VllmEplbAdaptor(EplbAdaptor): for name in self.expert_weight_names: complete_name = "model.layers." + str( self.num_dense_layers) + ".mlp.experts." + name - expert_tensor = self.param_dict[complete_name].data[0] - if name in ["w13_weight", "w2_weight"]: + if name in [ + "w13_weight_list", "w2_weight_list", + "w13_weight_scale_fp32_list", "w2_weight_scale_list" + ]: + expert_tensor = self.param_dict[complete_name][0] expert_tensor = expert_tensor.clone() + else: + expert_tensor = self.param_dict[complete_name][0].data[0] buffer_tensor = torch.empty_like(expert_tensor) self.buffer_tensor_list[buffer_id].append(buffer_tensor) @@ -97,12 +113,23 @@ class VllmEplbAdaptor(EplbAdaptor): layer_idx = self.num_dense_layers + moe_layer_id self.expert_param_per_layer[layer_idx] = list() for local_expert_id in range(num_local_expert): - self.expert_param_per_layer[layer_idx].append([ - self.param_dict["model.layers." + str(layer_idx) + - ".mlp.experts." + - name].data[local_expert_id] - for name in self.expert_weight_names - ]) + per_expert_param = list() + for name in self.expert_weight_names: + if name in [ + "w13_weight_list", "w2_weight_list", + "w13_weight_scale_fp32_list", + "w2_weight_scale_list" + ]: + per_expert_param.append( + self.param_dict["model.layers." + str(layer_idx) + + ".mlp.experts." + + name][local_expert_id]) + else: + per_expert_param.append( + self.param_dict["model.layers." + str(layer_idx) + + ".mlp.experts." 
+ + name][0].data[local_expert_id]) + self.expert_param_per_layer[layer_idx].append(per_expert_param) def get_rank_expert_workload(self) -> torch.Tensor: self.moe_load = self.model.get_all_moe_loads() diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py index c48ce1a4..802dbe5d 100644 --- a/vllm_ascend/ops/fused_moe/moe_comm_method.py +++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py @@ -83,8 +83,8 @@ class MoECommMethod(ABC): def fused_experts( self, hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, + w1: torch.Tensor | list[torch.Tensor], + w2: torch.Tensor | list[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str = "silu", @@ -93,8 +93,8 @@ class MoECommMethod(ABC): use_int4_w4a8: bool = False, global_num_experts: Optional[int] = None, expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, + w1_scale: Optional[list[torch.Tensor]] = None, + w2_scale: Optional[list[torch.Tensor]] = None, w1_scale_bias: torch.Tensor = None, w2_scale_bias: torch.Tensor = None, # For TorchAir graph diff --git a/vllm_ascend/ops/fused_moe/moe_mlp.py b/vllm_ascend/ops/fused_moe/moe_mlp.py index 07ba732f..13e1efc0 100644 --- a/vllm_ascend/ops/fused_moe/moe_mlp.py +++ b/vllm_ascend/ops/fused_moe/moe_mlp.py @@ -23,7 +23,11 @@ from vllm.forward_context import get_forward_context from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.utils import (AscendDeviceType, dispose_tensor, - get_ascend_device_type) + enable_custom_op, get_ascend_device_type) + + +def _custom_gmm_swiglu_enabled(fusion, dynamic_eplb): + return fusion and dynamic_eplb and enable_custom_op() def cumsum_group_list(group_list: torch.Tensor, @@ -55,10 +59,10 @@ def cumsum_group_list(group_list: torch.Tensor, def quant_apply_mlp(hidden_states: torch.Tensor, - w1: torch.Tensor, - w1_scale: torch.Tensor, - w2: torch.Tensor, - w2_scale: torch.Tensor, + w1: list[torch.Tensor], + w1_scale: list[torch.Tensor], + w2: list[torch.Tensor], + w2_scale: list[torch.Tensor], group_list: torch.Tensor, group_list_type: int = 1, dynamic_scale: torch.Tensor = None, @@ -79,7 +83,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor, quantized_hidden_states = hidden_states bias1, bias2 = None, None - _output_dtype = w2_scale.dtype + _output_dtype = w2_scale[0].dtype weight_prefetch_method = get_forward_context().weight_prefetch_method if weight_prefetch_method: @@ -87,23 +91,34 @@ def quant_apply_mlp(hidden_states: torch.Tensor, hidden_states) is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2 if w1_scale_bias is None and is_mc2: - if fusion and not dynamic_eplb: + if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb): + # gmm1: gate_up_proj & act_fn: swiglu + hidden_states, swiglu_out_scale, _ = ( + torch.ops._C_ascend. 
+ grouped_matmul_swiglu_quant_weight_nz_tensor_list( + x=hidden_states, + weight=w1, + weight_scale=w1_scale, + x_scale=pertoken_scale, + group_list=cumsum_group_list(group_list, group_list_type), + )) + elif fusion and not dynamic_eplb: # gmm1: gate_up_proj & act_fn: swiglu hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant( x=hidden_states, - weight=w1, + weight=w1[0], group_list=cumsum_group_list(group_list, group_list_type), - weight_scale=w1_scale, + weight_scale=w1_scale[0], x_scale=pertoken_scale) if quantized_hidden_states is not None: dispose_tensor(quantized_hidden_states) else: - if w1_scale.dtype != torch.float32: - w1_scale = w1_scale.to(torch.float32) + if w1_scale[0].dtype != torch.float32: + w1_scale[0] = w1_scale[0].to(torch.float32) # gmm1: gate_up_proj hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], - weight=[w1], + weight=w1, split_item=3, group_list_type=group_list_type, group_type=0, @@ -126,14 +141,14 @@ def quant_apply_mlp(hidden_states: torch.Tensor, # gmm2: down_proj hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], - weight=[w2], - scale=[w2_scale], + weight=w2, + scale=w2_scale, per_token_scale=[swiglu_out_scale], split_item=2, group_list_type=group_list_type, group_type=0, group_list=group_list, - output_dtype=w2_scale.dtype)[0] + output_dtype=w2_scale[0].dtype)[0] else: if w1_scale_bias is not None: if group_list_type == 0: @@ -146,23 +161,36 @@ def quant_apply_mlp(hidden_states: torch.Tensor, # TODO w4a8 scene: dynamic acquisition of dtype in the future _output_dtype = torch.bfloat16 - if fusion and not dynamic_eplb: + if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb): + # gmm1: gate_up_proj & act_fn: swiglu + hidden_states, swiglu_out_scale, _ = ( + torch.ops._C_ascend. 
+ grouped_matmul_swiglu_quant_weight_nz_tensor_list( + x=hidden_states, + weight=w1, + weight_scale=w1_scale, + x_scale=pertoken_scale, + group_list=cumsum_group_list(group_list, group_list_type), + bias=bias1, + )) + elif fusion and not dynamic_eplb: # gmm1: gate_up_proj & act_fn: swiglu hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant( x=hidden_states, - weight=w1, + weight=w1[0], bias=bias1, group_list=cumsum_group_list(group_list, group_list_type), - weight_scale=w1_scale, + weight_scale=w1_scale[0], x_scale=pertoken_scale) if quantized_hidden_states is not None: dispose_tensor(quantized_hidden_states) else: + w1_scale[0] = w1_scale[0].to(w2_scale[0].dtype) # gmm1: gate_up_proj hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], - weight=[w1], - scale=[w1_scale.to(w2_scale.dtype)], + weight=w1, + scale=w1_scale, bias=bias1, per_token_scale=[pertoken_scale], split_item=2, @@ -179,8 +207,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor, # gmm2: down_proj hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], - weight=[w2], - scale=[w2_scale], + weight=w2, + scale=w2_scale, bias=bias2, per_token_scale=[swiglu_out_scale], split_item=2, @@ -232,11 +260,11 @@ def unquant_apply_mlp(hidden_states: torch.Tensor, def unified_apply_mlp(hidden_states: torch.Tensor, - w1: torch.Tensor, - w1_scale: torch.Tensor, - w2: torch.Tensor, - w2_scale: torch.Tensor, + w1: torch.Tensor | list[torch.Tensor], + w2: torch.Tensor | list[torch.Tensor], group_list: torch.Tensor, + w1_scale: Optional[list[torch.Tensor]] = None, + w2_scale: Optional[list[torch.Tensor]] = None, dynamic_scale: torch.Tensor = None, group_list_type: int = 1, w1_scale_bias: torch.Tensor = None, @@ -247,6 +275,7 @@ def unified_apply_mlp(hidden_states: torch.Tensor, need_trans: bool = True, dynamic_eplb: bool = False) -> torch.Tensor: if with_quant: + assert w1_scale is not None and w2_scale is not None return quant_apply_mlp(hidden_states=hidden_states, w1=w1, w1_scale=w1_scale, diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py index c7f1dfab..a73050c3 100644 --- a/vllm_ascend/quantization/w4a8_dynamic.py +++ b/vllm_ascend/quantization/w4a8_dynamic.py @@ -379,10 +379,10 @@ class AscendW4A8DynamicFusedMoEMethod: moe_comm_method = get_forward_context().moe_comm_method return moe_comm_method.fused_experts( hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, + w1=[layer.w13_weight], + w2=[layer.w2_weight], + w1_scale=[layer.w13_weight_scale], + w2_scale=[layer.w2_weight_scale], w1_scale_bias=layer.w13_scale_bias, w2_scale_bias=layer.w2_scale_bias, topk_weights=topk_weights, diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index e64814be..cfeee223 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -236,13 +236,24 @@ class AscendW8A8DynamicFusedMoEMethod: topk_weights = topk_weights.to(self.in_dtype) moe_comm_method = get_forward_context().moe_comm_method + if self.dynamic_eplb: + w1 = layer.w13_weight_list + w1_scale = layer.w13_weight_scale_fp32_list + w2 = layer.w2_weight_list + w2_scale = layer.w2_weight_scale_list + else: + w1 = [layer.w13_weight] + w1_scale = [layer.w13_weight_scale_fp32] + w2 = [layer.w2_weight] + w2_scale = [layer.w2_weight_scale] + return moe_comm_method.fused_experts( hidden_states=x, pertoken_scale=pertoken_scale, - w1=layer.w13_weight, - 
w1_scale=layer.w13_weight_scale_fp32, - w2=layer.w2_weight, - w2_scale=layer.w2_weight_scale, + w1=w1, + w1_scale=w1_scale, + w2=w2, + w2_scale=w2_scale, topk_weights=topk_weights, topk_ids=topk_ids, use_int8_w8a8=True, @@ -274,3 +285,25 @@ class AscendW8A8DynamicFusedMoEMethod: layer.w2_weight_scale.data.shape[0], -1) layer.w2_weight_offset.data = layer.w2_weight_offset.data.view( layer.w2_weight_offset.data.shape[0], -1) + if self.dynamic_eplb: + layer.w13_weight_list = [ + weight.clone() + for weight in layer.w13_weight.data.unbind(dim=0) + ] + layer.w2_weight_list = [ + weight.clone() for weight in layer.w2_weight.data.unbind(dim=0) + ] + layer.w13_weight_scale_fp32_list = [ + weight.clone() + for weight in layer.w13_weight_scale.data.unbind(dim=0) + ] + layer.w2_weight_scale_list = [ + weight.clone() + for weight in layer.w2_weight_scale.data.unbind(dim=0) + ] + del layer.w13_weight + del layer.w2_weight + del layer.w13_weight_scale + del layer.w13_weight_scale_fp32 + del layer.w2_weight_scale + torch.npu.empty_cache()
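
Before this patch, the fused `npu_grouped_matmul_swiglu_quant` path only accepted a single stacked weight tensor, so the fusion branch was skipped whenever dynamic EPLB was enabled (the old `fusion and not dynamic_eplb` condition). The new list-based operator removes that restriction by consuming the same per-expert weight lists that dynamic EPLB keeps so it can overwrite one expert's weights in place during rebalancing. The sketch below illustrates that list layout under stated assumptions: the shapes and the `swap_in_expert` helper are illustrative only and are not part of this patch.

```
import torch

# Illustrative shapes only: 4 local experts, hidden=8, intermediate=16.
NUM_EXPERTS, HIDDEN, INTER = 4, 8, 16

# Stacked layout used when dynamic EPLB is off:
# one tensor of shape [num_experts, ...] per parameter.
w13_weight = torch.randint(-128, 127, (NUM_EXPERTS, 2 * INTER, HIDDEN),
                           dtype=torch.int8)
w13_weight_scale = torch.rand(NUM_EXPERTS, 2 * INTER, 1, dtype=torch.float32)

# List layout used when dynamic EPLB is on (mirrors the unbind/clone pattern
# in process_weights_after_loading above): each expert owns an independent
# tensor, so one expert can be replaced without touching its neighbours.
w13_weight_list = [w.clone() for w in w13_weight.unbind(dim=0)]
w13_weight_scale_fp32_list = [s.clone() for s in w13_weight_scale.unbind(dim=0)]


def swap_in_expert(weight_list, scale_list, local_slot, new_weight, new_scale):
    """Hypothetical helper: overwrite one local expert slot in place,
    roughly what the EPLB adaptor does when an expert moves onto this rank."""
    weight_list[local_slot].copy_(new_weight)
    scale_list[local_slot].copy_(new_scale)


# Example: replace local expert 2 with weights received from another rank.
incoming_w = torch.randint(-128, 127, (2 * INTER, HIDDEN), dtype=torch.int8)
incoming_s = torch.rand(2 * INTER, 1, dtype=torch.float32)
swap_in_expert(w13_weight_list, w13_weight_scale_fp32_list, 2,
               incoming_w, incoming_s)

# The list is what apply() now passes as w1/w1_scale when dynamic_eplb is on;
# the stacked path instead wraps the single tensor as [w13_weight].
```

Because the fused kernel now accepts the list directly, the apply() dispatch in w8a8_dynamic.py only has to choose between the list attributes and single-element lists wrapping the stacked tensors, and the fusion path no longer has to be disabled for dynamic EPLB.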