[EPLB][Ops] Integrate grouped_matmul_swiglu_quant_weight_nz_tensor_list operator into dynamic EPLB (#4216)

### What this PR does / why we need it?
Integrate grouped_matmul_swiglu_quant_weight_nz_tensor_list into
dynamic EPLB to support list-type parameters.
This PR also modifies the model-loading logic in the dynamic-EPLB scenario.
The operator is based on this PR:
https://github.com/vllm-project/vllm-ascend/pull/3804

### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?

```
vllm serve /home/weight/DeepSeek-V3.1_w8a8mix_mtp \
    --max_num_seqs 8 \
    --max-model-len 8192 \
    --max-num-batched-tokens 16384 \
    --tensor-parallel-size 8 \
    --data-parallel-size 2 \
    --enable-expert-parallel \
    --served-model-name ds_r1 \
    --enable-auto-tool-choice \
    --tool-call-parser hermes \
    --no-enable-prefix-caching \
    --port 8999 \
    --quantization "ascend" \
    --gpu-memory-utilization 0.85 \
    --trust-remote-code \
    --compilation_config '{"cudagraph_capture_sizes":[1,2,4,8,16,32]}' \
    --additional-config='{"dynamic_eplb":true, "num_iterations_eplb_update":100, "num_wait_worker_iterations":100}'
 
```
input&output: 2k 2k
This PR:
<img width="1318" height="695" alt="fusion"
src="https://github.com/user-attachments/assets/f8657813-0c02-42f4-8396-d99e730f48cd"
/>

Baseline:
<img width="1323" height="690" alt="baseline"
src="https://github.com/user-attachments/assets/e1323a78-af26-4523-820c-e20e5642a38e"
/>


- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

---------

Signed-off-by: 白永斌 <baiyongbin3@h-partners.com>
Signed-off-by: 欧派果奶我还要 <845473182@qq.com>
Co-authored-by: 白永斌 <baiyongbin3@h-partners.com>
This commit is contained in:
欧派果奶我还要
2025-11-30 22:52:05 +08:00
committed by GitHub
parent 18eefc23c3
commit bc67696a02
6 changed files with 139 additions and 50 deletions

View File

@@ -83,8 +83,8 @@ class MoECommMethod(ABC):
def fused_experts(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1: torch.Tensor | list[torch.Tensor],
w2: torch.Tensor | list[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
@@ -93,8 +93,8 @@ class MoECommMethod(ABC):
use_int4_w4a8: bool = False,
global_num_experts: Optional[int] = None,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_scale: Optional[list[torch.Tensor]] = None,
w2_scale: Optional[list[torch.Tensor]] = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
# For TorchAir graph

View File

@@ -23,7 +23,11 @@ from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.utils import (AscendDeviceType, dispose_tensor,
get_ascend_device_type)
enable_custom_op, get_ascend_device_type)
def _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
    """Decide whether the custom grouped-matmul + SwiGLU path may be used.

    The result is truthy only when ``fusion`` and ``dynamic_eplb`` are both
    truthy and ``enable_custom_op()`` also returns a truthy value.
    ``enable_custom_op()`` is not called unless both flags are truthy.

    Args:
        fusion: Flag indicating that the fused gmm1 + swiglu path is requested.
        dynamic_eplb: Flag indicating that dynamic EPLB is active.

    Returns:
        The first falsy value among ``fusion`` and ``dynamic_eplb``;
        otherwise whatever ``enable_custom_op()`` returns.
    """
    # Guard clauses mirror the short-circuit evaluation of a chained `and`:
    # the first falsy operand is returned unchanged.
    if not fusion:
        return fusion
    if not dynamic_eplb:
        return dynamic_eplb
    return enable_custom_op()
def cumsum_group_list(group_list: torch.Tensor,
@@ -55,10 +59,10 @@ def cumsum_group_list(group_list: torch.Tensor,
def quant_apply_mlp(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
w1: list[torch.Tensor],
w1_scale: list[torch.Tensor],
w2: list[torch.Tensor],
w2_scale: list[torch.Tensor],
group_list: torch.Tensor,
group_list_type: int = 1,
dynamic_scale: torch.Tensor = None,
@@ -79,7 +83,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
quantized_hidden_states = hidden_states
bias1, bias2 = None, None
_output_dtype = w2_scale.dtype
_output_dtype = w2_scale[0].dtype
weight_prefetch_method = get_forward_context().weight_prefetch_method
if weight_prefetch_method:
@@ -87,23 +91,34 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
hidden_states)
is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
if w1_scale_bias is None and is_mc2:
if fusion and not dynamic_eplb:
if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
# gmm1: gate_up_proj & act_fn: swiglu
hidden_states, swiglu_out_scale, _ = (
torch.ops._C_ascend.
grouped_matmul_swiglu_quant_weight_nz_tensor_list(
x=hidden_states,
weight=w1,
weight_scale=w1_scale,
x_scale=pertoken_scale,
group_list=cumsum_group_list(group_list, group_list_type),
))
elif fusion and not dynamic_eplb:
# gmm1: gate_up_proj & act_fn: swiglu
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
x=hidden_states,
weight=w1,
weight=w1[0],
group_list=cumsum_group_list(group_list, group_list_type),
weight_scale=w1_scale,
weight_scale=w1_scale[0],
x_scale=pertoken_scale)
if quantized_hidden_states is not None:
dispose_tensor(quantized_hidden_states)
else:
if w1_scale.dtype != torch.float32:
w1_scale = w1_scale.to(torch.float32)
if w1_scale[0].dtype != torch.float32:
w1_scale[0] = w1_scale[0].to(torch.float32)
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
weight=w1,
split_item=3,
group_list_type=group_list_type,
group_type=0,
@@ -126,14 +141,14 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
scale=[w2_scale],
weight=w2,
scale=w2_scale,
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=w2_scale.dtype)[0]
output_dtype=w2_scale[0].dtype)[0]
else:
if w1_scale_bias is not None:
if group_list_type == 0:
@@ -146,23 +161,36 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
# TODO w4a8 scene: dynamic acquisition of dtype in the future
_output_dtype = torch.bfloat16
if fusion and not dynamic_eplb:
if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
# gmm1: gate_up_proj & act_fn: swiglu
hidden_states, swiglu_out_scale, _ = (
torch.ops._C_ascend.
grouped_matmul_swiglu_quant_weight_nz_tensor_list(
x=hidden_states,
weight=w1,
weight_scale=w1_scale,
x_scale=pertoken_scale,
group_list=cumsum_group_list(group_list, group_list_type),
bias=bias1,
))
elif fusion and not dynamic_eplb:
# gmm1: gate_up_proj & act_fn: swiglu
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
x=hidden_states,
weight=w1,
weight=w1[0],
bias=bias1,
group_list=cumsum_group_list(group_list, group_list_type),
weight_scale=w1_scale,
weight_scale=w1_scale[0],
x_scale=pertoken_scale)
if quantized_hidden_states is not None:
dispose_tensor(quantized_hidden_states)
else:
w1_scale[0] = w1_scale[0].to(w2_scale[0].dtype)
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
scale=[w1_scale.to(w2_scale.dtype)],
weight=w1,
scale=w1_scale,
bias=bias1,
per_token_scale=[pertoken_scale],
split_item=2,
@@ -179,8 +207,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
scale=[w2_scale],
weight=w2,
scale=w2_scale,
bias=bias2,
per_token_scale=[swiglu_out_scale],
split_item=2,
@@ -232,11 +260,11 @@ def unquant_apply_mlp(hidden_states: torch.Tensor,
def unified_apply_mlp(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
w1: torch.Tensor | list[torch.Tensor],
w2: torch.Tensor | list[torch.Tensor],
group_list: torch.Tensor,
w1_scale: Optional[list[torch.Tensor]] = None,
w2_scale: Optional[list[torch.Tensor]] = None,
dynamic_scale: torch.Tensor = None,
group_list_type: int = 1,
w1_scale_bias: torch.Tensor = None,
@@ -247,6 +275,7 @@ def unified_apply_mlp(hidden_states: torch.Tensor,
need_trans: bool = True,
dynamic_eplb: bool = False) -> torch.Tensor:
if with_quant:
assert w1_scale is not None and w2_scale is not None
return quant_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=w1_scale,