From bc67696a02fffff3b5d246efe17d0a01e52ab2a4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=AC=A7=E6=B4=BE=E6=9E=9C=E5=A5=B6=E6=88=91=E8=BF=98?=
=?UTF-8?q?=E8=A6=81?= <47294568+845473182@users.noreply.github.com>
Date: Sun, 30 Nov 2025 22:52:05 +0800
Subject: [PATCH] [EPLB][Ops] Integerate
grouped_matmul_swiglu_quant_weight_nz_tensor_list operator into dynamic EPLB
(#4216)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
### What this PR does / why we need it?
Integerate grouped_matmul_swiglu_quant_weight_nz_tensor_list into
dynamic EPLB to support list-type parameters
This PR also modify the logic of loading model in dynamic-eplb scenario.
The operator is based on this pr:
https://github.com/vllm-project/vllm-ascend/pull/3804
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
```
vllm serve /home/weight/DeepSeek-V3.1_w8a8mix_mtp \
--max_num_seqs 8 \
--max-model-len 8192 \
--max-num-batched-tokens 16384 \
--tensor-parallel-size 8 \
--data-parallel-size 2 \
--enable-expert-parallel \
--served-model-name ds_r1 \
--enable-auto-tool-choice \
--tool-call-parser hermes \
--no-enable-prefix-caching \
--port 8999 \
--quantization "ascend" \
--gpu-memory-utilization 0.85 \
--trust-remote-code \
--compilation_config '{"cudagraph_capture_sizes":[1,2,4,8,16,32]}' \
--additional-config='{"dynamic_eplb":true, "num_iterations_eplb_update":100, "num_wait_worker_iterations":100}'
```
input&output: 2k 2k
This PR:
Baseline:
- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2
---------
Signed-off-by: 白永斌
Signed-off-by: 欧派果奶我还要 <845473182@qq.com>
Co-authored-by: 白永斌
---
tests/ut/ops/test_moe_comm_method.py | 4 +-
vllm_ascend/eplb/adaptor/vllm_adaptor.py | 47 +++++++++---
vllm_ascend/ops/fused_moe/moe_comm_method.py | 8 +-
vllm_ascend/ops/fused_moe/moe_mlp.py | 81 +++++++++++++-------
vllm_ascend/quantization/w4a8_dynamic.py | 8 +-
vllm_ascend/quantization/w8a8_dynamic.py | 41 +++++++++-
6 files changed, 139 insertions(+), 50 deletions(-)
diff --git a/tests/ut/ops/test_moe_comm_method.py b/tests/ut/ops/test_moe_comm_method.py
index f258f8e7..8adde876 100644
--- a/tests/ut/ops/test_moe_comm_method.py
+++ b/tests/ut/ops/test_moe_comm_method.py
@@ -226,8 +226,8 @@ class TestMoECommMethod(TestBase):
w2 = w2.contiguous()
result = comm_impl.fused_experts(hidden_states=hidden_states,
- w1=w1,
- w2=w2,
+ w1=[w1],
+ w2=[w2],
topk_weights=topk_weights,
topk_ids=topk_ids,
activation="silu")
diff --git a/vllm_ascend/eplb/adaptor/vllm_adaptor.py b/vllm_ascend/eplb/adaptor/vllm_adaptor.py
index d200fa6c..47a99d1b 100644
--- a/vllm_ascend/eplb/adaptor/vllm_adaptor.py
+++ b/vllm_ascend/eplb/adaptor/vllm_adaptor.py
@@ -44,11 +44,22 @@ class VllmEplbAdaptor(EplbAdaptor):
self.init_redundancy_expert = get_ascend_config(
).init_redundancy_expert
+ for i in range(self.num_dense_layers,
+ self.model.config.num_hidden_layers):
+ self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_list"] = \
+ self.model.model.layers[i].mlp.experts.w13_weight_list
+ self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_list"] = \
+ self.model.model.layers[i].mlp.experts.w2_weight_list
+ self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_scale_fp32_list"] = \
+ self.model.model.layers[i].mlp.experts.w13_weight_scale_fp32_list
+ self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_list"] = \
+ self.model.model.layers[i].mlp.experts.w2_weight_scale_list
# TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here
if self.model.quant_config is not None:
self.expert_weight_names = [
- "w13_weight", "w2_weight", "w13_weight_scale",
- "w13_weight_offset", "w2_weight_scale", "w2_weight_offset"
+ "w13_weight_list", "w2_weight_list",
+ "w13_weight_scale_fp32_list", "w13_weight_offset",
+ "w2_weight_scale_list", "w2_weight_offset"
]
else:
self.expert_weight_names = ["w13_weight", "w2_weight"]
@@ -84,9 +95,14 @@ class VllmEplbAdaptor(EplbAdaptor):
for name in self.expert_weight_names:
complete_name = "model.layers." + str(
self.num_dense_layers) + ".mlp.experts." + name
- expert_tensor = self.param_dict[complete_name].data[0]
- if name in ["w13_weight", "w2_weight"]:
+ if name in [
+ "w13_weight_list", "w2_weight_list",
+ "w13_weight_scale_fp32_list", "w2_weight_scale_list"
+ ]:
+ expert_tensor = self.param_dict[complete_name][0]
expert_tensor = expert_tensor.clone()
+ else:
+ expert_tensor = self.param_dict[complete_name][0].data[0]
buffer_tensor = torch.empty_like(expert_tensor)
self.buffer_tensor_list[buffer_id].append(buffer_tensor)
@@ -97,12 +113,23 @@ class VllmEplbAdaptor(EplbAdaptor):
layer_idx = self.num_dense_layers + moe_layer_id
self.expert_param_per_layer[layer_idx] = list()
for local_expert_id in range(num_local_expert):
- self.expert_param_per_layer[layer_idx].append([
- self.param_dict["model.layers." + str(layer_idx) +
- ".mlp.experts." +
- name].data[local_expert_id]
- for name in self.expert_weight_names
- ])
+ per_expert_param = list()
+ for name in self.expert_weight_names:
+ if name in [
+ "w13_weight_list", "w2_weight_list",
+ "w13_weight_scale_fp32_list",
+ "w2_weight_scale_list"
+ ]:
+ per_expert_param.append(
+ self.param_dict["model.layers." + str(layer_idx) +
+ ".mlp.experts." +
+ name][local_expert_id])
+ else:
+ per_expert_param.append(
+ self.param_dict["model.layers." + str(layer_idx) +
+ ".mlp.experts." +
+ name][0].data[local_expert_id])
+ self.expert_param_per_layer[layer_idx].append(per_expert_param)
def get_rank_expert_workload(self) -> torch.Tensor:
self.moe_load = self.model.get_all_moe_loads()
diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py
index c48ce1a4..802dbe5d 100644
--- a/vllm_ascend/ops/fused_moe/moe_comm_method.py
+++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py
@@ -83,8 +83,8 @@ class MoECommMethod(ABC):
def fused_experts(
self,
hidden_states: torch.Tensor,
- w1: torch.Tensor,
- w2: torch.Tensor,
+ w1: torch.Tensor | list[torch.Tensor],
+ w2: torch.Tensor | list[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
@@ -93,8 +93,8 @@ class MoECommMethod(ABC):
use_int4_w4a8: bool = False,
global_num_experts: Optional[int] = None,
expert_map: Optional[torch.Tensor] = None,
- w1_scale: Optional[torch.Tensor] = None,
- w2_scale: Optional[torch.Tensor] = None,
+ w1_scale: Optional[list[torch.Tensor]] = None,
+ w2_scale: Optional[list[torch.Tensor]] = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
# For TorchAir graph
diff --git a/vllm_ascend/ops/fused_moe/moe_mlp.py b/vllm_ascend/ops/fused_moe/moe_mlp.py
index 07ba732f..13e1efc0 100644
--- a/vllm_ascend/ops/fused_moe/moe_mlp.py
+++ b/vllm_ascend/ops/fused_moe/moe_mlp.py
@@ -23,7 +23,11 @@ from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.utils import (AscendDeviceType, dispose_tensor,
- get_ascend_device_type)
+ enable_custom_op, get_ascend_device_type)
+
+
+def _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
+ return fusion and dynamic_eplb and enable_custom_op()
def cumsum_group_list(group_list: torch.Tensor,
@@ -55,10 +59,10 @@ def cumsum_group_list(group_list: torch.Tensor,
def quant_apply_mlp(hidden_states: torch.Tensor,
- w1: torch.Tensor,
- w1_scale: torch.Tensor,
- w2: torch.Tensor,
- w2_scale: torch.Tensor,
+ w1: list[torch.Tensor],
+ w1_scale: list[torch.Tensor],
+ w2: list[torch.Tensor],
+ w2_scale: list[torch.Tensor],
group_list: torch.Tensor,
group_list_type: int = 1,
dynamic_scale: torch.Tensor = None,
@@ -79,7 +83,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
quantized_hidden_states = hidden_states
bias1, bias2 = None, None
- _output_dtype = w2_scale.dtype
+ _output_dtype = w2_scale[0].dtype
weight_prefetch_method = get_forward_context().weight_prefetch_method
if weight_prefetch_method:
@@ -87,23 +91,34 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
hidden_states)
is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
if w1_scale_bias is None and is_mc2:
- if fusion and not dynamic_eplb:
+ if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
+ # gmm1: gate_up_proj & act_fn: swiglu
+ hidden_states, swiglu_out_scale, _ = (
+ torch.ops._C_ascend.
+ grouped_matmul_swiglu_quant_weight_nz_tensor_list(
+ x=hidden_states,
+ weight=w1,
+ weight_scale=w1_scale,
+ x_scale=pertoken_scale,
+ group_list=cumsum_group_list(group_list, group_list_type),
+ ))
+ elif fusion and not dynamic_eplb:
# gmm1: gate_up_proj & act_fn: swiglu
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
x=hidden_states,
- weight=w1,
+ weight=w1[0],
group_list=cumsum_group_list(group_list, group_list_type),
- weight_scale=w1_scale,
+ weight_scale=w1_scale[0],
x_scale=pertoken_scale)
if quantized_hidden_states is not None:
dispose_tensor(quantized_hidden_states)
else:
- if w1_scale.dtype != torch.float32:
- w1_scale = w1_scale.to(torch.float32)
+ if w1_scale[0].dtype != torch.float32:
+ w1_scale[0] = w1_scale[0].to(torch.float32)
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
- weight=[w1],
+ weight=w1,
split_item=3,
group_list_type=group_list_type,
group_type=0,
@@ -126,14 +141,14 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
- weight=[w2],
- scale=[w2_scale],
+ weight=w2,
+ scale=w2_scale,
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
- output_dtype=w2_scale.dtype)[0]
+ output_dtype=w2_scale[0].dtype)[0]
else:
if w1_scale_bias is not None:
if group_list_type == 0:
@@ -146,23 +161,36 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
# TODO w4a8 scene: dynamic acquisition of dtype in the future
_output_dtype = torch.bfloat16
- if fusion and not dynamic_eplb:
+ if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
+ # gmm1: gate_up_proj & act_fn: swiglu
+ hidden_states, swiglu_out_scale, _ = (
+ torch.ops._C_ascend.
+ grouped_matmul_swiglu_quant_weight_nz_tensor_list(
+ x=hidden_states,
+ weight=w1,
+ weight_scale=w1_scale,
+ x_scale=pertoken_scale,
+ group_list=cumsum_group_list(group_list, group_list_type),
+ bias=bias1,
+ ))
+ elif fusion and not dynamic_eplb:
# gmm1: gate_up_proj & act_fn: swiglu
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
x=hidden_states,
- weight=w1,
+ weight=w1[0],
bias=bias1,
group_list=cumsum_group_list(group_list, group_list_type),
- weight_scale=w1_scale,
+ weight_scale=w1_scale[0],
x_scale=pertoken_scale)
if quantized_hidden_states is not None:
dispose_tensor(quantized_hidden_states)
else:
+ w1_scale[0] = w1_scale[0].to(w2_scale[0].dtype)
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
- weight=[w1],
- scale=[w1_scale.to(w2_scale.dtype)],
+ weight=w1,
+ scale=w1_scale,
bias=bias1,
per_token_scale=[pertoken_scale],
split_item=2,
@@ -179,8 +207,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
- weight=[w2],
- scale=[w2_scale],
+ weight=w2,
+ scale=w2_scale,
bias=bias2,
per_token_scale=[swiglu_out_scale],
split_item=2,
@@ -232,11 +260,11 @@ def unquant_apply_mlp(hidden_states: torch.Tensor,
def unified_apply_mlp(hidden_states: torch.Tensor,
- w1: torch.Tensor,
- w1_scale: torch.Tensor,
- w2: torch.Tensor,
- w2_scale: torch.Tensor,
+ w1: torch.Tensor | list[torch.Tensor],
+ w2: torch.Tensor | list[torch.Tensor],
group_list: torch.Tensor,
+ w1_scale: Optional[list[torch.Tensor]] = None,
+ w2_scale: Optional[list[torch.Tensor]] = None,
dynamic_scale: torch.Tensor = None,
group_list_type: int = 1,
w1_scale_bias: torch.Tensor = None,
@@ -247,6 +275,7 @@ def unified_apply_mlp(hidden_states: torch.Tensor,
need_trans: bool = True,
dynamic_eplb: bool = False) -> torch.Tensor:
if with_quant:
+ assert w1_scale is not None and w2_scale is not None
return quant_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=w1_scale,
diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py
index c7f1dfab..a73050c3 100644
--- a/vllm_ascend/quantization/w4a8_dynamic.py
+++ b/vllm_ascend/quantization/w4a8_dynamic.py
@@ -379,10 +379,10 @@ class AscendW4A8DynamicFusedMoEMethod:
moe_comm_method = get_forward_context().moe_comm_method
return moe_comm_method.fused_experts(
hidden_states=x,
- w1=layer.w13_weight,
- w2=layer.w2_weight,
- w1_scale=layer.w13_weight_scale,
- w2_scale=layer.w2_weight_scale,
+ w1=[layer.w13_weight],
+ w2=[layer.w2_weight],
+ w1_scale=[layer.w13_weight_scale],
+ w2_scale=[layer.w2_weight_scale],
w1_scale_bias=layer.w13_scale_bias,
w2_scale_bias=layer.w2_scale_bias,
topk_weights=topk_weights,
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index e64814be..cfeee223 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -236,13 +236,24 @@ class AscendW8A8DynamicFusedMoEMethod:
topk_weights = topk_weights.to(self.in_dtype)
moe_comm_method = get_forward_context().moe_comm_method
+ if self.dynamic_eplb:
+ w1 = layer.w13_weight_list
+ w1_scale = layer.w13_weight_scale_fp32_list
+ w2 = layer.w2_weight_list
+ w2_scale = layer.w2_weight_scale_list
+ else:
+ w1 = [layer.w13_weight]
+ w1_scale = [layer.w13_weight_scale_fp32]
+ w2 = [layer.w2_weight]
+ w2_scale = [layer.w2_weight_scale]
+
return moe_comm_method.fused_experts(
hidden_states=x,
pertoken_scale=pertoken_scale,
- w1=layer.w13_weight,
- w1_scale=layer.w13_weight_scale_fp32,
- w2=layer.w2_weight,
- w2_scale=layer.w2_weight_scale,
+ w1=w1,
+ w1_scale=w1_scale,
+ w2=w2,
+ w2_scale=w2_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_int8_w8a8=True,
@@ -274,3 +285,25 @@ class AscendW8A8DynamicFusedMoEMethod:
layer.w2_weight_scale.data.shape[0], -1)
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
layer.w2_weight_offset.data.shape[0], -1)
+ if self.dynamic_eplb:
+ layer.w13_weight_list = [
+ weight.clone()
+ for weight in layer.w13_weight.data.unbind(dim=0)
+ ]
+ layer.w2_weight_list = [
+ weight.clone() for weight in layer.w2_weight.data.unbind(dim=0)
+ ]
+ layer.w13_weight_scale_fp32_list = [
+ weight.clone()
+ for weight in layer.w13_weight_scale.data.unbind(dim=0)
+ ]
+ layer.w2_weight_scale_list = [
+ weight.clone()
+ for weight in layer.w2_weight_scale.data.unbind(dim=0)
+ ]
+ del layer.w13_weight
+ del layer.w2_weight
+ del layer.w13_weight_scale
+ del layer.w13_weight_scale_fp32
+ del layer.w2_weight_scale
+ torch.npu.empty_cache()